Implemented various heuristics for choosing a variable elimination sequence in JTree
[libdai.git] / src / clustergraph.cpp
index 60adf3a..c947f13 100644 (file)
@@ -45,27 +45,54 @@ ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : G(), vars(), clu
 }
 
 
+size_t ClusterGraph::eliminationCost( size_t i ) const {
+    return eliminationCost_MinFill( *this, i );
+}
+
 
 ClusterGraph ClusterGraph::VarElim( const std::vector<Var> &ElimSeq ) const {
-    // Make a copy
-    ClusterGraph cl(*this);
-    cl.eraseNonMaximal();
+    return VarElim( sequentialVariableElimination( ElimSeq ) );
+}
 
-    ClusterGraph result;
 
-    // Do variable elimination
-    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
-        size_t i = cl.findVar( *n );
-        DAI_ASSERT( i != cl.vars.size() );
+ClusterGraph ClusterGraph::VarElim_MinFill() const {
+    return VarElim( greedyVariableElimination( &eliminationCost_MinFill ) );
+}
+
 
-        result.insert( cl.Delta(i) );
+size_t sequentialVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ) {
+    return cl.findVar( seq.at(i++) );
+}
 
-        cl.insert( cl.delta(i) );
-        cl.eraseSubsuming( i );
-        cl.eraseNonMaximal();
+
+size_t greedyVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
+    set<size_t>::const_iterator lowest = remainingVars.end();
+    size_t lowest_cost = -1UL;
+    for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
+        size_t cost = heuristic( cl, *i );
+        if( lowest == remainingVars.end() || lowest_cost > cost ) {
+            lowest = i;
+            lowest_cost = cost;
+        }
     }
+    return *lowest;
+}
+
+
+size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+    return id_n.size();
+}
+
+
+size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+    
+    size_t cost = 1;
+    for( size_t _i = 0; _i < id_n.size(); _i++ )
+        cost *= cl.vars[id_n[_i]].states();
 
-    return result;
+    return cost;
 }
 
 
@@ -73,7 +100,6 @@ size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
     std::vector<size_t> id_n = cl.G.delta1( i );
 
     size_t cost = 0;
-
     // for each unordered pair {i1,i2} adjacent to n
     for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
         for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
@@ -86,18 +112,20 @@ size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
 }
 
 
-size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
-    set<size_t>::const_iterator lowest = remainingVars.end();
-    size_t lowest_cost = -1UL;
-    for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
-        size_t cost = eliminationCost_MinFill( cl, *i );
-        if( lowest == remainingVars.end() || lowest_cost > cost ) {
-            lowest = i;
-            lowest_cost = cost;
+size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+
+    size_t cost = 0;
+    // for each unordered pair {i1,i2} adjacent to n
+    for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
+        for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
+            // if i1 and i2 are not adjacent, eliminating n would make them adjacent
+            if( !cl.adj(id_n[_i1], id_n[_i2]) )
+                cost += cl.vars[id_n[_i1]].states() * cl.vars[id_n[_i2]].states();
         }
-    }
-    return *lowest;
-}    
+
+    return cost;
+}
 
 
 } // end of namespace dai