Generalized VarSet to "template<typename T> small_set<T>"
[libdai.git] / src / bp.cpp
index fb7220a..ac431a2 100644 (file)
@@ -58,6 +58,10 @@ void BP::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<double>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("inference") )
+        props.inference = opts.getStringAs<Properties::InfType>("inference");
+    else
+        props.inference = Properties::InfType::SUMPROD;
 }
 
 
@@ -69,6 +73,7 @@ PropertySet BP::getProperties() const {
     opts.Set( "logdomain", props.logdomain );
     opts.Set( "updates", props.updates );
     opts.Set( "damping", props.damping );
+    opts.Set( "inference", props.inference );
     return opts;
 }
 
@@ -81,7 +86,8 @@ string BP::printProperties() const {
     s << "verbose=" << props.verbose << ",";
     s << "logdomain=" << props.logdomain << ",";
     s << "updates=" << props.updates << ",";
-    s << "damping=" << props.damping << "]";
+    s << "damping=" << props.damping << ",";
+    s << "inference=" << props.inference << "]";
     return s.str();
 }
 
@@ -189,8 +195,13 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
         Prob marg( var(i).states(), 0.0 );
         // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
         const ind_t ind = index(i,_I);
-        for( size_t r = 0; r < prod.size(); ++r )
-            marg[ind[r]] += prod[r];
+        if( props.inference == Properties::InfType::SUMPROD ) 
+            for( size_t r = 0; r < prod.size(); ++r )
+                marg[ind[r]] += prod[r];
+        else
+            for( size_t r = 0; r < prod.size(); ++r )
+                if( prod[r] > marg[ind[r]] ) 
+                    marg[ind[r]] = prod[r];
         marg.normalize();
 
         // Store result