New git HEAD version
[libdai.git] / src / mf.cpp
index 227b11a..4c1d932 100644 (file)
@@ -1,14 +1,15 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
- *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_MF
+
+
 #include <iostream>
 #include <sstream>
 #include <map>
@@ -23,9 +24,6 @@ namespace dai {
 using namespace std;
 
 
-const char *MF::Name = "MF";
-
-
 void MF::setProperties( const PropertySet &opts ) {
     DAI_ASSERT( opts.hasKey("tol") );
     DAI_ASSERT( opts.hasKey("maxiter") );
@@ -40,15 +38,25 @@ void MF::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<Real>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("init") )
+        props.init = opts.getStringAs<Properties::InitType>("init");
+    else
+        props.init = Properties::InitType::UNIFORM;
+    if( opts.hasKey("updates") )
+        props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+    else
+        props.updates = Properties::UpdateType::NAIVE;
 }
 
 
 PropertySet MF::getProperties() const {
     PropertySet opts;
-    opts.Set( "tol", props.tol );
-    opts.Set( "maxiter", props.maxiter );
-    opts.Set( "verbose", props.verbose );
-    opts.Set( "damping", props.damping );
+    opts.set( "tol", props.tol );
+    opts.set( "maxiter", props.maxiter );
+    opts.set( "verbose", props.verbose );
+    opts.set( "damping", props.damping );
+    opts.set( "init", props.init );
+    opts.set( "updates", props.updates );
     return opts;
 }
 
@@ -59,6 +67,8 @@ string MF::printProperties() const {
     s << "tol=" << props.tol << ",";
     s << "maxiter=" << props.maxiter << ",";
     s << "verbose=" << props.verbose << ",";
+    s << "init=" << props.init << ",";
+    s << "updates=" << props.updates << ",";
     s << "damping=" << props.damping << "]";
     return s.str();
 }
@@ -73,29 +83,31 @@ void MF::construct() {
 }
 
 
-string MF::identify() const {
-    return string(Name) + printProperties();
-}
-
-
 void MF::init() {
-    for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
-        qi->fill(1.0);
+    if( props.init == Properties::InitType::UNIFORM )
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].fill( 1.0 );
+    else
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].randomize();
 }
 
 
 Factor MF::calcNewBelief( size_t i ) {
     Factor result;
-    foreach( const Neighbor &I, nbV(i) ) {
-        Factor henk;
-        foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
+    bforeach( const Neighbor &I, nbV(i) ) {
+        Factor belief_I_minus_i;
+        bforeach( const Neighbor &j, nbF(I) ) // for all j in I \ i
             if( j != i )
-                henk *= _beliefs[j];
-        Factor piet = factor(I).log(true);
-        piet *= henk;
-        piet = piet.marginal(var(i), false);
-        piet = piet.exp();
-        result *= piet;
+                belief_I_minus_i *= _beliefs[j];
+        Factor f_I = factor(I);
+        if( props.updates == Properties::UpdateType::NAIVE )
+            f_I.takeLog(true);
+        Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
+        if( props.updates == Properties::UpdateType::NAIVE )
+            result *= msg_I_i.exp();
+        else
+            result *= msg_I_i;
     }
     result.normalize();
     return result;
@@ -117,26 +129,26 @@ Real MF::run() {
     // been reached or until the maximum belief difference is smaller than tolerance
     Real maxDiff = INFINITY;
     for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
-        random_shuffle( update_seq.begin(), update_seq.end() );
+        random_shuffle( update_seq.begin(), update_seq.end(), rnd );
 
         maxDiff = -INFINITY;
-        foreach( const size_t &i, update_seq ) {
+        bforeach( const size_t &i, update_seq ) {
             Factor nb = calcNewBelief( i );
 
             if( nb.hasNaNs() ) {
-                cerr << Name << "::run():  ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
+                cerr << name() << "::run():  ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
                 return 1.0;
             }
 
             if( props.damping != 0.0 )
                 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
 
-            maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], Prob::DISTLINF ) );
+            maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
             _beliefs[i] = nb;
         }
 
         if( props.verbose >= 3 )
-            cerr << Name << "::run:  maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
+            cerr << name() << "::run:  maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
     }
 
     if( maxDiff > _maxdiff )
@@ -146,10 +158,10 @@ Real MF::run() {
         if( maxDiff > props.tol ) {
             if( props.verbose == 1 )
                 cerr << endl;
-            cerr << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
+            cerr << name() << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
         } else {
             if( props.verbose >= 3 )
-                cerr << Name << "::run:  ";
+                cerr << name() << "::run:  ";
             cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
         }
     }
@@ -190,7 +202,7 @@ Real MF::logZ() const {
         s -= beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); I++ ) {
         Factor henk;
-        foreach( const Neighbor &j, nbF(I) )  // for all j in I
+        bforeach( const Neighbor &j, nbF(I) )  // for all j in I
             henk *= _beliefs[j];
         henk.normalize();
         Factor piet;
@@ -204,11 +216,17 @@ Real MF::logZ() const {
 
 
 void MF::init( const VarSet &ns ) {
-    for( size_t i = 0; i < nrVars(); i++ ) {
-        if( ns.contains(var(i) ) )
-            _beliefs[i].fill( 1.0 );
-    }
+    for( size_t i = 0; i < nrVars(); i++ )
+        if( ns.contains(var(i) ) ) {
+            if( props.init == Properties::InitType::UNIFORM )
+                _beliefs[i].fill( 1.0 );
+            else
+                _beliefs[i].randomize();
+        }
 }
 
 
 } // end of namespace dai
+
+
+#endif