Cleaned up some code in TreeEP and JTree
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 13 Jan 2010 12:52:53 +0000 (13:52 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 13 Jan 2010 12:52:53 +0000 (13:52 +0100)
include/dai/jtree.h
include/dai/mr.h
src/jtree.cpp
src/treeep.cpp

index 3bed8b1..37f6981 100644 (file)
@@ -128,12 +128,20 @@ class JTree : public DAIAlgRG {
         /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and 
          *  each edge is weighted with the cardinality of the intersection of the state spaces of the nodes. 
          *  Then, a maximal spanning tree for this weighted graph is calculated.
-         *  Finally, a corresponding region graph is built:
+         *  Subsequently, a corresponding region graph is built:
          *    - the outer regions correspond with the cliques and have counting number 1;
          *    - the inner regions correspond with the seperators, i.e., the intersections of two 
          *      cliques that are neighbors in the spanning tree, and have counting number -1;
          *    - inner and outer regions are connected by an edge if the inner region is a
          *      seperator for the outer region.
+         *  Finally, Beliefs are constructed.
+         *  If \a verify == \c true, checks whether each factor is subsumed by a clique.
+         */
+        void construct( const std::vector<VarSet> &cl, bool verify=false );
+
+        /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+        /** Invokes construct() and then constructs messages.
+         *  \see construct()
          */
         void GenerateJT( const std::vector<VarSet> &cl );
 
index 64482d4..ce8a7c6 100644 (file)
@@ -33,6 +33,7 @@ namespace dai {
 
 /// Approximate inference algorithm by Montanari and Rizzo [\ref MoR05]
 /** \author Bastian Wemmenhove wrote the original implementation before it was merged into libDAI
+ *  \todo Clean up code (use a BipartiteGraph-like implementation for the graph structure)
  */
 class MR : public DAIAlgFG {
     private:
index 4874f6f..b287caa 100644 (file)
@@ -88,14 +88,13 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
 }
 
 
-void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
+void JTree::construct( const std::vector<VarSet> &cl, bool verify ) {
     // Construct a weighted graph (each edge is weighted with the cardinality
-    // of the intersection of the nodes, where the nodes are the elements of
-    // Cliques).
+    // of the intersection of the nodes, where the nodes are the elements of cl).
     WeightedGraph<int> JuncGraph;
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        for( size_t j = i+1; j < Cliques.size(); j++ ) {
-            size_t w = (Cliques[i] & Cliques[j]).size();
+    for( size_t i = 0; i < cl.size(); i++ )
+        for( size_t j = i+1; j < cl.size(); j++ ) {
+            size_t w = (cl[i] & cl[j]).size();
             if( w )
                 JuncGraph[UEdge(i,j)] = w;
         }
@@ -106,24 +105,29 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     // Construct corresponding region graph
 
     // Create outer regions
-    ORs.reserve( Cliques.size() );
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
+    ORs.clear();
+    ORs.reserve( cl.size() );
+    for( size_t i = 0; i < cl.size(); i++ )
+        ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );
 
     // For each factor, find an outer region that subsumes that factor.
     // Then, multiply the outer region with that factor.
+    fac2OR.clear();
+    fac2OR.resize( nrFactors(), -1U );
     for( size_t I = 0; I < nrFactors(); I++ ) {
         size_t alpha;
         for( alpha = 0; alpha < nrORs(); alpha++ )
             if( OR(alpha).vars() >> factor(I).vars() ) {
-                fac2OR.push_back( alpha );
+                fac2OR[I] = alpha;
                 break;
             }
-        DAI_ASSERT( alpha != nrORs() );
+        if( verify )
+            DAI_ASSERT( alpha != nrORs() );
     }
     RecomputeORs();
 
     // Create inner regions and edges
+    IRs.clear();
     IRs.reserve( RTree.size() );
     vector<Edge> edges;
     edges.reserve( 2 * RTree.size() );
@@ -131,13 +135,18 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
         edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
         edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
         // inner clusters have counting number -1
-        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
+        IRs.push_back( Region( cl[RTree[i].n1] & cl[RTree[i].n2], -1.0 ) );
     }
 
     // create bipartite graph
     G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
 
-    // Create messages and beliefs
+    // Check counting numbers
+#ifdef DAI_DEBUG
+    checkCountingNumbers();
+#endif
+
+    // Create beliefs
     Qa.clear();
     Qa.reserve( nrORs() );
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
@@ -147,7 +156,13 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     Qb.reserve( nrIRs() );
     for( size_t beta = 0; beta < nrIRs(); beta++ )
         Qb.push_back( Factor( IR(beta), 1.0 ) );
+}
+
 
+void JTree::GenerateJT( const std::vector<VarSet> &cl ) {
+    construct( cl, true );
+
+    // Create messages
     _mes.clear();
     _mes.reserve( nrORs() );
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
@@ -157,11 +172,6 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
             _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
     }
 
-    // Check counting numbers
-#ifdef DAI_DEBUG
-    checkCountingNumbers();
-#endif
-
     if( props.verbose >= 3 )
         cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
 }
index 1685219..000ce59 100644 (file)
@@ -60,132 +60,6 @@ string TreeEP::printProperties() const {
 }
 
 
-TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
-    _ns = _I->vars();
-
-    // Make _Qa, _Qb, _a and _b corresponding to the subtree
-    _b.reserve( subRTree.size() );
-    _Qb.reserve( subRTree.size() );
-    _RTree.reserve( subRTree.size() );
-    for( size_t i = 0; i < subRTree.size(); i++ ) {
-        size_t alpha1 = subRTree[i].n1;     // old index 1
-        size_t alpha2 = subRTree[i].n2;     // old index 2
-        size_t beta;                        // old sep index
-        for( beta = 0; beta < jt_RTree.size(); beta++ )
-            if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
-                break;
-        DAI_ASSERT( beta != jt_RTree.size() );
-
-        size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
-        if( newalpha1 == _a.size() ) {
-            _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
-            _a.push_back( alpha1 );         // save old index in index conversion table
-        }
-
-        size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
-        if( newalpha2 == _a.size() ) {
-            _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
-            _a.push_back( alpha2 );         // save old index in index conversion table
-        }
-
-        _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
-        _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
-        _b.push_back( beta );
-    }
-
-    // Find remaining variables (which are not in the new root)
-    _nsrem = _ns / _Qa[0].vars();
-}
-
-
-void TreeEP::TreeEPSubTree::init() {
-    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        _Qa[alpha].fill( 1.0 );
-    for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        _Qb[beta].fill( 1.0 );
-}
-
-
-void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
-    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
-
-    for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
-}
-
-
-void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
-    // Backup _Qa and _Qb
-    vector<Factor> _Qa_old(_Qa);
-    vector<Factor> _Qb_old(_Qb);
-
-    // Clear Qa and Qb
-    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        Qa[_a[alpha]].fill( 0.0 );
-    for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        Qb[_b[beta]].fill( 0.0 );
-
-    // For all states of _nsrem
-    for( State s(_nsrem); s.valid(); s++ ) {
-        // Multiply root with slice of I
-        _Qa[0] *= _I->slice( _nsrem, s );
-
-        // CollectEvidence
-        for( size_t i = _RTree.size(); (i--) != 0; ) {
-            // clamp variables in nsrem
-            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
-                if( _Qa[_RTree[i].n2].vars() >> *n ) {
-                    Factor delta( *n, 0.0 );
-                    delta[s(*n)] = 1.0;
-                    _Qa[_RTree[i].n2] *= delta;
-                }
-            Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
-            _Qa[_RTree[i].n1] *= new_Qb / _Qb[i];
-            _Qb[i] = new_Qb;
-        }
-
-        // DistributeEvidence
-        for( size_t i = 0; i < _RTree.size(); i++ ) {
-            Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
-            _Qa[_RTree[i].n2] *= new_Qb / _Qb[i];
-            _Qb[i] = new_Qb;
-        }
-
-        // Store Qa's and Qb's
-        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-            Qa[_a[alpha]].p() += _Qa[alpha].p();
-        for( size_t beta = 0; beta < _Qb.size(); beta++ )
-            Qb[_b[beta]].p() += _Qb[beta].p();
-
-        // Restore _Qa and _Qb
-        _Qa = _Qa_old;
-        _Qb = _Qb_old;
-    }
-
-    // Normalize Qa and Qb
-    _logZ = 0.0;
-    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
-        _logZ += log(Qa[_a[alpha]].sum());
-        Qa[_a[alpha]].normalize();
-    }
-    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
-        _logZ -= log(Qb[_b[beta]].sum());
-        Qb[_b[beta]].normalize();
-    }
-}
-
-
-Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
-    Real s = 0.0;
-    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
-    for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
-    return s + _logZ;
-}
-
-
 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
     setProperties( opts );
 
