Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / src / bp.cpp
index 0d65992..0352565 100644 (file)
@@ -39,26 +39,33 @@ using namespace std;
 const char *BP::Name = "BP";
 
 
-bool BP::checkProperties() {
-    if( !HasProperty("updates") )
-        return false;
-    if( !HasProperty("tol") )
-        return false;
-    if (!HasProperty("maxiter") )
-        return false;
-    if (!HasProperty("verbose") )
-        return false;
+void BP::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("tol") );
+    assert( opts.hasKey("maxiter") );
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("logdomain") );
+    assert( opts.hasKey("updates") );
     
-    ConvertPropertyTo<double>("tol");
-    ConvertPropertyTo<size_t>("maxiter");
-    ConvertPropertyTo<size_t>("verbose");
-    ConvertPropertyTo<UpdateType>("updates");
+    props.tol = opts.getStringAs<double>("tol");
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.logdomain = opts.getStringAs<bool>("logdomain");
+    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+}
+
 
-    return true;
+PropertySet BP::getProperties() const {
+    PropertySet opts;
+    opts.Set( "tol", props.tol );
+    opts.Set( "maxiter", props.maxiter );
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "logdomain", props.logdomain );
+    opts.Set( "updates", props.updates );
+    return opts;
 }
 
 
-void BP::Regenerate() {
+void BP::create() {
     // create edge properties
     edges.clear();
     edges.reserve( nrVars() );
@@ -70,8 +77,8 @@ void BP::Regenerate() {
             newEP.message = Prob( var(i).states() );
             newEP.newMessage = Prob( var(i).states() );
 
-            newEP.index.reserve( factor(I).stateSpace() );
-            for( Index k( var(i), factor(I).vars() ); k >= 0; ++k )
+            newEP.index.reserve( factor(I).states() );
+            for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
                 newEP.index.push_back( k );
 
             newEP.residual = 0.0;
@@ -82,11 +89,15 @@ void BP::Regenerate() {
 
 
 void BP::init() {
-    assert( checkProperties() );
     for( size_t i = 0; i < nrVars(); ++i ) {
         foreach( const Neighbor &I, nbV(i) ) {
-            message( i, I.iter ).fill( 1.0 );
-            newMessage( i, I.iter ).fill( 1.0 );
+            if( props.logdomain ) {
+                message( i, I.iter ).fill( 0.0 );
+                newMessage( i, I.iter ).fill( 0.0 );
+            } else {
+                message( i, I.iter ).fill( 1.0 );
+                newMessage( i, I.iter ).fill( 1.0 );
+            }
         }
     }
 }
@@ -122,52 +133,66 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
 */
     
     Prob prod( factor(I).p() );
+    if( props.logdomain ) 
+        prod.takeLog();
 
     // Calculate product of incoming messages and factor I
     foreach( const Neighbor &j, nbF(I) ) {
         if( j != i ) {     // for all j in I \ i
             size_t _I = j.dual;
-            // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
             const ind_t & ind = index(j, _I);
 
             // prod_j will be the product of messages coming into j
-            Prob prod_j( var(j).states() ); 
-            foreach( const Neighbor &J, nbV(j) ) {
-                if( J != I )   // for all J in nb(j) \ I 
-                    prod_j *= message( j, J.iter );
-            }
+            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
+            foreach( const Neighbor &J, nbV(j) )
+                if( J != I ) { // for all J in nb(j) \ I 
+                    if( props.logdomain )
+                        prod_j += message( j, J.iter );
+                    else
+                        prod_j *= message( j, J.iter );
+                }
 
             // multiply prod with prod_j
             for( size_t r = 0; r < prod.size(); ++r )
-                prod[r] *= prod_j[ind[r]];
+                if( props.logdomain )
+                    prod[r] += prod_j[ind[r]];
+                else
+                    prod[r] *= prod_j[ind[r]];
         }
     }
+    if( props.logdomain ) {
+        prod -= prod.maxVal();
+        prod.takeExp();
+    }
 
     // Marginalize onto i
     Prob marg( var(i).states(), 0.0 );
-    // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+    // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
     const ind_t ind = index(i,_I);
     for( size_t r = 0; r < prod.size(); ++r )
         marg[ind[r]] += prod[r];
-    marg.normalize( _normtype );
+    marg.normalize( Prob::NORMPROB );
     
     // Store result
-    newMessage(i,_I) = marg;
+    if( props.logdomain )
+        newMessage(i,_I) = marg.log();
+    else
+        newMessage(i,_I) = marg;
 }
 
 
 // BP::run does not check for NANs for performance reasons
 // Somehow NaNs do not often occur in BP...
 double BP::run() {
-    if( Verbose() >= 1 )
+    if( props.verbose >= 1 )
         cout << "Starting " << identify() << "...";
-    if( Verbose() >= 3)
+    if( props.verbose >= 3)
        cout << endl; 
 
-    clock_t tic = toc();
+    double tic = toc();
     Diffs diffs(nrVars(), 1.0);
     
-    typedef pair<size_t,size_t> Edge;
     vector<Edge> update_seq;
 
     vector<Factor> old_beliefs;
@@ -178,7 +203,7 @@ double BP::run() {
     size_t iter = 0;
     size_t nredges = nrEdges();
 
-    if( Updates() == UpdateType::SEQMAX ) {
+    if( props.updates == Properties::UpdateType::SEQMAX ) {
         // do the first pass
         for( size_t i = 0; i < nrVars(); ++i )
             foreach( const Neighbor &I, nbV(i) ) {
@@ -195,8 +220,8 @@ double BP::run() {
 
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( iter=0; iter < MaxIter() && diffs.max() > Tol(); ++iter ) {
-        if( Updates() == UpdateType::SEQMAX ) {
+    for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
+        if( props.updates == Properties::UpdateType::SEQMAX ) {
             // Residuals-BP by Koller et al.
             for( size_t t = 0; t < nredges; ++t ) {
                 // update the message with the largest residual
@@ -220,7 +245,7 @@ double BP::run() {
                     }
                 }
             }
-        } else if( Updates() == UpdateType::PARALL ) {
+        } else if( props.updates == Properties::UpdateType::PARALL ) {
             // Parallel updates 
             for( size_t i = 0; i < nrVars(); ++i )
                 foreach( const Neighbor &I, nbV(i) )
@@ -231,7 +256,7 @@ double BP::run() {
                     message( i, I.iter ) = newMessage( i, I.iter );
         } else {
             // Sequential updates
-            if( Updates() == UpdateType::SEQRND )
+            if( props.updates == Properties::UpdateType::SEQRND )
                 random_shuffle( update_seq.begin(), update_seq.end() );
             
             foreach( const Edge &e, update_seq ) {
@@ -247,32 +272,40 @@ double BP::run() {
             old_beliefs[i] = nb;
         }
 
-        if( Verbose() >= 3 )
-            cout << "BP::run:  maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
+        if( props.verbose >= 3 )
+            cout << "BP::run:  maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
     }
 
-    updateMaxDiff( diffs.max() );
+    if( diffs.maxDiff() > maxdiff )
+        maxdiff = diffs.maxDiff();
 
-    if( Verbose() >= 1 ) {
-        if( diffs.max() > Tol() ) {
-            if( Verbose() == 1 )
+    if( props.verbose >= 1 ) {
+        if( diffs.maxDiff() > props.tol ) {
+            if( props.verbose == 1 )
                 cout << endl;
-                cout << "BP::run:  WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
+                cout << "BP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
-            if( Verbose() >= 3 )
+            if( props.verbose >= 3 )
                 cout << "BP::run:  ";
                 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
         }
     }
 
-    return diffs.max();
+    return diffs.maxDiff();
 }
 
 
 Factor BP::beliefV( size_t i ) const {
-    Prob prod( var(i).states() ); 
+    Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 ); 
     foreach( const Neighbor &I, nbV(i) )
-        prod *= newMessage( i, I.iter );
+        if( props.logdomain )
+            prod += newMessage( i, I.iter );
+        else
+            prod *= newMessage( i, I.iter );
+    if( props.logdomain ) {
+        prod -= prod.maxVal();
+        prod.takeExp();
+    }
 
     prod.normalize( Prob::NORMPROB );
     return( Factor( var(i), prod ) );
@@ -310,22 +343,37 @@ Factor BP::belief( const VarSet &ns ) const {
 
 Factor BP::beliefF (size_t I) const {
     Prob prod( factor(I).p() );
+    if( props.logdomain )
+        prod.takeLog();
 
     foreach( const Neighbor &j, nbF(I) ) {
         size_t _I = j.dual;
-        // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+        // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
         const ind_t & ind = index(j, _I);
 
         // prod_j will be the product of messages coming into j
-        Prob prod_j( var(j).states() ); 
+        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
         foreach( const Neighbor &J, nbV(j) ) {
-            if( J != I )   // for all J in nb(j) \ I 
-                prod_j *= newMessage( j, J.iter );
+            if( J != I ) { // for all J in nb(j) \ I 
+                if( props.logdomain )
+                    prod_j += newMessage( j, J.iter );
+                else
+                    prod_j *= newMessage( j, J.iter );
+            }
         }
 
         // multiply prod with prod_j
-        for( size_t r = 0; r < prod.size(); ++r )
-            prod[r] *= prod_j[ind[r]];
+        for( size_t r = 0; r < prod.size(); ++r ) {
+            if( props.logdomain )
+                prod[r] += prod_j[ind[r]];
+            else
+                prod[r] *= prod_j[ind[r]];
+        }
+    }
+
+    if( props.logdomain ) {
+        prod -= prod.maxVal();
+        prod.takeExp();
     }
 
     Factor result( factor(I).vars(), prod );
@@ -345,10 +393,10 @@ Factor BP::beliefF (size_t I) const {
 }
 
 
-Complex BP::logZ() const {
-    Complex sum = 0.0;
+Real BP::logZ() const {
+    Real sum = 0.0;
     for(size_t i = 0; i < nrVars(); ++i )
-        sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
+        sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
         sum -= KL_dist( beliefF(I), factor(I) );
     return sum;
@@ -357,7 +405,7 @@ Complex BP::logZ() const {
 
 string BP::identify() const { 
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
@@ -366,7 +414,7 @@ void BP::init( const VarSet &ns ) {
     for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
         size_t ni = findVar( *n );
         foreach( const Neighbor &I, nbV( ni ) )
-            message( ni, I.iter ).fill( 1.0 );
+            message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
     }
 }