Improved documentation of bipgraph.h and added example_bipgraph.cpp
[libdai.git] / src / bp.cpp
index fb7220a..8e9f7f4 100644 (file)
@@ -1,6 +1,7 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
-    Radboud University Nijmegen, The Netherlands
-    
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -25,7 +26,6 @@
 #include <set>
 #include <algorithm>
 #include <dai/bp.h>
-#include <dai/diffs.h>
 #include <dai/util.h>
 #include <dai/properties.h>
 
@@ -39,6 +39,9 @@ using namespace std;
 const char *BP::Name = "BP";
 
 
+#define DAI_BP_FAST 1
+
+
 void BP::setProperties( const PropertySet &opts ) {
     assert( opts.hasKey("tol") );
     assert( opts.hasKey("maxiter") );
@@ -58,6 +61,10 @@ void BP::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<double>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("inference") )
+        props.inference = opts.getStringAs<Properties::InfType>("inference");
+    else
+        props.inference = Properties::InfType::SUMPROD;
 }
 
 
@@ -69,6 +76,7 @@ PropertySet BP::getProperties() const {
     opts.Set( "logdomain", props.logdomain );
     opts.Set( "updates", props.updates );
     opts.Set( "damping", props.damping );
+    opts.Set( "inference", props.inference );
     return opts;
 }
 
@@ -81,7 +89,8 @@ string BP::printProperties() const {
     s << "verbose=" << props.verbose << ",";
     s << "logdomain=" << props.logdomain << ",";
     s << "updates=" << props.updates << ",";
-    s << "damping=" << props.damping << "]";
+    s << "damping=" << props.damping << ",";
+    s << "inference=" << props.inference << "]";
     return s.str();
 }
 
@@ -98,9 +107,11 @@ void BP::construct() {
             newEP.message = Prob( var(i).states() );
             newEP.newMessage = Prob( var(i).states() );
 
-            newEP.index.reserve( factor(I).states() );
-            for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
-                newEP.index.push_back( k );
+            if( DAI_BP_FAST ) {
+                newEP.index.reserve( factor(I).states() );
+                for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
+                    newEP.index.push_back( k );
+            }
 
             newEP.residual = 0.0;
             _edges[i].push_back( newEP );
@@ -138,7 +149,7 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
 
-    if( 0 == 1 ) {
+    if( !DAI_BP_FAST ) {
         /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
         Factor prod( factor( I ) );
         foreach( const Neighbor &j, nbF(I) )
@@ -189,8 +200,13 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
         Prob marg( var(i).states(), 0.0 );
         // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
         const ind_t ind = index(i,_I);
-        for( size_t r = 0; r < prod.size(); ++r )
-            marg[ind[r]] += prod[r];
+        if( props.inference == Properties::InfType::SUMPROD ) 
+            for( size_t r = 0; r < prod.size(); ++r )
+                marg[ind[r]] += prod[r];
+        else
+            for( size_t r = 0; r < prod.size(); ++r )
+                if( prod[r] > marg[ind[r]] ) 
+                    marg[ind[r]] = prod[r];
         marg.normalize();
 
         // Store result
@@ -359,7 +375,7 @@ Factor BP::belief( const VarSet &ns ) const {
 
 
 Factor BP::beliefF (size_t I) const {
-    if( 0 == 1 ) {
+    if( !DAI_BP_FAST ) {
         /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
 
         Factor prod( factor(I) );
@@ -419,7 +435,7 @@ Real BP::logZ() const {
     for(size_t i = 0; i < nrVars(); ++i )
         sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
-        sum -= KL_dist( beliefF(I), factor(I) );
+        sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
     return sum;
 }