Merge branch 'vaskeEmFix' of git://disco.cse.ucsc.edu/libDAI into mergeVaskeEmFix
[libdai.git] / src / emalg.cpp
index dff6913..62102ef 100644 (file)
@@ -193,8 +193,8 @@ SharedParameters::SharedParameters( const SharedParameters &sp )
 }
 
 
-SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation )
-  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(false
+SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deleteParameterEstimationInDestructor )
+  : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(deleteParameterEstimationInDestructor
 {
     // Calculate the necessary permutations
     setPermsAndVarSetsFromVarOrders();
@@ -230,6 +230,26 @@ void SharedParameters::setParameters( FactorGraph &fg ) {
 }
 
 
+void SharedParameters::collectParameters( const FactorGraph& fg, std::vector< Real >& outVals, std::vector< Var >& outVarOrder ) {
+    FactorOrientations::iterator it = _varorders.begin();
+    if (it == _varorders.end()) {
+       return;
+    }
+    FactorIndex i = it->first;
+    std::vector< Var >::iterator var_it = _varorders[i].begin();
+    std::vector< Var >::iterator var_stop = _varorders[i].end();
+    for ( ; var_it != var_stop; ++var_it) {
+       outVarOrder.push_back(*var_it);
+    }
+    const Factor& f = fg.factor(i);
+    assert(f.vars() == _varsets[i]);
+    const Permute& perm = _perms[i];
+    for (size_t val_index = 0; val_index < f.states(); ++val_index) {
+       outVals.push_back(f[perm.convert_linear_index(val_index)]);
+    }
+}
+
+
 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
     size_t num_params = -1;
     is >> num_params;
@@ -295,32 +315,36 @@ bool EMAlg::hasSatisfiedTermConditions() const {
             std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
             return true;
         }
-        return diff / abs(previous) <= _log_z_tol;
+        return (diff / fabs(previous)) <= _log_z_tol;
     }
 }
 
 
 Real EMAlg::iterate( MaximizationStep &mstep ) {
     Real logZ = 0;
-
+       Real likelihood = 0;
+       
+       _estep.run();
+       logZ = _estep.logZ();
+       
     // Expectation calculation
     for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
         InfAlg* clamped = _estep.clone();
         e->applyEvidence( *clamped );
         clamped->init();
         clamped->run();
-      
-        logZ += clamped->logZ();
-
+               
+               likelihood += clamped->logZ() - logZ;
+               
         mstep.addExpectations( *clamped );
-
+               
         delete clamped;
     }
     
     // Maximization of parameters
     mstep.maximize( _estep.fg() );
-
-    return logZ;
+       
+       return likelihood;
 }