// using the parameters specified by opts and an additional property
// that specifies the type of updates the JTree algorithm should perform
JTree jt( fg, opts("updates",string("HUGIN")) );
- // Initialize junction tree algoritm
+ // Initialize junction tree algorithm
jt.init();
// Run junction tree algorithm
jt.run();
+ // Construct another JTree (junction tree) object that is used to calculate
+ // the joint configuration of variables that has maximum probability (MAP state)
+ JTree jtmap( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
+ // Initialize junction tree algorithm
+ jtmap.init();
+ // Run junction tree algorithm
+ jtmap.run();
+ // Calculate joint state of all variables that has maximum probability
+ vector<size_t> jtmapstate = jtmap.findMaximum();
+
// Construct a BP (belief propagation) object from the FactorGraph fg
// using the parameters specified by opts and two additional properties,
// specifying the type of updates the BP algorithm should perform and
// Report log partition sum of fg, approximated by the belief propagation algorithm
cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
+ // Report exact MAP variable marginals
+ cout << "Exact MAP variable marginals:" << endl;
+ for( size_t i = 0; i < fg.nrVars(); i++ )
+ cout << jtmap.belief(fg.var(i)) << endl;
+
// Report max-product variable marginals
- cout << "Max-product variable marginals:" << endl;
+ cout << "Approximate (max-product) MAP variable marginals:" << endl;
for( size_t i = 0; i < fg.nrVars(); i++ )
cout << mp.belief(fg.var(i)) << endl;
+ // Report exact MAP factor marginals
+ cout << "Exact MAP factor marginals:" << endl;
+ for( size_t I = 0; I < fg.nrFactors(); I++ )
+ cout << jtmap.belief(fg.factor(I).vars()) << "=" << jtmap.beliefF(I) << endl;
+
// Report max-product factor marginals
- cout << "Max-product factor marginals:" << endl;
+ cout << "Approximate (max-product) MAP factor marginals:" << endl;
for( size_t I = 0; I < fg.nrFactors(); I++ )
cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
- // Report max-product joint state
- cout << "Max-product state:" << endl;
+ // Report exact MAP joint state
+ cout << "Exact MAP state:" << endl;
+ for( size_t i = 0; i < jtmapstate.size(); i++ )
+ cout << fg.var(i) << ": " << jtmapstate[i] << endl;
+
+ // Report max-product MAP joint state
+ cout << "Approximate (max-product) MAP state:" << endl;
for( size_t i = 0; i < mpstate.size(); i++ )
cout << fg.var(i) << ": " << mpstate[i] << endl;
}
#include <iostream>
+#include <stack>
#include <dai/jtree.h>
props.verbose = opts.getStringAs<size_t>("verbose");
props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+ if( opts.hasKey("inference") )
+ props.inference = opts.getStringAs<Properties::InfType>("inference");
+ else
+ props.inference = Properties::InfType::SUMPROD;
}
PropertySet opts;
opts.Set( "verbose", props.verbose );
opts.Set( "updates", props.updates );
+ opts.Set( "inference", props.inference );
return opts;
}
stringstream s( stringstream::out );
s << "[";
s << "verbose=" << props.verbose << ",";
- s << "updates=" << props.updates << "]";
+ s << "updates=" << props.updates << ",";
+ s << "inference=" << props.inference << "]";
return s.str();
}
for( size_t i = RTree.size(); (i--) != 0; ) {
// Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
// IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
- Factor new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
+ Factor new_Qb;
+ if( props.inference == Properties::InfType::SUMPROD )
+ new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
+ else
+ new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );
+
_logZ += log(new_Qb.normalize());
Qa[RTree[i].n1] *= new_Qb / Qb[i];
Qb[i] = new_Qb;
for( size_t i = 0; i < RTree.size(); i++ ) {
// Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
// IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
- Factor new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
+ Factor new_Qb;
+ if( props.inference == Properties::InfType::SUMPROD )
+ new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
+ else
+ new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );
+
Qa[RTree[i].n2] *= new_Qb / Qb[i];
Qb[i] = new_Qb;
}
size_t j = nbIR(e)[0].node; // = RTree[e].n1
size_t _e = nbIR(e)[0].dual;
- Factor piet = OR(i);
+ Factor msg = OR(i);
foreach( const Neighbor &k, nbOR(i) )
if( k != e )
- piet *= message( i, k.iter );
- message( j, _e ) = piet.marginal( IR(e), false );
+ msg *= message( i, k.iter );
+ if( props.inference == Properties::InfType::SUMPROD )
+ message( j, _e ) = msg.marginal( IR(e), false );
+ else
+ message( j, _e ) = msg.maxMarginal( IR(e), false );
_logZ += log( message(j,_e).normalize() );
}
size_t j = nbIR(e)[1].node; // = RTree[e].n2
size_t _e = nbIR(e)[1].dual;
- Factor piet = OR(i);
+ Factor msg = OR(i);
foreach( const Neighbor &k, nbOR(i) )
if( k != e )
- piet *= message( i, k.iter );
- message( j, _e ) = piet.marginal( IR(e) );
+ msg *= message( i, k.iter );
+ if( props.inference == Properties::InfType::SUMPROD )
+ message( j, _e ) = msg.marginal( IR(e) );
+ else
+ message( j, _e ) = msg.maxMarginal( IR(e) );
}
// Calculate beliefs
}
// Only for logZ (and for belief)...
- for( size_t beta = 0; beta < nrIRs(); beta++ )
- Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
+ for( size_t beta = 0; beta < nrIRs(); beta++ ) {
+ if( props.inference == Properties::InfType::SUMPROD )
+ Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
+ else
+ Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
+ }
}
}
+std::vector<size_t> JTree::findMaximum() const {
+ vector<size_t> maximum( nrVars() );
+ vector<bool> visitedVars( nrVars(), false );
+ vector<bool> visitedFactors( nrFactors(), false );
+ stack<size_t> scheduledFactors;
+ for( size_t i = 0; i < nrVars(); ++i ) {
+ if( visitedVars[i] )
+ continue;
+ visitedVars[i] = true;
+
+ // Maximise with respect to variable i
+ Prob prod = beliefV(i).p();
+ maximum[i] = max_element( prod.begin(), prod.end() ) - prod.begin();
+
+ foreach( const Neighbor &I, nbV(i) )
+ if( !visitedFactors[I] )
+ scheduledFactors.push(I);
+
+ while( !scheduledFactors.empty() ){
+ size_t I = scheduledFactors.top();
+ scheduledFactors.pop();
+ if( visitedFactors[I] )
+ continue;
+ visitedFactors[I] = true;
+
+ // Evaluate if some neighboring variables still need to be fixed; if not, we're done
+ bool allDetermined = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( !visitedVars[j.node] ) {
+ allDetermined = false;
+ break;
+ }
+ if( allDetermined )
+ continue;
+
+ // Calculate product of incoming messages on factor I
+ Prob prod2 = beliefF(I).p();
+
+ // The allowed configuration is restrained according to the variables assigned so far:
+ // pick the argmax amongst the allowed states
+ Real maxProb = numeric_limits<Real>::min();
+ State maxState( factor(I).vars() );
+ for( State s( factor(I).vars() ); s.valid(); ++s ){
+ // First, calculate whether this state is consistent with variables that
+ // have been assigned already
+ bool allowedState = true;
+ foreach( const Neighbor &j, nbF(I) )
+ if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
+ allowedState = false;
+ break;
+ }
+ // If it is consistent, check if its probability is larger than what we have seen so far
+ if( allowedState && prod2[s] > maxProb ) {
+ maxState = s;
+ maxProb = prod2[s];
+ }
+ }
+
+ // Decode the argmax
+ foreach( const Neighbor &j, nbF(I) ) {
+ if( visitedVars[j.node] ) {
+ // We have already visited j earlier - hopefully our state is consistent
+ if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
+ cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
+ } else {
+ // We found a consistent state for variable j
+ visitedVars[j.node] = true;
+ maximum[j.node] = maxState( var(j.node) );
+ foreach( const Neighbor &J, nbV(j) )
+ if( !visitedFactors[J] )
+ scheduledFactors.push(J);
+ }
+ }
+ }
+ }
+ return maximum;
+}
+
+
} // end of namespace dai