@@ -243,111 +117,34 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt
 
 
 void TreeEP::construct( const RootedTree &tree ) {
-    vector<VarSet> Cliques;
+    vector<VarSet> cl;
     for( size_t i = 0; i < tree.size(); i++ )
-        Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
-
-    // Construct a weighted graph (each edge is weighted with the cardinality
-    // of the intersection of the nodes, where the nodes are the elements of
-    // Cliques).
-    WeightedGraph<int> JuncGraph;
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        for( size_t j = i+1; j < Cliques.size(); j++ ) {
-            size_t w = (Cliques[i] & Cliques[j]).size();
-            if( w )
-                JuncGraph[UEdge(i,j)] = w;
-        }
+        cl.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
 
-    // Construct maximal spanning tree using Prim's algorithm
-    RTree = MaxSpanningTreePrims( JuncGraph );
-
-    // Construct corresponding region graph
-
-    // Create outer regions
-    ORs.reserve( Cliques.size() );
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
-
-    // For each factor, find an outer region that subsumes that factor.
-    // Then, multiply the outer region with that factor.
     // If no outer region can be found subsuming that factor, label the
     // factor as off-tree.
-    fac2OR.clear();
-    fac2OR.resize( nrFactors(), -1U );
-    for( size_t I = 0; I < nrFactors(); I++ ) {
-        size_t alpha;
-        for( alpha = 0; alpha < nrORs(); alpha++ )
-            if( OR(alpha).vars() >> factor(I).vars() ) {
-                fac2OR[I] = alpha;
-                break;
-            }
-    // DIFF WITH JTree::GenerateJT: assert
-    }
-    RecomputeORs();
-
-    // Create inner regions and edges
-    IRs.reserve( RTree.size() );
-    vector<Edge> edges;
-    edges.reserve( 2 * RTree.size() );
-    for( size_t i = 0; i < RTree.size(); i++ ) {
-        edges.push_back( Edge( RTree[i].n1, IRs.size() ) );
-        edges.push_back( Edge( RTree[i].n2, IRs.size() ) );
-        // inner clusters have counting number -1
-        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
-    }
-
-    // create bipartite graph
-    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
+    JTree::construct( cl, false );
 
-    // Check counting numbers
-    checkCountingNumbers();
-
-    // Create messages and beliefs
-    Qa.clear();
-    Qa.reserve( nrORs() );
-    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        Qa.push_back( OR(alpha) );
-
-    Qb.clear();
-    Qb.reserve( nrIRs() );
-    for( size_t beta = 0; beta < nrIRs(); beta++ )
-        Qb.push_back( Factor( IR(beta), 1.0 ) );
-
-    // DIFF with JTree::GenerateJT:  no messages
-
-    // DIFF with JTree::GenerateJT:
     // Create factor approximations
     _Q.clear();
     size_t PreviousRoot = (size_t)-1;
-    for( size_t I = 0; I < nrFactors(); I++ )
-        if( offtree(I) ) {
-            // find efficient subtree
-            RootedTree subTree;
-            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
-            PreviousRoot = subTree[0].n1;
-            //subTree.resize( subTreeSize );  // FIXME
-//          cerr << "subtree " << I << " has size " << subTreeSize << endl;
-
-            TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
-            _Q[I] = QI;
-        }
-    // Previous root of first off-tree factor should be the root of the last off-tree factor
-    for( size_t I = 0; I < nrFactors(); I++ )
-        if( offtree(I) ) {
-            RootedTree subTree;
-            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
-            PreviousRoot = subTree[0].n1;
-            //subTree.resize( subTreeSize ); // FIXME
-//          cerr << "subtree " << I << " has size " << subTreeSize << endl;
-
-            TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
-            _Q[I] = QI;
-            break;
-        }
+    // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
+    for( size_t repeats = 0; repeats < 2; repeats++ )
+        for( size_t I = 0; I < nrFactors(); I++ )
+            if( offtree(I) ) {
+                // find efficient subtree
+                RootedTree subTree;
+                /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
+                PreviousRoot = subTree[0].n1;
+                //subTree.resize( subTreeSize );  // FIXME
+                //cerr << "subtree " << I << " has size " << subTreeSize << endl;
+                _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
+                if( repeats == 1 )
+                    break;
+            }
 
-    if( props.verbose >= 3 ) {
+    if( props.verbose >= 3 )
         cerr << "Resulting regiongraph: " << *this << endl;
-    }
 }
 
 
