Improved documentation of bipgraph.h and added example_bipgraph.cpp
[libdai.git] / src / bp.cpp
index 8162bf7..8e9f7f4 100644 (file)
@@ -1,6 +1,7 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
-    Radboud University Nijmegen, The Netherlands
-    
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -25,7 +26,6 @@
 #include <set>
 #include <algorithm>
 #include <dai/bp.h>
-#include <dai/diffs.h>
 #include <dai/util.h>
 #include <dai/properties.h>
 
@@ -39,6 +39,9 @@ using namespace std;
 const char *BP::Name = "BP";
 
 
+#define DAI_BP_FAST 1
+
+
 void BP::setProperties( const PropertySet &opts ) {
     assert( opts.hasKey("tol") );
     assert( opts.hasKey("maxiter") );
@@ -58,6 +61,10 @@ void BP::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<double>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("inference") )
+        props.inference = opts.getStringAs<Properties::InfType>("inference");
+    else
+        props.inference = Properties::InfType::SUMPROD;
 }
 
 
@@ -69,6 +76,7 @@ PropertySet BP::getProperties() const {
     opts.Set( "logdomain", props.logdomain );
     opts.Set( "updates", props.updates );
     opts.Set( "damping", props.damping );
+    opts.Set( "inference", props.inference );
     return opts;
 }
 
@@ -81,7 +89,8 @@ string BP::printProperties() const {
     s << "verbose=" << props.verbose << ",";
     s << "logdomain=" << props.logdomain << ",";
     s << "updates=" << props.updates << ",";
-    s << "damping=" << props.damping << "]";
+    s << "damping=" << props.damping << ",";
+    s << "inference=" << props.inference << "]";
     return s.str();
 }
 
@@ -98,9 +107,11 @@ void BP::construct() {
             newEP.message = Prob( var(i).states() );
             newEP.newMessage = Prob( var(i).states() );
 
-            newEP.index.reserve( factor(I).states() );
-            for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
-                newEP.index.push_back( k );
+            if( DAI_BP_FAST ) {
+                newEP.index.reserve( factor(I).states() );
+                for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
+                    newEP.index.push_back( k );
+            }
 
             newEP.residual = 0.0;
             _edges[i].push_back( newEP );
@@ -138,7 +149,7 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
 
-    if( 0 == 1 ) {
+    if( !DAI_BP_FAST ) {
         /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
         Factor prod( factor( I ) );
         foreach( const Neighbor &j, nbF(I) )
@@ -189,8 +200,13 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
         Prob marg( var(i).states(), 0.0 );
         // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
         const ind_t ind = index(i,_I);
-        for( size_t r = 0; r < prod.size(); ++r )
-            marg[ind[r]] += prod[r];
+        if( props.inference == Properties::InfType::SUMPROD ) 
+            for( size_t r = 0; r < prod.size(); ++r )
+                marg[ind[r]] += prod[r];
+        else
+            for( size_t r = 0; r < prod.size(); ++r )
+                if( prod[r] > marg[ind[r]] ) 
+                    marg[ind[r]] = prod[r];
         marg.normalize();
 
         // Store result
@@ -247,7 +263,6 @@ double BP::run() {
                 size_t i, _I;
                 findMaxResidual( i, _I );
                 updateMessage( i, _I );
-                residual( i, _I ) = 0.0;
 
                 // I->i has been updated, which means that residuals for all
                 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
@@ -291,7 +306,7 @@ double BP::run() {
         }
 
         if( props.verbose >= 3 )
-            cout << "BP::run:  maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
+            cout << Name << "::run:  maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
     }
 
     if( diffs.maxDiff() > _maxdiff )
@@ -301,11 +316,11 @@ double BP::run() {
         if( diffs.maxDiff() > props.tol ) {
             if( props.verbose == 1 )
                 cout << endl;
-                cout << "BP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+                cout << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
             if( props.verbose >= 3 )
-                cout << "BP::run:  ";
-                cout << "converged in " << _iters << " passes (" << toc() - tic << " clocks)." << endl;
+                cout << Name << "::run:  ";
+                cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
         }
     }
 
@@ -360,7 +375,7 @@ Factor BP::belief( const VarSet &ns ) const {
 
 
 Factor BP::beliefF (size_t I) const {
-    if( 0 == 1 ) {
+    if( !DAI_BP_FAST ) {
         /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
 
         Factor prod( factor(I) );
@@ -420,7 +435,7 @@ Real BP::logZ() const {
     for(size_t i = 0; i < nrVars(); ++i )
         sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
-        sum -= KL_dist( beliefF(I), factor(I) );
+        sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
     return sum;
 }