Cleaned up variable elimination code in ClusterGraph
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 13 Jan 2010 17:36:52 +0000 (18:36 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 13 Jan 2010 17:36:52 +0000 (18:36 +0100)
include/dai/bipgraph.h
include/dai/clustergraph.h
src/clustergraph.cpp
src/jtree.cpp

index 815b43e..62f1f57 100644 (file)
@@ -42,7 +42,6 @@ namespace dai {
  *  Thus, each node has an associated variable of type BipartiteGraph::Neighbors, which is a vector of
  *  Neighbor structures, describing its neighboring nodes of the other type.
  *  \idea Cache second-order neighborhoods in BipartiteGraph.
- *  \todo Check whether BGL isConnected improves performance.
  */
 class BipartiteGraph {
     public:
index bb77ba4..8c3730f 100644 (file)
 namespace dai {
 
 
+    class ClusterGraph;
+
+    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
+    /** The cost is measured as "number of added edges in the adjacency graph",
+     *  where the adjacency graph has the variables as its nodes and connects
+     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+     */
+    size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i );
+
+    /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl according to the "MinFill" criterion.
+    /** This function invokes eliminationCost_MinFill() for each variable in \a remainingVars, and returns
+     *  the variable which has lowest cost according to eliminationCost_MinFill().
+     *  \note This function can be passed to ClusterGraph::VarElim().
+     */
+    size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars );
+
+
     /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
     /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
      */
@@ -85,7 +102,7 @@ namespace dai {
             }
 
             /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
-            bool adj( size_t i1, size_t i2 ) {
+            bool adj( size_t i1, size_t i2 ) const {
                 bool result = false;
                 foreach( const Neighbor &I, G.nb1(i1) )
                     if( find( G.nb2(I).begin(), G.nb2(I).end(), i2 ) != G.nb2(I).end() ) {
@@ -171,37 +188,59 @@ namespace dai {
             /** The cost is measured as "number of added edges in the adjacency graph",
              *  where the adjacency graph has the variables as its nodes and connects
              *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+             *  \deprecated Please use dai::eliminationCost_MinFill() instead.
              */
-            size_t eliminationCost( size_t i ) {
-                std::vector<size_t> id_n = G.delta1( i );
-
-                size_t cost = 0;
-
-                // for each unordered pair {i1,i2} adjacent to n
-                for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
-                    for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
-                        // if i1 and i2 are not adjacent, eliminating n would make them adjacent
-                        if( !adj(id_n[_i1], id_n[_i2]) )
-                            cost++;
-                    }
-
-                return cost;
+            size_t eliminationCost( size_t i ) const {
+                return eliminationCost_MinFill( *this, i );
             }
 
             /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
             /** \param ElimSeq The sequence in which to eliminate the variables
              *  \return A set of elimination "cliques"
-             *  \todo Variable elimination should be implemented generically using a function
-             *  object that specifies which variable to delete next.
+             *  \deprecated Not used; if necessary, dai::ClusterGraph::VarElim( EliminationChoice & ) can be used instead.
              */
             ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
 
-            /// Performs Variable Eliminiation using the "MinFill" heuristic
+            /// Performs Variable Elimination using the "MinFill" heuristic
             /** The "MinFill" heuristic greedily minimizes the cost of eliminating a variable,
              *  measured with eliminationCost().
              *  \return A set of elimination "cliques"
+             *  \deprecated Please use dai::ClusterGraph::VarElim( eliminationChoice_MinFill ) instead.
              */
-            ClusterGraph VarElim_MinFill() const;
+            ClusterGraph VarElim_MinFill() const {
+                return VarElim( eliminationChoice_MinFill );
+            }
+
+            /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
+            /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
+             *  \param f function object which returns the next variable index to eliminate; for example, eliminationChoice_MinFill()
+             *  \return A set of elimination "cliques"
+             */
+            template<class EliminationChoice>
+            ClusterGraph VarElim( EliminationChoice &f ) const {
+                // Make a copy
+                ClusterGraph cl(*this);
+                cl.eraseNonMaximal();
+
+                ClusterGraph result;
+
+                // Construct set of variable indices
+                std::set<size_t> varindices;
+                for( size_t i = 0; i < vars.size(); ++i )
+                    varindices.insert( i );
+
+                // Do variable elimination
+                while( !varindices.empty() ) {
+                    size_t i = f( cl, varindices );
+                    result.insert( cl.Delta( i ) );
+                    cl.insert( cl.delta( i ) );
+                    cl.eraseSubsuming( i );
+                    cl.eraseNonMaximal();
+                    varindices.erase( i );
+                }
+
+                return result;
+            }
         //@}
     };
 
index e257f46..60adf3a 100644 (file)
@@ -45,65 +45,59 @@ ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : G(), vars(), clu
 }
 
 
-ClusterGraph ClusterGraph::VarElim_MinFill() const {
+
+ClusterGraph ClusterGraph::VarElim( const std::vector<Var> &ElimSeq ) const {
     // Make a copy
     ClusterGraph cl(*this);
     cl.eraseNonMaximal();
 
     ClusterGraph result;
 
-    // Construct set of variable indices
-    set<size_t> varindices;
-    for( size_t i = 0; i < vars.size(); ++i )
-        varindices.insert( i );
-
     // Do variable elimination
-    while( !varindices.empty() ) {
-        set<size_t>::const_iterator lowest = varindices.end();
-        size_t lowest_cost = -1UL;
-        for( set<size_t>::const_iterator i = varindices.begin(); i != varindices.end(); i++ ) {
-            size_t cost = cl.eliminationCost( *i );
-            if( lowest == varindices.end() || lowest_cost > cost ) {
-                lowest = i;
-                lowest_cost = cost;
-            }
-        }
-        size_t i = *lowest;
+    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
+        size_t i = cl.findVar( *n );
+        DAI_ASSERT( i != cl.vars.size() );
 
-        result.insert( cl.Delta( i ) );
+        result.insert( cl.Delta(i) );
 
-        cl.insert( cl.delta( i ) );
+        cl.insert( cl.delta(i) );
         cl.eraseSubsuming( i );
         cl.eraseNonMaximal();
-        varindices.erase( i );
     }
 
     return result;
 }
 
 
+size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
 
-ClusterGraph ClusterGraph::VarElim( const std::vector<Var> & ElimSeq ) const {
-    // Make a copy
-    ClusterGraph cl(*this);
-    cl.eraseNonMaximal();
+    size_t cost = 0;
 
-    ClusterGraph result;
+    // for each unordered pair {i1,i2} adjacent to n
+    for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
+        for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
+            // if i1 and i2 are not adjacent, eliminating n would make them adjacent
+            if( !cl.adj(id_n[_i1], id_n[_i2]) )
+                cost++;
+        }
 
-    // Do variable elimination
-    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
-        size_t i = cl.findVar( *n );
-        DAI_ASSERT( i != cl.vars.size() );
+    return cost;
+}
 
-        result.insert( cl.Delta(i) );
 
-        cl.insert( cl.delta(i) );
-        cl.eraseSubsuming( i );
-        cl.eraseNonMaximal();
+size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
+    set<size_t>::const_iterator lowest = remainingVars.end();
+    size_t lowest_cost = -1UL;
+    for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
+        size_t cost = eliminationCost_MinFill( cl, *i );
+        if( lowest == remainingVars.end() || lowest_cost > cost ) {
+            lowest = i;
+            lowest_cost = cost;
+        }
     }
-
-    return result;
-}
+    return *lowest;
+}    
 
 
 } // end of namespace dai
index b287caa..bcdbefc 100644 (file)
@@ -78,9 +78,9 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
             cerr << "Maximal clusters: " << _cg << endl;
 
         // Use MinFill heuristic to guess optimal elimination sequence
-        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
+        vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
         if( props.verbose >= 3 )
-            cerr << "VarElim_MinFill result: " << ElimVec << endl;
+            cerr << "VarElim result: " << ElimVec << endl;
 
         // Generate the junction tree corresponding to the elimination sequence
         GenerateJT( ElimVec );
@@ -511,7 +511,7 @@ std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
     _cg.eraseNonMaximal();
 
     // Obtain elimination sequence
-    vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
+    vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
 
     // Calculate treewidth
     size_t treewidth = 0;