New git HEAD version
[libdai.git] / src / trwbp.cpp
index 34c2943..bb294a4 100644 (file)
@@ -1,13 +1,15 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2010  Joris Mooij  [joris dot mooij at libdai dot org]
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_TRWBP
+
+
 #include <dai/trwbp.h>
 
 
@@ -20,9 +22,6 @@ namespace dai {
 using namespace std;
 
 
-const char *TRWBP::Name = "TRWBP";
-
-
 void TRWBP::setProperties( const PropertySet &opts ) {
     BP::setProperties( opts );
 
@@ -50,11 +49,6 @@ string TRWBP::printProperties() const {
 }
 
 
-string TRWBP::identify() const {
-    return string(Name) + printProperties();
-}
-
-
 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
 Real TRWBP::logZ() const {
     Real sum = 0.0;
@@ -64,9 +58,10 @@ Real TRWBP::logZ() const {
     }
     for( size_t i = 0; i < nrVars(); ++i ) {
         Real c_i = 0.0;
-        foreach( const Neighbor &I, nbV(i) )
+        bforeach( const Neighbor &I, nbV(i) )
             c_i += Weight(I);
-        sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
+        if( c_i != 1.0 )
+            sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
     }
     return sum;
 }
@@ -85,20 +80,20 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
         prod ^= (1.0 / c_I); // TRWBP
 
     // Calculate product of incoming messages and factor I
-    foreach( const Neighbor &j, nbF(I) )
+    bforeach( const Neighbor &j, nbF(I) )
         if( !(without_i && (j == i)) ) {
             const Var &v_j = var(j);
             // prod_j will be the product of messages coming into j
             // TRWBP: corresponds to messages n_jI
             Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
-            foreach( const Neighbor &J, nbV(j) ) {
+            bforeach( const Neighbor &J, nbV(j) ) {
                 Real c_J = Weight(J);  // TRWBP
                 if( J != I ) { // for all J in nb(j) \ I
                     if( props.logdomain )
                         prod_j += message( j, J.iter ) * c_J;
                     else
                         prod_j *= message( j, J.iter ) ^ c_J;
-                } else { // TRWBP: multiply by m_Ij^(c_I-1)
+                } else if( c_J != 1.0 ) { // TRWBP: multiply by m_Ij^(c_I-1)
                     if( props.logdomain )
                         prod_j += message( j, J.iter ) * (c_J - 1.0);
                     else
@@ -119,12 +114,11 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
                 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
                 const ind_t &ind = index(j, _I);
 
-                for( size_t r = 0; r < prod.size(); ++r ) {
+                for( size_t r = 0; r < prod.size(); ++r )
                     if( props.logdomain )
-                        prod[r] += prod_j[ind[r]];
+                        prod.set( r, prod[r] + prod_j[ind[r]] );
                     else
-                        prod[r] *= prod_j[ind[r]];
-                }
+                        prod.set( r, prod[r] * prod_j[ind[r]] );
             }
         }
     
@@ -135,7 +129,7 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
     p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
-    foreach( const Neighbor &I, nbV(i) ) {
+    bforeach( const Neighbor &I, nbV(i) ) {
         Real c_I = Weight(I);
         if( props.logdomain )
             p += newMessage( i, I.iter ) * c_I;
@@ -184,7 +178,7 @@ void TRWBP::sampleWeights( size_t nrTrees ) {
     // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
     for( size_t nr = 0; nr < nrTrees; nr++ ) {
         // find minimal spanning tree
-        RootedTree randTree = MinSpanningTree( wg, true );
+        RootedTree randTree = MinSpanningTree( wg, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
         // add it to the weights
         addTreeToWeights( randTree );
         // resample weights of the graph
@@ -206,3 +200,6 @@ void TRWBP::sampleWeights( size_t nrTrees ) {
 
 
 } // end of namespace dai
+
+
+#endif