Generalized VarSet to "template<typename T> small_set<T>"
[libdai.git] / src / bp.cpp
index 94680d6..ac431a2 100644 (file)
@@ -42,15 +42,26 @@ const char *BP::Name = "BP";
 void BP::setProperties( const PropertySet &opts ) {
     assert( opts.hasKey("tol") );
     assert( opts.hasKey("maxiter") );
-    assert( opts.hasKey("verbose") );
     assert( opts.hasKey("logdomain") );
     assert( opts.hasKey("updates") );
     
     props.tol = opts.getStringAs<double>("tol");
     props.maxiter = opts.getStringAs<size_t>("maxiter");
-    props.verbose = opts.getStringAs<size_t>("verbose");
     props.logdomain = opts.getStringAs<bool>("logdomain");
     props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+
+    if( opts.hasKey("verbose") )
+        props.verbose = opts.getStringAs<size_t>("verbose");
+    else
+        props.verbose = 0;
+    if( opts.hasKey("damping") )
+        props.damping = opts.getStringAs<double>("damping");
+    else
+        props.damping = 0.0;
+    if( opts.hasKey("inference") )
+        props.inference = opts.getStringAs<Properties::InfType>("inference");
+    else
+        props.inference = Properties::InfType::SUMPROD;
 }
 
 
@@ -61,17 +72,33 @@ PropertySet BP::getProperties() const {
     opts.Set( "verbose", props.verbose );
     opts.Set( "logdomain", props.logdomain );
     opts.Set( "updates", props.updates );
+    opts.Set( "damping", props.damping );
+    opts.Set( "inference", props.inference );
     return opts;
 }
 
 
-void BP::create() {
+string BP::printProperties() const {
+    stringstream s( stringstream::out );
+    s << "[";
+    s << "tol=" << props.tol << ",";
+    s << "maxiter=" << props.maxiter << ",";
+    s << "verbose=" << props.verbose << ",";
+    s << "logdomain=" << props.logdomain << ",";
+    s << "updates=" << props.updates << ",";
+    s << "damping=" << props.damping << ",";
+    s << "inference=" << props.inference << "]";
+    return s.str();
+}
+
+
+void BP::construct() {
     // create edge properties
-    edges.clear();
-    edges.reserve( nrVars() );
+    _edges.clear();
+    _edges.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); ++i ) {
-        edges.push_back( vector<EdgeProp>() );
-        edges[i].reserve( nbV(i).size() ); 
+        _edges.push_back( vector<EdgeProp>() );
+        _edges[i].reserve( nbV(i).size() ); 
         foreach( const Neighbor &I, nbV(i) ) {
             EdgeProp newEP;
             newEP.message = Prob( var(i).states() );
@@ -82,22 +109,18 @@ void BP::create() {
                 newEP.index.push_back( k );
 
             newEP.residual = 0.0;
-            edges[i].push_back( newEP );
+            _edges[i].push_back( newEP );
         }
     }
 }
 
 
 void BP::init() {
+    double c = props.logdomain ? 0.0 : 1.0;
     for( size_t i = 0; i < nrVars(); ++i ) {
         foreach( const Neighbor &I, nbV(i) ) {
-            if( props.logdomain ) {
-                message( i, I.iter ).fill( 0.0 );
-                newMessage( i, I.iter ).fill( 0.0 );
-            } else {
-                message( i, I.iter ).fill( 1.0 );
-                newMessage( i, I.iter ).fill( 1.0 );
-            }
+            message( i, I.iter ).fill( c );
+            newMessage( i, I.iter ).fill( c );
         }
     }
 }
@@ -121,64 +144,72 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
 
