Improved ClusterGraph implementation and MaxSpanningTreePrims implementation.
[libdai.git] / src / clustergraph.cpp
index 80e107e..4b0695e 100644 (file)
@@ -32,87 +32,84 @@ namespace dai {
 using namespace std;
 
 
-ClusterGraph ClusterGraph::VarElim( const std::vector<Var> & ElimSeq ) const {
-    const long verbose = 0;
+ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : G(), vars(), clusters() {
+    // construct vars, clusters and edge list
+    vector<Edge> edges;
+    foreach( const VarSet &cl, cls ) {
+        if( find( clusters.begin(), clusters.end(), cl ) == clusters.end() ) {
+            // add cluster
+            size_t n2 = clusters.size();
+            clusters.push_back( cl );
+            for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
+                size_t n1 = find( vars.begin(), vars.end(), *n ) - vars.begin();
+                if( n1 == vars.size() )
+                    // add variable
+                    vars.push_back( *n );
+                edges.push_back( Edge( n1, n2 ) );
+            }
+        } // disregard duplicate clusters
+    }
+
+    // Create bipartite graph
+    G.create( vars.size(), clusters.size(), edges.begin(), edges.end() );
+}
+
 
+ClusterGraph ClusterGraph::VarElim_MinFill() const {
     // Make a copy
-    ClusterGraph _Cl(*this);
+    ClusterGraph cl(*this);
+    cl.eraseNonMaximal();
 
     ClusterGraph result;
-    _Cl.eraseNonMaximal();
+
+    // Construct set of variable indices
+    set<size_t> varindices;
+    for( size_t i = 0; i < vars.size(); ++i )
+        varindices.insert( i );
     
     // Do variable elimination
-    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
-        assert( _Cl.vars() && *n );
-
-        if( verbose >= 1 )
-            cout << "Cost of eliminating " << *n << ": " << _Cl.eliminationCost( *n ) << " new edges" << endl;
-        
-        result.insert( _Cl.Delta(*n) );
-
-        if( verbose >= 1 )
-            cout << "_Cl = " << _Cl << endl;
-
-        if( verbose >= 1 )
-            cout << "After inserting " << _Cl.delta(*n) << ", _Cl = ";
-        _Cl.insert( _Cl.delta(*n) );
-        if( verbose >= 1 )
-            cout << _Cl << endl;
-
-        if( verbose >= 1 )
-            cout << "After erasing clusters that contain " << *n <<  ", _Cl = ";
-        _Cl.eraseSubsuming( *n );
-        if( verbose >= 1 )
-            cout << _Cl << endl;
-
-        if( verbose >= 1 )
-            cout << "After erasing nonmaximal clusters, _Cl = ";
-        _Cl.eraseNonMaximal();
-        if( verbose >= 1 )
-            cout << _Cl << endl;
+    while( !varindices.empty() ) {
+        set<size_t>::const_iterator lowest = varindices.end();
+        size_t lowest_cost = -1UL;
+        for( set<size_t>::const_iterator i = varindices.begin(); i != varindices.end(); i++ ) {
+            size_t cost = cl.eliminationCost( *i );
+            if( lowest == varindices.end() || lowest_cost > cost ) {
+                lowest = i;
+                lowest_cost = cost;
+            }
+        }
+        size_t i = *lowest;
+
+        result.insert( cl.Delta( i ) );
+
+        cl.insert( cl.delta( i ) );
+        cl.eraseSubsuming( i );
+        cl.eraseNonMaximal();
+        varindices.erase( i );
     }
 
     return result;
 }
 
 
-ClusterGraph ClusterGraph::VarElim_MinFill() const {
-    const long verbose = 0;
 
+ClusterGraph ClusterGraph::VarElim( const std::vector<Var> & ElimSeq ) const {
     // Make a copy
-    ClusterGraph _Cl(*this);
-    VarSet _vars( vars() );
+    ClusterGraph cl(*this);
+    cl.eraseNonMaximal();
 
     ClusterGraph result;
-    _Cl.eraseNonMaximal();
     
     // Do variable elimination
-    while( !_vars.empty() ) {
-        if( verbose >= 1 )
-            cout << "Var  Eliminiation cost" << endl;
-        VarSet::const_iterator lowest = _vars.end();
-        size_t lowest_cost = -1UL;
-        for( VarSet::const_iterator n = _vars.begin(); n != _vars.end(); n++ ) {
-            size_t cost = _Cl.eliminationCost( *n );
-            if( verbose >= 1 )
-                cout << *n << "  " << cost << endl;
-            if( lowest == _vars.end() || lowest_cost > cost ) {
-                lowest = n;
-                lowest_cost = cost;
-            }
-        }
-        Var n = *lowest;
-
-        if( verbose >= 1 )
-            cout << "Lowest: " << n << " (" << lowest_cost << ")" << endl;
-
-        result.insert( _Cl.Delta(n) );
+    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
+        size_t i = cl.findVar( *n );
+        assert( i != cl.vars.size() );
 
-        _Cl.insert( _Cl.delta(n) );
-        _Cl.eraseSubsuming( n );
-        _Cl.eraseNonMaximal();
-        _vars /= n;
+        result.insert( cl.Delta(i) );
 
+        cl.insert( cl.delta(i) );
+        cl.eraseSubsuming( i );
+        cl.eraseNonMaximal();
     }
 
     return result;