Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / src / bp.cpp
index 1045ded..0352565 100644 (file)
@@ -39,26 +39,29 @@ using namespace std;
 const char *BP::Name = "BP";
 
 
-bool BP::checkProperties() {
-    if( !HasProperty("updates") )
-        return false;
-    if( !HasProperty("tol") )
-        return false;
-    if (!HasProperty("maxiter") )
-        return false;
-    if (!HasProperty("verbose") )
-        return false;
-    if (!HasProperty("logdomain") )
-        return false;
+void BP::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("tol") );
+    assert( opts.hasKey("maxiter") );
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("logdomain") );
+    assert( opts.hasKey("updates") );
     
-    ConvertPropertyTo<double>("tol");
-    ConvertPropertyTo<size_t>("maxiter");
-    ConvertPropertyTo<size_t>("verbose");
-    ConvertPropertyTo<UpdateType>("updates");
-    ConvertPropertyTo<bool>("logdomain");
-    logDomain = GetPropertyAs<bool>("logdomain");
-
-    return true;
+    props.tol = opts.getStringAs<double>("tol");
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.logdomain = opts.getStringAs<bool>("logdomain");
+    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+}
+
+
+PropertySet BP::getProperties() const {
+    PropertySet opts;
+    opts.Set( "tol", props.tol );
+    opts.Set( "maxiter", props.maxiter );
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "logdomain", props.logdomain );
+    opts.Set( "updates", props.updates );
+    return opts;
 }
 
 