-/*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
-
-    Factor prod( factor( I ) );
-    for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
-        if( *j != i ) {     // for all j in I \ i
-            for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ ) 
-                if( *J != I ) {     // for all J in nb(j) \ I 
-                    prod *= Factor( *j, message(*j,*J) );
-    Factor marg = prod.marginal(var(i));
-*/
-    
-    Prob prod( factor(I).p() );
-    if( props.logdomain ) 
-        prod.takeLog();
-
-    // Calculate product of incoming messages and factor I
-    foreach( const Neighbor &j, nbF(I) ) {
-        if( j != i ) {     // for all j in I \ i
-            size_t _I = j.dual;
-            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-            const ind_t & ind = index(j, _I);
+    if( 0 == 1 ) {
+        /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+        Factor prod( factor( I ) );
+        foreach( const Neighbor &j, nbF(I) )
+            if( j != i ) {     // for all j in I \ i
+                foreach( const Neighbor &J, nbV(j) )
+                    if( J != I ) {     // for all J in nb(j) \ I 
+                        prod *= Factor( var(j), message(j, J.iter) );
+                    }
+            }
+        newMessage(i,_I) = prod.marginal( var(i) ).p();
+    } else {
+        /* OPTIMIZED VERSION */
+        Prob prod( factor(I).p() );
+        if( props.logdomain ) 
+            prod.takeLog();
+
+        // Calculate product of incoming messages and factor I
+        foreach( const Neighbor &j, nbF(I) ) {
+            if( j != i ) {     // for all j in I \ i
+                size_t _I = j.dual;
+                // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+                const ind_t &ind = index(j, _I);
+
+                // prod_j will be the product of messages coming into j
+                Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
+                foreach( const Neighbor &J, nbV(j) )
+                    if( J != I ) { // for all J in nb(j) \ I 
+                        if( props.logdomain )
+                            prod_j += message( j, J.iter );
+                        else
+                            prod_j *= message( j, J.iter );
+                    }
 
-            // prod_j will be the product of messages coming into j
-            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
-            foreach( const Neighbor &J, nbV(j) )
-                if( J != I ) { // for all J in nb(j) \ I 
+                // multiply prod with prod_j
+                for( size_t r = 0; r < prod.size(); ++r )
                     if( props.logdomain )
-                        prod_j += message( j, J.iter );
+                        prod[r] += prod_j[ind[r]];
                     else
-                        prod_j *= message( j, J.iter );
-                }
+                        prod[r] *= prod_j[ind[r]];
+            }
+        }
+        if( props.logdomain ) {
+            prod -= prod.maxVal();
+            prod.takeExp();
+        }
 
-            // multiply prod with prod_j
+        // Marginalize onto i
+        Prob marg( var(i).states(), 0.0 );
+        // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+        const ind_t ind = index(i,_I);
+        if( props.inference == Properties::InfType::SUMPROD ) 
             for( size_t r = 0; r < prod.size(); ++r )
-                if( props.logdomain )
-                    prod[r] += prod_j[ind[r]];
-                else
-                    prod[r] *= prod_j[ind[r]];
-        }
-    }
-    if( props.logdomain ) {
-        prod -= prod.maxVal();
-        prod.takeExp();
-    }
+                marg[ind[r]] += prod[r];
+        else
+            for( size_t r = 0; r < prod.size(); ++r )
+                if( prod[r] > marg[ind[r]] ) 
+                    marg[ind[r]] = prod[r];
+        marg.normalize();
 
-    // Marginalize onto i
-    Prob marg( var(i).states(), 0.0 );
-    // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
-    const ind_t ind = index(i,_I);
-    for( size_t r = 0; r < prod.size(); ++r )
-        marg[ind[r]] += prod[r];
-    marg.normalize( _normtype );
-    
-    // Store result
-    if( props.logdomain )
-        newMessage(i,_I) = marg.log();
-    else
-        newMessage(i,_I) = marg;
+        // Store result
+        if( props.logdomain )
+            newMessage(i,_I) = marg.log();
+        else
+            newMessage(i,_I) = marg;
+    }
 }
 
 
