Improved BP (added 'maxtime' property)
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 3 Aug 2010 13:41:26 +0000 (15:41 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 3 Aug 2010 13:41:26 +0000 (15:41 +0200)
ChangeLog
include/dai/bp.h
src/bp.cpp

index 5e7b073..11b4cc1 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,6 @@
 git HEAD
 --------
+* Improved BP (added 'maxtime' property)
 * Added fromString<>( const std::string& x )
 * Added SmallSet::erase( const T& t )
 * Added DECMAP algorithm.
index dc133af..c319f88 100644 (file)
@@ -114,6 +114,9 @@ class BP : public DAIAlgFG {
             /// Maximum number of iterations
             size_t maxiter;
 
+            /// Maximum time (in seconds)
+            double maxtime;
+
             /// Tolerance for convergence test
             Real tol;
 
@@ -136,6 +139,15 @@ class BP : public DAIAlgFG {
         /// Specifies whether the history of message updates should be recorded
         bool recordSentMessages;
 
+        /// Stores variable beliefs of previous iteration
+        std::vector<Factor> oldBeliefsV;
+
+        /// Stores factor beliefs of previous iteration
+        std::vector<Factor> oldBeliefsF;
+
+        /// Stores the update schedule
+        std::vector<Edge> updateSeq;
+
     public:
     /// \name Constructors/destructors
     //@{
index ab18715..db2ab72 100644 (file)
@@ -34,15 +34,21 @@ const char *BP::Name = "BP";
 
 void BP::setProperties( const PropertySet &opts ) {
     DAI_ASSERT( opts.hasKey("tol") );
-    DAI_ASSERT( opts.hasKey("maxiter") );
     DAI_ASSERT( opts.hasKey("logdomain") );
     DAI_ASSERT( opts.hasKey("updates") );
 
     props.tol = opts.getStringAs<Real>("tol");
-    props.maxiter = opts.getStringAs<size_t>("maxiter");
     props.logdomain = opts.getStringAs<bool>("logdomain");
     props.updates = opts.getStringAs<Properties::UpdateType>("updates");
 
+    if( opts.hasKey("maxiter") )
+        props.maxiter = opts.getStringAs<size_t>("maxiter");
+    else
+        props.maxiter = 10000;
+    if( opts.hasKey("maxtime") )
+        props.maxtime = opts.getStringAs<Real>("maxtime");
+    else
+        props.maxtime = INFINITY;
     if( opts.hasKey("verbose") )
         props.verbose = opts.getStringAs<size_t>("verbose");
     else
@@ -62,6 +68,7 @@ PropertySet BP::getProperties() const {
     PropertySet opts;
     opts.set( "tol", props.tol );
     opts.set( "maxiter", props.maxiter );
+    opts.set( "maxtime", props.maxtime );
     opts.set( "verbose", props.verbose );
     opts.set( "logdomain", props.logdomain );
     opts.set( "updates", props.updates );
@@ -76,6 +83,7 @@ string BP::printProperties() const {
     s << "[";
     s << "tol=" << props.tol << ",";
     s << "maxiter=" << props.maxiter << ",";
+    s << "maxtime=" << props.maxtime << ",";
     s << "verbose=" << props.verbose << ",";
     s << "logdomain=" << props.logdomain << ",";
     s << "updates=" << props.updates << ",";
@@ -116,6 +124,23 @@ void BP::construct() {
                 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
         }
     }
+
+    // create old beliefs
+    oldBeliefsV.clear();
+    oldBeliefsV.reserve( nrVars() );
+    for( size_t i = 0; i < nrVars(); ++i )
+        oldBeliefsV.push_back( Factor( var(i) ) );
+    oldBeliefsF.clear();
+    oldBeliefsF.reserve( nrFactors() );
+    for( size_t I = 0; I < nrFactors(); ++I )
+        oldBeliefsF.push_back( Factor( factor(I).vars() ) );
+    
+    // create update sequence
+    updateSeq.clear();
+    updateSeq.reserve( nrEdges() );
+    for( size_t I = 0; I < nrFactors(); I++ )
+        foreach( const Neighbor &i, nbF(I) )
+            updateSeq.push_back( Edge( i, i.dual ) );
 }
 
 
@@ -129,6 +154,7 @@ void BP::init() {
                 updateResidual( i, I.iter, 0.0 );
         }
     }
+    _iters = 0;
 }
 
 
@@ -245,36 +271,20 @@ Real BP::run() {
         cerr << endl;
 
     double tic = toc();
-    Real maxDiff = INFINITY;
-
-    vector<Factor> oldBeliefsV, oldBeliefsF;
-    oldBeliefsV.reserve( nrVars() );
-    for( size_t i = 0; i < nrVars(); ++i )
-        oldBeliefsV.push_back( beliefV(i) );
-    oldBeliefsF.reserve( nrFactors() );
-    for( size_t I = 0; I < nrFactors(); ++I )
-        oldBeliefsF.push_back( beliefF(I) );
-
-    size_t nredges = nrEdges();
-    vector<Edge> update_seq;
-    if( props.updates == Properties::UpdateType::SEQMAX ) {
-        // do the first pass
-        for( size_t i = 0; i < nrVars(); ++i )
-            foreach( const Neighbor &I, nbV(i) )
-                calcNewMessage( i, I.iter );
-    } else {
-        update_seq.reserve( nredges );
-        for( size_t I = 0; I < nrFactors(); I++ )
-            foreach( const Neighbor &i, nbF(I) )
-                update_seq.push_back( Edge( i, i.dual ) );
-    }
 
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( _iters=0; _iters < props.maxiter && maxDiff > props.tol; ++_iters ) {
+    Real maxDiff = INFINITY;
+    for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
         if( props.updates == Properties::UpdateType::SEQMAX ) {
-            // Residuals-BP by Koller et al.
-            for( size_t t = 0; t < nredges; ++t ) {
+            if( _iters == 0 ) {
+                // do the first pass
+                for( size_t i = 0; i < nrVars(); ++i )
+                  foreach( const Neighbor &I, nbV(i) )
+                      calcNewMessage( i, I.iter );
+            }
+            // Maximum-Residual BP [\ref EMK06]
+            for( size_t t = 0; t < updateSeq.size(); ++t ) {
                 // update the message with the largest residual
                 size_t i, _I;
                 findMaxResidual( i, _I );
@@ -304,9 +314,9 @@ Real BP::run() {
         } else {
             // Sequential updates
             if( props.updates == Properties::UpdateType::SEQRND )
-                random_shuffle( update_seq.begin(), update_seq.end() );
+                random_shuffle( updateSeq.begin(), updateSeq.end() );
 
-            foreach( const Edge &e, update_seq ) {
+            foreach( const Edge &e, updateSeq ) {
                 calcNewMessage( e.first, e.second );
                 updateMessage( e.first, e.second );
             }
@@ -336,7 +346,7 @@ Real BP::run() {
         if( maxDiff > props.tol ) {
             if( props.verbose == 1 )
                 cerr << endl;
-                cerr << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
+                cerr << Name << "::run:  WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
         } else {
             if( props.verbose >= 3 )
                 cerr << Name << "::run:  ";
@@ -439,6 +449,7 @@ void BP::init( const VarSet &ns ) {
                 updateResidual( ni, I.iter, 0.0 );
         }
     }
+    _iters = 0;
 }