Fixed regression FBP and bugs in TRWBP
[libdai.git] / src / bp.cpp
index 5d4a79a..dc7dd91 100644 (file)
@@ -145,70 +145,74 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
 
-    Factor Fprod( factor(I) );
-    Prob &prod = Fprod.p();
-    if( props.logdomain )
-        prod.takeLog();
+    Prob marg;
+    if( factor(I).vars().size() == 1 ) // optimization
+        marg = factor(I).p();
+    else {
+        Factor Fprod( factor(I) );
+        Prob &prod = Fprod.p();
+        if( props.logdomain )
+            prod.takeLog();
+
+        // Calculate product of incoming messages and factor I
+        foreach( const Neighbor &j, nbF(I) )
+            if( j != i ) { // for all j in I \ i
+                // prod_j will be the product of messages coming into j
+                Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+                foreach( const Neighbor &J, nbV(j) )
+                    if( J != I ) { // for all J in nb(j) \ I
+                        if( props.logdomain )
+                            prod_j += message( j, J.iter );
+                        else
+                            prod_j *= message( j, J.iter );
+                    }
 
-    // Calculate product of incoming messages and factor I
-    foreach( const Neighbor &j, nbF(I) )
-        if( j != i ) { // for all j in I \ i
-            // prod_j will be the product of messages coming into j
-            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
-            foreach( const Neighbor &J, nbV(j) )
-                if( J != I ) { // for all J in nb(j) \ I
+                // multiply prod with prod_j
+                if( !DAI_BP_FAST ) {
+                    /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
                     if( props.logdomain )
-                        prod_j += message( j, J.iter );
+                        Fprod += Factor( var(j), prod_j );
                     else
-                        prod_j *= message( j, J.iter );
+                        Fprod *= Factor( var(j), prod_j );
+                } else {
+                    /* OPTIMIZED VERSION */
+                    size_t _I = j.dual;
+                    // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+                    const ind_t &ind = index(j, _I);
+                    for( size_t r = 0; r < prod.size(); ++r )
+                        if( props.logdomain )
+                            prod[r] += prod_j[ind[r]];
+                        else
+                            prod[r] *= prod_j[ind[r]];
                 }
-
-            // multiply prod with prod_j
-            if( !DAI_BP_FAST ) {
-                /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-                if( props.logdomain )
-                    Fprod += Factor( var(j), prod_j );
-                else
-                    Fprod *= Factor( var(j), prod_j );
-            } else {
-                /* OPTIMIZED VERSION */
-                size_t _I = j.dual;
-                // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-                const ind_t &ind = index(j, _I);
-                for( size_t r = 0; r < prod.size(); ++r )
-                    if( props.logdomain )
-                        prod[r] += prod_j[ind[r]];
-                    else
-                        prod[r] *= prod_j[ind[r]];
             }
-        }
 
-    if( props.logdomain ) {
-        prod -= prod.max();
-        prod.takeExp();
-    }
+        if( props.logdomain ) {
+            prod -= prod.max();
+            prod.takeExp();
+        }
 
-    // Marginalize onto i
-    Prob marg;
-    if( !DAI_BP_FAST ) {
-        /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-        if( props.inference == Properties::InfType::SUMPROD )
-            marg = Fprod.marginal( var(i) ).p();
-        else
-            marg = Fprod.maxMarginal( var(i) ).p();
-    } else {
-        /* OPTIMIZED VERSION */
-        marg = Prob( var(i).states(), 0.0 );
-        // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
-        const ind_t ind = index(i,_I);
-        if( props.inference == Properties::InfType::SUMPROD )
-            for( size_t r = 0; r < prod.size(); ++r )
-                marg[ind[r]] += prod[r];
-        else
-            for( size_t r = 0; r < prod.size(); ++r )
-                if( prod[r] > marg[ind[r]] )
-                    marg[ind[r]] = prod[r];
-        marg.normalize();
+        // Marginalize onto i
+        if( !DAI_BP_FAST ) {
+            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+            if( props.inference == Properties::InfType::SUMPROD )
+                marg = Fprod.marginal( var(i) ).p();
+            else
+                marg = Fprod.maxMarginal( var(i) ).p();
+        } else {
+            /* OPTIMIZED VERSION */
+            marg = Prob( var(i).states(), 0.0 );
+            // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+            const ind_t ind = index(i,_I);
+            if( props.inference == Properties::InfType::SUMPROD )
+                for( size_t r = 0; r < prod.size(); ++r )
+                    marg[ind[r]] += prod[r];
+            else
+                for( size_t r = 0; r < prod.size(); ++r )
+                    if( prod[r] > marg[ind[r]] )
+                        marg[ind[r]] = prod[r];
+            marg.normalize();
+        }
     }
 
     // Store result