Added initialization of TRWBP weights by sampling spanning trees
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 17 Mar 2010 12:29:07 +0000 (13:29 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 17 Mar 2010 12:29:07 +0000 (13:29 +0100)
ChangeLog
include/dai/doc.h
include/dai/trwbp.h
src/trwbp.cpp
tests/aliases.conf
utils/createfg.cpp

index 3d8135d..dbadbac 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,6 +1,7 @@
 git HEAD
 --------
 
+* Added initialization of TRWBP weights by sampling spanning trees
 * Cleaned up MR code:
   - rewrote response propagation implementation with help of BBP
   - now uses GraphAL to represent the Markov graph
index 5fe1a6f..cfa6382 100644 (file)
@@ -11,6 +11,8 @@
 /** \file
  *  \brief Contains additional doxygen documentation
  *
+ *  \todo Replace all Name members by virtual functions (or add virtual functions returning the Name)
+ *
  *  \idea Adapt (part of the) guidelines in http://www.boost.org/development/requirements.html#Design_and_Programming
  *
  *  \idea Use "gcc -MM" to generate dependencies for targets: http://make.paulandlesley.org/autodep.html
index 572f463..9f85815 100644 (file)
@@ -45,7 +45,6 @@ namespace dai {
  *    \f[ c_i := \sum_{I \in N_i} c_I \f]
  *
  *  \note TRWBP is actually equivalent to FBP
- *  \todo Add nice way to set weights
  *  \todo Merge code of FBP and TRWBP
  */
 class TRWBP : public BP {
@@ -59,6 +58,13 @@ class TRWBP : public BP {
         std::vector<Real> _weight;
 
     public:
+        /// Size of sample of trees used to set the weights
+        /** \todo See if there is a way to wrap TRWBP::nrtrees in a props struct
+         *  together with the other properties currently in TRWBP::props
+         *  (without copying al lot of BP code literally)
+         */
+        size_t nrtrees;
+
         /// Name of this inference algorithm
         static const char *Name;
 
@@ -68,8 +74,10 @@ class TRWBP : public BP {
         /// Default constructor
         TRWBP() : BP(), _weight() {}
 
-        /// Construct from FactorGraph \a fg and PropertySet \a opts
-        /** \param opts Parameters @see BP::Properties
+        /// Construct from FactorGraph \a fg and PropertySet \a opts.
+        /** There is an additional property "nrtrees" which allows to specify the
+         *  number of random spanning trees used to set the scale parameters.
+         *  \param opts Parameters @see BP::Properties. 
          */
         TRWBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _weight() {
             setProperties( opts );
@@ -82,6 +90,9 @@ class TRWBP : public BP {
         virtual TRWBP* clone() const { return new TRWBP(*this); }
         virtual std::string identify() const;
         virtual Real logZ() const;
+        virtual void setProperties( const PropertySet &opts );
+        virtual PropertySet getProperties() const;
+        virtual std::string printProperties() const;
     //@}
 
     /// \name TRWBP accessors/mutators for scale parameters
@@ -100,6 +111,12 @@ class TRWBP : public BP {
          */
         void setWeights( const std::vector<Real> &c ) { _weight = c; }
 
+        /// Increases weights corresponding to pairwise factors in \a tree with 1
+        void addTreeToWeights( const RootedTree &tree );
+
+        /// Samples weights from a sample of \a nrTrees random spanning trees
+        void sampleWeights( size_t nrTrees );
+
     protected:
         /// Calculate the product of factor \a I and the incoming messages
         /** If \a without_i == \c true, the message coming from variable \a i is omitted from the product
index 0674e8f..3184376 100644 (file)
@@ -23,6 +23,33 @@ using namespace std;
 const char *TRWBP::Name = "TRWBP";
 
 
+void TRWBP::setProperties( const PropertySet &opts ) {
+    BP::setProperties( opts );
+
+    if( opts.hasKey("nrtrees") )
+        nrtrees = opts.getStringAs<size_t>("nrtrees");
+    else
+        nrtrees = 0;
+}
+
+
+PropertySet TRWBP::getProperties() const {
+    PropertySet opts = BP::getProperties();
+    opts.Set( "nrtrees", nrtrees );
+    return opts;
+}
+
+
+string TRWBP::printProperties() const {
+    stringstream s( stringstream::out );
+    string sbp = BP::printProperties();
+    s << sbp.substr( 0, sbp.size() - 1 );
+    s << ",";
+    s << "nrtrees=" << nrtrees << "]";
+    return s.str();
+}
+
+
 string TRWBP::identify() const {
     return string(Name) + printProperties();
 }
@@ -121,6 +148,60 @@ void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
 void TRWBP::construct() {
     BP::construct();
     _weight.resize( nrFactors(), 1.0 );
+    sampleWeights( nrtrees );
+    if( props.verbose >= 2 )
+        cerr << "Weights: " << _weight << endl;
+}
+
+
+void TRWBP::addTreeToWeights( const RootedTree &tree ) {
+    for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
+        VarSet ij( var(e->n1), var(e->n2) );
+        size_t I = findFactor( ij );
+        _weight[I] += 1.0;
+    }
+}
+
+
+void TRWBP::sampleWeights( size_t nrTrees ) {
+    if( !nrTrees )
+        return;
+
+    // initialize weights to zero
+    fill( _weight.begin(), _weight.end(), 0.0 );
+
+    // construct Markov adjacency graph, with edges weighted with
+    // random weights drawn from the uniform distribution on the interval [0,1]
+    WeightedGraph<Real> wg;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        const Var &v_i = var(i);
+        VarSet di = delta(i);
+        for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
+            if( v_i < *j )
+                wg[UEdge(i,findVar(*j))] = rnd_uniform();
+    }
+
+    // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
+    for( size_t nr = 0; nr < nrTrees; nr++ ) {
+        // find minimal spanning tree
+        RootedTree randTree = MinSpanningTreePrims( wg );
+        // add it to the weights
+        addTreeToWeights( randTree );
+        // resample weights of the graph
+        for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
+            e->second = rnd_uniform();
+    }
+
+    // normalize the weights and set the single-variable weights to 1.0
+    for( size_t I = 0; I < nrFactors(); I++ ) {
+        size_t sizeI = factor(I).vars().size();
+        if( sizeI == 1 )
+            _weight[I] = 1.0;
+        else if( sizeI == 2 )
+            _weight[I] /= nrTrees;
+        else
+            DAI_THROW(NOT_IMPLEMENTED);
+    }
 }
 
 
index 0890a29..75c4d80 100644 (file)
@@ -41,7 +41,7 @@ FBP:                            FBP[updates=SEQFIX,tol=1e-9,maxiter=10000,logdom
 
 # --- TRWBP -------------------
 
-TRWBP:                          TRWBP[updates=SEQFIX,tol=1e-9,maxiter=10000,logdomain=0]
+TRWBP:                          TRWBP[updates=SEQFIX,tol=1e-9,maxiter=10000,logdomain=0,nrtrees=0]
 
 # --- JTREE -------------------
 
index 1685526..6135349 100644 (file)
@@ -484,7 +484,7 @@ int main( int argc, char *argv[] ) {
             } else
                 NEED_ARG("N", "number of variables");
 
-            if( states > 2 || ft == FactorType::POTTS ) {
+            if( ft != FactorType::ISING ) {
                 NEED_ARG("beta", "stddev of log-factor entries");
             } else {
                 NEED_ARG("mean_w", "mean of pairwise interactions");