-/* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
- Radboud University Nijmegen, The Netherlands
-
+/* Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
+ Radboud University Nijmegen, The Netherlands /
+ Max Planck Institute for Biological Cybernetics, Germany
+ Giuseppe Passino
+
This file is part of libDAI.
libDAI is free software; you can redistribute it and/or modify
#include <map>
#include <set>
#include <algorithm>
+#include <stack>
#include <dai/bp.h>
-#include <dai/diffs.h>
#include <dai/util.h>
#include <dai/properties.h>
const char *BP::Name = "BP";
+#define DAI_BP_FAST 1
+
+
void BP::setProperties( const PropertySet &opts ) {
assert( opts.hasKey("tol") );
assert( opts.hasKey("maxiter") );
assert( opts.hasKey("logdomain") );
assert( opts.hasKey("updates") );
-
+
props.tol = opts.getStringAs<double>("tol");
props.maxiter = opts.getStringAs<size_t>("maxiter");
props.logdomain = opts.getStringAs<bool>("logdomain");
// create edge properties
_edges.clear();
_edges.reserve( nrVars() );
+ _edge2lut.clear();
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ _edge2lut.reserve( nrVars() );
for( size_t i = 0; i < nrVars(); ++i ) {
_edges.push_back( vector<EdgeProp>() );
- _edges[i].reserve( nbV(i).size() );
+ _edges[i].reserve( nbV(i).size() );
+ if( props.updates == Properties::UpdateType::SEQMAX ) {
+ _edge2lut.push_back( vector<LutType::iterator>() );
+ _edge2lut[i].reserve( nbV(i).size() );
+ }
foreach( const Neighbor &I, nbV(i) ) {
EdgeProp newEP;
newEP.message = Prob( var(i).states() );
newEP.newMessage = Prob( var(i).states() );
- newEP.index.reserve( factor(I).states() );
- for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
- newEP.index.push_back( k );
+ if( DAI_BP_FAST ) {
+ newEP.index.reserve( factor(I).states() );
+ for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
+ newEP.index.push_back( k );
+ }
newEP.residual = 0.0;
_edges[i].push_back( newEP );
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
}
}
}
foreach( const Neighbor &I, nbV(i) ) {
message( i, I.iter ).fill( c );
newMessage( i, I.iter ).fill( c );
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ updateResidual( i, I.iter, 0.0 );
}
}
}
void BP::findMaxResidual( size_t &i, size_t &_I ) {
- i = 0;
- _I = 0;
- double maxres = residual( i, _I );
- for( size_t j = 0; j < nrVars(); ++j )
- foreach( const Neighbor &I, nbV(j) )
- if( residual( j, I.iter ) > maxres ) {
- i = j;
- _I = I.iter;
- maxres = residual( i, _I );
- }
+ assert( !_lut.empty() );
+ LutType::const_iterator largestEl = _lut.end();
+ --largestEl;
+ i = largestEl->second.first;
+ _I = largestEl->second.second;
}
// calculate updated message I->i
size_t I = nbV(i,_I);
- if( 0 == 1 ) {
- /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
- Factor prod( factor( I ) );
- foreach( const Neighbor &j, nbF(I) )
- if( j != i ) { // for all j in I \ i
- foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
- prod *= Factor( var(j), message(j, J.iter) );
- }
- }
- newMessage(i,_I) = prod.marginal( var(i) ).p();
- } else {
- /* OPTIMIZED VERSION */
- Prob prod( factor(I).p() );
- if( props.logdomain )
- prod.takeLog();
+ Factor Fprod( factor(I) );
+ Prob &prod = Fprod.p();
+ if( props.logdomain )
+ prod.takeLog();
+
+ // Calculate product of incoming messages and factor I
+ foreach( const Neighbor &j, nbF(I) )
+ if( j != i ) { // for all j in I \ i
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) )
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += message( j, J.iter );
+ else
+ prod_j *= message( j, J.iter );
+ }
- // Calculate product of incoming messages and factor I
- foreach( const Neighbor &j, nbF(I) ) {
- if( j != i ) { // for all j in I \ i
+ // multiply prod with prod_j
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ if( props.logdomain )
+ Fprod += Factor( var(j), prod_j );
+ else
+ Fprod *= Factor( var(j), prod_j );
+ } else {
+ /* OPTIMIZED VERSION */
size_t _I = j.dual;
// ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
const ind_t &ind = index(j, _I);
-
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
- foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
- if( props.logdomain )
- prod_j += message( j, J.iter );
- else
- prod_j *= message( j, J.iter );
- }
-
- // multiply prod with prod_j
for( size_t r = 0; r < prod.size(); ++r )
if( props.logdomain )
prod[r] += prod_j[ind[r]];
prod[r] *= prod_j[ind[r]];
}
}
- if( props.logdomain ) {
- prod -= prod.maxVal();
- prod.takeExp();
- }
- // Marginalize onto i
- Prob marg( var(i).states(), 0.0 );
+ if( props.logdomain ) {
+ prod -= prod.max();
+ prod.takeExp();
+ }
+
+ // Marginalize onto i
+ Prob marg;
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ if( props.inference == Properties::InfType::SUMPROD )
+ marg = Fprod.marginal( var(i) ).p();
+ else
+ marg = Fprod.maxMarginal( var(i) ).p();
+ } else {
+ /* OPTIMIZED VERSION */
+ marg = Prob( var(i).states(), 0.0 );
// ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
const ind_t ind = index(i,_I);
- if( props.inference == Properties::InfType::SUMPROD )
+ if( props.inference == Properties::InfType::SUMPROD )
for( size_t r = 0; r < prod.size(); ++r )
marg[ind[r]] += prod[r];
else
for( size_t r = 0; r < prod.size(); ++r )
- if( prod[r] > marg[ind[r]] )
+ if( prod[r] > marg[ind[r]] )
marg[ind[r]] = prod[r];
marg.normalize();
-
- // Store result
- if( props.logdomain )
- newMessage(i,_I) = marg.log();
- else
- newMessage(i,_I) = marg;
}
+
+ // Store result
+ if( props.logdomain )
+ newMessage(i,_I) = marg.log();
+ else
+ newMessage(i,_I) = marg;
+
+ // Update the residual if necessary
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
}
// Somehow NaNs do not often occur in BP...
double BP::run() {
if( props.verbose >= 1 )
- cout << "Starting " << identify() << "...";
+ cerr << "Starting " << identify() << "...";
if( props.verbose >= 3)
- cout << endl;
+ cerr << endl;
double tic = toc();
Diffs diffs(nrVars(), 1.0);
-
+
vector<Edge> update_seq;
vector<Factor> old_beliefs;
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) ) {
calcNewMessage( i, I.iter );
- // calculate initial residuals
- residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
}
} else {
update_seq.reserve( nredges );
+ /// \todo Investigate whether performance increases by switching the order of following two loops:
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) )
update_seq.push_back( Edge( i, I.iter ) );
if( J.iter != _I ) {
foreach( const Neighbor &j, nbF(J) ) {
size_t _J = j.dual;
- if( j != i ) {
+ if( j != i )
calcNewMessage( j, _J );
- residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
- }
}
}
}
}
} else if( props.updates == Properties::UpdateType::PARALL ) {
- // Parallel updates
+ // Parallel updates
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) )
calcNewMessage( i, I.iter );
// Sequential updates
if( props.updates == Properties::UpdateType::SEQRND )
random_shuffle( update_seq.begin(), update_seq.end() );
-
+
foreach( const Edge &e, update_seq ) {
calcNewMessage( e.first, e.second );
updateMessage( e.first, e.second );
}
if( props.verbose >= 3 )
- cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
+ cerr << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
}
if( diffs.maxDiff() > _maxdiff )
if( props.verbose >= 1 ) {
if( diffs.maxDiff() > props.tol ) {
if( props.verbose == 1 )
- cout << endl;
- cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
+ cerr << endl;
+ cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
} else {
if( props.verbose >= 3 )
- cout << Name << "::run: ";
- cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
+ cerr << Name << "::run: ";
+ cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
}
}
}
-Factor BP::beliefV( size_t i ) const {
- Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
+void BP::calcBeliefV( size_t i, Prob &p ) const {
+ p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &I, nbV(i) )
if( props.logdomain )
- prod += newMessage( i, I.iter );
+ p += newMessage( i, I.iter );
else
- prod *= newMessage( i, I.iter );
+ p *= newMessage( i, I.iter );
+}
+
+
+void BP::calcBeliefF( size_t I, Prob &p ) const {
+ Factor Fprod( factor( I ) );
+ Prob &prod = Fprod.p();
+
+ if( props.logdomain )
+ prod.takeLog();
+
+ foreach( const Neighbor &j, nbF(I) ) {
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) )
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += newMessage( j, J.iter );
+ else
+ prod_j *= newMessage( j, J.iter );
+ }
+
+ // multiply prod with prod_j
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ if( props.logdomain )
+ Fprod += Factor( var(j), prod_j );
+ else
+ Fprod *= Factor( var(j), prod_j );
+ } else {
+ /* OPTIMIZED VERSION */
+ size_t _I = j.dual;
+ // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+ const ind_t & ind = index(j, _I);
+
+ for( size_t r = 0; r < prod.size(); ++r ) {
+ if( props.logdomain )
+ prod[r] += prod_j[ind[r]];
+ else
+ prod[r] *= prod_j[ind[r]];
+ }
+ }
+ }
+
+ p = prod;
+}
+
+
+Factor BP::beliefV( size_t i ) const {
+ Prob p;
+ calcBeliefV( i, p );
+
if( props.logdomain ) {
- prod -= prod.maxVal();
- prod.takeExp();
+ p -= p.max();
+ p.takeExp();
}
+ p.normalize();
- prod.normalize();
- return( Factor( var(i), prod ) );
+ return( Factor( var(i), p ) );
}
-Factor BP::belief (const Var &n) const {
+Factor BP::beliefF( size_t I ) const {
+ Prob p;
+ calcBeliefF( I, p );
+
+ if( props.logdomain ) {
+ p -= p.max();
+ p.takeExp();
+ }
+ p.normalize();
+
+ return( Factor( factor(I).vars(), p ) );
+}
+
+
+Factor BP::belief( const Var &n ) const {
return( beliefV( findVar( n ) ) );
}
}
-Factor BP::beliefF (size_t I) const {
- if( 0 == 1 ) {
- /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-
- Factor prod( factor(I) );
- foreach( const Neighbor &j, nbF(I) ) {
- foreach( const Neighbor &J, nbV(j) ) {
- if( J != I ) // for all J in nb(j) \ I
- prod *= Factor( var(j), newMessage(j, J.iter) );
- }
- }
- return prod.normalized();
- } else {
- /* OPTIMIZED VERSION */
- Prob prod( factor(I).p() );
- if( props.logdomain )
- prod.takeLog();
-
- foreach( const Neighbor &j, nbF(I) ) {
- size_t _I = j.dual;
- // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
- const ind_t & ind = index(j, _I);
-
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
- foreach( const Neighbor &J, nbV(j) ) {
- if( J != I ) { // for all J in nb(j) \ I
- if( props.logdomain )
- prod_j += newMessage( j, J.iter );
- else
- prod_j *= newMessage( j, J.iter );
- }
- }
-
- // multiply prod with prod_j
- for( size_t r = 0; r < prod.size(); ++r ) {
- if( props.logdomain )
- prod[r] += prod_j[ind[r]];
- else
- prod[r] *= prod_j[ind[r]];
- }
- }
-
- if( props.logdomain ) {
- prod -= prod.maxVal();
- prod.takeExp();
- }
-
- Factor result( factor(I).vars(), prod );
- result.normalize();
-
- return( result );
- }
-}
-
-
Real BP::logZ() const {
Real sum = 0.0;
for(size_t i = 0; i < nrVars(); ++i )
sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
for( size_t I = 0; I < nrFactors(); ++I )
- sum -= KL_dist( beliefF(I), factor(I) );
+ sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
return sum;
}
-string BP::identify() const {
+string BP::identify() const {
return string(Name) + printProperties();
}
void BP::init( const VarSet &ns ) {
for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
size_t ni = findVar( *n );
- foreach( const Neighbor &I, nbV( ni ) )
- message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &I, nbV( ni ) ) {
+ double val = props.logdomain ? 0.0 : 1.0;
+ message( ni, I.iter ).fill( val );
+ newMessage( ni, I.iter ).fill( val );
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ updateResidual( ni, I.iter, 0.0 );
+ }
+ }
+}
+
+
+void BP::updateMessage( size_t i, size_t _I ) {
+ if( recordSentMessages )
+ _sentMessages.push_back(make_pair(i,_I));
+ if( props.damping == 0.0 ) {
+ message(i,_I) = newMessage(i,_I);
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ updateResidual( i, _I, 0.0 );
+ } else {
+ message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
+ if( props.updates == Properties::UpdateType::SEQMAX )
+ updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), Prob::DISTLINF ) );
+ }
+}
+
+
+void BP::updateResidual( size_t i, size_t _I, double r ) {
+ EdgeProp* pEdge = &_edges[i][_I];
+ pEdge->residual = r;
+
+ // rearrange look-up table (delete and reinsert new key)
+ _lut.erase( _edge2lut[i][_I] );
+ _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
+}
+
+
+std::vector<size_t> BP::findMaximum() const {
+ vector<size_t> maximum( nrVars() );
+ vector<bool> visitedVars( nrVars(), false );
+ vector<bool> visitedFactors( nrFactors(), false );
+ stack<size_t> scheduledFactors;
+ for( size_t i = 0; i < nrVars(); ++i ) {
+ if( visitedVars[i] )
+ continue;
+ visitedVars[i] = true;
+
+ // Maximise with respect to variable i
+ Prob prod;
+ calcBeliefV( i, prod );
+ maximum[i] = max_element( prod.begin(), prod.end() ) - prod.begin();
+
+ foreach( const Neighbor &I, nbV(i) )
+ if( !visitedFactors[I] )
+ scheduledFactors.push(I);
+
+ while( !scheduledFactors.empty() ){
+ size_t I = scheduledFactors.top();
+ scheduledFactors.pop();
+ if( visitedFactors[I] )
+ continue;
+ visitedFactors[I] = true;
+
+ // Evaluate if some neighboring variables still need to be fixed; if not, we're done
+ bool allDetermined = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( !visitedVars[j.node] ) {
+ allDetermined = false;
+ break;
+ }
+ if( allDetermined )
+ continue;
+
+ // Calculate product of incoming messages on factor I
+ Prob prod2;
+ calcBeliefF( I, prod2 );
+
+ // The allowed configuration is restrained according to the variables assigned so far:
+ // pick the argmax amongst the allowed states
+ Real maxProb = numeric_limits<Real>::min();
+ State maxState( factor(I).vars() );
+ for( State s( factor(I).vars() ); s.valid(); ++s ){
+ // First, calculate whether this state is consistent with variables that
+ // have been assigned already
+ bool allowedState = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
+ allowedState = false;
+ break;
+ }
+ // If it is consistent, check if its probability is larger than what we have seen so far
+ if( allowedState && prod2[s] > maxProb ) {
+ maxState = s;
+ maxProb = prod2[s];
+ }
+ }
+
+ // Decode the argmax
+ foreach( const Neighbor &j, nbF(I) ) {
+ if( visitedVars[j.node] ) {
+ // We have already visited j earlier - hopefully our state is consistent
+ if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
+ cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << endl;
+ } else {
+ // We found a consistent state for variable j
+ visitedVars[j.node] = true;
+ maximum[j.node] = maxState( var(j.node) );
+ foreach( const Neighbor &J, nbV(j) )
+ if( !visitedFactors[J] )
+ scheduledFactors.push(J);
+ }
+ }
+ }
}
+ return maximum;
}