New git HEAD version
[libdai.git] / src / trwbp.cpp
index 407d847..bb294a4 100644 (file)
@@ -1,13 +1,15 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2010  Joris Mooij  [joris dot mooij at libdai dot org]
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_TRWBP
+
+
 #include <dai/trwbp.h>
 
 
@@ -20,135 +22,115 @@ namespace dai {
 using namespace std;
 
 
-const char *TRWBP::Name = "TRWBP";
+void TRWBP::setProperties( const PropertySet &opts ) {
+    BP::setProperties( opts );
 
+    if( opts.hasKey("nrtrees") )
+        nrtrees = opts.getStringAs<size_t>("nrtrees");
+    else
+        nrtrees = 0;
+}
 
-string TRWBP::identify() const {
-    return string(Name) + printProperties();
+
+PropertySet TRWBP::getProperties() const {
+    PropertySet opts = BP::getProperties();
+    opts.set( "nrtrees", nrtrees );
+    return opts;
 }
 
 
-/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+string TRWBP::printProperties() const {
+    stringstream s( stringstream::out );
+    string sbp = BP::printProperties();
+    s << sbp.substr( 0, sbp.size() - 1 );
+    s << ",";
+    s << "nrtrees=" << nrtrees << "]";
+    return s.str();
+}
+
+
+// This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
 Real TRWBP::logZ() const {
     Real sum = 0.0;
     for( size_t I = 0; I < nrFactors(); I++ ) {
-        sum += (beliefF(I) * factor(I).log(true)).sum();  // TRWBP
-        if( factor(I).vars().size() == 2 )
-            sum -= edgeWeight(I) * MutualInfo( beliefF(I) );  // TRWBP
+        sum += (beliefF(I) * factor(I).log(true)).sum();  // TRWBP/FBP
+        sum += Weight(I) * beliefF(I).entropy();  // TRWBP/FBP
+    }
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        Real c_i = 0.0;
+        bforeach( const Neighbor &I, nbV(i) )
+            c_i += Weight(I);
+        if( c_i != 1.0 )
+            sum += (1.0 - c_i) * beliefV(i).entropy();  // TRWBP/FBP
     }
-    for( size_t i = 0; i < nrVars(); ++i )
-        sum += beliefV(i).entropy();  // TRWBP
     return sum;
 }
 
 
-/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
-void TRWBP::calcNewMessage( size_t i, size_t _I ) {
-    // calculate updated message I->i
-    size_t I = nbV(i,_I);
-    const Var &v_i = var(i);
-    const VarSet &v_I = factor(I).vars();
-    Real c_I = edgeWeight(I); // TRWBP: c_I (\mu_I in the paper)
+// This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
+Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
+    Real c_I = Weight(I); // TRWBP: c_I
 
-    Prob marg;
-    if( v_I.size() == 1 ) { // optimization
-        marg = factor(I).p();
+    Factor Fprod( factor(I) );
+    Prob &prod = Fprod.p();
+    if( props.logdomain ) {
+        prod.takeLog();
+        prod /= c_I; // TRWBP
     } else
-        Factor Fprod( factor(I) );
-        Prob &prod = Fprod.p();
-        if( props.logdomain ) {
-            prod.takeLog();
-            prod /= c_I;         // TRWBP
-        } else
-            prod ^= (1.0 / c_I); // TRWBP
-    
-        // Calculate product of incoming messages and factor I
-        foreach( const Neighbor &j, nbF(I) )
-            if( j != i ) { // for all j in I \ i
-                const Var &v_j = var(j);
-
-                // TRWBP: corresponds to messages n_jI
-                // prod_j will be the product of messages coming into j
-                Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
-                foreach( const Neighbor &J, nbV(j) ) {
-                    Real c_J = edgeWeight(J);
-                    if( J != I ) { // for all J in nb(j) \ I
-                        if( props.logdomain )
-                            prod_j += message( j, J.iter ) * c_J;
-                        else
-                            prod_j *= message( j, J.iter ) ^ c_J;
-                    } else { // TRWBP: multiply by m_Ij^(c_I-1)
-                        if( props.logdomain )
-                            prod_j += message( j, J.iter ) * (c_J - 1.0);
-                        else
-                            prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
-                    }
-                }
+        prod ^= (1.0 / c_I); // TRWBP
 
-                // multiply prod with prod_j
-                if( !DAI_TRWBP_FAST ) {
-                    /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+    // Calculate product of incoming messages and factor I
+    bforeach( const Neighbor &j, nbF(I) )
+        if( !(without_i && (j == i)) ) {
+            const Var &v_j = var(j);
+            // prod_j will be the product of messages coming into j
+            // TRWBP: corresponds to messages n_jI
+            Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
+            bforeach( const Neighbor &J, nbV(j) ) {
+                Real c_J = Weight(J);  // TRWBP
+                if( J != I ) { // for all J in nb(j) \ I
                     if( props.logdomain )
-                        Fprod += Factor( v_j, prod_j );
+                        prod_j += message( j, J.iter ) * c_J;
                     else
-                        Fprod *= Factor( v_j, prod_j );
-                } else {
-                    /* OPTIMIZED VERSION */
-                    // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-                    const ind_t &ind = index(j, _I);
-                    for( size_t r = 0; r < prod.size(); ++r )
-                        if( props.logdomain )
-                            prod[r] += prod_j[ind[r]];
-                        else
-                            prod[r] *= prod_j[ind[r]];
+                        prod_j *= message( j, J.iter ) ^ c_J;
+                } else if( c_J != 1.0 ) { // TRWBP: multiply by m_Ij^(c_I-1)
+                    if( props.logdomain )
+                        prod_j += message( j, J.iter ) * (c_J - 1.0);
+                    else
+                        prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
                 }
             }
 
-        if( props.logdomain ) {
-            prod -= prod.max();
-            prod.takeExp();
-        }
+            // multiply prod with prod_j
+            if( !DAI_TRWBP_FAST ) {
+                // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
+                if( props.logdomain )
+                    Fprod += Factor( v_j, prod_j );
+                else
+                    Fprod *= Factor( v_j, prod_j );
+            } else {
+                // OPTIMIZED VERSION
+                size_t _I = j.dual;
+                // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+                const ind_t &ind = index(j, _I);
 
-        // Marginalize onto i
-        if( !DAI_TRWBP_FAST ) {
-            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-            if( props.inference == Properties::InfType::SUMPROD )
-                marg = Fprod.marginal( v_i ).p();
-            else
-                marg = Fprod.maxMarginal( v_i ).p();
-        } else {
-            /* OPTIMIZED VERSION */
-            marg = Prob( v_i.states(), 0.0 );
-            // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
-            const ind_t ind = index(i,_I);
-            if( props.inference == Properties::InfType::SUMPROD )
                 for( size_t r = 0; r < prod.size(); ++r )
-                    marg[ind[r]] += prod[r];
-            else
-                for( size_t r = 0; r < prod.size(); ++r )
-                    if( prod[r] > marg[ind[r]] )
-                        marg[ind[r]] = prod[r];
-            marg.normalize();
+                    if( props.logdomain )
+                        prod.set( r, prod[r] + prod_j[ind[r]] );
+                    else
+                        prod.set( r, prod[r] * prod_j[ind[r]] );
+            }
         }
-
-        // Store result
-        if( props.logdomain )
-            newMessage(i,_I) = marg.log();
-        else
-            newMessage(i,_I) = marg;
-
-        // Update the residual if necessary
-        if( props.updates == Properties::UpdateType::SEQMAX )
-            updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
-    }
+    
+    return prod;
 }
 
 
-/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+// This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
     p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
-    foreach( const Neighbor &I, nbV(i) ) {
-        Real c_I = edgeWeight(I);
+    bforeach( const Neighbor &I, nbV(i) ) {
+        Real c_I = Weight(I);
         if( props.logdomain )
             p += newMessage( i, I.iter ) * c_I;
         else
@@ -157,72 +139,67 @@ void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
 }
 
 
-/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
-void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
-    Real c_I = edgeWeight(I); // TRWBP: c_I
-    const VarSet &v_I = factor(I).vars();
+void TRWBP::construct() {
+    BP::construct();
+    _weight.resize( nrFactors(), 1.0 );
+    sampleWeights( nrtrees );
+    if( props.verbose >= 2 )
+        cerr << "Weights: " << _weight << endl;
+}
 
-    Factor Fprod( factor(I) );
-    Prob &prod = Fprod.p();
 
-    if( props.logdomain ) {
-        prod.takeLog();
-        prod /= c_I; // TRWBP
-    } else
-        prod ^= (1.0 / c_I); // TRWBP
+void TRWBP::addTreeToWeights( const RootedTree &tree ) {
+    for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
+        VarSet ij( var(e->first), var(e->second) );
+        size_t I = findFactor( ij );
+        _weight[I] += 1.0;
+    }
+}
 
-    // Calculate product of incoming messages and factor I
-    foreach( const Neighbor &j, nbF(I) ) {
-        const Var &v_j = var(j);
-
-        // TRWBP: corresponds to messages n_jI
-        // prod_j will be the product of messages coming into j
-        Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
-        foreach( const Neighbor &J, nbV(j) ) {
-            Real c_J = edgeWeight(J);
-            if( J != I ) { // for all J in nb(j) \ I
-                if( props.logdomain )
-                    prod_j += newMessage( j, J.iter ) * c_J;
-                else
-                    prod_j *= newMessage( j, J.iter ) ^ c_J;
-            } else { // TRWBP: multiply by m_Ij^(c_I-1)
-                if( props.logdomain )
-                    prod_j += newMessage( j, J.iter ) * (c_J - 1.0);
-                else
-                    prod_j *= newMessage( j, J.iter ) ^ (c_J - 1.0);
-            }
-        }
 
-        // multiply prod with prod_j
-        if( !DAI_TRWBP_FAST ) {
-            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-            if( props.logdomain )
-                Fprod += Factor( v_j, prod_j );
-            else
-                Fprod *= Factor( v_j, prod_j );
-        } else {
-            /* OPTIMIZED VERSION */
-            size_t _I = j.dual;
-            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-            const ind_t &ind = index(j, _I);
-
-            for( size_t r = 0; r < prod.size(); ++r ) {
-                if( props.logdomain )
-                    prod[r] += prod_j[ind[r]];
-                else
-                    prod[r] *= prod_j[ind[r]];
-            }
-        }
+void TRWBP::sampleWeights( size_t nrTrees ) {
+    if( !nrTrees )
+        return;
+
+    // initialize weights to zero
+    fill( _weight.begin(), _weight.end(), 0.0 );
+
+    // construct Markov adjacency graph, with edges weighted with
+    // random weights drawn from the uniform distribution on the interval [0,1]
+    WeightedGraph<Real> wg;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        const Var &v_i = var(i);
+        VarSet di = delta(i);
+        for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
+            if( v_i < *j )
+                wg[UEdge(i,findVar(*j))] = rnd_uniform();
     }
-    
-    p = prod;
-}
 
+    // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
+    for( size_t nr = 0; nr < nrTrees; nr++ ) {
+        // find minimal spanning tree
+        RootedTree randTree = MinSpanningTree( wg, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
+        // add it to the weights
+        addTreeToWeights( randTree );
+        // resample weights of the graph
+        for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
+            e->second = rnd_uniform();
+    }
 
-void TRWBP::construct() {
-    BP::construct();
-    _edge_weight.resize( nrFactors(), 1.0 );
+    // normalize the weights and set the single-variable weights to 1.0
+    for( size_t I = 0; I < nrFactors(); I++ ) {
+        size_t sizeI = factor(I).vars().size();
+        if( sizeI == 1 )
+            _weight[I] = 1.0;
+        else if( sizeI == 2 )
+            _weight[I] /= nrTrees;
+        else
+            DAI_THROW(NOT_IMPLEMENTED);
+    }
 }
 
 
 } // end of namespace dai
+
+
+#endif