Fixed regression FBP and bugs in TRWBP
[libdai.git] / src / trwbp.cpp
index 407d847..bda5bc7 100644 (file)
@@ -32,12 +32,15 @@ string TRWBP::identify() const {
 Real TRWBP::logZ() const {
     Real sum = 0.0;
     for( size_t I = 0; I < nrFactors(); I++ ) {
-        sum += (beliefF(I) * factor(I).log(true)).sum();  // TRWBP
-        if( factor(I).vars().size() == 2 )
-            sum -= edgeWeight(I) * MutualInfo( beliefF(I) );  // TRWBP
+        sum += (beliefF(I) * factor(I).log(true)).sum();  // TRWBP/FBP
+        sum += Weight(I) * beliefF(I).entropy();  // TRWBP/FBP
+    }
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        Real c_i = 0.0;
+        foreach( const Neighbor &I, nbV(i) )
+            c_i += Weight(I);
+        sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
     }
-    for( size_t i = 0; i < nrVars(); ++i )
-        sum += beliefV(i).entropy();  // TRWBP
     return sum;
 }
 
@@ -47,13 +50,12 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
     const Var &v_i = var(i);
-    const VarSet &v_I = factor(I).vars();
-    Real c_I = edgeWeight(I); // TRWBP: c_I (\mu_I in the paper)
+    Real c_I = Weight(I); // TRWBP: c_I (\mu_I in the paper)
 
     Prob marg;
-    if( v_I.size() == 1 ) { // optimization
+    if( factor(I).vars().size() == 1 ) { // optimization
         marg = factor(I).p();
-    } else
+    } else {
         Factor Fprod( factor(I) );
         Prob &prod = Fprod.p();
         if( props.logdomain ) {
@@ -71,7 +73,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) {
                 // prod_j will be the product of messages coming into j
                 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
                 foreach( const Neighbor &J, nbV(j) ) {
-                    Real c_J = edgeWeight(J);
+                    Real c_J = Weight(J);
                     if( J != I ) { // for all J in nb(j) \ I
                         if( props.logdomain )
                             prod_j += message( j, J.iter ) * c_J;
@@ -94,6 +96,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) {
                         Fprod *= Factor( v_j, prod_j );
                 } else {
                     /* OPTIMIZED VERSION */
+                    size_t _I = j.dual;
                     // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
                     const ind_t &ind = index(j, _I);
                     for( size_t r = 0; r < prod.size(); ++r )
@@ -130,17 +133,17 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) {
                         marg[ind[r]] = prod[r];
             marg.normalize();
         }
+    }
 
-        // Store result
-        if( props.logdomain )
-            newMessage(i,_I) = marg.log();
-        else
-            newMessage(i,_I) = marg;
+    // Store result
+    if( props.logdomain )
+        newMessage(i,_I) = marg.log();
+    else
+        newMessage(i,_I) = marg;
 
-        // Update the residual if necessary
-        if( props.updates == Properties::UpdateType::SEQMAX )
-            updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
-    }
+    // Update the residual if necessary
+    if( props.updates == Properties::UpdateType::SEQMAX )
+        updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
 }
 
 
@@ -148,7 +151,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) {
 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
     p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
     foreach( const Neighbor &I, nbV(i) ) {
-        Real c_I = edgeWeight(I);
+        Real c_I = Weight(I);
         if( props.logdomain )
             p += newMessage( i, I.iter ) * c_I;
         else
@@ -159,8 +162,7 @@ void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
 
 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
 void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
-    Real c_I = edgeWeight(I); // TRWBP: c_I
-    const VarSet &v_I = factor(I).vars();
+    Real c_I = Weight(I); // TRWBP: c_I
 
     Factor Fprod( factor(I) );
     Prob &prod = Fprod.p();
@@ -179,7 +181,7 @@ void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
         // prod_j will be the product of messages coming into j
         Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
         foreach( const Neighbor &J, nbV(j) ) {
-            Real c_J = edgeWeight(J);
+            Real c_J = Weight(J);
             if( J != I ) { // for all J in nb(j) \ I
                 if( props.logdomain )
                     prod_j += newMessage( j, J.iter ) * c_J;
@@ -221,7 +223,7 @@ void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
 
 void TRWBP::construct() {
     BP::construct();
-    _edge_weight.resize( nrFactors(), 1.0 );
+    _weight.resize( nrFactors(), 1.0 );
 }