Full command line inference and tests
[libdai.git] / include / dai / emalg.h
1 /*
2 Copyright 2009 Charles Vaske <cvaske@soe.ucsc.edu>
3 University of California Santa Cruz
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, either version 3 of the License, or
8 (at your option) any later version.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with this program. If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #ifndef __defined_libdai_emalg_h
20 #define __defined_libdai_emalg_h
21
22 #include<vector>
23 #include<map>
24
25 #include <dai/factor.h>
26 #include <dai/daialg.h>
27 #include <dai/evidence.h>
28 #include <dai/index.h>
29 #include <dai/properties.h>
30
31
32 /// \file
33 /** \brief Defines classes related to Expectation Maximization:
34 * EMAlg, ParameterEstimate, and FactorOrientations
35 */
36
37 namespace dai {
38
39 ///Interface for a parameter estimation method.
40 /** This parameter estimation interface is based on sufficient statistics.
41 * Implementations are responsible for collecting data from a probability
42 * vector passed to it from a SharedParameters container object.
43 *
44 * Implementations of this interface should register a factory function
45 * via the static ParameterEstimation::registerMethod function.
46 */
47 class ParameterEstimation {
48 public:
49 /// A pointer to a factory function.
50 typedef ParameterEstimation* (*ParamEstFactory)(const PropertySet&);
51
52 /// General factory method for construction of ParameterEstimation subclasses.
53 static ParameterEstimation* construct(const std::string& method,
54 const PropertySet& p);
55 /// Register a subclass with ParameterEstimation::construct.
56 static void registerMethod(const std::string method,
57 const ParamEstFactory f) {
58 if (_registry == NULL) {
59 loadDefaultRegistry();
60 }
61 (*_registry)[method] = f;
62 }
63 /// Virtual destructor for deleting pointers to derived classes.
64 virtual ~ParameterEstimation() {}
65 /// Estimate the factor using the accumulated sufficient statistics and reset.
66 virtual Prob estimate() = 0;
67 /// Accumulate the sufficient statistics for p.
68 virtual void addSufficientStatistics(Prob& p) = 0;
69 /// Returns the size of the Prob that is passed to addSufficientStatistics.
70 virtual size_t probSize() const = 0;
71 /// A virtual copy constructor.
72 virtual ParameterEstimation* clone() const= 0;
73 private:
74 static std::map< std::string, ParamEstFactory >* _registry;
75 static void loadDefaultRegistry();
76 };
77
78 /// Estimates the parameters of a conditional probability, using pseudocounts.
79 class CondProbEstimation : private ParameterEstimation {
80 private:
81 size_t _target_dim;
82 Prob _stats;
83 Prob _initial_stats;
84 public:
85 /** For a conditional probability \f$ Pr( X | Y ) \f$,
86 * \param target_dimension should equal \f$ | X | \f$
87 * \param pseudocounts has length \f$ |X| \cdot |Y| \f$
88 */
89 CondProbEstimation(size_t target_dimension, Prob pseudocounts);
90
91 /// Virtual constructor, using a PropertySet.
92 /** Some keys in the PropertySet are required:
93 * - target_dimension, which should be equal to \f$ | X | \f$
94 * - total_dimension, which sholud be equal to \f$ |X| \cdot |Y| \f$
95 *
96 * An optional key is:
97 * - pseudo_count which specifies the initial counts (defaults to 1)
98 */
99 static ParameterEstimation* factory(const PropertySet& p);
100 /// Virtual destructor
101 virtual ~CondProbEstimation() {}
102 /// Returns an estimate of the conditional probability distribution.
103 /** The format of the resulting Prob keeps all the values for
104 * \f$ P(X | Y=a) \f$ sequential in teh array.
105 */
106 virtual Prob estimate();
107 /// Accumulate sufficient statistics from the expectations in p.
108 virtual void addSufficientStatistics(Prob& p);
109 /// Returns the required size for arguments to addSufficientStatistics
110 virtual size_t probSize() const { return _stats.size(); }
111 /// Virtual copy constructor.
112 virtual ParameterEstimation* clone() const {
113 return new CondProbEstimation(_target_dim, _initial_stats);
114 }
115 };
116
117 /** A single factor or set of factors whose parameters should be
118 * estimated. Each factor's values are reordered to match a
119 * canonical variable ordering. This canonical variable ordering
120 * will likely not be the order of variables required to make two
121 * factors parameters isomorphic. Therefore, this ordering of the
122 * variables must be specified for ever factor to ensure that
123 * parameters can be shared between different factors during EM.
124 */
125 class SharedParameters {
126 public:
127 /// Convenience label for an index into a FactorGraph to a factor.
128 typedef size_t FactorIndex;
129 /// Convenience label for a grouping of factor orientations.
130 typedef std::map< FactorIndex, std::vector< Var > > FactorOrientations;
131 private:
132 std::map< FactorIndex, VarSet > _varsets;
133 std::map< FactorIndex, Permute > _perms;
134 FactorOrientations _varorders;
135 ParameterEstimation* _estimation;
136 bool _deleteEstimation;
137
138 static Permute calculatePermutation(const std::vector< Var >& varorder,
139 const std::vector< size_t >& dims,
140 VarSet& outVS);
141 void setPermsAndVarSetsFromVarOrders();
142 public:
143 /// Copy constructor
144 SharedParameters(const SharedParameters& sp);
145 /// Constructor useful in programmatic settings
146 /** \param varorders all the factor orientations for this parameter
147 \param estimation a pointer to the parameter estimation method
148 */
149 SharedParameters(const FactorOrientations& varorders,
150 ParameterEstimation* estimation);
151
152 /// Constructor for making an object from a stream
153 SharedParameters(std::istream& is, const FactorGraph& fg_varlookup);
154
155 /// Destructor
156 ~SharedParameters() { if (_deleteEstimation) delete _estimation; }
157
158 /// Collect the necessary statistics from expected values
159 void collectSufficientStatistics(InfAlg& alg);
160
161 /// Estimate and set the shared parameters
162 void setParameters(FactorGraph& fg);
163 };
164
165 /** A maximization step groups together several parameter estimation
166 * tasks into a single unit.
167 */
168 class MaximizationStep {
169 private:
170 std::vector< SharedParameters > _params;
171 public:
172 MaximizationStep() : _params() {}
173
174 /// Construct an step object taht contains all these estimation probelms
175 MaximizationStep(std::vector< SharedParameters >& maximizations) :
176 _params(maximizations) {}
177
178 /// Construct a step from an input stream
179 MaximizationStep(std::istream& is, const FactorGraph& fg_varlookup);
180
181 /** Collect the beliefs from this InfAlg as expectations for
182 * the next Maximization step.
183 */
184 void addExpectations(InfAlg& alg);
185
186 /** Using all of the currently added expectations, make new factors
187 * with maximized parameters and set them in the FactorGraph.
188 */
189 void maximize(FactorGraph& fg);
190 };
191
192 /// EMAlg performs Expectation Maximization to learn factor parameters.
193 /** This requires specifying:
194 * - Evidence (instances of observations from the graphical model),
195 * - InfAlg for performing the E-step, which includes the factor graph
196 * - a vector of MaximizationSteps steps to be performed
197 *
198 * This implementation can peform incremental EM by using multiple
199 * MaximizationSteps. An expectation step is performed between execution
200 * of each MaximizationStep. A call to iterate() will cycle through all
201 * MaximizationSteps.
202 */
203 class EMAlg {
204 private:
205 /// All the data samples used during learning
206 const Evidence& _evidence;
207
208 /// How to do the expectation step
209 InfAlg& _estep;
210
211 /// The maximization steps to take
212 std::vector<MaximizationStep> _msteps;
213 size_t _iters;
214 std::vector<Real> _lastLogZ;
215
216 size_t _max_iters;
217 Real _log_z_tol;
218
219 public:
220 /// Key for setting maximum iterations @see setTermConditions
221 static const std::string MAX_ITERS_KEY;//("max_iters");
222 /// Default maximum iterations
223 static const size_t MAX_ITERS_DEFAULT;
224 /// Key for setting likelihood termination condition @see setTermConditions
225 static const std::string LOG_Z_TOL_KEY;
226 /// Default log_z_tol
227 static const Real LOG_Z_TOL_DEFAULT;
228
229 /// Construct an EMAlg from all these objects
230 EMAlg(const Evidence& evidence, InfAlg& estep,
231 std::vector<MaximizationStep>& msteps,
232 PropertySet* termconditions=NULL)
233 : _evidence(evidence),
234 _estep(estep),
235 _msteps(msteps),
236 _iters(0),
237 _lastLogZ(),
238 _max_iters(MAX_ITERS_DEFAULT),
239 _log_z_tol(LOG_Z_TOL_DEFAULT) {
240 setTermConditions(termconditions);
241 }
242
243 /// Construct an EMAlg from an input stream
244 EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& mstep_file);
245
246 /// Change the coditions for termination
247 /** There are two possible parameters in the PropertySety
248 * - max_iters maximum number of iterations (default 30)
249 * - log_z_tol proportion of increase in logZ (default 0.01)
250 *
251 * \see hasSatisifiedTermConditions()
252 */
253 void setTermConditions(const PropertySet* p);
254
255 /// Determine if the termination conditions have been met.
256 /** There are two sufficient termination conditions:
257 * -# the maximum number of iterations has been performed
258 * -# the ratio of logZ increase over previous logZ is less than the
259 * tolerance. I.e.
260 \f$ \frac{\log(Z_{current}) - \log(Z_{previous})}
261 {| \log(Z_{previous}) | } < tol \f$.
262 */
263 bool hasSatisfiedTermConditions() const;
264
265 size_t getCurrentIters() const { return _iters; }
266
267 /// Perform an iteration over all maximization steps
268 Real iterate();
269
270 /// Performs an iteration over a single MaximizationStep
271 Real iterate(MaximizationStep& mstep);
272
273 /// Iterate until termination conditions satisfied
274 void run();
275
276 };
277
278 } // namespace dai
279
280 #endif