EM bugfix. Convenience methods in Factor, Permute, Properties, EM.
authorCharles Vaske <cvaske@hgwdev.cse.ucsc.edu>
Sat, 15 Aug 2009 18:27:51 +0000 (11:27 -0700)
committerCharles Vaske <cvaske@hgwdev.cse.ucsc.edu>
Sat, 15 Aug 2009 18:27:51 +0000 (11:27 -0700)
 -- Bugfix: was using abs() instead of fabs() in determining EM termination,
    which caused a loss of precision.
 -- New constructor in Permute for canonical variable ordering
 -- New constructor in Factor that reorders variables to canonical ordering
 -- New accessor in Properties for getting all keys
 -- New function for getting inferred parameters

include/dai/emalg.h
include/dai/evidence.h
include/dai/factor.h
include/dai/index.h
include/dai/properties.h
src/emalg.cpp
tests/testem/testem.out

index 021cf7e..622a25b 100644 (file)
@@ -180,7 +180,7 @@ class SharedParameters {
         /** \param varorders  all the factor orientations for this parameter
          *  \param estimation a pointer to the parameter estimation method
          */ 
         /** \param varorders  all the factor orientations for this parameter
          *  \param estimation a pointer to the parameter estimation method
          */ 
-        SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation );
+        SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deleteParameterEstimationInDestructor=0);
 
         /// Constructor for making an object from a stream and a factor graph
         SharedParameters( std::istream &is, const FactorGraph &fg_varlookup );
 
         /// Constructor for making an object from a stream and a factor graph
         SharedParameters( std::istream &is, const FactorGraph &fg_varlookup );
@@ -196,6 +196,8 @@ class SharedParameters {
 
         /// Estimate and set the shared parameters
         void setParameters( FactorGraph &fg );
 
         /// Estimate and set the shared parameters
         void setParameters( FactorGraph &fg );
+
+       void collectParameters( const FactorGraph& fg, std::vector< Real >& outVals, std::vector< Var >& outVarOrder );
 };
 
 
 };
 
 
@@ -219,6 +221,16 @@ class MaximizationStep {
 
         /// Using all of the currently added expectations, make new factors with maximized parameters and set them in the FactorGraph.
         void maximize( FactorGraph &fg );
 
         /// Using all of the currently added expectations, make new factors with maximized parameters and set them in the FactorGraph.
         void maximize( FactorGraph &fg );
+
+       /// @name iterator interface
+       //@{
+       typedef std::vector< SharedParameters >::iterator iterator;
+       typedef std::vector< SharedParameters >::const_iterator const_iterator;
+       iterator begin() { return _params.begin(); }
+       const_iterator begin() const { return _params.begin(); }
+       iterator end() { return _params.end(); }
+       const_iterator end() const { return _params.end(); }
+       //@}
 };
 
 
 };
 
 
@@ -298,9 +310,15 @@ class EMAlg {
          */
         bool hasSatisfiedTermConditions() const;
 
          */
         bool hasSatisfiedTermConditions() const;
 
+       /// Return the last calculated log likelihood
+       Real getLogZ() const { return _lastLogZ.back(); }
+
         /// Returns number of iterations done so far
         size_t getCurrentIters() const { return _iters; }
 
         /// Returns number of iterations done so far
         size_t getCurrentIters() const { return _iters; }
 
+       /// Get the iteration method used
+       const InfAlg& eStep() const { return _estep; }
+
         /// Perform an iteration over all maximization steps
         Real iterate();
 
         /// Perform an iteration over all maximization steps
         Real iterate();
 
@@ -309,6 +327,16 @@ class EMAlg {
 
         /// Iterate until termination conditions are satisfied
         void run();
 
         /// Iterate until termination conditions are satisfied
         void run();
+
+       /// @name iterator interface
+       //@{ !!!
+       typedef std::vector< MaximizationStep >::iterator s_iterator;
+       typedef std::vector< MaximizationStep >::const_iterator const_s_iterator;
+       s_iterator s_begin() { return _msteps.begin(); }
+       const_s_iterator s_begin() const { return _msteps.begin(); }
+       s_iterator s_end() { return _msteps.end(); }
+       const_s_iterator s_end() const { return _msteps.end(); }
+       //@}
 };
 
 
 };
 
 
