Added BETHE method for GBP/HAK cluster choice
[libdai.git] / src / regiongraph.cpp
index 5b1ef7e..6df320e 100644 (file)
@@ -23,14 +23,22 @@ namespace dai {
 using namespace std;
 
 
-RegionGraph::RegionGraph( const FactorGraph &fg, const std::vector<Region> &ors, const std::vector<Region> &irs, const std::vector<std::pair<size_t,size_t> > &edges) : FactorGraph(fg), G(), ORs(), IRs(irs), fac2OR() {
-    // Copy outer regions (giving them counting number 1.0)
+void RegionGraph::construct( const FactorGraph &fg, const std::vector<VarSet> &ors, const std::vector<Region> &irs, const std::vector<std::pair<size_t,size_t> > &edges ) {
+    // Copy factor graph structure
+    FactorGraph::operator=( fg );
+
+    // Copy inner regions
+    IRs = irs;
+
+    // Construct outer regions (giving them counting number 1.0)
+    ORs.clear();
     ORs.reserve( ors.size() );
-    for( vector<Region>::const_iterator alpha = ors.begin(); alpha != ors.end(); alpha++ )
-        ORs.push_back( FRegion(Factor(*alpha, 1.0), 1.0) );
+    foreach( const VarSet &alpha, ors )
+        ORs.push_back( FRegion(Factor(alpha, 1.0), 1.0) );
 
     // For each factor, find an outer region that subsumes that factor.
     // Then, multiply the outer region with that factor.
+    fac2OR.clear();
     fac2OR.reserve( nrFactors() );
     for( size_t I = 0; I < nrFactors(); I++ ) {
         size_t alpha;
@@ -45,39 +53,14 @@ RegionGraph::RegionGraph( const FactorGraph &fg, const std::vector<Region> &ors,
 
     // Create bipartite graph
     G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
-
-    // Check counting numbers
-#ifdef DAI_DEBUG
-    checkCountingNumbers();
-#endif
 }
 
 
-// CVM style
-RegionGraph::RegionGraph( const FactorGraph &fg, const std::vector<VarSet> &cl ) : FactorGraph(fg), G(), ORs(), IRs(), fac2OR() {
+void RegionGraph::constructCVM( const FactorGraph &fg, const std::vector<VarSet> &cl ) {
     // Retain only maximal clusters
     ClusterGraph cg( cl );
     cg.eraseNonMaximal();
 
-    // Create outer regions, giving them counting number 1.0
-    ORs.reserve( cg.size() );
-    foreach( const VarSet &ns, cg.clusters )
-        ORs.push_back( FRegion(Factor(ns, 1.0), 1.0) );
-
-    // For each factor, find an outer regions that subsumes that factor.
-    // Then, multiply the outer region with that factor.
-    fac2OR.reserve( nrFactors() );
-    for( size_t I = 0; I < nrFactors(); I++ ) {
-        size_t alpha;
-        for( alpha = 0; alpha < nrORs(); alpha++ )
-            if( OR(alpha).vars() >> factor(I).vars() ) {
-                fac2OR.push_back( alpha );
-                break;
-            }
-        DAI_ASSERT( alpha != nrORs() );
-    }
-    RecomputeORs();
-
     // Create inner regions - first pass
     set<VarSet> betas;
     for( size_t alpha = 0; alpha < cg.clusters.size(); alpha++ )
@@ -100,25 +83,35 @@ RegionGraph::RegionGraph( const FactorGraph &fg, const std::vector<VarSet> &cl )
         betas.insert(new_betas.begin(), new_betas.end());
     } while( new_betas.size() );
 
-    // Create inner regions - store them in the bipartite graph
-    IRs.reserve( betas.size() );
+    // Create inner regions - final phase
+    vector<Region> irs;
+    irs.reserve( betas.size() );
     for( set<VarSet>::const_iterator beta = betas.begin(); beta != betas.end(); beta++ )
-        IRs.push_back( Region(*beta,0.0) );
+        irs.push_back( Region(*beta,0.0) );
 
     // Create edges
     vector<pair<size_t,size_t> > edges;
-    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
-        for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
-            if( OR(alpha).vars() >> IR(beta) )
+    for( size_t beta = 0; beta < irs.size(); beta++ )
+        for( size_t alpha = 0; alpha < cg.clusters.size(); alpha++ )
+            if( cg.clusters[alpha] >> irs[beta] )
                 edges.push_back( pair<size_t,size_t>(alpha,beta) );
-        }
-    }
 
-    // Create bipartite graph
-    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
+    // Construct region graph
+    construct( fg, cg.clusters, irs, edges );
 
     // Calculate counting numbers
     calcCountingNumbers();
+}
+
+
+RegionGraph::RegionGraph( const FactorGraph &fg, const std::vector<Region> &ors, const std::vector<Region> &irs, const std::vector<std::pair<size_t,size_t> > &edges ) {
+    vector<VarSet> ors_alt;
+    ors_alt.reserve( ors.size() );
+    for( size_t alpha = 0; alpha < ors.size(); alpha++ ) {
+        ors_alt.push_back( ors[alpha] );
+        DAI_ASSERT( ors[alpha].c() == 1.0 );
+    }
+    construct( fg, ors_alt, irs, edges );
 
     // Check counting numbers
 #ifdef DAI_DEBUG