Various EM code improvements by Charles Vaske and Andy Nguyen
[libdai.git] / include / dai / emalg.h
index 543cc67..9f99c63 100644 (file)
@@ -73,15 +73,22 @@ class ParameterEstimation {
             (*_registry)[method] = f;
         }
 
-        /// Estimate the factor using the accumulated sufficient statistics and reset.
-        virtual Prob estimate() = 0;
+        /// Estimate the factor using the provided expectations.
+        virtual Prob estimate(const Prob &p) { return parametersToFactor(parameters(p)); }
 
-        /// Accumulate the sufficient statistics for \a p.
-        virtual void addSufficientStatistics( const Prob &p ) = 0;
+        /// Convert a set of estimated parameters to a factor
+        virtual Prob parametersToFactor(const Prob &params) = 0;
 
-        /// Returns the size of the Prob that should be passed to addSufficientStatistics.
+        /// Return parameters for the estimated factor, in a format specific to the parameter estimation
+        virtual Prob parameters(const Prob &p) = 0;
+
+        /// Returns the size of the Prob that should be passed to estimate and parameters
         virtual size_t probSize() const = 0;
 
+        // Returns the name of the class of this parameter estimation
+        virtual const std::string& name() const = 0;
+
+        virtual const PropertySet& properties() const = 0;
     private:
         /// A static registry containing all methods registered so far.
         static std::map<std::string, ParamEstFactory> *_registry;
@@ -98,11 +105,14 @@ class CondProbEstimation : private ParameterEstimation {
     private:
         /// Number of states of the variable of interest
         size_t _target_dim;
-        /// Current pseudocounts
-        Prob _stats;
         /// Initial pseudocounts
         Prob _initial_stats;
 
+        static std::string _name; // = "CondProbEstimation";
+
+        /// PropertySet that allows reconstruction of this estimator
+        PropertySet _props;
+
     public:
         /// Constructor
         /** For a conditional probability \f$ P( X | Y ) \f$,
@@ -132,13 +142,18 @@ class CondProbEstimation : private ParameterEstimation {
         /** The format of the resulting Prob keeps all the values for
          *  \f$ P(X | Y=y) \f$ in sequential order in the array.
          */
-        virtual Prob estimate();
-
-        /// Accumulate sufficient statistics from the expectations in \a p
-        virtual void addSufficientStatistics( const Prob &p );
-
-        /// Returns the required size for arguments to addSufficientStatistics().
-        virtual size_t probSize() const { return _stats.size(); }
+        virtual Prob parameters(const Prob &p);
+
+        // For a discrete conditional probability distribution,
+        // the parameters are equivalent to the resulting factor
+        virtual Prob parametersToFactor(const Prob &p);
+
+        /// Returns the required size for arguments to estimate().
+        virtual size_t probSize() const { return _initial_stats.size(); }
+        
+        virtual const std::string& name() const { return _name; }
+        
+        virtual const PropertySet& properties() const { return _props; }
 };
 
 
@@ -171,6 +186,8 @@ class SharedParameters {
         ParameterEstimation *_estimation;
         /// Indicates whether \c *this gets ownership of _estimation
         bool _ownEstimation;
+        /// The accumulated expectations
+        Prob* _expectations;
 
         /// Calculates the permutation that permutes the canonical ordering into the desired ordering
         /** \param varOrder Desired ordering of variables
@@ -197,10 +214,11 @@ class SharedParameters {
         SharedParameters( std::istream &is, const FactorGraph &fg );
 
         /// Copy constructor
-        SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation) {
+        SharedParameters( const SharedParameters &sp ) : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _ownEstimation(sp._ownEstimation), _expectations(NULL) {
             // If sp owns its _estimation object, we should clone it instead of copying the pointer
             if( _ownEstimation )
                 _estimation = _estimation->clone();
+            _expectations = new Prob(*sp._expectations);
         }
 
         /// Destructor
@@ -208,24 +226,39 @@ class SharedParameters {
             // If we own the _estimation object, we should delete it now
             if( _ownEstimation )
                 delete _estimation;
+            if( _expectations != NULL) 
+                delete _expectations;
         }
 
-        /// Collect the sufficient statistics from expected values (beliefs) according to \a alg
+        /// Collect the expected values (beliefs) according to \a alg
         /** For each of the relevant factors (that shares the parameters we are interested in),
          *  the corresponding belief according to \a alg is obtained and its entries are permuted
          *  such that their ordering corresponds with the shared parameters that we are estimating.
-         *  Then, the parameter estimation subclass method addSufficientStatistics() is called with
-         *  this vector of expected values of the parameters as input.
          */
-        void collectSufficientStatistics( InfAlg &alg );
+        void collectExpectations( InfAlg &alg );
+
+        /// Return the current accumulated expectations
+        const Prob& currentExpectations() const { return *_expectations; }
+
+        ParameterEstimation& getPEst() const { return *_estimation; }
 
         /// Estimate and set the shared parameters
-        /** Based on the sufficient statistics collected so far, the shared parameters are estimated
+        /** Based on the expectation statistics collected so far, the shared parameters are estimated
          *  using the parameter estimation subclass method estimate(). Then, each of the relevant
          *  factors in \a fg (that shares the parameters we are interested in) is set according 
          *  to those parameters (permuting the parameters accordingly).
          */
         void setParameters( FactorGraph &fg );
+
+        /// Return a reference to the vector of factor orientations
+        /** This is necessary for determing which variables were used
+         *  to estimate parameters, and analysis of expectations
+         *  after an Estimation step has been performed.
+         */
+        const FactorOrientations& getFactorOrientations() const { return _varorders; }
+
+        /// Reset the current expectations
+        void clear( ) { _expectations->fill(0); }
 };
 
 
@@ -254,6 +287,9 @@ class MaximizationStep {
 
         /// Using all of the currently added expectations, make new factors with maximized parameters and set them in the FactorGraph.
         void maximize( FactorGraph &fg );
+        
+        /// Clear the step, to be called at the begining of each step
+        void clear( );
 
     /// \name Iterator interface
     //@{