/*  Copyright (C) 2009  Charles Vaske  [cvaske at soe dot ucsc dot edu]
    University of California Santa Cruz

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
22 #ifndef __defined_libdai_emalg_h
23 #define __defined_libdai_emalg_h
29 #include <dai/factor.h>
30 #include <dai/daialg.h>
31 #include <dai/evidence.h>
32 #include <dai/index.h>
33 #include <dai/properties.h>
/// \file
/// \brief Defines classes related to Expectation Maximization: EMAlg, MaximizationStep, ParameterEstimation, CondProbEstimation and SharedParameters
/// \todo Describe EM file format
44 /// Interface for a parameter estimation method.
45 /** This parameter estimation interface is based on sufficient statistics.
46 * Implementations are responsible for collecting data from a probability
47 * vector passed to it from a SharedParameters container object.
49 * Implementations of this interface should register a factory function
50 * via the static ParameterEstimation::registerMethod function.
51 * The default registry only contains CondProbEstimation, named
52 * "ConditionalProbEstimation".
54 class ParameterEstimation
{
56 /// A pointer to a factory function.
57 typedef ParameterEstimation
* (*ParamEstFactory
)( const PropertySet
& );
59 /// Virtual destructor for deleting pointers to derived classes.
60 virtual ~ParameterEstimation() {}
61 /// Virtual copy constructor.
62 virtual ParameterEstimation
* clone() const = 0;
64 /// General factory method for construction of ParameterEstimation subclasses.
65 static ParameterEstimation
* construct( const std::string
&method
, const PropertySet
&p
);
67 /// Register a subclass so that it can be used with ParameterEstimation::construct.
68 static void registerMethod( const std::string
&method
, const ParamEstFactory
&f
) {
69 if( _registry
== NULL
)
70 loadDefaultRegistry();
71 (*_registry
)[method
] = f
;
74 /// Estimate the factor using the accumulated sufficient statistics and reset.
75 virtual Prob
estimate() = 0;
77 /// Accumulate the sufficient statistics for p.
78 virtual void addSufficientStatistics( const Prob
&p
) = 0;
80 /// Returns the size of the Prob that should be passed to addSufficientStatistics.
81 virtual size_t probSize() const = 0;
84 /// A static registry containing all methods registered so far.
85 static std::map
<std::string
, ParamEstFactory
> *_registry
;
87 /// Registers default ParameterEstimation subclasses (currently, only CondProbEstimation).
88 static void loadDefaultRegistry();
92 /// Estimates the parameters of a conditional probability table, using pseudocounts.
93 class CondProbEstimation
: private ParameterEstimation
{
95 /// Number of states of the variable of interest
97 /// Current pseudocounts
99 /// Initial pseudocounts
104 /** For a conditional probability \f$ P( X | Y ) \f$,
105 * \param target_dimension should equal \f$ | X | \f$
106 * \param pseudocounts has length \f$ |X| \cdot |Y| \f$
108 CondProbEstimation( size_t target_dimension
, const Prob
&pseudocounts
);
110 /// Virtual constructor, using a PropertySet.
111 /** Some keys in the PropertySet are required.
112 * For a conditional probability \f$ P( X | Y ) \f$,
113 * - target_dimension should be equal to \f$ | X | \f$
114 * - total_dimension should be equal to \f$ |X| \cdot |Y| \f$
116 * An optional key is:
117 * - pseudo_count, which specifies the initial counts (defaults to 1)
119 static ParameterEstimation
* factory( const PropertySet
&p
);
121 /// Virtual copy constructor
122 virtual ParameterEstimation
* clone() const { return new CondProbEstimation( _target_dim
, _initial_stats
); }
124 /// Virtual destructor
125 virtual ~CondProbEstimation() {}
127 /// Returns an estimate of the conditional probability distribution.
128 /** The format of the resulting Prob keeps all the values for
129 * \f$ P(X | Y=y) \f$ in sequential order in the array.
131 virtual Prob
estimate();
133 /// Accumulate sufficient statistics from the expectations in p.
134 virtual void addSufficientStatistics( const Prob
&p
);
136 /// Returns the required size for arguments to addSufficientStatistics.
137 virtual size_t probSize() const { return _stats
.size(); }
141 /// A single factor or set of factors whose parameters should be estimated.
142 /** To ensure that parameters can be shared between different factors during
143 * EM learning, each factor's values are reordered to match a desired variable
144 * ordering. The ordering of the variables in a factor may therefore differ
145 * from the canonical ordering used in libDAI. The SharedParameters
146 * class couples one or more factors (together with the specified orderings
147 * of the variables) with a ParameterEstimation object, taking care of the
148 * necessary permutations of the factor entries / parameters.
150 class SharedParameters
{
152 /// Convenience label for an index into a factor in a FactorGraph.
153 typedef size_t FactorIndex
;
154 /// Convenience label for a grouping of factor orientations.
155 typedef std::map
<FactorIndex
, std::vector
<Var
> > FactorOrientations
;
158 /// Maps factor indices to the corresponding VarSets
159 std::map
<FactorIndex
, VarSet
> _varsets
;
160 /// Maps factor indices to the corresponding Permute objects that permute the desired ordering into the canonical ordering
161 std::map
<FactorIndex
, Permute
> _perms
;
162 /// Maps factor indices to the corresponding desired variable orderings
163 FactorOrientations _varorders
;
164 /// Parameter estimation method to be used
165 ParameterEstimation
*_estimation
;
166 /// Indicates whether the object pointed to by _estimation should be deleted upon destruction
167 bool _deleteEstimation
;
169 /// Calculates the permutation that permutes the variables in varorder into the canonical ordering
170 static Permute
calculatePermutation( const std::vector
<Var
> &varorder
, VarSet
&outVS
);
172 /// Initializes _varsets and _perms from _varorders
173 void setPermsAndVarSetsFromVarOrders();
177 SharedParameters( const SharedParameters
&sp
);
180 /** \param varorders all the factor orientations for this parameter
181 * \param estimation a pointer to the parameter estimation method
182 * \param deletePE whether the parameter estimation object should be deleted in the destructor
184 SharedParameters( const FactorOrientations
&varorders
, ParameterEstimation
*estimation
, bool deletePE
=false );
186 /// Constructor for making an object from a stream and a factor graph
187 SharedParameters( std::istream
&is
, const FactorGraph
&fg_varlookup
);
190 ~SharedParameters() {
191 if( _deleteEstimation
)
195 /// Collect the necessary statistics from expected values
196 void collectSufficientStatistics( InfAlg
&alg
);
198 /// Estimate and set the shared parameters
199 void setParameters( FactorGraph
&fg
);
201 /// Returns the parameters
202 void collectParameters( const FactorGraph
&fg
, std::vector
<Real
> &outVals
, std::vector
<Var
> &outVarOrder
);
206 /// A MaximizationStep groups together several parameter estimation tasks into a single unit.
207 class MaximizationStep
{
209 std::vector
<SharedParameters
> _params
;
212 /// Default constructor
213 MaximizationStep() : _params() {}
215 /// Constructor from a vector of SharedParameters objects
216 MaximizationStep( std::vector
<SharedParameters
> &maximizations
) : _params(maximizations
) {}
218 /// Constructor from an input stream and a corresponding factor graph
219 MaximizationStep( std::istream
&is
, const FactorGraph
&fg_varlookup
);
221 /// Collect the beliefs from this InfAlg as expectations for the next Maximization step.
222 void addExpectations( InfAlg
&alg
);
224 /// Using all of the currently added expectations, make new factors with maximized parameters and set them in the FactorGraph.
225 void maximize( FactorGraph
&fg
);
227 /// @name Iterator interface
229 typedef std::vector
<SharedParameters
>::iterator iterator
;
230 typedef std::vector
<SharedParameters
>::const_iterator const_iterator
;
231 iterator
begin() { return _params
.begin(); }
232 const_iterator
begin() const { return _params
.begin(); }
233 iterator
end() { return _params
.end(); }
234 const_iterator
end() const { return _params
.end(); }
239 /// EMAlg performs Expectation Maximization to learn factor parameters.
240 /** This requires specifying:
241 * - Evidence (instances of observations from the graphical model),
242 * - InfAlg for performing the E-step, which includes the factor graph,
243 * - a vector of MaximizationSteps steps to be performed.
245 * This implementation can perform incremental EM by using multiple
246 * MaximizationSteps. An expectation step is performed between execution
247 * of each MaximizationStep. A call to iterate() will cycle through all
250 * Having multiple and separate maximization steps allows for maximizing some
251 * parameters, performing another E step, and then maximizing separate
252 * parameters, which may result in faster convergence in some cases.
256 /// All the data samples used during learning
257 const Evidence
&_evidence
;
259 /// How to do the expectation step
262 /// The maximization steps to take
263 std::vector
<MaximizationStep
> _msteps
;
265 /// Number of iterations done
268 /// History of likelihoods
269 std::vector
<Real
> _lastLogZ
;
271 /// Maximum number of iterations
274 /// Convergence tolerance
278 /// Key for setting maximum iterations @see setTermConditions
279 static const std::string MAX_ITERS_KEY
;
280 /// Default maximum iterations @see setTermConditions
281 static const size_t MAX_ITERS_DEFAULT
;
282 /// Key for setting likelihood termination condition @see setTermConditions
283 static const std::string LOG_Z_TOL_KEY
;
284 /// Default likelihood tolerance @see setTermConditions
285 static const Real LOG_Z_TOL_DEFAULT
;
287 /// Construct an EMAlg from all these objects
288 EMAlg( const Evidence
&evidence
, InfAlg
&estep
, std::vector
<MaximizationStep
> &msteps
, const PropertySet
&termconditions
)
289 : _evidence(evidence
), _estep(estep
), _msteps(msteps
), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT
), _log_z_tol(LOG_Z_TOL_DEFAULT
)
291 setTermConditions( termconditions
);
294 /// Construct an EMAlg from an Evidence object, an InfAlg object, and an input stream
295 EMAlg( const Evidence
&evidence
, InfAlg
&estep
, std::istream
&mstep_file
);
297 /// Change the conditions for termination
298 /** There are two possible parameters in the PropertySet
299 * - max_iters maximum number of iterations
300 * - log_z_tol proportion of increase in logZ
302 * \see hasSatisifiedTermConditions()
304 void setTermConditions( const PropertySet
&p
);
306 /// Determine if the termination conditions have been met.
307 /** There are two sufficient termination conditions:
308 * -# the maximum number of iterations has been performed
309 * -# the ratio of logZ increase over previous logZ is less than the
311 * \f$ \frac{\log(Z_t) - \log(Z_{t-1})}{| \log(Z_{t-1}) | } < \mathrm{tol} \f$.
313 bool hasSatisfiedTermConditions() const;
315 /// Return the last calculated log likelihood
316 Real
getLogZ() const { return _lastLogZ
.back(); }
318 /// Returns number of iterations done so far
319 size_t getCurrentIters() const { return _iters
; }
321 /// Get the iteration method used
322 const InfAlg
& eStep() const { return _estep
; }
324 /// Perform an iteration over all maximization steps
327 /// Perform an iteration over a single MaximizationStep
328 Real
iterate( MaximizationStep
&mstep
);
330 /// Iterate until termination conditions are satisfied
333 /// @name Iterator interface
335 typedef std::vector
<MaximizationStep
>::iterator s_iterator
;
336 typedef std::vector
<MaximizationStep
>::const_iterator const_s_iterator
;
337 s_iterator
s_begin() { return _msteps
.begin(); }
338 const_s_iterator
s_begin() const { return _msteps
.begin(); }
339 s_iterator
s_end() { return _msteps
.end(); }
340 const_s_iterator
s_end() const { return _msteps
.end(); }
345 } // end of namespace dai