@@ -374,10 +171,7 @@ Real TreeEP::run() {
 
     double tic = toc();
 
-    vector<Factor> oldBeliefsV;
-    oldBeliefsV.reserve( nrVars() );
-    for( size_t i = 0; i < nrVars(); i++ )
-        oldBeliefsV.push_back( beliefV(i) );
+    vector<Factor> oldBeliefs = beliefs();
 
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
@@ -391,12 +185,11 @@ Real TreeEP::run() {
             }
 
         // calculate new beliefs and compare with old ones
+        vector<Factor> newBeliefs = beliefs();
         maxDiff = -INFINITY;
-        for( size_t i = 0; i < nrVars(); i++ ) {
-            Factor nb( beliefV(i) );
-            maxDiff = std::max( maxDiff, dist( nb, oldBeliefsV[i], Prob::DISTLINF ) );
-            oldBeliefsV[i] = nb;
-        }
+        for( size_t t = 0; t < oldBeliefs.size(); t++ )
+            maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], Prob::DISTLINF ) );
+        swap( newBeliefs, oldBeliefs );
 
         if( props.verbose >= 3 )
             cerr << Name << "::run:  maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
@@ -443,4 +236,130 @@ Real TreeEP::logZ() const {
 }
 
 
+TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
+    _ns = _I->vars();
+
+    // Make _Qa, _Qb, _a and _b corresponding to the subtree
+    _b.reserve( subRTree.size() );
+    _Qb.reserve( subRTree.size() );
+    _RTree.reserve( subRTree.size() );
+    for( size_t i = 0; i < subRTree.size(); i++ ) {
+        size_t alpha1 = subRTree[i].n1;     // old index 1
+        size_t alpha2 = subRTree[i].n2;     // old index 2
+        size_t beta;                        // old sep index
+        for( beta = 0; beta < jt_RTree.size(); beta++ )
+            if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
+                break;
+        DAI_ASSERT( beta != jt_RTree.size() );
+
+        size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
+        if( newalpha1 == _a.size() ) {
+            _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
+            _a.push_back( alpha1 );         // save old index in index conversion table
+        }
+
+        size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
+        if( newalpha2 == _a.size() ) {
+            _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
+            _a.push_back( alpha2 );         // save old index in index conversion table
+        }
+
+        _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
+        _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
+        _b.push_back( beta );
+    }
+
+    // Find remaining variables (which are not in the new root)
+    _nsrem = _ns / _Qa[0].vars();
+}
+
+
+void TreeEP::TreeEPSubTree::init() {
+    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
+        _Qa[alpha].fill( 1.0 );
+    for( size_t beta = 0; beta < _Qb.size(); beta++ )
+        _Qb[beta].fill( 1.0 );
+}
+
+
+void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
+    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
+        _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
+
+    for( size_t beta = 0; beta < _Qb.size(); beta++ )
+        _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
+}
+
+
+void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
+    // Backup _Qa and _Qb
+    vector<Factor> _Qa_old(_Qa);
+    vector<Factor> _Qb_old(_Qb);
+
+    // Clear Qa and Qb
+    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
+        Qa[_a[alpha]].fill( 0.0 );
+    for( size_t beta = 0; beta < _Qb.size(); beta++ )
+        Qb[_b[beta]].fill( 0.0 );
+
+    // For all states of _nsrem
+    for( State s(_nsrem); s.valid(); s++ ) {
+        // Multiply root with slice of I
+        _Qa[0] *= _I->slice( _nsrem, s );
+
+        // CollectEvidence
+        for( size_t i = _RTree.size(); (i--) != 0; ) {
+            // clamp variables in nsrem
+            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
+                if( _Qa[_RTree[i].n2].vars() >> *n ) {
+                    Factor delta( *n, 0.0 );
+                    delta[s(*n)] = 1.0;
+                    _Qa[_RTree[i].n2] *= delta;
+                }
+            Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
+            _Qa[_RTree[i].n1] *= new_Qb / _Qb[i];
+            _Qb[i] = new_Qb;
+        }
+
+        // DistributeEvidence
+        for( size_t i = 0; i < _RTree.size(); i++ ) {
+            Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
+            _Qa[_RTree[i].n2] *= new_Qb / _Qb[i];
+            _Qb[i] = new_Qb;
+        }
+
+        // Store Qa's and Qb's
+        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
+            Qa[_a[alpha]].p() += _Qa[alpha].p();
+        for( size_t beta = 0; beta < _Qb.size(); beta++ )
+            Qb[_b[beta]].p() += _Qb[beta].p();
+
+        // Restore _Qa and _Qb
+        _Qa = _Qa_old;
+        _Qb = _Qb_old;
+    }
+
+    // Normalize Qa and Qb
+    _logZ = 0.0;
+    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
+        _logZ += log(Qa[_a[alpha]].sum());
+        Qa[_a[alpha]].normalize();
+    }
+    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
+        _logZ -= log(Qb[_b[beta]].sum());
+        Qb[_b[beta]].normalize();
+    }
+}
+
+
+Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
+    Real s = 0.0;
+    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
+        s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
+    for( size_t beta = 0; beta < _Qb.size(); beta++ )
+        s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
+    return s + _logZ;
+}
+
+
 } // end of namespace dai