bp.init();
bp.run();
- cout << "Exact single node marginals:" << endl;
+ BP mp(fg, opts("updates",string("SEQFIX"))("logdomain",false)("inference",string("MAXPROD")));
+ mp.init();
+ mp.run();
+ vector<size_t> mpstate = mp.findMaximum();
+
+ cout << "Exact variable marginals:" << endl;
for( size_t i = 0; i < fg.nrVars(); i++ )
cout << jt.belief(fg.var(i)) << endl;
- cout << "Approximate (loopy belief propagation) single node marginals:" << endl;
+ cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
for( size_t i = 0; i < fg.nrVars(); i++ )
cout << bp.belief(fg.var(i)) << endl;
cout << "Exact log partition sum: " << jt.logZ() << endl;
cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
+
+ cout << "Max-product variable marginals:" << endl;
+ for( size_t i = 0; i < fg.nrVars(); i++ )
+ cout << mp.belief(fg.var(i)) << endl;
+
+ cout << "Max-product factor marginals:" << endl;
+ for( size_t I = 0; I < fg.nrFactors(); I++ )
+ cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
+
+ cout << "Max-product state:" << endl;
+ for( size_t i = 0; i < mpstate.size(); i++ )
+ cout << fg.var(i) << ": " << mpstate[i] << endl;
}
return 0;
Factor beliefF( size_t I ) const;
//@}
+ /// Calculates the joint state of all variables that has maximum probability
+ /** Assumes that run() has been called and that props.inference == MAXPROD
+ */
+ std::vector<std::size_t> findMaximum() const;
+
private:
const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
}
}
void findMaxResidual( size_t &i, size_t &_I );
+ /// Calculates unnormalized belief of variable
+ void calcBeliefV( size_t i, Prob &p ) const;
+ /// Calculates unnormalized belief of factor
+ void calcBeliefF( size_t I, Prob &p ) const;
void construct();
/// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
TProb<T> _p;
public:
+ /// Iterator over factor entries
+ typedef typename TProb<T>::iterator iterator;
+
+ /// Const iterator over factor entries
+ typedef typename TProb<T>::const_iterator const_iterator;
+
/// Construct Factor with empty VarSet
TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
/// Returns a reference to the i'th probability value
T& operator[] (size_t i) { return _p[i]; }
+
+ /// Returns iterator pointing to first entry
+ iterator begin() { return _p.begin(); }
+ /// Returns const iterator pointing to first entry
+ const_iterator begin() const { return _p.begin(); }
+ /// Returns iterator pointing beyond last entry
+ iterator end() { return _p.end(); }
+ /// Returns const iterator pointing beyond last entry
+ const_iterator end() const { return _p.end(); }
/// Sets all probability entries to p
TFactor<T> & fill (T p) { _p.fill( p ); return(*this); }
/// Shorthand for BipartiteGraph::Edge
typedef BipartiteGraph::Edge Edge;
+
+ /// Iterator over factors
+ typedef std::vector<Factor>::iterator iterator;
+
+ /// Const iterator over factors
+ typedef std::vector<Factor>::const_iterator const_iterator;
+
private:
std::vector<Var> _vars;
const Factor & factor(size_t I) const { return _factors[I]; }
/// Returns const reference to all factors
const std::vector<Factor> & factors() const { return _factors; }
+ /// Returns iterator pointing to first factor
+ iterator begin() { return _factors.begin(); }
+ /// Returns const iterator pointing to first factor
+ const_iterator begin() const { return _factors.begin(); }
+ /// Returns iterator pointing beyond last factor
+ iterator end() { return _factors.end(); }
+ /// Returns const iterator pointing beyond last factor
+ const_iterator end() const { return _factors.end(); }
/// Returns number of variables
size_t nrVars() const { return vars().size(); }
return vs_state;
}
- /// Postfix increment operator
- void operator++( int ) {
+ /// Prefix increment operator
+ void operator++( ) {
if( valid() ) {
state++;
states_type::iterator entry = states.begin();
state = -1;
}
}
+
+ /// Postfix increment operator
+ void operator++( int ) {
+ operator++();
+ }
/// Returns true if the current state is valid
bool valid() const {
std::vector<T> _p;
public:
+ /// Iterator over entries
+ typedef typename std::vector<T>::iterator iterator;
+ /// Const iterator over entries
+ typedef typename std::vector<T>::const_iterator const_iterator;
+
/// Enumerates different ways of normalizing a probability measure.
/**
* - NORMPROB means that the sum of all entries should be 1;
/// Returns a reference to the i'th probability entry
T& operator[]( size_t i ) { return _p[i]; }
+
+ /// Returns iterator pointing to first entry
+ iterator begin() { return _p.begin(); }
+
+ /// Returns const iterator pointing to first entry
+ const_iterator begin() const { return _p.begin(); }
+
+ /// Returns iterator pointing beyond last entry
+ iterator end() { return _p.end(); }
+
+ /// Returns const iterator pointing beyond last entry
+ const_iterator end() const { return _p.end(); }
/// Sets all elements to x
TProb<T> & fill(T x) {
#include <iostream>
+#include <cassert>
namespace dai {
/// Larger-than operator (only compares labels)
bool operator > ( const Var& n ) const { return( _label > n._label ); }
/// Smaller-than-or-equal-to operator (only compares labels)
- bool operator <= ( const Var& n ) const { return( _label <= n._label ); }
+ bool operator <= ( const Var& n ) const {
+#ifdef DAI_DEBUG
+ if( _label == n._label )
+ assert( _states == n._states );
+#endif
+ return( _label <= n._label );
+ }
/// Larger-than-or-equal-to operator (only compares labels)
- bool operator >= ( const Var& n ) const { return( _label >= n._label ); }
+ bool operator >= ( const Var& n ) const {
+#ifdef DAI_DEBUG
+ if( _label == n._label )
+ assert( _states == n._states );
+#endif
+ return( _label >= n._label );
+ }
/// Not-equal-to operator (only compares labels)
- bool operator != ( const Var& n ) const { return( _label != n._label ); }
+ bool operator != ( const Var& n ) const {
+#ifdef DAI_DEBUG
+ if( _label == n._label )
+ assert( _states == n._states );
+#endif
+ return( _label != n._label );
+ }
/// Equal-to operator (only compares labels)
- bool operator == ( const Var& n ) const { return( _label == n._label ); }
+ bool operator == ( const Var& n ) const {
+#ifdef DAI_DEBUG
+ if( _label == n._label )
+ assert( _states == n._states );
+#endif
+ return( _label == n._label );
+ }
/// Writes a Var to an output stream
friend std::ostream& operator << ( std::ostream& os, const Var& n ) {
-/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
+/* Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
Radboud University Nijmegen, The Netherlands /
Max Planck Institute for Biological Cybernetics, Germany
+ Giuseppe Passino
This file is part of libDAI.
#include <map>
#include <set>
#include <algorithm>
+#include <stack>
#include <dai/bp.h>
#include <dai/util.h>
#include <dai/properties.h>
}
-Factor BP::beliefV( size_t i ) const {
- Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
+void BP::calcBeliefV( size_t i, Prob &p ) const {
+ p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &I, nbV(i) )
if( props.logdomain )
- prod += newMessage( i, I.iter );
+ p += newMessage( i, I.iter );
else
- prod *= newMessage( i, I.iter );
+ p *= newMessage( i, I.iter );
+}
+
+
+void BP::calcBeliefF( size_t I, Prob &p ) const {
+ p = factor(I).p();
+ if( props.logdomain )
+ p.takeLog();
+
+ foreach( const Neighbor &j, nbF(I) ) {
+ size_t _I = j.dual;
+ // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+ const ind_t & ind = index(j, _I);
+
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) ) {
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += newMessage( j, J.iter );
+ else
+ prod_j *= newMessage( j, J.iter );
+ }
+ }
+
+ // multiply p with prod_j
+ for( size_t r = 0; r < p.size(); ++r ) {
+ if( props.logdomain )
+ p[r] += prod_j[ind[r]];
+ else
+ p[r] *= prod_j[ind[r]];
+ }
+ }
+}
+
+
+Factor BP::beliefV( size_t i ) const {
+ Prob p;
+ calcBeliefV( i, p );
+
if( props.logdomain ) {
- prod -= prod.maxVal();
- prod.takeExp();
+ p -= p.maxVal();
+ p.takeExp();
}
- prod.normalize();
- return( Factor( var(i), prod ) );
+ p.normalize();
+ return( Factor( var(i), p ) );
}
-Factor BP::belief (const Var &n) const {
+Factor BP::belief( const Var &n ) const {
return( beliefV( findVar( n ) ) );
}
}
-Factor BP::beliefF (size_t I) const {
+Factor BP::beliefF( size_t I ) const {
if( 0 == 1 ) {
/* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
return prod.normalized();
} else {
/* OPTIMIZED VERSION */
- Prob prod( factor(I).p() );
- if( props.logdomain )
- prod.takeLog();
-
- foreach( const Neighbor &j, nbF(I) ) {
- size_t _I = j.dual;
- // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
- const ind_t & ind = index(j, _I);
-
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
- foreach( const Neighbor &J, nbV(j) ) {
- if( J != I ) { // for all J in nb(j) \ I
- if( props.logdomain )
- prod_j += newMessage( j, J.iter );
- else
- prod_j *= newMessage( j, J.iter );
- }
- }
- // multiply prod with prod_j
- for( size_t r = 0; r < prod.size(); ++r ) {
- if( props.logdomain )
- prod[r] += prod_j[ind[r]];
- else
- prod[r] *= prod_j[ind[r]];
- }
- }
+ Prob prod;
+ calcBeliefF( I, prod );
if( props.logdomain ) {
prod -= prod.maxVal();
prod.takeExp();
}
+ prod.normalize();
Factor result( factor(I).vars(), prod );
- result.normalize();
return( result );
}
}
+std::vector<size_t> BP::findMaximum() const {
+ std::vector<size_t> maximum( nrVars() );
+ std::vector<bool> visitedVars( nrVars(), false );
+ std::vector<bool> visitedFactors( nrFactors(), false );
+ std::stack<size_t> scheduledFactors;
+ for( size_t i = 0; i < nrVars(); ++i ) {
+ if( visitedVars[i] )
+ continue;
+ visitedVars[i] = true;
+
+ // Maximise with respect to variable i
+ Prob prod;
+ calcBeliefV( i, prod );
+ maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
+
+ foreach( const Neighbor &I, nbV(i) )
+ if( !visitedFactors[I] )
+ scheduledFactors.push(I);
+
+ while( !scheduledFactors.empty() ){
+ size_t I = scheduledFactors.top();
+ scheduledFactors.pop();
+ if( visitedFactors[I] )
+ continue;
+ visitedFactors[I] = true;
+
+ // Evaluate if some neighboring variables still need to be fixed; if not, we're done
+ bool allDetermined = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( !visitedVars[j.node] ) {
+ allDetermined = false;
+ break;
+ }
+ if( allDetermined )
+ continue;
+
+ // Calculate product of incoming messages on factor I
+ Prob prod2;
+ calcBeliefF( I, prod2 );
+
+ // The allowed configuration is restrained according to the variables assigned so far:
+ // pick the argmax amongst the allowed states
+ Real maxProb = std::numeric_limits<Real>::min();
+ State maxState( factor(I).vars() );
+ for( State s( factor(I).vars() ); s.valid(); ++s ){
+ // First, calculate whether this state is consistent with variables that
+ // have been assigned already
+ bool allowedState = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
+ allowedState = false;
+ break;
+ }
+ // If it is consistent, check if its probability is larger than what we have seen so far
+ if( allowedState && prod2[s] > maxProb ) {
+ maxState = s;
+ maxProb = prod2[s];
+ }
+ }
+
+ // Decode the argmax
+ foreach( const Neighbor &j, nbF(I) ) {
+ if( visitedVars[j.node] ) {
+ // We have already visited j earlier - hopefully our state is consistent
+ if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
+ std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
+ } else {
+ // We found a consistent state for variable j
+ visitedVars[j.node] = true;
+ maximum[j.node] = maxState( var(j.node) );
+ foreach( const Neighbor &J, nbV(j) )
+ if( !visitedFactors[J] )
+ scheduledFactors.push(J);
+ }
+ }
+ }
+ }
+ return maximum;
+}
+
+
} // end of namespace dai