From 4fe8baa9d0ad8168dad49cb0ec6d82454ad99130 Mon Sep 17 00:00:00 2001 From: Joris Mooij Date: Wed, 13 Jan 2010 13:52:53 +0100 Subject: [PATCH] Cleaned up some code in TreeEP and JTree --- include/dai/jtree.h | 10 +- include/dai/mr.h | 1 + src/jtree.cpp | 46 +++--- src/treeep.cpp | 379 +++++++++++++++++--------------------------- 4 files changed, 187 insertions(+), 249 deletions(-) diff --git a/include/dai/jtree.h b/include/dai/jtree.h index 3bed8b1..37f6981 100644 --- a/include/dai/jtree.h +++ b/include/dai/jtree.h @@ -128,12 +128,20 @@ class JTree : public DAIAlgRG { /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes. * Then, a maximal spanning tree for this weighted graph is calculated. - * Finally, a corresponding region graph is built: + * Subsequently, a corresponding region graph is built: * - the outer regions correspond with the cliques and have counting number 1; * - the inner regions correspond with the seperators, i.e., the intersections of two * cliques that are neighbors in the spanning tree, and have counting number -1; * - inner and outer regions are connected by an edge if the inner region is a * seperator for the outer region. + * Finally, Beliefs are constructed. + * If \a verify == \c true, checks whether each factor is subsumed by a clique. + */ + void construct( const std::vector &cl, bool verify=false ); + + /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence). + /** Invokes construct() and then constructs messages. + * \see construct() */ void GenerateJT( const std::vector &cl ); diff --git a/include/dai/mr.h b/include/dai/mr.h index 64482d4..ce8a7c6 100644 --- a/include/dai/mr.h +++ b/include/dai/mr.h @@ -33,6 +33,7 @@ namespace dai { /// Approximate inference algorithm by Montanari and Rizzo [\ref MoR05] /** \author Bastian Wemmenhove wrote the original implementation before it was merged into libDAI + * \todo Clean up code (use a BipartiteGraph-like implementation for the graph structure) */ class MR : public DAIAlgFG { private: diff --git a/src/jtree.cpp b/src/jtree.cpp index 4874f6f..b287caa 100644 --- a/src/jtree.cpp +++ b/src/jtree.cpp @@ -88,14 +88,13 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : } -void JTree::GenerateJT( const std::vector &Cliques ) { +void JTree::construct( const std::vector &cl, bool verify ) { // Construct a weighted graph (each edge is weighted with the cardinality - // of the intersection of the nodes, where the nodes are the elements of - // Cliques). + // of the intersection of the nodes, where the nodes are the elements of cl). WeightedGraph JuncGraph; - for( size_t i = 0; i < Cliques.size(); i++ ) - for( size_t j = i+1; j < Cliques.size(); j++ ) { - size_t w = (Cliques[i] & Cliques[j]).size(); + for( size_t i = 0; i < cl.size(); i++ ) + for( size_t j = i+1; j < cl.size(); j++ ) { + size_t w = (cl[i] & cl[j]).size(); if( w ) JuncGraph[UEdge(i,j)] = w; } @@ -106,24 +105,29 @@ void JTree::GenerateJT( const std::vector &Cliques ) { // Construct corresponding region graph // Create outer regions - ORs.reserve( Cliques.size() ); - for( size_t i = 0; i < Cliques.size(); i++ ) - ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) ); + ORs.clear(); + ORs.reserve( cl.size() ); + for( size_t i = 0; i < cl.size(); i++ ) + ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) ); // For each factor, find an outer region that subsumes that factor. // Then, multiply the outer region with that factor. + fac2OR.clear(); + fac2OR.resize( nrFactors(), -1U ); for( size_t I = 0; I < nrFactors(); I++ ) { size_t alpha; for( alpha = 0; alpha < nrORs(); alpha++ ) if( OR(alpha).vars() >> factor(I).vars() ) { - fac2OR.push_back( alpha ); + fac2OR[I] = alpha; break; } - DAI_ASSERT( alpha != nrORs() ); + if( verify ) + DAI_ASSERT( alpha != nrORs() ); } RecomputeORs(); // Create inner regions and edges + IRs.clear(); IRs.reserve( RTree.size() ); vector edges; edges.reserve( 2 * RTree.size() ); @@ -131,13 +135,18 @@ void JTree::GenerateJT( const std::vector &Cliques ) { edges.push_back( Edge( RTree[i].n1, nrIRs() ) ); edges.push_back( Edge( RTree[i].n2, nrIRs() ) ); // inner clusters have counting number -1 - IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) ); + IRs.push_back( Region( cl[RTree[i].n1] & cl[RTree[i].n2], -1.0 ) ); } // create bipartite graph G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() ); - // Create messages and beliefs + // Check counting numbers +#ifdef DAI_DEBUG + checkCountingNumbers(); +#endif + + // Create beliefs Qa.clear(); Qa.reserve( nrORs() ); for( size_t alpha = 0; alpha < nrORs(); alpha++ ) @@ -147,7 +156,13 @@ void JTree::GenerateJT( const std::vector &Cliques ) { Qb.reserve( nrIRs() ); for( size_t beta = 0; beta < nrIRs(); beta++ ) Qb.push_back( Factor( IR(beta), 1.0 ) ); +} + +void JTree::GenerateJT( const std::vector &cl ) { + construct( cl, true ); + + // Create messages _mes.clear(); _mes.reserve( nrORs() ); for( size_t alpha = 0; alpha < nrORs(); alpha++ ) { @@ -157,11 +172,6 @@ void JTree::GenerateJT( const std::vector &Cliques ) { _mes[alpha].push_back( Factor( IR(beta), 1.0 ) ); } - // Check counting numbers -#ifdef DAI_DEBUG - checkCountingNumbers(); -#endif - if( props.verbose >= 3 ) cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl; } diff --git a/src/treeep.cpp b/src/treeep.cpp index 1685219..000ce59 100644 --- a/src/treeep.cpp +++ b/src/treeep.cpp @@ -60,132 +60,6 @@ string TreeEP::printProperties() const { } -TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector &jt_Qa, const std::vector &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) { - _ns = _I->vars(); - - // Make _Qa, _Qb, _a and _b corresponding to the subtree - _b.reserve( subRTree.size() ); - _Qb.reserve( subRTree.size() ); - _RTree.reserve( subRTree.size() ); - for( size_t i = 0; i < subRTree.size(); i++ ) { - size_t alpha1 = subRTree[i].n1; // old index 1 - size_t alpha2 = subRTree[i].n2; // old index 2 - size_t beta; // old sep index - for( beta = 0; beta < jt_RTree.size(); beta++ ) - if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) ) - break; - DAI_ASSERT( beta != jt_RTree.size() ); - - size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin(); - if( newalpha1 == _a.size() ) { - _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) ); - _a.push_back( alpha1 ); // save old index in index conversion table - } - - size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin(); - if( newalpha2 == _a.size() ) { - _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) ); - _a.push_back( alpha2 ); // save old index in index conversion table - } - - _RTree.push_back( DEdge( newalpha1, newalpha2 ) ); - _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) ); - _b.push_back( beta ); - } - - // Find remaining variables (which are not in the new root) - _nsrem = _ns / _Qa[0].vars(); -} - - -void TreeEP::TreeEPSubTree::init() { - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) - _Qa[alpha].fill( 1.0 ); - for( size_t beta = 0; beta < _Qb.size(); beta++ ) - _Qb[beta].fill( 1.0 ); -} - - -void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector &Qa, const std::vector &Qb ) { - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) - _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha]; - - for( size_t beta = 0; beta < _Qb.size(); beta++ ) - _Qb[beta] = Qb[_b[beta]] / _Qb[beta]; -} - - -void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector &Qa, std::vector &Qb ) { - // Backup _Qa and _Qb - vector _Qa_old(_Qa); - vector _Qb_old(_Qb); - - // Clear Qa and Qb - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) - Qa[_a[alpha]].fill( 0.0 ); - for( size_t beta = 0; beta < _Qb.size(); beta++ ) - Qb[_b[beta]].fill( 0.0 ); - - // For all states of _nsrem - for( State s(_nsrem); s.valid(); s++ ) { - // Multiply root with slice of I - _Qa[0] *= _I->slice( _nsrem, s ); - - // CollectEvidence - for( size_t i = _RTree.size(); (i--) != 0; ) { - // clamp variables in nsrem - for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ ) - if( _Qa[_RTree[i].n2].vars() >> *n ) { - Factor delta( *n, 0.0 ); - delta[s(*n)] = 1.0; - _Qa[_RTree[i].n2] *= delta; - } - Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false ); - _Qa[_RTree[i].n1] *= new_Qb / _Qb[i]; - _Qb[i] = new_Qb; - } - - // DistributeEvidence - for( size_t i = 0; i < _RTree.size(); i++ ) { - Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false ); - _Qa[_RTree[i].n2] *= new_Qb / _Qb[i]; - _Qb[i] = new_Qb; - } - - // Store Qa's and Qb's - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) - Qa[_a[alpha]].p() += _Qa[alpha].p(); - for( size_t beta = 0; beta < _Qb.size(); beta++ ) - Qb[_b[beta]].p() += _Qb[beta].p(); - - // Restore _Qa and _Qb - _Qa = _Qa_old; - _Qb = _Qb_old; - } - - // Normalize Qa and Qb - _logZ = 0.0; - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) { - _logZ += log(Qa[_a[alpha]].sum()); - Qa[_a[alpha]].normalize(); - } - for( size_t beta = 0; beta < _Qb.size(); beta++ ) { - _logZ -= log(Qb[_b[beta]].sum()); - Qb[_b[beta]].normalize(); - } -} - - -Real TreeEP::TreeEPSubTree::logZ( const std::vector &Qa, const std::vector &Qb ) const { - Real s = 0.0; - for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) - s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum(); - for( size_t beta = 0; beta < _Qb.size(); beta++ ) - s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum(); - return s + _logZ; -} - - TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() { setProperties( opts ); @@ -243,111 +117,34 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt void TreeEP::construct( const RootedTree &tree ) { - vector Cliques; + vector cl; for( size_t i = 0; i < tree.size(); i++ ) - Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) ); - - // Construct a weighted graph (each edge is weighted with the cardinality - // of the intersection of the nodes, where the nodes are the elements of - // Cliques). - WeightedGraph JuncGraph; - for( size_t i = 0; i < Cliques.size(); i++ ) - for( size_t j = i+1; j < Cliques.size(); j++ ) { - size_t w = (Cliques[i] & Cliques[j]).size(); - if( w ) - JuncGraph[UEdge(i,j)] = w; - } + cl.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) ); - // Construct maximal spanning tree using Prim's algorithm - RTree = MaxSpanningTreePrims( JuncGraph ); - - // Construct corresponding region graph - - // Create outer regions - ORs.reserve( Cliques.size() ); - for( size_t i = 0; i < Cliques.size(); i++ ) - ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) ); - - // For each factor, find an outer region that subsumes that factor. - // Then, multiply the outer region with that factor. // If no outer region can be found subsuming that factor, label the // factor as off-tree. - fac2OR.clear(); - fac2OR.resize( nrFactors(), -1U ); - for( size_t I = 0; I < nrFactors(); I++ ) { - size_t alpha; - for( alpha = 0; alpha < nrORs(); alpha++ ) - if( OR(alpha).vars() >> factor(I).vars() ) { - fac2OR[I] = alpha; - break; - } - // DIFF WITH JTree::GenerateJT: assert - } - RecomputeORs(); - - // Create inner regions and edges - IRs.reserve( RTree.size() ); - vector edges; - edges.reserve( 2 * RTree.size() ); - for( size_t i = 0; i < RTree.size(); i++ ) { - edges.push_back( Edge( RTree[i].n1, IRs.size() ) ); - edges.push_back( Edge( RTree[i].n2, IRs.size() ) ); - // inner clusters have counting number -1 - IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) ); - } - - // create bipartite graph - G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() ); + JTree::construct( cl, false ); - // Check counting numbers - checkCountingNumbers(); - - // Create messages and beliefs - Qa.clear(); - Qa.reserve( nrORs() ); - for( size_t alpha = 0; alpha < nrORs(); alpha++ ) - Qa.push_back( OR(alpha) ); - - Qb.clear(); - Qb.reserve( nrIRs() ); - for( size_t beta = 0; beta < nrIRs(); beta++ ) - Qb.push_back( Factor( IR(beta), 1.0 ) ); - - // DIFF with JTree::GenerateJT: no messages - - // DIFF with JTree::GenerateJT: // Create factor approximations _Q.clear(); size_t PreviousRoot = (size_t)-1; - for( size_t I = 0; I < nrFactors(); I++ ) - if( offtree(I) ) { - // find efficient subtree - RootedTree subTree; - /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot ); - PreviousRoot = subTree[0].n1; - //subTree.resize( subTreeSize ); // FIXME -// cerr << "subtree " << I << " has size " << subTreeSize << endl; - - TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) ); - _Q[I] = QI; - } - // Previous root of first off-tree factor should be the root of the last off-tree factor - for( size_t I = 0; I < nrFactors(); I++ ) - if( offtree(I) ) { - RootedTree subTree; - /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot ); - PreviousRoot = subTree[0].n1; - //subTree.resize( subTreeSize ); // FIXME -// cerr << "subtree " << I << " has size " << subTreeSize << endl; - - TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) ); - _Q[I] = QI; - break; - } + // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor + for( size_t repeats = 0; repeats < 2; repeats++ ) + for( size_t I = 0; I < nrFactors(); I++ ) + if( offtree(I) ) { + // find efficient subtree + RootedTree subTree; + /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot ); + PreviousRoot = subTree[0].n1; + //subTree.resize( subTreeSize ); // FIXME + //cerr << "subtree " << I << " has size " << subTreeSize << endl; + _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) ); + if( repeats == 1 ) + break; + } - if( props.verbose >= 3 ) { + if( props.verbose >= 3 ) cerr << "Resulting regiongraph: " << *this << endl; - } } @@ -374,10 +171,7 @@ Real TreeEP::run() { double tic = toc(); - vector oldBeliefsV; - oldBeliefsV.reserve( nrVars() ); - for( size_t i = 0; i < nrVars(); i++ ) - oldBeliefsV.push_back( beliefV(i) ); + vector oldBeliefs = beliefs(); // do several passes over the network until maximum number of iterations has // been reached or until the maximum belief difference is smaller than tolerance @@ -391,12 +185,11 @@ Real TreeEP::run() { } // calculate new beliefs and compare with old ones + vector newBeliefs = beliefs(); maxDiff = -INFINITY; - for( size_t i = 0; i < nrVars(); i++ ) { - Factor nb( beliefV(i) ); - maxDiff = std::max( maxDiff, dist( nb, oldBeliefsV[i], Prob::DISTLINF ) ); - oldBeliefsV[i] = nb; - } + for( size_t t = 0; t < oldBeliefs.size(); t++ ) + maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], Prob::DISTLINF ) ); + swap( newBeliefs, oldBeliefs ); if( props.verbose >= 3 ) cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl; @@ -443,4 +236,130 @@ Real TreeEP::logZ() const { } +TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector &jt_Qa, const std::vector &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) { + _ns = _I->vars(); + + // Make _Qa, _Qb, _a and _b corresponding to the subtree + _b.reserve( subRTree.size() ); + _Qb.reserve( subRTree.size() ); + _RTree.reserve( subRTree.size() ); + for( size_t i = 0; i < subRTree.size(); i++ ) { + size_t alpha1 = subRTree[i].n1; // old index 1 + size_t alpha2 = subRTree[i].n2; // old index 2 + size_t beta; // old sep index + for( beta = 0; beta < jt_RTree.size(); beta++ ) + if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) ) + break; + DAI_ASSERT( beta != jt_RTree.size() ); + + size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin(); + if( newalpha1 == _a.size() ) { + _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) ); + _a.push_back( alpha1 ); // save old index in index conversion table + } + + size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin(); + if( newalpha2 == _a.size() ) { + _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) ); + _a.push_back( alpha2 ); // save old index in index conversion table + } + + _RTree.push_back( DEdge( newalpha1, newalpha2 ) ); + _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) ); + _b.push_back( beta ); + } + + // Find remaining variables (which are not in the new root) + _nsrem = _ns / _Qa[0].vars(); +} + + +void TreeEP::TreeEPSubTree::init() { + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) + _Qa[alpha].fill( 1.0 ); + for( size_t beta = 0; beta < _Qb.size(); beta++ ) + _Qb[beta].fill( 1.0 ); +} + + +void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector &Qa, const std::vector &Qb ) { + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) + _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha]; + + for( size_t beta = 0; beta < _Qb.size(); beta++ ) + _Qb[beta] = Qb[_b[beta]] / _Qb[beta]; +} + + +void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector &Qa, std::vector &Qb ) { + // Backup _Qa and _Qb + vector _Qa_old(_Qa); + vector _Qb_old(_Qb); + + // Clear Qa and Qb + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) + Qa[_a[alpha]].fill( 0.0 ); + for( size_t beta = 0; beta < _Qb.size(); beta++ ) + Qb[_b[beta]].fill( 0.0 ); + + // For all states of _nsrem + for( State s(_nsrem); s.valid(); s++ ) { + // Multiply root with slice of I + _Qa[0] *= _I->slice( _nsrem, s ); + + // CollectEvidence + for( size_t i = _RTree.size(); (i--) != 0; ) { + // clamp variables in nsrem + for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ ) + if( _Qa[_RTree[i].n2].vars() >> *n ) { + Factor delta( *n, 0.0 ); + delta[s(*n)] = 1.0; + _Qa[_RTree[i].n2] *= delta; + } + Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false ); + _Qa[_RTree[i].n1] *= new_Qb / _Qb[i]; + _Qb[i] = new_Qb; + } + + // DistributeEvidence + for( size_t i = 0; i < _RTree.size(); i++ ) { + Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false ); + _Qa[_RTree[i].n2] *= new_Qb / _Qb[i]; + _Qb[i] = new_Qb; + } + + // Store Qa's and Qb's + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) + Qa[_a[alpha]].p() += _Qa[alpha].p(); + for( size_t beta = 0; beta < _Qb.size(); beta++ ) + Qb[_b[beta]].p() += _Qb[beta].p(); + + // Restore _Qa and _Qb + _Qa = _Qa_old; + _Qb = _Qb_old; + } + + // Normalize Qa and Qb + _logZ = 0.0; + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) { + _logZ += log(Qa[_a[alpha]].sum()); + Qa[_a[alpha]].normalize(); + } + for( size_t beta = 0; beta < _Qb.size(); beta++ ) { + _logZ -= log(Qb[_b[beta]].sum()); + Qb[_b[beta]].normalize(); + } +} + + +Real TreeEP::TreeEPSubTree::logZ( const std::vector &Qa, const std::vector &Qb ) const { + Real s = 0.0; + for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) + s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum(); + for( size_t beta = 0; beta < _Qb.size(); beta++ ) + s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum(); + return s + _logZ; +} + + } // end of namespace dai -- 2.20.1