Merge branch 'master' of git.tuebingen.mpg.de:libdai
[libdai.git] / src / mf.cpp
index a143608..42ede4a 100644 (file)
@@ -23,9 +23,6 @@ namespace dai {
 using namespace std;
 
 
-const char *MF::Name = "MF";
-
-
 void MF::setProperties( const PropertySet &opts ) {
     DAI_ASSERT( opts.hasKey("tol") );
     DAI_ASSERT( opts.hasKey("maxiter") );
@@ -40,6 +37,14 @@ void MF::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<Real>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("init") )
+        props.init = opts.getStringAs<Properties::InitType>("init");
+    else
+        props.init = Properties::InitType::UNIFORM;
+    if( opts.hasKey("updates") )
+        props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+    else
+        props.updates = Properties::UpdateType::NAIVE;
 }
 
 
@@ -49,6 +54,8 @@ PropertySet MF::getProperties() const {
     opts.set( "maxiter", props.maxiter );
     opts.set( "verbose", props.verbose );
     opts.set( "damping", props.damping );
+    opts.set( "init", props.init );
+    opts.set( "updates", props.updates );
     return opts;
 }
 
@@ -59,6 +66,8 @@ string MF::printProperties() const {
     s << "tol=" << props.tol << ",";
     s << "maxiter=" << props.maxiter << ",";
     s << "verbose=" << props.verbose << ",";
+    s << "init=" << props.init << ",";
+    s << "updates=" << props.updates << ",";
     s << "damping=" << props.damping << "]";
     return s.str();
 }
@@ -73,29 +82,31 @@ void MF::construct() {
 }
 
 
-string MF::identify() const {
-    return string(Name) + printProperties();
-}
-
-
 void MF::init() {
-    for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
-        qi->fill(1.0);
+    if( props.init == Properties::InitType::UNIFORM )
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].fill( 1.0 );
+    else
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].randomize();
 }
 
 
 Factor MF::calcNewBelief( size_t i ) {
     Factor result;
     foreach( const Neighbor &I, nbV(i) ) {
-        Factor henk;
+        Factor belief_I_minus_i;
         foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
             if( j != i )
-                henk *= _beliefs[j];
-        Factor piet = factor(I).log(true);
-        piet *= henk;
-        piet = piet.marginal(var(i), false);
-        piet = piet.exp();
-        result *= piet;
+                belief_I_minus_i *= _beliefs[j];
+        Factor f_I = factor(I);
+        if( props.updates == Properties::UpdateType::NAIVE )
+            f_I.takeLog();
+        Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
+        if( props.updates == Properties::UpdateType::NAIVE )
+            result *= msg_I_i.exp();
+        else
+            result *= msg_I_i;
     }
     result.normalize();
     return result;
@@ -124,19 +135,19 @@ Real MF::run() {
             Factor nb = calcNewBelief( i );
 
             if( nb.hasNaNs() ) {
-                cerr << Name << "::run():  ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
+                cerr << name() << "::run():  ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
                 return 1.0;
             }
 
             if( props.damping != 0.0 )
                 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
 
-            maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], Prob::DISTLINF ) );
+            maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
             _beliefs[i] = nb;
         }
 
         if( props.verbose >= 3 )
-            cerr << Name << "::run:  maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
+            cerr << name() << "::run:  maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
     }
 
     if( maxDiff > _maxdiff )
@@ -146,10 +157,10 @@ Real MF::run() {
         if( maxDiff > props.tol ) {
             if( props.verbose == 1 )
                 cerr << endl;
-            cerr << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
+            cerr << name() << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
         } else {
             if( props.verbose >= 3 )
-                cerr << Name << "::run:  ";
+                cerr << name() << "::run:  ";
             cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
         }
     }
@@ -204,10 +215,13 @@ Real MF::logZ() const {
 
 
 void MF::init( const VarSet &ns ) {
-    for( size_t i = 0; i < nrVars(); i++ ) {
-        if( ns.contains(var(i) ) )
-            _beliefs[i].fill( 1.0 );
-    }
+    for( size_t i = 0; i < nrVars(); i++ )
+        if( ns.contains(var(i) ) ) {
+            if( props.init == Properties::InitType::UNIFORM )
+                _beliefs[i].fill( 1.0 );
+            else
+                _beliefs[i].randomize();
+        }
 }