const char *BP::Name = "BP";
-bool BP::checkProperties() {
- if( !HasProperty("updates") )
- return false;
- if( !HasProperty("tol") )
- return false;
- if (!HasProperty("maxiter") )
- return false;
- if (!HasProperty("verbose") )
- return false;
- if (!HasProperty("logdomain") )
- return false;
+void BP::setProperties( const PropertySet &opts ) {
+ assert( opts.hasKey("tol") );
+ assert( opts.hasKey("maxiter") );
+ assert( opts.hasKey("verbose") );
+ assert( opts.hasKey("logdomain") );
+ assert( opts.hasKey("updates") );
- ConvertPropertyTo<double>("tol");
- ConvertPropertyTo<size_t>("maxiter");
- ConvertPropertyTo<size_t>("verbose");
- ConvertPropertyTo<UpdateType>("updates");
- ConvertPropertyTo<bool>("logdomain");
- logDomain = GetPropertyAs<bool>("logdomain");
-
- return true;
+ props.tol = opts.getStringAs<double>("tol");
+ props.maxiter = opts.getStringAs<size_t>("maxiter");
+ props.verbose = opts.getStringAs<size_t>("verbose");
+ props.logdomain = opts.getStringAs<bool>("logdomain");
+ props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+}
+
+
+PropertySet BP::getProperties() const {
+ PropertySet opts;
+ opts.Set( "tol", props.tol );
+ opts.Set( "maxiter", props.maxiter );
+ opts.Set( "verbose", props.verbose );
+ opts.Set( "logdomain", props.logdomain );
+ opts.Set( "updates", props.updates );
+ return opts;
}
void BP::init() {
- assert( checkProperties() );
for( size_t i = 0; i < nrVars(); ++i ) {
foreach( const Neighbor &I, nbV(i) ) {
- if( logDomain ) {
+ if( props.logdomain ) {
message( i, I.iter ).fill( 0.0 );
newMessage( i, I.iter ).fill( 0.0 );
} else {
*/
Prob prod( factor(I).p() );
- if( logDomain )
+ if( props.logdomain )
prod.takeLog();
// Calculate product of incoming messages and factor I
const ind_t & ind = index(j, _I);
// prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 );
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &J, nbV(j) )
if( J != I ) { // for all J in nb(j) \ I
- if( logDomain )
+ if( props.logdomain )
prod_j += message( j, J.iter );
else
prod_j *= message( j, J.iter );
// multiply prod with prod_j
for( size_t r = 0; r < prod.size(); ++r )
- if( logDomain )
+ if( props.logdomain )
prod[r] += prod_j[ind[r]];
else
prod[r] *= prod_j[ind[r]];
}
}
- if( logDomain ) {
+ if( props.logdomain ) {
prod -= prod.maxVal();
prod.takeExp();
}
marg.normalize( _normtype );
// Store result
- if( logDomain )
+ if( props.logdomain )
newMessage(i,_I) = marg.log();
else
newMessage(i,_I) = marg;
// BP::run does not check for NANs for performance reasons
// Somehow NaNs do not often occur in BP...
double BP::run() {
- if( Verbose() >= 1 )
+ if( props.verbose >= 1 )
cout << "Starting " << identify() << "...";
- if( Verbose() >= 3)
+ if( props.verbose >= 3)
cout << endl;
double tic = toc();
size_t iter = 0;
size_t nredges = nrEdges();
- if( Updates() == UpdateType::SEQMAX ) {
+ if( props.updates == Properties::UpdateType::SEQMAX ) {
// do the first pass
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) ) {
// do several passes over the network until maximum number of iterations has
// been reached or until the maximum belief difference is smaller than tolerance
- for( iter=0; iter < MaxIter() && diffs.maxDiff() > Tol(); ++iter ) {
- if( Updates() == UpdateType::SEQMAX ) {
+ for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
+ if( props.updates == Properties::UpdateType::SEQMAX ) {
// Residuals-BP by Koller et al.
for( size_t t = 0; t < nredges; ++t ) {
// update the message with the largest residual
}
}
}
- } else if( Updates() == UpdateType::PARALL ) {
+ } else if( props.updates == Properties::UpdateType::PARALL ) {
// Parallel updates
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) )
message( i, I.iter ) = newMessage( i, I.iter );
} else {
// Sequential updates
- if( Updates() == UpdateType::SEQRND )
+ if( props.updates == Properties::UpdateType::SEQRND )
random_shuffle( update_seq.begin(), update_seq.end() );
foreach( const Edge &e, update_seq ) {
old_beliefs[i] = nb;
}
- if( Verbose() >= 3 )
+ if( props.verbose >= 3 )
cout << "BP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
}
- updateMaxDiff( diffs.maxDiff() );
+ if( diffs.maxDiff() > maxdiff )
+ maxdiff = diffs.maxDiff();
- if( Verbose() >= 1 ) {
- if( diffs.maxDiff() > Tol() ) {
- if( Verbose() == 1 )
+ if( props.verbose >= 1 ) {
+ if( diffs.maxDiff() > props.tol ) {
+ if( props.verbose == 1 )
cout << endl;
- cout << "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+ cout << "BP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
} else {
- if( Verbose() >= 3 )
+ if( props.verbose >= 3 )
cout << "BP::run: ";
cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
}
Factor BP::beliefV( size_t i ) const {
- Prob prod( var(i).states(), logDomain ? 0.0 : 1.0 );
+ Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &I, nbV(i) )
- if( logDomain )
+ if( props.logdomain )
prod += newMessage( i, I.iter );
else
prod *= newMessage( i, I.iter );
- if( logDomain ) {
+ if( props.logdomain ) {
prod -= prod.maxVal();
prod.takeExp();
}
Factor BP::beliefF (size_t I) const {
Prob prod( factor(I).p() );
- if( logDomain )
+ if( props.logdomain )
prod.takeLog();
foreach( const Neighbor &j, nbF(I) ) {
const ind_t & ind = index(j, _I);
// prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 );
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &J, nbV(j) ) {
if( J != I ) { // for all J in nb(j) \ I
- if( logDomain )
+ if( props.logdomain )
prod_j += newMessage( j, J.iter );
else
prod_j *= newMessage( j, J.iter );
// multiply prod with prod_j
for( size_t r = 0; r < prod.size(); ++r ) {
- if( logDomain )
+ if( props.logdomain )
prod[r] += prod_j[ind[r]];
else
prod[r] *= prod_j[ind[r]];
}
}
- if( logDomain ) {
+ if( props.logdomain ) {
prod -= prod.maxVal();
prod.takeExp();
}
string BP::identify() const {
stringstream result (stringstream::out);
- result << Name << GetProperties();
+ result << Name << getProperties();
return result.str();
}
for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
size_t ni = findVar( *n );
foreach( const Neighbor &I, nbV( ni ) )
- message( ni, I.iter ).fill( logDomain ? 0.0 : 1.0 );
+ message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
}
}