Added max-product functionality to JTree
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 15 Sep 2009 10:47:21 +0000 (12:47 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 15 Sep 2009 10:47:21 +0000 (12:47 +0200)
README
examples/example.cpp
include/dai/jtree.h
src/jtree.cpp
tests/aliases.conf
tests/maxprodbug.fg [new file with mode: 0644]

diff --git a/README b/README
index 91670dd..3c5908a 100644 (file)
--- a/README
+++ b/README
@@ -92,6 +92,10 @@ Currently, libDAI supports the following (approximate) inference methods:
     * Gibbs sampler;
     * Conditioned BP [EaG09].
 
+These inference methods can be used to calculate partition sums, marginals
+over subsets of variables, and MAP states (the joint state of variables that
+has maximum probability).
+
 In addition, libDAI supports parameter learning of conditional probability
 tables by Expectation Maximization.
 
@@ -147,7 +151,8 @@ You need:
 - GNU make
 - doxygen
 - graphviz
-- recent boost C++ libraries (at least version 1.34, or 1.37 for cygwin)
+- recent boost C++ libraries (at least version 1.34, or 1.37 for cygwin;
+  version 1.37 shipped with Ubuntu 9.04 is known not to work)
 
 On Debian/Ubuntu, you can easily install all these packages with a single command:
 "apt-get install g++ make doxygen graphviz libboost-dev libboost-graph-dev libboost-program-options-dev"
index 0705bde..3884fe6 100644 (file)
@@ -54,11 +54,21 @@ int main( int argc, char *argv[] ) {
         // using the parameters specified by opts and an additional property
         // that specifies the type of updates the JTree algorithm should perform
         JTree jt( fg, opts("updates",string("HUGIN")) );
-        // Initialize junction tree algoritm
+        // Initialize junction tree algorithm
         jt.init();
         // Run junction tree algorithm
         jt.run();
 
+        // Construct another JTree (junction tree) object that is used to calculate
+        // the joint configuration of variables that has maximum probability (MAP state)
+        JTree jtmap( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
+        // Initialize junction tree algorithm
+        jtmap.init();
+        // Run junction tree algorithm
+        jtmap.run();
+        // Calculate joint state of all variables that has maximum probability
+        vector<size_t> jtmapstate = jtmap.findMaximum();
+
         // Construct a BP (belief propagation) object from the FactorGraph fg
         // using the parameters specified by opts and two additional properties,
         // specifying the type of updates the BP algorithm should perform and
@@ -111,18 +121,33 @@ int main( int argc, char *argv[] ) {
         // Report log partition sum of fg, approximated by the belief propagation algorithm
         cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
 
+        // Report exact MAP variable marginals
+        cout << "Exact MAP variable marginals:" << endl;
+        for( size_t i = 0; i < fg.nrVars(); i++ )
+            cout << jtmap.belief(fg.var(i)) << endl;
+
         // Report max-product variable marginals
-        cout << "Max-product variable marginals:" << endl;
+        cout << "Approximate (max-product) MAP variable marginals:" << endl;
         for( size_t i = 0; i < fg.nrVars(); i++ )
             cout << mp.belief(fg.var(i)) << endl;
 
+        // Report exact MAP factor marginals
+        cout << "Exact MAP factor marginals:" << endl;
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            cout << jtmap.belief(fg.factor(I).vars()) << "=" << jtmap.beliefF(I) << endl;
+
         // Report max-product factor marginals
-        cout << "Max-product factor marginals:" << endl;
+        cout << "Approximate (max-product) MAP factor marginals:" << endl;
         for( size_t I = 0; I < fg.nrFactors(); I++ )
             cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
 
-        // Report max-product joint state
-        cout << "Max-product state:" << endl;
+        // Report exact MAP joint state
+        cout << "Exact MAP state:" << endl;
+        for( size_t i = 0; i < jtmapstate.size(); i++ )
+            cout << fg.var(i) << ": " << jtmapstate[i] << endl;
+
+        // Report max-product MAP joint state
+        cout << "Approximate (max-product) MAP state:" << endl;
         for( size_t i = 0; i < mpstate.size(); i++ )
             cout << fg.var(i) << ": " << mpstate[i] << endl;
     }
index f46e194..295dc2c 100644 (file)
@@ -65,11 +65,17 @@ class JTree : public DAIAlgRG {
             /// Enumeration of possible JTree updates
             DAI_ENUM(UpdateType,HUGIN,SHSH)
 
+            /// Enumeration of inference variants
+            DAI_ENUM(InfType,SUMPROD,MAXPROD);
+
             /// Verbosity
             size_t verbose;
 
             /// Type of updates
             UpdateType updates;
+
+            /// Type of inference: sum-product or max-product?
+            InfType inference;
         } props;
 
         /// Name of this inference algorithm
@@ -119,6 +125,11 @@ class JTree : public DAIAlgRG {
 
         /// Calculates the marginal of a set of variables
         Factor calcMarginal( const VarSet& ns );
+
+        /// Calculates the joint state of all variables that has maximum probability
+        /** Assumes that run() has been called and that props.inference == MAXPROD
+         */
+        std::vector<std::size_t> findMaximum() const;
         //@}
 
     private:
index 13d2ad2..0eecfae 100644 (file)
@@ -21,6 +21,7 @@
 
 
 #include <iostream>
+#include <stack>
 #include <dai/jtree.h>
 
 
@@ -39,6 +40,10 @@ void JTree::setProperties( const PropertySet &opts ) {
 
     props.verbose = opts.getStringAs<size_t>("verbose");
     props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+    if( opts.hasKey("inference") )
+        props.inference = opts.getStringAs<Properties::InfType>("inference");
+    else
+        props.inference = Properties::InfType::SUMPROD;
 }
 
 
@@ -46,6 +51,7 @@ PropertySet JTree::getProperties() const {
     PropertySet opts;
     opts.Set( "verbose", props.verbose );
     opts.Set( "updates", props.updates );
+    opts.Set( "inference", props.inference );
     return opts;
 }
 
@@ -54,7 +60,8 @@ string JTree::printProperties() const {
     stringstream s( stringstream::out );
     s << "[";
     s << "verbose=" << props.verbose << ",";
-    s << "updates=" << props.updates << "]";
+    s << "updates=" << props.updates << ",";
+    s << "inference=" << props.inference << "]";
     return s.str();
 }
 
@@ -219,7 +226,12 @@ void JTree::runHUGIN() {
     for( size_t i = RTree.size(); (i--) != 0; ) {
 //      Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
 //      IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
-        Factor new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
+        Factor new_Qb;
+        if( props.inference == Properties::InfType::SUMPROD )
+            new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
+        else
+            new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );
+
         _logZ += log(new_Qb.normalize());
         Qa[RTree[i].n1] *= new_Qb / Qb[i];
         Qb[i] = new_Qb;
@@ -233,7 +245,12 @@ void JTree::runHUGIN() {
     for( size_t i = 0; i < RTree.size(); i++ ) {
 //      Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
 //      IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
-        Factor new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
+        Factor new_Qb;
+        if( props.inference == Properties::InfType::SUMPROD )
+            new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
+        else
+            new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );
+
         Qa[RTree[i].n2] *= new_Qb / Qb[i];
         Qb[i] = new_Qb;
     }
@@ -256,11 +273,14 @@ void JTree::runShaferShenoy() {
         size_t j = nbIR(e)[0].node; // = RTree[e].n1
         size_t _e = nbIR(e)[0].dual;
 
-        Factor piet = OR(i);
+        Factor msg = OR(i);
         foreach( const Neighbor &k, nbOR(i) )
             if( k != e )
-                piet *= message( i, k.iter );
-        message( j, _e ) = piet.marginal( IR(e), false );
+                msg *= message( i, k.iter );
+        if( props.inference == Properties::InfType::SUMPROD )
+            message( j, _e ) = msg.marginal( IR(e), false );
+        else
+            message( j, _e ) = msg.maxMarginal( IR(e), false );
         _logZ += log( message(j,_e).normalize() );
     }
 
@@ -270,11 +290,14 @@ void JTree::runShaferShenoy() {
         size_t j = nbIR(e)[1].node; // = RTree[e].n2
         size_t _e = nbIR(e)[1].dual;
 
-        Factor piet = OR(i);
+        Factor msg = OR(i);
         foreach( const Neighbor &k, nbOR(i) )
             if( k != e )
-                piet *= message( i, k.iter );
-        message( j, _e ) = piet.marginal( IR(e) );
+                msg *= message( i, k.iter );
+        if( props.inference == Properties::InfType::SUMPROD )
+            message( j, _e ) = msg.marginal( IR(e) );
+        else
+            message( j, _e ) = msg.maxMarginal( IR(e) );
     }
 
     // Calculate beliefs
@@ -293,8 +316,12 @@ void JTree::runShaferShenoy() {
     }
 
     // Only for logZ (and for belief)...
-    for( size_t beta = 0; beta < nrIRs(); beta++ )
-        Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
+    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
+        if( props.inference == Properties::InfType::SUMPROD )
+            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
+        else
+            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
+    }
 }
 
 
@@ -536,4 +563,83 @@ std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
 }
 
 
+std::vector<size_t> JTree::findMaximum() const {
+    vector<size_t> maximum( nrVars() );
+    vector<bool> visitedVars( nrVars(), false );
+    vector<bool> visitedFactors( nrFactors(), false );
+    stack<size_t> scheduledFactors;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        if( visitedVars[i] )
+            continue;
+        visitedVars[i] = true;
+
+        // Maximise with respect to variable i
+        Prob prod = beliefV(i).p();
+        maximum[i] = max_element( prod.begin(), prod.end() ) - prod.begin();
+
+        foreach( const Neighbor &I, nbV(i) )
+            if( !visitedFactors[I] )
+                scheduledFactors.push(I);
+
+        while( !scheduledFactors.empty() ){
+            size_t I = scheduledFactors.top();
+            scheduledFactors.pop();
+            if( visitedFactors[I] )
+                continue;
+            visitedFactors[I] = true;
+
+            // Evaluate if some neighboring variables still need to be fixed; if not, we're done
+            bool allDetermined = true;
+            foreach( const Neighbor &j, nbF(I) )
+                if( !visitedVars[j.node] ) {
+                    allDetermined = false;
+                    break;
+                }
+            if( allDetermined )
+                continue;
+
+            // Calculate product of incoming messages on factor I
+            Prob prod2 = beliefF(I).p();
+
+            // The allowed configuration is restrained according to the variables assigned so far:
+            // pick the argmax amongst the allowed states
+            Real maxProb = numeric_limits<Real>::min();
+            State maxState( factor(I).vars() );
+            for( State s( factor(I).vars() ); s.valid(); ++s ){
+                // First, calculate whether this state is consistent with variables that
+                // have been assigned already
+                bool allowedState = true;
+                foreach( const Neighbor &j, nbF(I) )
+                    if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
+                        allowedState = false;
+                        break;
+                    }
+                // If it is consistent, check if its probability is larger than what we have seen so far
+                if( allowedState && prod2[s] > maxProb ) {
+                    maxState = s;
+                    maxProb = prod2[s];
+                }
+            }
+
+            // Decode the argmax
+            foreach( const Neighbor &j, nbF(I) ) {
+                if( visitedVars[j.node] ) {
+                    // We have already visited j earlier - hopefully our state is consistent
+                    if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
+                        cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
+                } else {
+                    // We found a consistent state for variable j
+                    visitedVars[j.node] = true;
+                    maximum[j.node] = maxState( var(j.node) );
+                    foreach( const Neighbor &J, nbV(j) )
+                        if( !visitedFactors[J] )
+                            scheduledFactors.push(J);
+                }
+            }
+        }
+    }
+    return maximum;
+}
+
+
 } // end of namespace dai
index 5f5b5d5..81eee6a 100644 (file)
@@ -22,6 +22,8 @@ MP_PARALL_LOG:                  BP[updates=PARALL,tol=1e-9,maxiter=10000,logdoma
 
 JTREE_HUGIN:                    JTREE[updates=HUGIN,verbose=0]
 JTREE_SHSH:                     JTREE[updates=SHSH,verbose=0]
+JTREE_HUGIN_MAP:               JTREE[updates=HUGIN,verbose=0,inference=MAXPROD]
+JTREE_SHSH_MAP:                        JTREE[updates=SHSH,verbose=0,inference=MAXPROD]
 
 # --- MF ----------------------
 
diff --git a/tests/maxprodbug.fg b/tests/maxprodbug.fg
new file mode 100644 (file)
index 0000000..8246970
--- /dev/null
@@ -0,0 +1,28 @@
+3
+
+2
+1 2
+2 2
+4
+0 1
+1 0
+2 0
+3 1
+
+2
+1 3
+2 2
+4
+0 1
+1 1
+2 1
+3 1
+
+2
+2 3
+2 2
+4
+0 0
+1 1
+2 1
+3 0