New git HEAD version
[libdai.git] / src / trwbp.cpp
index 0674e8f..bb294a4 100644 (file)
@@ -1,13 +1,15 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2010  Joris Mooij  [joris dot mooij at libdai dot org]
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_TRWBP
+
+
 #include <dai/trwbp.h>
 
 
@@ -20,11 +22,30 @@ namespace dai {
 using namespace std;
 
 
-const char *TRWBP::Name = "TRWBP";
+void TRWBP::setProperties( const PropertySet &opts ) {
+    BP::setProperties( opts );
+
+    if( opts.hasKey("nrtrees") )
+        nrtrees = opts.getStringAs<size_t>("nrtrees");
+    else
+        nrtrees = 0;
+}
+
+
+PropertySet TRWBP::getProperties() const {
+    PropertySet opts = BP::getProperties();
+    opts.set( "nrtrees", nrtrees );
+    return opts;
+}
 
 
-string TRWBP::identify() const {
-    return string(Name) + printProperties();
+string TRWBP::printProperties() const {
+    stringstream s( stringstream::out );
+    string sbp = BP::printProperties();
+    s << sbp.substr( 0, sbp.size() - 1 );
+    s << ",";
+    s << "nrtrees=" << nrtrees << "]";
+    return s.str();
 }
 
 
@@ -37,9 +58,10 @@ Real TRWBP::logZ() const {
     }
     for( size_t i = 0; i < nrVars(); ++i ) {
         Real c_i = 0.0;
-        foreach( const Neighbor &I, nbV(i) )
+        bforeach( const Neighbor &I, nbV(i) )
             c_i += Weight(I);
-        sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
+        if( c_i != 1.0 )
+            sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
     }
     return sum;
 }
@@ -58,20 +80,20 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
         prod ^= (1.0 / c_I); // TRWBP
 
     // Calculate product of incoming messages and factor I
-    foreach( const Neighbor &j, nbF(I) )
+    bforeach( const Neighbor &j, nbF(I) )
         if( !(without_i && (j == i)) ) {
             const Var &v_j = var(j);
             // prod_j will be the product of messages coming into j
             // TRWBP: corresponds to messages n_jI
             Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
-            foreach( const Neighbor &J, nbV(j) ) {
+            bforeach( const Neighbor &J, nbV(j) ) {
                 Real c_J = Weight(J);  // TRWBP
                 if( J != I ) { // for all J in nb(j) \ I
                     if( props.logdomain )
                         prod_j += message( j, J.iter ) * c_J;
                     else
                         prod_j *= message( j, J.iter ) ^ c_J;
-                } else { // TRWBP: multiply by m_Ij^(c_I-1)
+                } else if( c_J != 1.0 ) { // TRWBP: multiply by m_Ij^(c_I-1)
                     if( props.logdomain )
                         prod_j += message( j, J.iter ) * (c_J - 1.0);
                     else
@@ -92,12 +114,11 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
                 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
                 const ind_t &ind = index(j, _I);
 
-                for( size_t r = 0; r < prod.size(); ++r ) {
+                for( size_t r = 0; r < prod.size(); ++r )
                     if( props.logdomain )
-                        prod[r] += prod_j[ind[r]];
+                        prod.set( r, prod[r] + prod_j[ind[r]] );
                     else
-                        prod[r] *= prod_j[ind[r]];
-                }
+                        prod.set( r, prod[r] * prod_j[ind[r]] );
             }
         }
     
@@ -108,7 +129,7 @@ Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) con
 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
     p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
-    foreach( const Neighbor &I, nbV(i) ) {
+    bforeach( const Neighbor &I, nbV(i) ) {
         Real c_I = Weight(I);
         if( props.logdomain )
             p += newMessage( i, I.iter ) * c_I;
@@ -121,7 +142,64 @@ void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
 void TRWBP::construct() {
     BP::construct();
     _weight.resize( nrFactors(), 1.0 );
+    sampleWeights( nrtrees );
+    if( props.verbose >= 2 )
+        cerr << "Weights: " << _weight << endl;
+}
+
+
+void TRWBP::addTreeToWeights( const RootedTree &tree ) {
+    for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
+        VarSet ij( var(e->first), var(e->second) );
+        size_t I = findFactor( ij );
+        _weight[I] += 1.0;
+    }
+}
+
+
+void TRWBP::sampleWeights( size_t nrTrees ) {
+    if( !nrTrees )
+        return;
+
+    // initialize weights to zero
+    fill( _weight.begin(), _weight.end(), 0.0 );
+
+    // construct Markov adjacency graph, with edges weighted with
+    // random weights drawn from the uniform distribution on the interval [0,1]
+    WeightedGraph<Real> wg;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        const Var &v_i = var(i);
+        VarSet di = delta(i);
+        for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
+            if( v_i < *j )
+                wg[UEdge(i,findVar(*j))] = rnd_uniform();
+    }
+
+    // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
+    for( size_t nr = 0; nr < nrTrees; nr++ ) {
+        // find minimal spanning tree
+        RootedTree randTree = MinSpanningTree( wg, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
+        // add it to the weights
+        addTreeToWeights( randTree );
+        // resample weights of the graph
+        for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
+            e->second = rnd_uniform();
+    }
+
+    // normalize the weights and set the single-variable weights to 1.0
+    for( size_t I = 0; I < nrFactors(); I++ ) {
+        size_t sizeI = factor(I).vars().size();
+        if( sizeI == 1 )
+            _weight[I] = 1.0;
+        else if( sizeI == 2 )
+            _weight[I] /= nrTrees;
+        else
+            DAI_THROW(NOT_IMPLEMENTED);
+    }
 }
 
 
 } // end of namespace dai
+
+
+#endif