Added EM code by Charlie Vaske (and cleaned up the style a bit)
[libdai.git] / include / dai / emalg.h
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_emalg_h
23 #define __defined_libdai_emalg_h
24
25
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include <dai/factor.h>
#include <dai/daialg.h>
#include <dai/evidence.h>
#include <dai/index.h>
#include <dai/properties.h>
33 #include <dai/properties.h>
34
35
/// \file
/** \brief Defines classes related to Expectation Maximization:
 *  ParameterEstimation, CondProbEstimation, SharedParameters,
 *  MaximizationStep and EMAlg
 */
40
41
42 namespace dai {
43
44
45 /// Interface for a parameter estimation method.
46 /** This parameter estimation interface is based on sufficient statistics.
47 * Implementations are responsible for collecting data from a probability
48 * vector passed to it from a SharedParameters container object.
49 *
50 * Implementations of this interface should register a factory function
51 * via the static ParameterEstimation::registerMethod function.
52 */
53 class ParameterEstimation {
54 public:
55 /// A pointer to a factory function.
56 typedef ParameterEstimation* (*ParamEstFactory)(const PropertySet&);
57
58 /// General factory method for construction of ParameterEstimation subclasses.
59 static ParameterEstimation* construct( const std::string& method, const PropertySet& p );
60 /// Register a subclass with ParameterEstimation::construct.
61 static void registerMethod( const std::string method, const ParamEstFactory f ) {
62 if( _registry == NULL )
63 loadDefaultRegistry();
64 (*_registry)[method] = f;
65 }
66 /// Virtual destructor for deleting pointers to derived classes.
67 virtual ~ParameterEstimation() {}
68 /// Estimate the factor using the accumulated sufficient statistics and reset.
69 virtual Prob estimate() = 0;
70 /// Accumulate the sufficient statistics for p.
71 virtual void addSufficientStatistics( Prob& p ) = 0;
72 /// Returns the size of the Prob that is passed to addSufficientStatistics.
73 virtual size_t probSize() const = 0;
74 /// A virtual copy constructor.
75 virtual ParameterEstimation* clone() const = 0;
76
77 private:
78 static std::map<std::string, ParamEstFactory>* _registry;
79 static void loadDefaultRegistry();
80 };
81
82
83 /// Estimates the parameters of a conditional probability, using pseudocounts.
84 class CondProbEstimation : private ParameterEstimation {
85 private:
86 size_t _target_dim;
87 Prob _stats;
88 Prob _initial_stats;
89 public:
90 /** For a conditional probability \f$ Pr( X | Y ) \f$,
91 * \param target_dimension should equal \f$ | X | \f$
92 * \param pseudocounts has length \f$ |X| \cdot |Y| \f$
93 */
94 CondProbEstimation( size_t target_dimension, Prob pseudocounts );
95
96 /// Virtual constructor, using a PropertySet.
97 /** Some keys in the PropertySet are required:
98 * - target_dimension, which should be equal to \f$ | X | \f$
99 * - total_dimension, which sholud be equal to \f$ |X| \cdot |Y| \f$
100 *
101 * An optional key is:
102 * - pseudo_count which specifies the initial counts (defaults to 1)
103 */
104 static ParameterEstimation* factory( const PropertySet& p );
105 /// Virtual destructor
106 virtual ~CondProbEstimation() {}
107 /// Returns an estimate of the conditional probability distribution.
108 /** The format of the resulting Prob keeps all the values for
109 * \f$ P(X | Y=a) \f$ sequential in teh array.
110 */
111 virtual Prob estimate();
112 /// Accumulate sufficient statistics from the expectations in p.
113 virtual void addSufficientStatistics( Prob& p );
114 /// Returns the required size for arguments to addSufficientStatistics
115 virtual size_t probSize() const {
116 return _stats.size();
117 }
118 /// Virtual copy constructor.
119 virtual ParameterEstimation* clone() const {
120 return new CondProbEstimation( _target_dim, _initial_stats );
121 }
122 };
123
124
125 /** A single factor or set of factors whose parameters should be
126 * estimated. Each factor's values are reordered to match a
127 * canonical variable ordering. This canonical variable ordering
128 * will likely not be the order of variables required to make two
129 * factors parameters isomorphic. Therefore, this ordering of the
130 * variables must be specified for ever factor to ensure that
131 * parameters can be shared between different factors during EM.
132 */
133 class SharedParameters {
134 public:
135 /// Convenience label for an index into a FactorGraph to a factor.
136 typedef size_t FactorIndex;
137 /// Convenience label for a grouping of factor orientations.
138 typedef std::map< FactorIndex, std::vector< Var > > FactorOrientations;
139 private:
140 std::map<FactorIndex, VarSet> _varsets;
141 std::map<FactorIndex, Permute> _perms;
142 FactorOrientations _varorders;
143 ParameterEstimation* _estimation;
144 bool _deleteEstimation;
145
146 static Permute calculatePermutation( const std::vector<Var>& varorder, const std::vector<size_t>& dims, VarSet& outVS );
147 void setPermsAndVarSetsFromVarOrders();
148
149 public:
150 /// Copy constructor
151 SharedParameters( const SharedParameters& sp );
152 /// Constructor useful in programmatic settings
153 /** \param varorders all the factor orientations for this parameter
154 * \param estimation a pointer to the parameter estimation method
155 */
156 SharedParameters( const FactorOrientations& varorders, ParameterEstimation* estimation );
157
158 /// Constructor for making an object from a stream
159 SharedParameters( std::istream& is, const FactorGraph& fg_varlookup );
160
161 /// Destructor
162 ~SharedParameters() {
163 if( _deleteEstimation )
164 delete _estimation;
165 }
166
167 /// Collect the necessary statistics from expected values
168 void collectSufficientStatistics( InfAlg& alg );
169
170 /// Estimate and set the shared parameters
171 void setParameters( FactorGraph& fg );
172 };
173
174
175 /** A maximization step groups together several parameter estimation
176 * tasks into a single unit.
177 */
178 class MaximizationStep {
179 private:
180 std::vector<SharedParameters> _params;
181 public:
182 MaximizationStep() : _params() {}
183
184 /// Construct an step object taht contains all these estimation probelms
185 MaximizationStep( std::vector<SharedParameters>& maximizations ) : _params(maximizations) {}
186
187 /// Construct a step from an input stream
188 MaximizationStep( std::istream& is, const FactorGraph& fg_varlookup );
189
190 /** Collect the beliefs from this InfAlg as expectations for
191 * the next Maximization step.
192 */
193 void addExpectations( InfAlg& alg );
194
195 /** Using all of the currently added expectations, make new factors
196 * with maximized parameters and set them in the FactorGraph.
197 */
198 void maximize( FactorGraph& fg );
199 };
200
201
202 /// EMAlg performs Expectation Maximization to learn factor parameters.
203 /** This requires specifying:
204 * - Evidence (instances of observations from the graphical model),
205 * - InfAlg for performing the E-step, which includes the factor graph
206 * - a vector of MaximizationSteps steps to be performed
207 *
208 * This implementation can peform incremental EM by using multiple
209 * MaximizationSteps. An expectation step is performed between execution
210 * of each MaximizationStep. A call to iterate() will cycle through all
211 * MaximizationSteps.
212 */
213 class EMAlg {
214 private:
215 /// All the data samples used during learning
216 const Evidence& _evidence;
217
218 /// How to do the expectation step
219 InfAlg& _estep;
220
221 /// The maximization steps to take
222 std::vector<MaximizationStep> _msteps;
223 size_t _iters;
224 std::vector<Real> _lastLogZ;
225
226 size_t _max_iters;
227 Real _log_z_tol;
228
229 public:
230 /// Key for setting maximum iterations @see setTermConditions
231 static const std::string MAX_ITERS_KEY; //("max_iters");
232 /// Default maximum iterations
233 static const size_t MAX_ITERS_DEFAULT;
234 /// Key for setting likelihood termination condition @see setTermConditions
235 static const std::string LOG_Z_TOL_KEY;
236 /// Default log_z_tol
237 static const Real LOG_Z_TOL_DEFAULT;
238
239 /// Construct an EMAlg from all these objects
240 EMAlg( const Evidence& evidence, InfAlg& estep, std::vector<MaximizationStep>& msteps, PropertySet* termconditions = NULL) :
241 _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT) {
242 setTermConditions( termconditions );
243 }
244
245 /// Construct an EMAlg from an input stream
246 EMAlg( const Evidence& evidence, InfAlg& estep, std::istream& mstep_file );
247
248 /// Change the coditions for termination
249 /** There are two possible parameters in the PropertySety
250 * - max_iters maximum number of iterations (default 30)
251 * - log_z_tol proportion of increase in logZ (default 0.01)
252 *
253 * \see hasSatisifiedTermConditions()
254 */
255 void setTermConditions( const PropertySet* p );
256
257 /// Determine if the termination conditions have been met.
258 /** There are two sufficient termination conditions:
259 * -# the maximum number of iterations has been performed
260 * -# the ratio of logZ increase over previous logZ is less than the
261 * tolerance. I.e.
262 \f$ \frac{\log(Z_{current}) - \log(Z_{previous})}
263 {| \log(Z_{previous}) | } < tol \f$.
264 */
265 bool hasSatisfiedTermConditions() const;
266
267 size_t getCurrentIters() const {
268 return _iters;
269 }
270
271 /// Perform an iteration over all maximization steps
272 Real iterate();
273
274 /// Performs an iteration over a single MaximizationStep
275 Real iterate( MaximizationStep& mstep );
276
277 /// Iterate until termination conditions satisfied
278 void run();
279 };
280
281
282 } // end of namespace dai
283
284
285 #endif