-/* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
- Radboud University Nijmegen, The Netherlands
-
+/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
+ Radboud University Nijmegen, The Netherlands /
+ Max Planck Institute for Biological Cybernetics, Germany
+
This file is part of libDAI.
libDAI is free software; you can redistribute it and/or modify
#include <set>
#include <algorithm>
#include <dai/bp.h>
-#include <dai/diffs.h>
#include <dai/util.h>
#include <dai/properties.h>
const char *BP::Name = "BP";
-bool BP::checkProperties() {
- if( !HasProperty("updates") )
- return false;
- if( !HasProperty("tol") )
- return false;
- if (!HasProperty("maxiter") )
- return false;
- if (!HasProperty("verbose") )
- return false;
+#define DAI_BP_FAST 1
+
+
+void BP::setProperties( const PropertySet &opts ) {
+ assert( opts.hasKey("tol") );
+ assert( opts.hasKey("maxiter") );
+ assert( opts.hasKey("logdomain") );
+ assert( opts.hasKey("updates") );
- ConvertPropertyTo<double>("tol");
- ConvertPropertyTo<size_t>("maxiter");
- ConvertPropertyTo<size_t>("verbose");
- ConvertPropertyTo<UpdateType>("updates");
+ props.tol = opts.getStringAs<double>("tol");
+ props.maxiter = opts.getStringAs<size_t>("maxiter");
+ props.logdomain = opts.getStringAs<bool>("logdomain");
+ props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+
+ if( opts.hasKey("verbose") )
+ props.verbose = opts.getStringAs<size_t>("verbose");
+ else
+ props.verbose = 0;
+ if( opts.hasKey("damping") )
+ props.damping = opts.getStringAs<double>("damping");
+ else
+ props.damping = 0.0;
+ if( opts.hasKey("inference") )
+ props.inference = opts.getStringAs<Properties::InfType>("inference");
+ else
+ props.inference = Properties::InfType::SUMPROD;
+}
+
- return true;
+PropertySet BP::getProperties() const {
+ PropertySet opts;
+ opts.Set( "tol", props.tol );
+ opts.Set( "maxiter", props.maxiter );
+ opts.Set( "verbose", props.verbose );
+ opts.Set( "logdomain", props.logdomain );
+ opts.Set( "updates", props.updates );
+ opts.Set( "damping", props.damping );
+ opts.Set( "inference", props.inference );
+ return opts;
}
-void BP::Regenerate() {
- DAIAlgFG::Regenerate();
-
- // clear messages
- _messages.clear();
- _messages.reserve(nr_edges());
-
- // clear indices
- _indices.clear();
- _indices.reserve(nr_edges());
-
- // create messages and indices
- for( vector<_edge_t>::const_iterator iI=edges().begin(); iI!=edges().end(); ++iI ) {
- _messages.push_back( Prob( var(iI->first).states() ) );
-
- vector<size_t> ind( factor(iI->second).stateSpace(), 0 );
- Index i (var(iI->first), factor(iI->second).vars() );
- for( size_t j = 0; i >= 0; ++i,++j )
- ind[j] = i;
- _indices.push_back( ind );
- }
+string BP::printProperties() const {
+ stringstream s( stringstream::out );
+ s << "[";
+ s << "tol=" << props.tol << ",";
+ s << "maxiter=" << props.maxiter << ",";
+ s << "verbose=" << props.verbose << ",";
+ s << "logdomain=" << props.logdomain << ",";
+ s << "updates=" << props.updates << ",";
+ s << "damping=" << props.damping << ",";
+ s << "inference=" << props.inference << "]";
+ return s.str();
+}
+
+
+void BP::construct() {
+ // create edge properties
+ _edges.clear();
+ _edges.reserve( nrVars() );
+ for( size_t i = 0; i < nrVars(); ++i ) {
+ _edges.push_back( vector<EdgeProp>() );
+ _edges[i].reserve( nbV(i).size() );
+ foreach( const Neighbor &I, nbV(i) ) {
+ EdgeProp newEP;
+ newEP.message = Prob( var(i).states() );
+ newEP.newMessage = Prob( var(i).states() );
+
+ if( DAI_BP_FAST ) {
+ newEP.index.reserve( factor(I).states() );
+ for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
+ newEP.index.push_back( k );
+ }
- // create new_messages
- _newmessages = _messages;
+ newEP.residual = 0.0;
+ _edges[i].push_back( newEP );
+ }
+ }
}
void BP::init() {
- assert( checkProperties() );
- for( vector<Prob>::iterator mij = _messages.begin(); mij != _messages.end(); ++mij )
- mij->fill(1.0 / mij->size());
- _newmessages = _messages;
+ double c = props.logdomain ? 0.0 : 1.0;
+ for( size_t i = 0; i < nrVars(); ++i ) {
+ foreach( const Neighbor &I, nbV(i) ) {
+ message( i, I.iter ).fill( c );
+ newMessage( i, I.iter ).fill( c );
+ }
+ }
}
-void BP::calcNewMessage (size_t iI) {
- // calculate updated message I->i
- size_t i = edge(iI).first;
- size_t I = edge(iI).second;
-
-/* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
-
- Factor prod( factor( I ) );
- for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
- if( *j != i ) { // for all j in I \ i
- for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
- if( *J != I ) { // for all J in nb(j) \ I
- prod *= Factor( *j, message(*j,*J) );
- Factor marg = prod.marginal(var(i));
-*/
-
- Prob prod( factor(I).p() );
-
- // Calculate product of incoming messages and factor I
- for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j )
- if( *j != i ) { // for all j in I \ i
- // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
- _ind_t* ind = &(index(*j,I));
+void BP::findMaxResidual( size_t &i, size_t &_I ) {
+ i = 0;
+ _I = 0;
+ double maxres = residual( i, _I );
+ for( size_t j = 0; j < nrVars(); ++j )
+ foreach( const Neighbor &I, nbV(j) )
+ if( residual( j, I.iter ) > maxres ) {
+ i = j;
+ _I = I.iter;
+ maxres = residual( i, _I );
+ }
+}
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(*j).states() );
- for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J )
- if( *J != I ) // for all J in nb(j) \ I
- prod_j *= message(*j,*J);
- // multiply prod with prod_j
- for( size_t r = 0; r < prod.size(); ++r )
- prod[r] *= prod_j[(*ind)[r]];
+void BP::calcNewMessage( size_t i, size_t _I ) {
+ // calculate updated message I->i
+ size_t I = nbV(i,_I);
+
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ Factor prod( factor( I ) );
+ foreach( const Neighbor &j, nbF(I) )
+ if( j != i ) { // for all j in I \ i
+ foreach( const Neighbor &J, nbV(j) )
+ if( J != I ) { // for all J in nb(j) \ I
+ prod *= Factor( var(j), message(j, J.iter) );
+ }
+ }
+ newMessage(i,_I) = prod.marginal( var(i) ).p();
+ } else {
+ /* OPTIMIZED VERSION */
+ Prob prod( factor(I).p() );
+ if( props.logdomain )
+ prod.takeLog();
+
+ // Calculate product of incoming messages and factor I
+ foreach( const Neighbor &j, nbF(I) ) {
+ if( j != i ) { // for all j in I \ i
+ size_t _I = j.dual;
+ // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+ const ind_t &ind = index(j, _I);
+
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) )
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += message( j, J.iter );
+ else
+ prod_j *= message( j, J.iter );
+ }
+
+ // multiply prod with prod_j
+ for( size_t r = 0; r < prod.size(); ++r )
+ if( props.logdomain )
+ prod[r] += prod_j[ind[r]];
+ else
+ prod[r] *= prod_j[ind[r]];
+ }
+ }
+ if( props.logdomain ) {
+ prod -= prod.maxVal();
+ prod.takeExp();
}
- // Marginalize onto i
- Prob marg( var(i).states(), 0.0 );
- // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
- _ind_t* ind = &(index(i,I));
- for( size_t r = 0; r < prod.size(); ++r )
- marg[(*ind)[r]] += prod[r];
- marg.normalize( _normtype );
-
- // Store result
- _newmessages[iI] = marg;
+ // Marginalize onto i
+ Prob marg( var(i).states(), 0.0 );
+ // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+ const ind_t ind = index(i,_I);
+ if( props.inference == Properties::InfType::SUMPROD )
+ for( size_t r = 0; r < prod.size(); ++r )
+ marg[ind[r]] += prod[r];
+ else
+ for( size_t r = 0; r < prod.size(); ++r )
+ if( prod[r] > marg[ind[r]] )
+ marg[ind[r]] = prod[r];
+ marg.normalize();
+
+ // Store result
+ if( props.logdomain )
+ newMessage(i,_I) = marg.log();
+ else
+ newMessage(i,_I) = marg;
+ }
}
// BP::run does not check for NANs for performance reasons
// Somehow NaNs do not often occur in BP...
double BP::run() {
- if( Verbose() >= 1 )
+ if( props.verbose >= 1 )
cout << "Starting " << identify() << "...";
- if( Verbose() >= 3)
+ if( props.verbose >= 3)
cout << endl;
- clock_t tic = toc();
+ double tic = toc();
Diffs diffs(nrVars(), 1.0);
- vector<size_t> edge_seq;
- vector<double> residuals;
+ vector<Edge> update_seq;
vector<Factor> old_beliefs;
old_beliefs.reserve( nrVars() );
for( size_t i = 0; i < nrVars(); ++i )
- old_beliefs.push_back(belief1(i));
+ old_beliefs.push_back( beliefV(i) );
- size_t iter = 0;
+ size_t nredges = nrEdges();
- if( Updates() == UpdateType::SEQMAX ) {
+ if( props.updates == Properties::UpdateType::SEQMAX ) {
// do the first pass
- for(size_t iI = 0; iI < nr_edges(); ++iI )
- calcNewMessage(iI);
-
- // calculate initial residuals
- residuals.reserve(nr_edges());
- for( size_t iI = 0; iI < nr_edges(); ++iI )
- residuals.push_back( dist( _newmessages[iI], _messages[iI], Prob::DISTLINF ) );
+ for( size_t i = 0; i < nrVars(); ++i )
+ foreach( const Neighbor &I, nbV(i) ) {
+ calcNewMessage( i, I.iter );
+ // calculate initial residuals
+ residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
+ }
} else {
- edge_seq.reserve( nr_edges() );
- for( size_t i = 0; i < nr_edges(); ++i )
- edge_seq.push_back( i );
+ update_seq.reserve( nredges );
+ for( size_t i = 0; i < nrVars(); ++i )
+ foreach( const Neighbor &I, nbV(i) )
+ update_seq.push_back( Edge( i, I.iter ) );
}
// do several passes over the network until maximum number of iterations has
// been reached or until the maximum belief difference is smaller than tolerance
- for( iter=0; iter < MaxIter() && diffs.max() > Tol(); ++iter ) {
- if( Updates() == UpdateType::SEQMAX ) {
+ for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
+ if( props.updates == Properties::UpdateType::SEQMAX ) {
// Residuals-BP by Koller et al.
- for( size_t t = 0; t < nr_edges(); ++t ) {
+ for( size_t t = 0; t < nredges; ++t ) {
// update the message with the largest residual
- size_t iI = max_element(residuals.begin(), residuals.end()) - residuals.begin();
- _messages[iI] = _newmessages[iI];
- residuals[iI] = 0;
+ size_t i, _I;
+ findMaxResidual( i, _I );
+ updateMessage( i, _I );
// I->i has been updated, which means that residuals for all
// J->j with J in nb[i]\I and j in nb[J]\i have to be updated
- size_t i = edge(iI).first;
- size_t I = edge(iI).second;
- for( _nb_cit J = nb1(i).begin(); J != nb1(i).end(); ++J )
- if( *J != I )
- for( _nb_cit j = nb2(*J).begin(); j != nb2(*J).end(); ++j )
- if( *j != i ) {
- size_t jJ = VV2E(*j,*J);
- calcNewMessage(jJ);
- residuals[jJ] = dist( _newmessages[jJ], _messages[jJ], Prob::DISTLINF );
+ foreach( const Neighbor &J, nbV(i) ) {
+ if( J.iter != _I ) {
+ foreach( const Neighbor &j, nbF(J) ) {
+ size_t _J = j.dual;
+ if( j != i ) {
+ calcNewMessage( j, _J );
+ residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
}
+ }
+ }
+ }
}
- } else if( Updates() == UpdateType::PARALL ) {
+ } else if( props.updates == Properties::UpdateType::PARALL ) {
// Parallel updates
- for( size_t t = 0; t < nr_edges(); ++t )
- calcNewMessage(t);
+ for( size_t i = 0; i < nrVars(); ++i )
+ foreach( const Neighbor &I, nbV(i) )
+ calcNewMessage( i, I.iter );
- for( size_t t = 0; t < nr_edges(); ++t )
- _messages[t] = _newmessages[t];
+ for( size_t i = 0; i < nrVars(); ++i )
+ foreach( const Neighbor &I, nbV(i) )
+ updateMessage( i, I.iter );
} else {
// Sequential updates
- if( Updates() == UpdateType::SEQRND )
- random_shuffle( edge_seq.begin(), edge_seq.end() );
+ if( props.updates == Properties::UpdateType::SEQRND )
+ random_shuffle( update_seq.begin(), update_seq.end() );
- for( size_t t = 0; t < nr_edges(); ++t ) {
- size_t k = edge_seq[t];
- calcNewMessage(k);
- _messages[k] = _newmessages[k];
+ foreach( const Edge &e, update_seq ) {
+ calcNewMessage( e.first, e.second );
+ updateMessage( e.first, e.second );
}
}
// calculate new beliefs and compare with old ones
for( size_t i = 0; i < nrVars(); ++i ) {
- Factor nb( belief1(i) );
+ Factor nb( beliefV(i) );
diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
old_beliefs[i] = nb;
}
- if( Verbose() >= 3 )
- cout << "BP::run: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
+ if( props.verbose >= 3 )
+ cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
}
- updateMaxDiff( diffs.max() );
+ if( diffs.maxDiff() > _maxdiff )
+ _maxdiff = diffs.maxDiff();
- if( Verbose() >= 1 ) {
- if( diffs.max() > Tol() ) {
- if( Verbose() == 1 )
+ if( props.verbose >= 1 ) {
+ if( diffs.maxDiff() > props.tol ) {
+ if( props.verbose == 1 )
cout << endl;
- cout << "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
+ cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
} else {
- if( Verbose() >= 3 )
- cout << "BP::run: ";
- cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
+ if( props.verbose >= 3 )
+ cout << Name << "::run: ";
+ cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
}
}
- return diffs.max();
+ return diffs.maxDiff();
}
-Factor BP::belief1( size_t i ) const {
- Prob prod( var(i).states() );
- for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); ++I )
- prod *= newMessage(i,*I);
+Factor BP::beliefV( size_t i ) const {
+ Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &I, nbV(i) )
+ if( props.logdomain )
+ prod += newMessage( i, I.iter );
+ else
+ prod *= newMessage( i, I.iter );
+ if( props.logdomain ) {
+ prod -= prod.maxVal();
+ prod.takeExp();
+ }
- prod.normalize( Prob::NORMPROB );
+ prod.normalize();
return( Factor( var(i), prod ) );
}
Factor BP::belief (const Var &n) const {
- return( belief1( findVar( n ) ) );
+ return( beliefV( findVar( n ) ) );
}
vector<Factor> BP::beliefs() const {
vector<Factor> result;
for( size_t i = 0; i < nrVars(); ++i )
- result.push_back( belief1(i) );
+ result.push_back( beliefV(i) );
for( size_t I = 0; I < nrFactors(); ++I )
- result.push_back( belief2(I) );
+ result.push_back( beliefF(I) );
return result;
}
if( factor(I).vars() >> ns )
break;
assert( I != nrFactors() );
- return belief2(I).marginal(ns);
+ return beliefF(I).marginal(ns);
}
}
-Factor BP::belief2 (size_t I) const {
- Prob prod( factor(I).p() );
+Factor BP::beliefF (size_t I) const {
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
- for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j ) {
- // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
- const _ind_t *ind = &(index(*j, I));
+ Factor prod( factor(I) );
+ foreach( const Neighbor &j, nbF(I) ) {
+ foreach( const Neighbor &J, nbV(j) ) {
+ if( J != I ) // for all J in nb(j) \ I
+ prod *= Factor( var(j), newMessage(j, J.iter) );
+ }
+ }
+ return prod.normalized();
+ } else {
+ /* OPTIMIZED VERSION */
+ Prob prod( factor(I).p() );
+ if( props.logdomain )
+ prod.takeLog();
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(*j).states() );
- for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J )
- if( *J != I ) // for all J in nb(j) \ I
- prod_j *= newMessage(*j,*J);
+ foreach( const Neighbor &j, nbF(I) ) {
+ size_t _I = j.dual;
+ // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+ const ind_t & ind = index(j, _I);
- // multiply prod with prod_j
- for( size_t r = 0; r < prod.size(); ++r )
- prod[r] *= prod_j[(*ind)[r]];
- }
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) ) {
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += newMessage( j, J.iter );
+ else
+ prod_j *= newMessage( j, J.iter );
+ }
+ }
+
+ // multiply prod with prod_j
+ for( size_t r = 0; r < prod.size(); ++r ) {
+ if( props.logdomain )
+ prod[r] += prod_j[ind[r]];
+ else
+ prod[r] *= prod_j[ind[r]];
+ }
+ }
- Factor result( factor(I).vars(), prod );
- result.normalize( Prob::NORMPROB );
+ if( props.logdomain ) {
+ prod -= prod.maxVal();
+ prod.takeExp();
+ }
- return( result );
+ Factor result( factor(I).vars(), prod );
+ result.normalize();
-/* UNOPTIMIZED VERSION
-
- Factor prod( factor(I) );
- for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
- for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
- if( *J != I )
- prod *= Factor( var(*i), newMessage(*i,*J)) );
+ return( result );
}
- return prod.normalize( Prob::NORMPROB );*/
}
-Complex BP::logZ() const {
- Complex sum = 0.0;
+Real BP::logZ() const {
+ Real sum = 0.0;
for(size_t i = 0; i < nrVars(); ++i )
- sum += Complex(1.0 - nb1(i).size()) * belief1(i).entropy();
+ sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
for( size_t I = 0; I < nrFactors(); ++I )
- sum -= KL_dist( belief2(I), factor(I) );
+ sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
return sum;
}
string BP::identify() const {
- stringstream result (stringstream::out);
- result << Name << GetProperties();
- return result.str();
+ return string(Name) + printProperties();
}
void BP::init( const VarSet &ns ) {
for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
size_t ni = findVar( *n );
- for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); ++I )
- message(ni,*I).fill( 1.0 );
+ foreach( const Neighbor &I, nbV( ni ) )
+ message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
}
}