Cleaned up some code in TreeEP and JTree
[libdai.git] / src / jtree.cpp
index 4874f6f..b287caa 100644 (file)
@@ -88,14 +88,13 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
 }
 
 
-void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
+void JTree::construct( const std::vector<VarSet> &cl, bool verify ) {
     // Construct a weighted graph (each edge is weighted with the cardinality
-    // of the intersection of the nodes, where the nodes are the elements of
-    // Cliques).
+    // of the intersection of the nodes, where the nodes are the elements of cl).
     WeightedGraph<int> JuncGraph;
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        for( size_t j = i+1; j < Cliques.size(); j++ ) {
-            size_t w = (Cliques[i] & Cliques[j]).size();
+    for( size_t i = 0; i < cl.size(); i++ )
+        for( size_t j = i+1; j < cl.size(); j++ ) {
+            size_t w = (cl[i] & cl[j]).size();
             if( w )
                 JuncGraph[UEdge(i,j)] = w;
         }
@@ -106,24 +105,29 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     // Construct corresponding region graph
 
     // Create outer regions
-    ORs.reserve( Cliques.size() );
-    for( size_t i = 0; i < Cliques.size(); i++ )
-        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
+    ORs.clear();
+    ORs.reserve( cl.size() );
+    for( size_t i = 0; i < cl.size(); i++ )
+        ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );
 
     // For each factor, find an outer region that subsumes that factor.
     // Then, multiply the outer region with that factor.
+    fac2OR.clear();
+    fac2OR.resize( nrFactors(), -1U );
     for( size_t I = 0; I < nrFactors(); I++ ) {
         size_t alpha;
         for( alpha = 0; alpha < nrORs(); alpha++ )
             if( OR(alpha).vars() >> factor(I).vars() ) {
-                fac2OR.push_back( alpha );
+                fac2OR[I] = alpha;
                 break;
             }
-        DAI_ASSERT( alpha != nrORs() );
+        if( verify )
+            DAI_ASSERT( alpha != nrORs() );
     }
     RecomputeORs();
 
     // Create inner regions and edges
+    IRs.clear();
     IRs.reserve( RTree.size() );
     vector<Edge> edges;
     edges.reserve( 2 * RTree.size() );
@@ -131,13 +135,18 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
         edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
         edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
         // inner clusters have counting number -1
-        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
+        IRs.push_back( Region( cl[RTree[i].n1] & cl[RTree[i].n2], -1.0 ) );
     }
 
     // create bipartite graph
     G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
 
-    // Create messages and beliefs
+    // Check counting numbers
+#ifdef DAI_DEBUG
+    checkCountingNumbers();
+#endif
+
+    // Create beliefs
     Qa.clear();
     Qa.reserve( nrORs() );
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
@@ -147,7 +156,13 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     Qb.reserve( nrIRs() );
     for( size_t beta = 0; beta < nrIRs(); beta++ )
         Qb.push_back( Factor( IR(beta), 1.0 ) );
+}
+
 
+void JTree::GenerateJT( const std::vector<VarSet> &cl ) {
+    construct( cl, true );
+
+    // Create messages
     _mes.clear();
     _mes.reserve( nrORs() );
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
@@ -157,11 +172,6 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
             _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
     }
 
-    // Check counting numbers
-#ifdef DAI_DEBUG
-    checkCountingNumbers();
-#endif
-
     if( props.verbose >= 3 )
         cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
 }