index 275c3bf..1738732 100644 (file)
@@ -69,11 +69,15 @@ class Evidence {
         /// Default constructor
         Evidence() : _samples() {}
       
         /// Default constructor
         Evidence() : _samples() {}
       
-        /// Read in tabular data from a stream. 
+        /// Constructor with existing samples
+         Evidence(std::vector<Observation>& samples) : _samples(samples) {}
+
+/// Read in tabular data from a stream. 
         /** Each line contains one sample, and the first line is a header line with names.
          */
         void addEvidenceTabFile( std::istream& is, std::map<std::string, Var> &varMap );
 
         /** Each line contains one sample, and the first line is a header line with names.
          */
         void addEvidenceTabFile( std::istream& is, std::map<std::string, Var> &varMap );
 
+    
         /// Read in tabular data from a stream. 
         /** Each line contains one sample, and the first line is a header line with 
          *  variable labels which should correspond with a subset of the variables in fg.
         /// Read in tabular data from a stream. 
         /** Each line contains one sample, and the first line is a header line with 
          *  variable labels which should correspond with a subset of the variables in fg.
index 42df4ce..1427002 100644 (file)
@@ -113,6 +113,12 @@ template <typename T> class TFactor {
             assert( _vs.nrStates() == _p.size() );
 #endif
         }
             assert( _vs.nrStates() == _p.size() );
 #endif
         }
+        TFactor( const std::vector< Var >& vars, const std::vector< T >& p ) : _vs(vars.begin(), vars.end(), vars.size()), _p(p.size()) {
+            Permute permindex(vars);
+            for (size_t li = 0; li < p.size(); ++li) {
+                _p[permindex.convert_linear_index(li)] = p[li];
+            }
+        }
         
         /// Constructs TFactor depending on the variable n, with uniform distribution
         TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
         
         /// Constructs TFactor depending on the variable n, with uniform distribution
         TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
index 46b2244..94b9325 100644 (file)
@@ -65,7 +65,7 @@ class IndexFor {
     private:
         /// The current linear index corresponding to the state of indexVars
         long                _index;
     private:
         /// The current linear index corresponding to the state of indexVars
         long                _index;
-
+       
         /// For each variable in forVars, the amount of change in _index
         std::vector<long>   _sum;
 
         /// For each variable in forVars, the amount of change in _index
         std::vector<long>   _sum;
 
@@ -74,7 +74,7 @@ class IndexFor {
         
         /// For each variable in forVars, its number of possible values
         std::vector<size_t> _dims;
         
         /// For each variable in forVars, its number of possible values
         std::vector<size_t> _dims;
-
+       
     public:
         /// Default constructor
         IndexFor() { 
     public:
         /// Default constructor
         IndexFor() { 
@@ -224,13 +224,26 @@ class Permute {
         Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
             assert( _dims.size() == _sigma.size() );
         }
         Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
             assert( _dims.size() == _sigma.size() );
         }
-
+  
+       Permute(const std::vector< Var >& vars) : _dims(vars.size()), _sigma(vars.size()) {
+           VarSet vs(vars.begin(), vars.end(), vars.size());
+           for (size_t i = 0; i < vars.size(); ++i) {
+               _dims[i] = vars[i].states();
+           }
+           VarSet::iterator set_iter = vs.begin();
+           for (size_t i = 0; i < vs.size(); ++i, ++set_iter) {
+               std::vector< Var >::const_iterator j;
+               j = find(vars.begin(), vars.end(), *set_iter);
+               _sigma[i] = j - vars.begin();
+           }
+       }
+       
         /// Calculates a permuted linear index.
         /** Converts the linear index li to a vector index
          *  corresponding with the dimensions in _dims, permutes it according to sigma, 
          *  and converts it back to a linear index  according to the permuted dimensions.
          */
         /// Calculates a permuted linear index.
         /** Converts the linear index li to a vector index
          *  corresponding with the dimensions in _dims, permutes it according to sigma, 
          *  and converts it back to a linear index  according to the permuted dimensions.
          */
-        size_t convert_linear_index( size_t li ) {
+        size_t convert_linear_index( size_t li ) const {
             size_t N = _dims.size();
 
             // calculate vector index corresponding to linear index
             size_t N = _dims.size();
 
             // calculate vector index corresponding to linear index
index 6ad5f2a..9539d5c 100644 (file)
@@ -33,6 +33,7 @@
 #include <sstream>
 #include <boost/any.hpp>
 #include <map>
 #include <sstream>
 #include <boost/any.hpp>
 #include <map>
+#include <vector>
 #include <cassert>
 #include <typeinfo>
 #include <dai/exceptions.h>
 #include <cassert>
 #include <typeinfo>
 #include <dai/exceptions.h>
@@ -151,7 +152,17 @@ class PropertySet : private std::map<PropertyKey, PropertyValue> {
         }
 
         /// Shorthand for (temporarily) adding properties, e.g. PropertySet p()("method","BP")("verbose",1)("tol",1e-9)
         }
 
         /// Shorthand for (temporarily) adding properties, e.g. PropertySet p()("method","BP")("verbose",1)("tol",1e-9)
-        PropertySet operator()(const PropertyKey &key, const PropertyValue &val) const { PropertySet copy = *this; return copy.Set(key,val); }
+               PropertySet operator()(const PropertyKey &key, const PropertyValue &val) const { PropertySet copy = *this; return copy.Set(key,val); }
+               
+               std::vector< PropertyKey > keys() const {
+                       std::vector< PropertyKey > result;
+                       result.reserve(size());
+                       PropertySet::const_iterator i = begin();
+                       for ( ; i != end(); ++i) {
+                               result.push_back(i->first);
+                       }
+                       return result;
+               }
 
         /// Check if a property with the given key exists
         bool hasKey(const PropertyKey &key) const { PropertySet::const_iterator x = find(key); return (x != this->end()); }
 
         /// Check if a property with the given key exists
         bool hasKey(const PropertyKey &key) const { PropertySet::const_iterator x = find(key); return (x != this->end()); }
index 8e93700..e0fcab3 100644 (file)
@@ -196,8 +196,8 @@ SharedParameters::SharedParameters( const SharedParameters &sp )
 }
 
 
 }
 
 
-SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation )
-  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(false
+SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deleteParameterEstimationInDestructor )
+  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(deleteParameterEstimationInDestructor
 {
     // Calculate the necessary permutations
     setPermsAndVarSetsFromVarOrders();
 {
     // Calculate the necessary permutations
     setPermsAndVarSetsFromVarOrders();
@@ -233,6 +233,26 @@ void SharedParameters::setParameters( FactorGraph &fg ) {
 }
 
 
 }
 
 
+void SharedParameters::collectParameters( const FactorGraph& fg, std::vector< Real >& outVals, std::vector< Var >& outVarOrder ) {
+    FactorOrientations::iterator it = _varorders.begin();
+    if (it == _varorders.end()) {
+       return;
+    }
+    FactorIndex i = it->first;
+    std::vector< Var >::iterator var_it = _varorders[i].begin();
+    std::vector< Var >::iterator var_stop = _varorders[i].end();
+    for ( ; var_it != var_stop; ++var_it) {
+       outVarOrder.push_back(*var_it);
+    }
+    const Factor& f = fg.factor(i);
+    assert(f.vars() == _varsets[i]);
+    const Permute& perm = _perms[i];
+    for (size_t val_index = 0; val_index < f.states(); ++val_index) {
+       outVals.push_back(f[perm.convert_linear_index(val_index)]);
+    }
+}
+
+
 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
     size_t num_params = -1;
     is >> num_params;
 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
     size_t num_params = -1;
     is >> num_params;
@@ -298,32 +318,36 @@ bool EMAlg::hasSatisfiedTermConditions() const {
             std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
             return true;
         }
             std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
             return true;
         }
-        return diff / abs(previous) <= _log_z_tol;
+        return (diff / fabs(previous)) <= _log_z_tol;
     }
 }
 
 
 Real EMAlg::iterate( MaximizationStep &mstep ) {
     Real logZ = 0;
     }
 }
 
 
 Real EMAlg::iterate( MaximizationStep &mstep ) {
     Real logZ = 0;
-
+       Real likelihood = 0;
+       
+       _estep.run();
+       logZ = _estep.logZ();
+       
     // Expectation calculation
     for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
         InfAlg* clamped = _estep.clone();
         e->applyEvidence( *clamped );
         clamped->init();
         clamped->run();
     // Expectation calculation
     for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
         InfAlg* clamped = _estep.clone();
         e->applyEvidence( *clamped );
         clamped->init();
         clamped->run();
-      
-        logZ += clamped->logZ();
-
+               
+               likelihood += clamped->logZ() - logZ;
+               
         mstep.addExpectations( *clamped );
         mstep.addExpectations( *clamped );
-
+               
         delete clamped;
     }
     
     // Maximization of parameters
     mstep.maximize( _estep.fg() );
         delete clamped;
     }
     
     // Maximization of parameters
     mstep.maximize( _estep.fg() );
