Several changes by Giuseppe Passino
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 2 Mar 2009 19:35:36 +0000 (20:35 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 2 Mar 2009 19:35:36 +0000 (20:35 +0100)
- [Giuseppe Passino] Added BP::findMaximum(), which constructs a
  global state with maximum probability after running max-product
- [Giuseppe Passino] Added iterator interfaces to TProb, TFactor,
  and FactorGraph
- [Giuseppe Passino] Added prefix iterator to State
- [Joris Mooij] Added debug assertions to Var which check for
  inconsistent dimensions of variables with the same labels

example.cpp
include/dai/bp.h
include/dai/factor.h
include/dai/factorgraph.h
include/dai/index.h
include/dai/prob.h
include/dai/var.h
src/bp.cpp

index 47c83f3..1f9bd71 100644 (file)
@@ -55,11 +55,16 @@ int main( int argc, char *argv[] ) {
         bp.init();
         bp.run();
 
-        cout << "Exact single node marginals:" << endl;
+        BP mp(fg, opts("updates",string("SEQFIX"))("logdomain",false)("inference",string("MAXPROD")));
+        mp.init();
+        mp.run();
+        vector<size_t> mpstate = mp.findMaximum();
+
+        cout << "Exact variable marginals:" << endl;
         for( size_t i = 0; i < fg.nrVars(); i++ )
             cout << jt.belief(fg.var(i)) << endl;
 
-        cout << "Approximate (loopy belief propagation) single node marginals:" << endl;
+        cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
         for( size_t i = 0; i < fg.nrVars(); i++ )
             cout << bp.belief(fg.var(i)) << endl;
 
@@ -73,6 +78,18 @@ int main( int argc, char *argv[] ) {
 
         cout << "Exact log partition sum: " << jt.logZ() << endl;
         cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
+
+        cout << "Max-product variable marginals:" << endl;
+        for( size_t i = 0; i < fg.nrVars(); i++ )
+            cout << mp.belief(fg.var(i)) << endl;
+
+        cout << "Max-product factor marginals:" << endl;
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
+
+        cout << "Max-product state:" << endl;
+        for( size_t i = 0; i < mpstate.size(); i++ )
+            cout << fg.var(i) << ": " << mpstate[i] << endl;
     }
 
     return 0;
index c762d7e..d8c9a00 100644 (file)
@@ -137,6 +137,11 @@ class BP : public DAIAlgFG {
         Factor beliefF( size_t I ) const;
         //@}
 
+        /// Calculates the joint state of all variables that has maximum probability
+        /** Assumes that run() has been called and that props.inference == MAXPROD
+         */
+        std::vector<std::size_t> findMaximum() const;
+
     private:
         const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
         Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
@@ -158,6 +163,10 @@ class BP : public DAIAlgFG {
             }
         }
         void findMaxResidual( size_t &i, size_t &_I );
+        /// Calculates unnormalized belief of variable
+        void calcBeliefV( size_t i, Prob &p ) const;
+        /// Calculates unnormalized belief of factor
+        void calcBeliefF( size_t I, Prob &p ) const;
 
         void construct();
         /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
index d65937a..ad56cd5 100644 (file)
@@ -68,6 +68,12 @@ template <typename T> class TFactor {
         TProb<T>    _p;
 
     public:
+        /// Iterator over factor entries
+               typedef typename TProb<T>::iterator iterator;
+
+        /// Const iterator over factor entries
+               typedef typename TProb<T>::const_iterator const_iterator;
+
         /// Construct Factor with empty VarSet
         TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
 
@@ -118,6 +124,15 @@ template <typename T> class TFactor {
 
         /// Returns a reference to the i'th probability value
         T& operator[] (size_t i) { return _p[i]; }
+        
+        /// Returns iterator pointing to first entry
+        iterator begin() { return _p.begin(); }
+        /// Returns const iterator pointing to first entry
+               const_iterator begin() const { return _p.begin(); }
+               /// Returns iterator pointing beyond last entry
+               iterator end() { return _p.end(); }
+               /// Returns const iterator pointing beyond last entry
+               const_iterator end() const { return _p.end(); }
 
         /// Sets all probability entries to p
         TFactor<T> & fill (T p) { _p.fill( p ); return(*this); }
index 11d8215..2199418 100644 (file)
@@ -73,6 +73,13 @@ class FactorGraph {
 
         /// Shorthand for BipartiteGraph::Edge
         typedef BipartiteGraph::Edge      Edge;
+        
+        /// Iterator over factors
+        typedef std::vector<Factor>::iterator iterator;
+        
+        /// Const iterator over factors
+        typedef std::vector<Factor>::const_iterator const_iterator;
+        
 
     private:
         std::vector<Var>         _vars;
@@ -127,6 +134,14 @@ class FactorGraph {
         const Factor & factor(size_t I) const { return _factors[I]; }
         /// Returns const reference to all factors
         const std::vector<Factor> & factors() const { return _factors; }
+        /// Returns iterator pointing to first factor
+        iterator begin() { return _factors.begin(); }
+        /// Returns const iterator pointing to first factor
+        const_iterator begin() const { return _factors.begin(); }
+        /// Returns iterator pointing beyond last factor
+        iterator end() { return _factors.end(); }
+        /// Returns const iterator pointing beyond last factor
+        const_iterator end() const { return _factors.end(); }
 
         /// Returns number of variables
         size_t nrVars() const { return vars().size(); }
index 8665a6a..1137005 100644 (file)
@@ -355,8 +355,8 @@ class State {
             return vs_state;
         }
 
-        /// Postfix increment operator
-        void operator++( int ) {
+        /// Prefix increment operator
+        void operator++( ) {
             if( valid() ) {
                 state++;
                 states_type::iterator entry = states.begin();
@@ -370,6 +370,11 @@ class State {
                     state = -1;
             }
         }
+        
+        /// Postfix increment operator
+        void operator++( int ) {
+               operator++();
+        }
 
         /// Returns true if the current state is valid
         bool valid() const {
index dbd434b..3200c22 100644 (file)
@@ -61,6 +61,11 @@ template <typename T> class TProb {
         std::vector<T> _p;
 
     public:
+        /// Iterator over entries
+       typedef typename std::vector<T>::iterator iterator;
+        /// Const iterator over entries
+       typedef typename std::vector<T>::const_iterator const_iterator;
+
         /// Enumerates different ways of normalizing a probability measure.
         /** 
          *  - NORMPROB means that the sum of all entries should be 1;
@@ -105,6 +110,18 @@ template <typename T> class TProb {
         
         /// Returns a reference to the i'th probability entry
         T& operator[]( size_t i ) { return _p[i]; }
+        
+        /// Returns iterator pointing to first entry
+        iterator begin() { return _p.begin(); }
+
+        /// Returns const iterator pointing to first entry
+        const_iterator begin() const { return _p.begin(); }
+
+        /// Returns iterator pointing beyond last entry
+        iterator end() { return _p.end(); }
+
+        /// Returns const iterator pointing beyond last entry
+        const_iterator end() const { return _p.end(); }
 
         /// Sets all elements to x
         TProb<T> & fill(T x) { 
index fbb6e72..b4a04ee 100644 (file)
@@ -32,6 +32,7 @@
 
 
 #include <iostream>
+#include <cassert>
 
 
 namespace dai {
@@ -70,13 +71,37 @@ class Var {
         /// Larger-than operator (only compares labels)
         bool operator > ( const Var& n ) const { return( _label >  n._label ); }
         /// Smaller-than-or-equal-to operator (only compares labels)
-        bool operator <= ( const Var& n ) const { return( _label <= n._label ); }
+        bool operator <= ( const Var& n ) const { 
+#ifdef DAI_DEBUG
+            if( _label == n._label )
+                assert( _states == n._states );
+#endif
+            return( _label <= n._label ); 
+        }
         /// Larger-than-or-equal-to operator (only compares labels)
-        bool operator >= ( const Var& n ) const { return( _label >= n._label ); }
+        bool operator >= ( const Var& n ) const { 
+#ifdef DAI_DEBUG
+            if( _label == n._label )
+                assert( _states == n._states );
+#endif
+            return( _label >= n._label ); 
+        }
         /// Not-equal-to operator (only compares labels)
-        bool operator != ( const Var& n ) const { return( _label != n._label ); }
+        bool operator != ( const Var& n ) const { 
+#ifdef DAI_DEBUG
+            if( _label == n._label )
+                assert( _states == n._states );
+#endif
+            return( _label != n._label ); 
+        }
         /// Equal-to operator (only compares labels)
-        bool operator == ( const Var& n ) const { return( _label == n._label ); }
+        bool operator == ( const Var& n ) const { 
+#ifdef DAI_DEBUG
+            if( _label == n._label )
+                assert( _states == n._states );
+#endif            
+            return( _label == n._label ); 
+        }
 
         /// Writes a Var to an output stream
         friend std::ostream& operator << ( std::ostream& os, const Var& n ) {
index 7c8d81a..6417c78 100644 (file)
@@ -1,6 +1,7 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+/*  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
     Radboud University Nijmegen, The Netherlands /
     Max Planck Institute for Biological Cybernetics, Germany
+    Giuseppe Passino
 
     This file is part of libDAI.
 
@@ -25,6 +26,7 @@
 #include <map>
 #include <set>
 #include <algorithm>
+#include <stack>
 #include <dai/bp.h>
 #include <dai/util.h>
 #include <dai/properties.h>
@@ -323,24 +325,63 @@ double BP::run() {
 }
 
 
-Factor BP::beliefV( size_t i ) const {
-    Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 ); 
+void BP::calcBeliefV( size_t i, Prob &p ) const {
+    p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 ); 
     foreach( const Neighbor &I, nbV(i) )
         if( props.logdomain )
-            prod += newMessage( i, I.iter );
+            p += newMessage( i, I.iter );
         else
-            prod *= newMessage( i, I.iter );
+            p *= newMessage( i, I.iter );
+}
+
+
+void BP::calcBeliefF( size_t I, Prob &p ) const {
+    p = factor(I).p();
+    if( props.logdomain )
+        p.takeLog();
+
+    foreach( const Neighbor &j, nbF(I) ) {
+        size_t _I = j.dual;
+        // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+        const ind_t & ind = index(j, _I);
+
+        // prod_j will be the product of messages coming into j
+        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
+        foreach( const Neighbor &J, nbV(j) ) {
+            if( J != I ) { // for all J in nb(j) \ I 
+                if( props.logdomain )
+                    prod_j += newMessage( j, J.iter );
+                else
+                    prod_j *= newMessage( j, J.iter );
+            }
+        }
+
+        // multiply p with prod_j
+        for( size_t r = 0; r < p.size(); ++r ) {
+            if( props.logdomain )
+                p[r] += prod_j[ind[r]];
+            else
+                p[r] *= prod_j[ind[r]];
+        }
+    }
+}
+
+
+Factor BP::beliefV( size_t i ) const {
+    Prob p;
+    calcBeliefV( i, p );
+
     if( props.logdomain ) {
-        prod -= prod.maxVal();
-        prod.takeExp();
+        p -= p.maxVal();
+        p.takeExp();
     }
 
-    prod.normalize();
-    return( Factor( var(i), prod ) );
+    p.normalize();
+    return( Factor( var(i), p ) );
 }
 
 
-Factor BP::belief (const Var &n) const {
+Factor BP::belief( const Var &n ) const {
     return( beliefV( findVar( n ) ) );
 }
 
@@ -369,7 +410,7 @@ Factor BP::belief( const VarSet &ns ) const {
 }
 
 
-Factor BP::beliefF (size_t I) const {
+Factor BP::beliefF( size_t I ) const {
     if( 0 == 1 ) {
         /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
 
@@ -383,42 +424,17 @@ Factor BP::beliefF (size_t I) const {
         return prod.normalized();
     } else {
         /* OPTIMIZED VERSION */
-        Prob prod( factor(I).p() );
-        if( props.logdomain )
-            prod.takeLog();
-
-        foreach( const Neighbor &j, nbF(I) ) {
-            size_t _I = j.dual;
-            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-            const ind_t & ind = index(j, _I);
-
-            // prod_j will be the product of messages coming into j
-            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
-            foreach( const Neighbor &J, nbV(j) ) {
-                if( J != I ) { // for all J in nb(j) \ I 
-                    if( props.logdomain )
-                        prod_j += newMessage( j, J.iter );
-                    else
-                        prod_j *= newMessage( j, J.iter );
-                }
-            }
 
-            // multiply prod with prod_j
-            for( size_t r = 0; r < prod.size(); ++r ) {
-                if( props.logdomain )
-                    prod[r] += prod_j[ind[r]];
-                else
-                    prod[r] *= prod_j[ind[r]];
-            }
-        }
+        Prob prod;
+        calcBeliefF( I, prod );
 
         if( props.logdomain ) {
             prod -= prod.maxVal();
             prod.takeExp();
         }
+        prod.normalize();
 
         Factor result( factor(I).vars(), prod );
-        result.normalize();
 
         return( result );
     }
@@ -449,4 +465,85 @@ void BP::init( const VarSet &ns ) {
 }
 
 
+std::vector<size_t> BP::findMaximum() const {
+    std::vector<size_t> maximum( nrVars() );
+    std::vector<bool> visitedVars( nrVars(), false );
+    std::vector<bool> visitedFactors( nrFactors(), false );
+    std::stack<size_t> scheduledFactors;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        if( visitedVars[i] )
+            continue;
+        visitedVars[i] = true;
+        
+        // Maximise with respect to variable i
+        Prob prod;
+        calcBeliefV( i, prod );
+        maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
+        
+        foreach( const Neighbor &I, nbV(i) )
+            if( !visitedFactors[I] ) 
+                scheduledFactors.push(I);
+
+        while( !scheduledFactors.empty() ){
+            size_t I = scheduledFactors.top();
+            scheduledFactors.pop();
+            if( visitedFactors[I] )
+                continue;
+            visitedFactors[I] = true;
+            
+            // Evaluate if some neighboring variables still need to be fixed; if not, we're done
+            bool allDetermined = true;
+            foreach( const Neighbor &j, nbF(I) ) 
+                if( !visitedVars[j.node] ) {
+                    allDetermined = false;
+                    break;
+                }
+            if( allDetermined )
+                continue;
+            
+            // Calculate product of incoming messages on factor I
+            Prob prod2;
+            calcBeliefF( I, prod2 );
+
+            // The allowed configuration is restrained according to the variables assigned so far:
+            // pick the argmax amongst the allowed states
+            Real maxProb = std::numeric_limits<Real>::min();
+            State maxState( factor(I).vars() );
+            for( State s( factor(I).vars() ); s.valid(); ++s ){
+                // First, calculate whether this state is consistent with variables that
+                // have been assigned already
+                bool allowedState = true;
+                foreach( const Neighbor &j, nbF(I) )
+                    if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
+                        allowedState = false;
+                        break;
+                    }
+                // If it is consistent, check if its probability is larger than what we have seen so far
+                if( allowedState && prod2[s] > maxProb ) {
+                    maxState = s;
+                    maxProb = prod2[s];
+                }
+            }
+            
+            // Decode the argmax
+            foreach( const Neighbor &j, nbF(I) ) {
+                if( visitedVars[j.node] ) {
+                    // We have already visited j earlier - hopefully our state is consistent
+                    if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
+                        std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
+                } else {
+                    // We found a consistent state for variable j
+                    visitedVars[j.node] = true;
+                    maximum[j.node] = maxState( var(j.node) );
+                    foreach( const Neighbor &J, nbV(j) )
+                        if( !visitedFactors[J] ) 
+                            scheduledFactors.push(J);
+                }
+            }
+        }
+    }
+    return maximum;
+}
+
+
 } // end of namespace dai