@@ -86,10 +89,9 @@ void BP::create() {
 
 
 void BP::init() {
-    assert( checkProperties() );
     for( size_t i = 0; i < nrVars(); ++i ) {
         foreach( const Neighbor &I, nbV(i) ) {
-            if( logDomain ) {
+            if( props.logdomain ) {
                 message( i, I.iter ).fill( 0.0 );
                 newMessage( i, I.iter ).fill( 0.0 );
             } else {
@@ -131,7 +133,7 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
 */
     
     Prob prod( factor(I).p() );
-    if( logDomain ) 
+    if( props.logdomain ) 
         prod.takeLog();
 
     // Calculate product of incoming messages and factor I
@@ -142,10 +144,10 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
             const ind_t & ind = index(j, _I);
 
             // prod_j will be the product of messages coming into j
-            Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 ); 
+            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
             foreach( const Neighbor &J, nbV(j) )
                 if( J != I ) { // for all J in nb(j) \ I 
-                    if( logDomain )
+                    if( props.logdomain )
                         prod_j += message( j, J.iter );
                     else
                         prod_j *= message( j, J.iter );
@@ -153,13 +155,13 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
 
             // multiply prod with prod_j
             for( size_t r = 0; r < prod.size(); ++r )
-                if( logDomain )
+                if( props.logdomain )
                     prod[r] += prod_j[ind[r]];
                 else
                     prod[r] *= prod_j[ind[r]];
         }
     }
-    if( logDomain ) {
+    if( props.logdomain ) {
         prod -= prod.maxVal();
         prod.takeExp();
     }
@@ -170,10 +172,10 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     const ind_t ind = index(i,_I);
     for( size_t r = 0; r < prod.size(); ++r )
         marg[ind[r]] += prod[r];
-    marg.normalize( _normtype );
+    marg.normalize( Prob::NORMPROB );
     
     // Store result
-    if( logDomain )
+    if( props.logdomain )
         newMessage(i,_I) = marg.log();
     else
         newMessage(i,_I) = marg;
@@ -183,15 +185,14 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
 // BP::run does not check for NANs for performance reasons
 // Somehow NaNs do not often occur in BP...
 double BP::run() {
-    if( Verbose() >= 1 )
+    if( props.verbose >= 1 )
         cout << "Starting " << identify() << "...";
-    if( Verbose() >= 3)
+    if( props.verbose >= 3)
        cout << endl; 
 
     double tic = toc();
     Diffs diffs(nrVars(), 1.0);
     
-    typedef pair<size_t,size_t> Edge;
     vector<Edge> update_seq;
 
     vector<Factor> old_beliefs;
@@ -202,7 +203,7 @@ double BP::run() {
     size_t iter = 0;
     size_t nredges = nrEdges();
 
-    if( Updates() == UpdateType::SEQMAX ) {
+    if( props.updates == Properties::UpdateType::SEQMAX ) {
         // do the first pass
         for( size_t i = 0; i < nrVars(); ++i )
             foreach( const Neighbor &I, nbV(i) ) {
@@ -219,8 +220,8 @@ double BP::run() {
 
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( iter=0; iter < MaxIter() && diffs.maxDiff() > Tol(); ++iter ) {
-        if( Updates() == UpdateType::SEQMAX ) {
+    for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
+        if( props.updates == Properties::UpdateType::SEQMAX ) {
             // Residuals-BP by Koller et al.
             for( size_t t = 0; t < nredges; ++t ) {
                 // update the message with the largest residual
@@ -244,7 +245,7 @@ double BP::run() {
                     }
                 }
             }
-        } else if( Updates() == UpdateType::PARALL ) {
+        } else if( props.updates == Properties::UpdateType::PARALL ) {
             // Parallel updates 
             for( size_t i = 0; i < nrVars(); ++i )
                 foreach( const Neighbor &I, nbV(i) )
@@ -255,7 +256,7 @@ double BP::run() {
                     message( i, I.iter ) = newMessage( i, I.iter );
         } else {
             // Sequential updates
-            if( Updates() == UpdateType::SEQRND )
+            if( props.updates == Properties::UpdateType::SEQRND )
                 random_shuffle( update_seq.begin(), update_seq.end() );
             
             foreach( const Edge &e, update_seq ) {
@@ -271,19 +272,20 @@ double BP::run() {
             old_beliefs[i] = nb;
         }
 
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "BP::run:  maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
     }
 
-    updateMaxDiff( diffs.maxDiff() );
+    if( diffs.maxDiff() > maxdiff )
+        maxdiff = diffs.maxDiff();
 
-    if( Verbose() >= 1 ) {
-        if( diffs.maxDiff() > Tol() ) {
-            if( Verbose() == 1 )
+    if( props.verbose >= 1 ) {
+        if( diffs.maxDiff() > props.tol ) {
+            if( props.verbose == 1 )
                 cout << endl;
-                cout << "BP::run:  WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+                cout << "BP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
-            if( Verbose() >= 3 )
+            if( props.verbose >= 3 )
                 cout << "BP::run:  ";
                 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
         }
@@ -294,13 +296,13 @@ double BP::run() {
 
 
 Factor BP::beliefV( size_t i ) const {
-    Prob prod( var(i).states(), logDomain ? 0.0 : 1.0 ); 
+    Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 ); 
     foreach( const Neighbor &I, nbV(i) )
-        if( logDomain )
+        if( props.logdomain )
             prod += newMessage( i, I.iter );
         else
             prod *= newMessage( i, I.iter );
-    if( logDomain ) {
+    if( props.logdomain ) {
         prod -= prod.maxVal();
         prod.takeExp();
     }
@@ -341,7 +343,7 @@ Factor BP::belief( const VarSet &ns ) const {
 
 Factor BP::beliefF (size_t I) const {
     Prob prod( factor(I).p() );
-    if( logDomain )
+    if( props.logdomain )
         prod.takeLog();
 
     foreach( const Neighbor &j, nbF(I) ) {
@@ -350,10 +352,10 @@ Factor BP::beliefF (size_t I) const {
         const ind_t & ind = index(j, _I);
 
         // prod_j will be the product of messages coming into j
-        Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 ); 
+        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
         foreach( const Neighbor &J, nbV(j) ) {
             if( J != I ) { // for all J in nb(j) \ I 
-                if( logDomain )
+                if( props.logdomain )
                     prod_j += newMessage( j, J.iter );
                 else
                     prod_j *= newMessage( j, J.iter );
@@ -362,14 +364,14 @@ Factor BP::beliefF (size_t I) const {
 
         // multiply prod with prod_j
         for( size_t r = 0; r < prod.size(); ++r ) {
-            if( logDomain )
+            if( props.logdomain )
                 prod[r] += prod_j[ind[r]];
             else
                 prod[r] *= prod_j[ind[r]];
         }
     }
 
-    if( logDomain ) {
+    if( props.logdomain ) {
         prod -= prod.maxVal();
         prod.takeExp();
     }
@@ -391,10 +393,10 @@ Factor BP::beliefF (size_t I) const {
 }
 
 
-Complex BP::logZ() const {
-    Complex sum = 0.0;
+Real BP::logZ() const {
+    Real sum = 0.0;
     for(size_t i = 0; i < nrVars(); ++i )
-        sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
+        sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
         sum -= KL_dist( beliefF(I), factor(I) );
     return sum;
@@ -403,7 +405,7 @@ Complex BP::logZ() const {
 
 string BP::identify() const { 
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
@@ -412,7 +414,7 @@ void BP::init( const VarSet &ns ) {
     for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
         size_t ni = findVar( *n );
         foreach( const Neighbor &I, nbV( ni ) )
-            message( ni, I.iter ).fill( logDomain ? 0.0 : 1.0 );
+            message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
     }
 }