-
-    return logZ;
+       
+       return likelihood;
 }
 
 
 }
 
 
index 5e777cb..54861d3 100644 (file)
@@ -19,9 +19,9 @@ Sample #16 has 2 observations.
 Sample #17 has 2 observations.
 Sample #18 has 2 observations.
 Sample #19 has 2 observations.
 Sample #17 has 2 observations.
 Sample #18 has 2 observations.
 Sample #19 has 2 observations.
-Iteration 1 likelihood: -13.8629
-Iteration 2 likelihood: -9.56675
-Iteration 3 likelihood: -9.56675
+Iteration 1 likelihood: -27.7259
+Iteration 2 likelihood: -23.4297
+Iteration 3 likelihood: -23.4297
 
 Inferred Factor Graph:
 ######################
 
 Inferred Factor Graph:
 ######################
@@ -56,9 +56,9 @@ Sample #16 has 2 observations.
 Sample #17 has 2 observations.
 Sample #18 has 2 observations.
 Sample #19 has 2 observations.
 Sample #17 has 2 observations.
 Sample #18 has 2 observations.
 Sample #19 has 2 observations.
-Iteration 1 likelihood: 0
-Iteration 2 likelihood: 3.97035
-Iteration 3 likelihood: 3.97035
+Iteration 1 likelihood: -27.7259
+Iteration 2 likelihood: -23.7555
+Iteration 3 likelihood: -23.7555
 
 Inferred Factor Graph:
 ######################
 
 Inferred Factor Graph:
 ######################
@@ -82,11 +82,9 @@ Sample #1 has 4 observations.
 Sample #2 has 6 observations.
 Sample #3 has 6 observations.
 Sample #4 has 5 observations.
 Sample #2 has 6 observations.
 Sample #3 has 6 observations.
 Sample #4 has 5 observations.
-Iteration 1 likelihood: 11.1646
-Iteration 2 likelihood: 1.53723
-Iteration 3 likelihood: 1.64691
-Iteration 4 likelihood: 1.67497
-Iteration 5 likelihood: 1.68191
+Iteration 1 likelihood: -16.4893
+Iteration 2 likelihood: -15.4406
+Iteration 3 likelihood: -15.3299
 
 Inferred Factor Graph:
 ######################
 
 Inferred Factor Graph:
 ######################
@@ -96,14 +94,14 @@ Inferred Factor Graph:
 2 6 7 
 2 2 2 
 8
 2 6 7 
 2 2 2 
 8
-0   0.398343360806
-1   0.351464146565
-2   0.601656639194
-3   0.648535853435
-4   0.804844687957
-5    0.67374245803
-6   0.195155312043
-7    0.32625754197
+0   0.396387112396
+1   0.356649908835
+2   0.603612887604
+3   0.643350091165
+4   0.805779039527
+5   0.669278114097
+6   0.194220960473
+7   0.330721885903
 
 3
 0 1 6 
 
 3
 0 1 6 
@@ -122,23 +120,23 @@ Inferred Factor Graph:
 1 2 4 
 2 2 2 
 8
 1 2 4 
 2 2 2 
 8
-0   0.398343360806
-1   0.601656639194
-2   0.351464146565
-3   0.648535853435
-4   0.804844687957
-5   0.195155312043
-6    0.67374245803
-7    0.32625754197
+0   0.396387112396
+1   0.603612887604
+2   0.356649908835
+3   0.643350091165
+4   0.805779039527
+5   0.194220960473
+6   0.669278114097
+7   0.330721885903
 Number of samples: 5
 Sample #0 has 5 observations.
 Sample #1 has 4 observations.
 Sample #2 has 6 observations.
 Sample #3 has 6 observations.
 Sample #4 has 5 observations.
 Number of samples: 5
 Sample #0 has 5 observations.
 Sample #1 has 4 observations.
 Sample #2 has 6 observations.
 Sample #3 has 6 observations.
 Sample #4 has 5 observations.
-Iteration 1 likelihood: 11.1646
-Iteration 2 likelihood: -7.29331
-Iteration 3 likelihood: -7.261
+Iteration 1 likelihood: -16.4893
+Iteration 2 likelihood: -17.6905
+Iteration 3 likelihood: -17.6582
 
 Inferred Factor Graph:
 ######################
 
 Inferred Factor Graph:
 ######################