6f973cc971e890df6d6f14d3dd631d175ddb8794
[libdai.git] / include / dai / emalg.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
8 * Copyright (C) 2009 University of California, Santa Cruz
9 */
10
11
12 #ifndef __defined_libdai_emalg_h
13 #define __defined_libdai_emalg_h
14
15
16 #include <vector>
17 #include <map>
18
19 #include <dai/factor.h>
20 #include <dai/daialg.h>
21 #include <dai/evidence.h>
22 #include <dai/index.h>
23 #include <dai/properties.h>
24
25
26 /// \file
/// \brief Defines classes related to Expectation Maximization: EMAlg, ParameterEstimation, CondProbEstimation, SharedParameters and MaximizationStep
28
29
30 namespace dai {
31
32
33 /// Base class for parameter estimation methods.
34 /** This class defines the general interface of parameter estimation methods.
35 *
36 * Implementations of this interface (see e.g. CondProbEstimation) should
37 * register a factory function (virtual constructor) via the static
38 * registerMethod() function.
39 * This factory function should return a pointer to a newly constructed
40 * object, whose type is a subclass of ParameterEstimation, and gets as
41 * input a PropertySet of parameters. After a subclass has been registered,
42 * instances of it can be constructed using the construct() method.
43 *
44 * Implementations are responsible for collecting data from a probability
45 * vector passed to it from a SharedParameters container object.
46 *
47 * The default registry only contains CondProbEstimation, named
48 * "ConditionalProbEstimation".
49 *
50 * \author Charles Vaske
51 */
52 class ParameterEstimation {
53 public:
54 /// Type of pointer to factory function.
55 typedef ParameterEstimation* (*ParamEstFactory)( const PropertySet& );
56
57 /// Virtual destructor for deleting pointers to derived classes.
58 virtual ~ParameterEstimation() {}
59
60 /// Virtual copy constructor.
61 virtual ParameterEstimation* clone() const = 0;
62
63 /// General factory method that constructs the desired ParameterEstimation subclass
64 /** \param method Name of the subclass that should be constructed;
65 * \param p Parameters passed to constructor of subclass.
66 * \note \a method should either be in the default registry or should be registered first using registerMethod().
67 */
68 static ParameterEstimation* construct( const std::string &method, const PropertySet &p );
69
70 /// Register a subclass so that it can be used with construct().
71 static void registerMethod( const std::string &method, const ParamEstFactory &f ) {
72 if( _registry == NULL )
73 loadDefaultRegistry();
74 (*_registry)[method] = f;
75 }
76
77 /// Estimate the factor using the accumulated sufficient statistics and reset.
78 virtual Prob estimate() = 0;
79
80 /// Accumulate the sufficient statistics for \a p.
81 virtual void addSufficientStatistics( const Prob &p ) = 0;
82
83 /// Returns the size of the Prob that should be passed to addSufficientStatistics.
84 virtual size_t probSize() const = 0;
85
86 private:
87 /// A static registry containing all methods registered so far.
88 static std::map<std::string, ParamEstFactory> *_registry;
89
90 /// Registers default ParameterEstimation subclasses (currently, only CondProbEstimation).
91 static void loadDefaultRegistry();
92 };
93
94
95 /// Estimates the parameters of a conditional probability table, using pseudocounts.
96 /** \author Charles Vaske
97 */
98 class CondProbEstimation : private ParameterEstimation {
99 private:
100 /// Number of states of the variable of interest
101 size_t _target_dim;
102 /// Current pseudocounts
103 Prob _stats;
104 /// Initial pseudocounts
105 Prob _initial_stats;
106
107 public:
108 /// Constructor
109 /** For a conditional probability \f$ P( X | Y ) \f$,
110 * \param target_dimension should equal \f$ | X | \f$
111 * \param pseudocounts are the initial pseudocounts, of length \f$ |X| \cdot |Y| \f$
112 */
113 CondProbEstimation( size_t target_dimension, const Prob &pseudocounts );
114
115 /// Virtual constructor, using a PropertySet.
116 /** Some keys in the PropertySet are required.
117 * For a conditional probability \f$ P( X | Y ) \f$,
118 * - \a target_dimension should be equal to \f$ | X | \f$
119 * - \a total_dimension should be equal to \f$ |X| \cdot |Y| \f$
120 *
121 * An optional key is:
122 * - \a pseudo_count which specifies the initial counts (defaults to 1)
123 */
124 static ParameterEstimation* factory( const PropertySet &p );
125
126 /// Virtual copy constructor
127 virtual ParameterEstimation* clone() const { return new CondProbEstimation( _target_dim, _initial_stats ); }
128
129 /// Virtual destructor
130 virtual ~CondProbEstimation() {}
131
132 /// Returns an estimate of the conditional probability distribution.
133 /** The format of the resulting Prob keeps all the values for
134 * \f$ P(X | Y=y) \f$ in sequential order in the array.
135 */
136 virtual Prob estimate();
137
138 /// Accumulate sufficient statistics from the expectations in \a p
139 virtual void addSufficientStatistics( const Prob &p );
140
141 /// Returns the required size for arguments to addSufficientStatistics().
142 virtual size_t probSize() const { return _stats.size(); }
143 };
144
145
146 /// Represents a single factor or set of factors whose parameters should be estimated.
147 /** To ensure that parameters can be shared between different factors during
148 * EM learning, each factor's values are reordered to match a desired variable
149 * ordering. The ordering of the variables in a factor may therefore differ
150 * from the canonical ordering used in libDAI. The SharedParameters
151 * class combines one or more factors (together with the specified orderings
152 * of the variables) with a ParameterEstimation object, taking care of the
153 * necessary permutations of the factor entries / parameters.
154 *
155 * \author Charles Vaske
156 */
157 class SharedParameters {
158 public:
159 /// Convenience label for an index of a factor in a FactorGraph.
160 typedef size_t FactorIndex;
161 /// Convenience label for a grouping of factor orientations.
162 typedef std::map<FactorIndex, std::vector<Var> > FactorOrientations;
163
164 private:
165 /// Maps factor indices to the corresponding VarSets
166 std::map<FactorIndex, VarSet> _varsets;
167 /// Maps factor indices to the corresponding Permute objects that permute the desired ordering into the canonical ordering
168 std::map<FactorIndex, Permute> _perms;
169 /// Maps factor indices to the corresponding desired variable orderings
170 FactorOrientations _varorders;
171 /// Parameter estimation method to be used
172 ParameterEstimation *_estimation;
173 /// Indicates whether \c *this gets ownership of _estimation
174 bool _ownEstimation;
175
176 /// Calculates the permutation that permutes the variables in varorder into the canonical ordering
177 /** \param varorder Given ordering of variables
178 * \param outVS Contains variables in \varorder represented as a VarSet
179 * \return Permute object for permuting variables in varorder into the canonical libDAI ordering
180 */
181 static Permute calculatePermutation( const std::vector<Var> &varorder, VarSet &outVS );
182
183 /// Initializes _varsets and _perms from _varorders and checks whether their state spaces correspond with _estimation.probSize()
184 void setPermsAndVarSetsFromVarOrders();
185
186 public:
187 /// Constructor
188 /** \param varorders all the factor orientations for this parameter
189 * \param estimation a pointer to the parameter estimation method
190 * \param ownPE whether the constructed object gets ownership of \a estimation
191 */
192 SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE=false );
193
194 /// Construct a SharedParameters object from a stream \a is and a factor graph \a fg
195 /** \see \ref fileformats-sharedparameters
196 */
197 SharedParameters( std::istream &is, const FactorGraph &fg );
198
199 /// Copy constructor
200 SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation) {
201 // If sp owns its _estimation object, we should clone it instead of copying the pointer
202 if( _ownEstimation )
203 _estimation = _estimation->clone();
204 }
205
206 /// Destructor
207 ~SharedParameters() {
208 // If we own the _estimation object, we should delete it now
209 if( _ownEstimation )
210 delete _estimation;
211 }
212
213 /// Collect the necessary statistics from expected values
214 void collectSufficientStatistics( InfAlg &alg );
215
216 /// Estimate and set the shared parameters
217 void setParameters( FactorGraph &fg );
218
219 /// Returns the parameters
220 void collectParameters( const FactorGraph &fg, std::vector<Real> &outVals, std::vector<Var> &outVarOrder );
221 };
222
223
224 /// A MaximizationStep groups together several parameter estimation tasks into a single unit.
225 /** \author Charles Vaske
226 */
227 class MaximizationStep {
228 private:
229 std::vector<SharedParameters> _params;
230
231 public:
232 /// Default constructor
233 MaximizationStep() : _params() {}
234
235 /// Constructor from a vector of SharedParameters objects
236 MaximizationStep( std::vector<SharedParameters> &maximizations ) : _params(maximizations) {}
237
238 /// Constructor from an input stream and a corresponding factor graph
239 MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup );
240
241 /// Collect the beliefs from this InfAlg as expectations for the next Maximization step.
242 void addExpectations( InfAlg &alg );
243
244 /// Using all of the currently added expectations, make new factors with maximized parameters and set them in the FactorGraph.
245 void maximize( FactorGraph &fg );
246
247 /// \name Iterator interface
248 //@{
249 typedef std::vector<SharedParameters>::iterator iterator;
250 typedef std::vector<SharedParameters>::const_iterator const_iterator;
251 iterator begin() { return _params.begin(); }
252 const_iterator begin() const { return _params.begin(); }
253 iterator end() { return _params.end(); }
254 const_iterator end() const { return _params.end(); }
255 //@}
256 };
257
258
259 /// EMAlg performs Expectation Maximization to learn factor parameters.
260 /** This requires specifying:
261 * - Evidence (instances of observations from the graphical model),
262 * - InfAlg for performing the E-step, which includes the factor graph,
263 * - a vector of MaximizationSteps steps to be performed.
264 *
265 * This implementation can perform incremental EM by using multiple
266 * MaximizationSteps. An expectation step is performed between execution
267 * of each MaximizationStep. A call to iterate() will cycle through all
268 * MaximizationSteps.
269 *
270 * Having multiple and separate maximization steps allows for maximizing some
271 * parameters, performing another E step, and then maximizing separate
272 * parameters, which may result in faster convergence in some cases.
273 *
274 * \author Charles Vaske
275 */
276 class EMAlg {
277 private:
278 /// All the data samples used during learning
279 const Evidence &_evidence;
280
281 /// How to do the expectation step
282 InfAlg &_estep;
283
284 /// The maximization steps to take
285 std::vector<MaximizationStep> _msteps;
286
287 /// Number of iterations done
288 size_t _iters;
289
290 /// History of likelihoods
291 std::vector<Real> _lastLogZ;
292
293 /// Maximum number of iterations
294 size_t _max_iters;
295
296 /// Convergence tolerance
297 Real _log_z_tol;
298
299 public:
300 /// Key for setting maximum iterations @see setTermConditions
301 static const std::string MAX_ITERS_KEY;
302 /// Default maximum iterations @see setTermConditions
303 static const size_t MAX_ITERS_DEFAULT;
304 /// Key for setting likelihood termination condition @see setTermConditions
305 static const std::string LOG_Z_TOL_KEY;
306 /// Default likelihood tolerance @see setTermConditions
307 static const Real LOG_Z_TOL_DEFAULT;
308
309 /// Construct an EMAlg from all these objects
310 EMAlg( const Evidence &evidence, InfAlg &estep, std::vector<MaximizationStep> &msteps, const PropertySet &termconditions )
311 : _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
312 {
313 setTermConditions( termconditions );
314 }
315
316 /// Construct an EMAlg from an Evidence object, an InfAlg object, and an input stream
317 EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &mstep_file );
318
319 /// Change the conditions for termination
320 /** There are two possible parameters in the PropertySet
321 * - max_iters maximum number of iterations
322 * - log_z_tol proportion of increase in logZ
323 *
324 * \see hasSatisifiedTermConditions()
325 */
326 void setTermConditions( const PropertySet &p );
327
328 /// Determine if the termination conditions have been met.
329 /** There are two sufficient termination conditions:
330 * -# the maximum number of iterations has been performed
331 * -# the ratio of logZ increase over previous logZ is less than the
332 * tolerance, i.e.,
333 * \f$ \frac{\log(Z_t) - \log(Z_{t-1})}{| \log(Z_{t-1}) | } < \mathrm{tol} \f$.
334 */
335 bool hasSatisfiedTermConditions() const;
336
337 /// Return the last calculated log likelihood
338 Real getLogZ() const { return _lastLogZ.back(); }
339
340 /// Returns number of iterations done so far
341 size_t getCurrentIters() const { return _iters; }
342
343 /// Get the iteration method used
344 const InfAlg& eStep() const { return _estep; }
345
346 /// Perform an iteration over all maximization steps
347 Real iterate();
348
349 /// Perform an iteration over a single MaximizationStep
350 Real iterate( MaximizationStep &mstep );
351
352 /// Iterate until termination conditions are satisfied
353 void run();
354
355 /// \name Iterator interface
356 //@{
357 typedef std::vector<MaximizationStep>::iterator s_iterator;
358 typedef std::vector<MaximizationStep>::const_iterator const_s_iterator;
359 s_iterator s_begin() { return _msteps.begin(); }
360 const_s_iterator s_begin() const { return _msteps.begin(); }
361 s_iterator s_end() { return _msteps.end(); }
362 const_s_iterator s_end() const { return _msteps.end(); }
363 //@}
364 };
365
366
367 } // end of namespace dai
368
369
370 #endif