@@ -200,7 +231,6 @@ double BP::run() {
     for( size_t i = 0; i < nrVars(); ++i )
         old_beliefs.push_back( beliefV(i) );
 
-    size_t iter = 0;
     size_t nredges = nrEdges();
 
     if( props.updates == Properties::UpdateType::SEQMAX ) {
@@ -220,16 +250,14 @@ double BP::run() {
 
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
+    for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
         if( props.updates == Properties::UpdateType::SEQMAX ) {
             // Residuals-BP by Koller et al.
             for( size_t t = 0; t < nredges; ++t ) {
                 // update the message with the largest residual
-
                 size_t i, _I;
                 findMaxResidual( i, _I );
-                message( i, _I ) = newMessage( i, _I );
-                residual( i, _I ) = 0.0;
+                updateMessage( i, _I );
 
                 // I->i has been updated, which means that residuals for all
                 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
@@ -253,7 +281,7 @@ double BP::run() {
 
             for( size_t i = 0; i < nrVars(); ++i )
                 foreach( const Neighbor &I, nbV(i) )
-                    message( i, I.iter ) = newMessage( i, I.iter );
+                    updateMessage( i, I.iter );
         } else {
             // Sequential updates
             if( props.updates == Properties::UpdateType::SEQRND )
@@ -261,7 +289,7 @@ double BP::run() {
             
             foreach( const Edge &e, update_seq ) {
                 calcNewMessage( e.first, e.second );
-                message( e.first, e.second ) = newMessage( e.first, e.second );
+                updateMessage( e.first, e.second );
             }
         }
 
@@ -273,21 +301,21 @@ double BP::run() {
         }
 
         if( props.verbose >= 3 )
-            cout << "BP::run:  maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
+            cout << Name << "::run:  maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
     }
 
-    if( diffs.maxDiff() > maxdiff )
-        maxdiff = diffs.maxDiff();
+    if( diffs.maxDiff() > _maxdiff )
+        _maxdiff = diffs.maxDiff();
 
     if( props.verbose >= 1 ) {
         if( diffs.maxDiff() > props.tol ) {
             if( props.verbose == 1 )
                 cout << endl;
-                cout << "BP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+                cout << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
             if( props.verbose >= 3 )
-                cout << "BP::run:  ";
-                cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
+                cout << Name << "::run:  ";
+                cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
         }
     }
 
@@ -307,7 +335,7 @@ Factor BP::beliefV( size_t i ) const {
         prod.takeExp();
     }
 
-    prod.normalize( Prob::NORMPROB );
+    prod.normalize();
     return( Factor( var(i), prod ) );
 }
 
@@ -342,61 +370,65 @@ Factor BP::belief( const VarSet &ns ) const {
 
 
 Factor BP::beliefF (size_t I) const {
-    Prob prod( factor(I).p() );
-    if( props.logdomain )
-        prod.takeLog();
-
-    foreach( const Neighbor &j, nbF(I) ) {
-        size_t _I = j.dual;
-        // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-        const ind_t & ind = index(j, _I);
-
-        // prod_j will be the product of messages coming into j
-        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
-        foreach( const Neighbor &J, nbV(j) ) {
-            if( J != I ) { // for all J in nb(j) \ I 
+    if( 0 == 1 ) {
+        /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+
+        Factor prod( factor(I) );
+        foreach( const Neighbor &j, nbF(I) ) {
+            foreach( const Neighbor &J, nbV(j) ) {
+                if( J != I )  // for all J in nb(j) \ I
+                    prod *= Factor( var(j), newMessage(j, J.iter) );
+            }
+        }
+        return prod.normalized();
+    } else {
+        /* OPTIMIZED VERSION */
+        Prob prod( factor(I).p() );
+        if( props.logdomain )
+            prod.takeLog();
+
+        foreach( const Neighbor &j, nbF(I) ) {
+            size_t _I = j.dual;
+            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+            const ind_t & ind = index(j, _I);
+
+            // prod_j will be the product of messages coming into j
+            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
+            foreach( const Neighbor &J, nbV(j) ) {
+                if( J != I ) { // for all J in nb(j) \ I 
+                    if( props.logdomain )
+                        prod_j += newMessage( j, J.iter );
+                    else
+                        prod_j *= newMessage( j, J.iter );
+                }
+            }
+
+            // multiply prod with prod_j
+            for( size_t r = 0; r < prod.size(); ++r ) {
                 if( props.logdomain )
-                    prod_j += newMessage( j, J.iter );
+                    prod[r] += prod_j[ind[r]];
                 else
-                    prod_j *= newMessage( j, J.iter );
+                    prod[r] *= prod_j[ind[r]];
             }
         }
 
-        // multiply prod with prod_j
-        for( size_t r = 0; r < prod.size(); ++r ) {
-            if( props.logdomain )
-                prod[r] += prod_j[ind[r]];
-            else
-                prod[r] *= prod_j[ind[r]];
+        if( props.logdomain ) {
+            prod -= prod.maxVal();
+            prod.takeExp();
         }
-    }
-
-    if( props.logdomain ) {
-        prod -= prod.maxVal();
-        prod.takeExp();
-    }
-
-    Factor result( factor(I).vars(), prod );
-    result.normalize( Prob::NORMPROB );
 
-    return( result );
+        Factor result( factor(I).vars(), prod );
+        result.normalize();
 
-/*  UNOPTIMIZED VERSION
-    Factor prod( factor(I) );
-    for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
-        for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
-            if( *J != I )
-                prod *= Factor( var(*i), newMessage(*i,*J)) );
+        return( result );
     }
-    return prod.normalize( Prob::NORMPROB );*/
 }
 
 
-Complex BP::logZ() const {
-    Complex sum = 0.0;
+Real BP::logZ() const {
+    Real sum = 0.0;
     for(size_t i = 0; i < nrVars(); ++i )
-        sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
+        sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
         sum -= KL_dist( beliefF(I), factor(I) );
     return sum;
@@ -404,9 +436,7 @@ Complex BP::logZ() const {
 
 
 string BP::identify() const { 
-    stringstream result (stringstream::out);
-    result << Name << getProperties();
-    return result.str();
+    return string(Name) + printProperties();
 }