Fixed tabs and trailing whitespaces
[libdai.git] / src / emalg.cpp
index dff6913..5acc9cf 100644 (file)
@@ -56,8 +56,8 @@ ParameterEstimation* CondProbEstimation::factory( const PropertySet &p ) {
 }
 
 
-CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts ) 
-  : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts) 
+CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts )
+  : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
 {
     assert( !(_stats.size() % _target_dim) );
 }
@@ -100,13 +100,13 @@ Permute SharedParameters::calculatePermutation( const std::vector<Var> &varorder
         labels.push_back( varorder[i].label() );
         outVS |= varorder[i];
     }
-  
+
     // Construct the sigma array for the permutation object
     std::vector<size_t> sigma;
     sigma.reserve( dims.size() );
     for( VarSet::iterator set_iterator = outVS.begin(); sigma.size() < dims.size(); ++set_iterator )
         sigma.push_back( find(labels.begin(), labels.end(), set_iterator->label()) - labels.begin() );
-  
+
     return Permute( dims, sigma );
 }
 
@@ -127,7 +127,7 @@ void SharedParameters::setPermsAndVarSetsFromVarOrders() {
 
 
 SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg_varlookup )
-  : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(true) 
+  : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(true)
 {
     // Read the desired parameter estimation method from the stream
     std::string est_method;
@@ -170,7 +170,7 @@ SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg_varl
             labelparse >> label;
             VarSet::const_iterator vsi = vs.begin();
             for( ; vsi != vs.end(); ++vsi )
-                if( vsi->label() == label ) 
+                if( vsi->label() == label )
                     break;
             if( vsi == vs.end() )
                 DAI_THROW(INVALID_EMALG_FILE);
@@ -184,7 +184,7 @@ SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg_varl
 }
 
 
-SharedParameters::SharedParameters( const SharedParameters &sp ) 
+SharedParameters::SharedParameters( const SharedParameters &sp )
   : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _deleteEstimation(sp._deleteEstimation)
 {
     // If sp owns its _estimation object, we should clone it instead
@@ -193,8 +193,8 @@ SharedParameters::SharedParameters( const SharedParameters &sp )
 }
 
 
-SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation )
-  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(false) 
+SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deletePE )
+  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(deletePE)
 {
     // Calculate the necessary permutations
     setPermsAndVarSetsFromVarOrders();
@@ -205,7 +205,7 @@ void SharedParameters::collectSufficientStatistics( InfAlg &alg ) {
     for( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
         Permute &perm = i->second;
         VarSet &vs = _varsets[i->first];
-        
+
         Factor b = alg.belief(vs);
         Prob p( b.states(), 0.0 );
         for( size_t entry = 0; entry < b.states(); ++entry )
@@ -220,7 +220,7 @@ void SharedParameters::setParameters( FactorGraph &fg ) {
     for( std::map<FactorIndex, Permute>::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
         Permute &perm = i->second;
         VarSet &vs = _varsets[i->first];
-        
+
         Factor f( vs, 0.0 );
         for( size_t entry = 0; entry < f.states(); ++entry )
             f[perm.convert_linear_index(entry)] = p[entry];
@@ -230,6 +230,22 @@ void SharedParameters::setParameters( FactorGraph &fg ) {
 }
 
 
+void SharedParameters::collectParameters( const FactorGraph &fg, std::vector<Real> &outVals, std::vector<Var> &outVarOrder ) {
+    FactorOrientations::iterator it = _varorders.begin();
+    if( it == _varorders.end() )
+        return;
+    FactorIndex I = it->first;
+    for( std::vector<Var>::const_iterator var_it = _varorders[I].begin(); var_it != _varorders[I].end(); ++var_it )
+        outVarOrder.push_back( *var_it );
+
+    const Factor &f = fg.factor(I);
+    assert( f.vars() == _varsets[I] );
+    const Permute &perm = _perms[I];
+    for( size_t val_index = 0; val_index < f.states(); ++val_index )
+        outVals.push_back( f[perm.convert_linear_index(val_index)] );
+}
+
+
 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
     size_t num_params = -1;
     is >> num_params;
@@ -266,7 +282,7 @@ EMAlg::EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &msteps_file
     _msteps.reserve(num_msteps);
     for( size_t i = 0; i < num_msteps; ++i )
         _msteps.push_back( MaximizationStep( msteps_file, estep.fg() ) );
-}      
+}
 
 
 void EMAlg::setTermConditions( const PropertySet &p ) {
@@ -295,13 +311,17 @@ bool EMAlg::hasSatisfiedTermConditions() const {
             std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
             return true;
         }
-        return diff / abs(previous) <= _log_z_tol;
+        return (diff / fabs(previous)) <= _log_z_tol;
     }
 }
 
 
 Real EMAlg::iterate( MaximizationStep &mstep ) {
     Real logZ = 0;
+    Real likelihood = 0;
+
+    _estep.run();
+    logZ = _estep.logZ();
 
     // Expectation calculation
     for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
@@ -309,18 +329,18 @@ Real EMAlg::iterate( MaximizationStep &mstep ) {
         e->applyEvidence( *clamped );
         clamped->init();
         clamped->run();
-      
-        logZ += clamped->logZ();
+
+        likelihood += clamped->logZ() - logZ;
 
         mstep.addExpectations( *clamped );
 
         delete clamped;
     }
-    
+
     // Maximization of parameters
     mstep.maximize( _estep.fg() );
 
-    return logZ;
+    return likelihood;
 }