Implemented various heuristics for choosing a variable elimination sequence in JTree
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 18 Jan 2010 16:03:16 +0000 (17:03 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 18 Jan 2010 16:03:16 +0000 (17:03 +0100)
ChangeLog
include/dai/clustergraph.h
include/dai/doc.h
include/dai/jtree.h
include/dai/varset.h
src/clustergraph.cpp
src/jtree.cpp
utils/fginfo.cpp

index 47164e1..12a59d0 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,5 @@
+* Implemented various heuristics for choosing a variable elimination sequence 
+  in JTree
 * Added BETHE method for GBP/HAK cluster choice
 * Renamed some functions of BipartiteGraph:
     add1() -> addNode1()
index 8c3730f..8be220b 100644 (file)
 namespace dai {
 
 
-    class ClusterGraph;
-
-    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
-    /** The cost is measured as "number of added edges in the adjacency graph",
-     *  where the adjacency graph has the variables as its nodes and connects
-     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
-     */
-    size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i );
-
-    /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl according to the "MinFill" criterion.
-    /** This function invokes eliminationCost_MinFill() for each variable in \a remainingVars, and returns
-     *  the variable which has lowest cost according to eliminationCost_MinFill().
-     *  \note This function can be passed to ClusterGraph::VarElim().
-     */
-    size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars );
-
-
     /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
     /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
      */
@@ -184,40 +167,13 @@ namespace dai {
 
         /// \name Variable elimination
         //@{
-            /// Calculates cost of eliminating the \a i 'th variable.
-            /** The cost is measured as "number of added edges in the adjacency graph",
-             *  where the adjacency graph has the variables as its nodes and connects
-             *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
-             *  \deprecated Please use dai::eliminationCost_MinFill() instead.
-             */
-            size_t eliminationCost( size_t i ) const {
-                return eliminationCost_MinFill( *this, i );
-            }
-
-            /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
-            /** \param ElimSeq The sequence in which to eliminate the variables
-             *  \return A set of elimination "cliques"
-             *  \deprecated Not used; if necessary, dai::ClusterGraph::VarElim( EliminationChoice & ) can be used instead.
-             */
-            ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
-
-            /// Performs Variable Elimination using the "MinFill" heuristic
-            /** The "MinFill" heuristic greedily minimizes the cost of eliminating a variable,
-             *  measured with eliminationCost().
-             *  \return A set of elimination "cliques"
-             *  \deprecated Please use dai::ClusterGraph::VarElim( eliminationChoice_MinFill ) instead.
-             */
-            ClusterGraph VarElim_MinFill() const {
-                return VarElim( eliminationChoice_MinFill );
-            }
-
             /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
             /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
-             *  \param f function object which returns the next variable index to eliminate; for example, eliminationChoice_MinFill()
-             *  \return A set of elimination "cliques"
+             *  \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
+             *  \return A set of elimination "cliques".
              */
             template<class EliminationChoice>
-            ClusterGraph VarElim( EliminationChoice &f ) const {
+            ClusterGraph VarElim( EliminationChoice f ) const {
                 // Make a copy
                 ClusterGraph cl(*this);
                 cl.eraseNonMaximal();
@@ -232,6 +188,7 @@ namespace dai {
                 // Do variable elimination
                 while( !varindices.empty() ) {
                     size_t i = f( cl, varindices );
+                    DAI_ASSERT( i < vars.size() );
                     result.insert( cl.Delta( i ) );
                     cl.insert( cl.delta( i ) );
                     cl.eraseSubsuming( i );
@@ -241,10 +198,112 @@ namespace dai {
 
                 return result;
             }
+
+            /// Calculates cost of eliminating the \a i 'th variable.
+            /** The cost is measured as "number of added edges in the adjacency graph",
+             *  where the adjacency graph has the variables as its nodes and connects
+             *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+             *  \deprecated Please use dai::eliminationCost_MinFill() instead.
+             */
+            size_t eliminationCost( size_t i ) const;
+
+            /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
+            /** \param ElimSeq The sequence in which to eliminate the variables
+             *  \return A set of elimination "cliques"
+             *  \deprecated Please use dai::ClusterGraph::VarElim( sequentialVariableElimination( ElimSeq ) ) instead.
+             */
+            ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
+
+            /// Performs Variable Elimination using the "MinFill" heuristic
+            /** The "MinFill" heuristic greedily minimizes the cost of eliminating a variable,
+             *  measured with eliminationCost().
+             *  \return A set of elimination "cliques".
+             *  \deprecated Please use dai::ClusterGraph::VarElim( greedyVariableElimination( eliminationCost_MinFill ) ) instead.
+             */
+            ClusterGraph VarElim_MinFill() const;
         //@}
     };
 
 
+    /// Helper object for dai::ClusterGraph::VarElim()
+    /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
+     */
+    class sequentialVariableElimination {
+        private:
+            /// The variable elimination sequence
+            std::vector<Var> seq;
+            /// Counter
+            size_t i;
+
+        public:
+            /// Construct from vector of variables
+            sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
+
+            /// Returns next variable in sequence
+           size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
+    };
+
+
+    /// Helper object for dai::ClusterGraph::VarElim()
+    /** Chooses the next variable to eliminate greedily by taking the one that minimizes
+     *  a given heuristic cost function.
+     */
+    class greedyVariableElimination {
+        public:
+            /// Type of cost functions to be used for greedy variable elimination
+            typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
+
+        private:
+            /// Pointer to the cost function used
+            eliminationCostFunction heuristic;
+
+        public:
+            /// Construct from cost function
+            /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
+             */
+            greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
+
+            /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
+            /** This function calculates the cost for eliminating each variable in \a remaingVars and returns the variable which has lowest cost.
+             */
+            size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars );
+    };
+
+
+    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
+    /** The cost is measured as "number of neigboring nodes in the current adjacency graph",
+     *  where the adjacency graph has the variables as its nodes and connects
+     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+     */
+    size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i );
+
+
+    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
+    /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
+     *  where the adjacency graph has the variables as its nodes and connects
+     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+     *  The weight of a node is the number of states of the corresponding variable.
+     */
+    size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i );
+
+
+    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
+    /** The cost is measured as "number of added edges in the adjacency graph",
+     *  where the adjacency graph has the variables as its nodes and connects
+     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+     */
+    size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i );
+
+
+    /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
+    /** The cost is measured as "total weight of added edges in the adjacency graph",
+     *  where the adjacency graph has the variables as its nodes and connects
+     *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
+     *  The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
+     */
+    size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i );
+
+
 } // end of namespace dai
 
 
index 963ba34..9ce94ea 100644 (file)
@@ -20,7 +20,7 @@
  *
  *  \todo Document tests and utils
  *
- *  \todo Add FAQ
+ *  \todo Implement routines for UAI probabilistic inference evaluation data
  *
  *  \idea Adapt (part of the) guidelines in http://www.boost.org/development/requirements.html#Design_and_Programming
  *
  *  "Divergence measures and message passing",
  *  <em>MicroSoft Research Technical Report</em> MSR-TR-2005-173,
  *  http://research.microsoft.com/en-us/um/people/minka/papers/message-passing/minka-divergence.pdf
+ *
+ *  \anchor KoF09 \ref KoF09
+ *  D. Koller and N. Friedman (2009):
+ *  <em>Probabilistic Graphical Models - Principles and Techniques</em>,
+ *  The MIT Press, Cambridge, Massachusetts, London, England.
  */
 
 
index fad066b..f22e462 100644 (file)
@@ -46,7 +46,7 @@ class JTree : public DAIAlgRG {
         /// Stores the messages
         std::vector<std::vector<Factor> >  _mes;
 
-        /// Stores the logarithm of the partition sum (not used - why not?)
+        /// Stores the logarithm of the partition sum
         Real _logZ;
 
     public:
@@ -75,6 +75,18 @@ class JTree : public DAIAlgRG {
              */
             DAI_ENUM(InfType,SUMPROD,MAXPROD);
 
+            /// Enumeration of elimination cost functions used for constructing the junction tree
+            /** The cost of eliminating a variable can be (\see [\ref KoF09], page 314)):
+             *  - MINNEIGHBORS the number of neighbors it has in the current adjacency graph;
+             *  - MINWEIGHT the product of the number of states of all neighbors in the current adjacency graph;
+             *  - MINFILL the number of edges that need to be added to the adjacency graph due to the elimination;
+             *  - WEIGHTEDMINFILL the sum of weights of the edges that need to be added to the adjacency graph
+             *    due to the elimination, where a weight of an edge is the produt of weights of its constituent
+             *    vertices.
+             *  The elimination sequence is chosen greedily in order to minimize the cost.
+             */
+            DAI_ENUM(HeuristicType,MINNEIGHBORS,MINWEIGHT,MINFILL,WEIGHTEDMINFILL);
+
             /// Verbosity (amount of output sent to stderr)
             size_t verbose;
 
@@ -83,6 +95,9 @@ class JTree : public DAIAlgRG {
 
             /// Type of inference
             InfType inference;
+
+            /// Heuristic to use for constructing the junction tree
+            HeuristicType heuristic;
         } props;
 
         /// Name of this inference algorithm
@@ -97,7 +112,7 @@ class JTree : public DAIAlgRG {
         /// Construct from FactorGraph \a fg and PropertySet \a opts
         /** \param fg factor graph (which has to be connected);
          ** \param opts Parameters @see Properties
-         *  \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic.
+         *  \param automatic if \c true, construct the junction tree automatically, using the heuristic in opts['heuristic'].
          *  \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected
          */
         JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
@@ -183,11 +198,11 @@ class JTree : public DAIAlgRG {
 };
 
 
-/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
+/// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
 /** \relates JTree
  *  \return a pair (number of variables in largest clique, number of states in largest clique)
  */
-std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg );
+std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn );
 
 
 } // end of namespace dai
index 361ef7f..c2a33bc 100644 (file)
@@ -133,8 +133,11 @@ class VarSet : public SmallSet<Var> {
          */
         size_t nrStates() const {
             size_t states = 1;
-            for( VarSet::const_iterator n = begin(); n != end(); n++ )
-                states *= n->states();
+            for( VarSet::const_iterator n = begin(); n != end(); n++ ) {
+                size_t newStates = states * n->states();
+                DAI_ASSERT( newStates >= states );
+                states = newStates;
+            }
             return states;
         }
     //@}
index 60adf3a..c947f13 100644 (file)
@@ -45,27 +45,54 @@ ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : G(), vars(), clu
 }
 
 
+size_t ClusterGraph::eliminationCost( size_t i ) const {
+    return eliminationCost_MinFill( *this, i );
+}
+
 
 ClusterGraph ClusterGraph::VarElim( const std::vector<Var> &ElimSeq ) const {
-    // Make a copy
-    ClusterGraph cl(*this);
-    cl.eraseNonMaximal();
+    return VarElim( sequentialVariableElimination( ElimSeq ) );
+}
 
-    ClusterGraph result;
 
-    // Do variable elimination
-    for( vector<Var>::const_iterator n = ElimSeq.begin(); n != ElimSeq.end(); n++ ) {
-        size_t i = cl.findVar( *n );
-        DAI_ASSERT( i != cl.vars.size() );
+ClusterGraph ClusterGraph::VarElim_MinFill() const {
+    return VarElim( greedyVariableElimination( &eliminationCost_MinFill ) );
+}
+
 
-        result.insert( cl.Delta(i) );
+size_t sequentialVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ) {
+    return cl.findVar( seq.at(i++) );
+}
 
-        cl.insert( cl.delta(i) );
-        cl.eraseSubsuming( i );
-        cl.eraseNonMaximal();
+
+size_t greedyVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
+    set<size_t>::const_iterator lowest = remainingVars.end();
+    size_t lowest_cost = -1UL;
+    for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
+        size_t cost = heuristic( cl, *i );
+        if( lowest == remainingVars.end() || lowest_cost > cost ) {
+            lowest = i;
+            lowest_cost = cost;
+        }
     }
+    return *lowest;
+}
+
+
+size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+    return id_n.size();
+}
+
+
+size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+    
+    size_t cost = 1;
+    for( size_t _i = 0; _i < id_n.size(); _i++ )
+        cost *= cl.vars[id_n[_i]].states();
 
-    return result;
+    return cost;
 }
 
 
@@ -73,7 +100,6 @@ size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
     std::vector<size_t> id_n = cl.G.delta1( i );
 
     size_t cost = 0;
-
     // for each unordered pair {i1,i2} adjacent to n
     for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
         for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
@@ -86,18 +112,20 @@ size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
 }
 
 
-size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
-    set<size_t>::const_iterator lowest = remainingVars.end();
-    size_t lowest_cost = -1UL;
-    for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
-        size_t cost = eliminationCost_MinFill( cl, *i );
-        if( lowest == remainingVars.end() || lowest_cost > cost ) {
-            lowest = i;
-            lowest_cost = cost;
+size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i ) {
+    std::vector<size_t> id_n = cl.G.delta1( i );
+
+    size_t cost = 0;
+    // for each unordered pair {i1,i2} adjacent to n
+    for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
+        for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
+            // if i1 and i2 are not adjacent, eliminating n would make them adjacent
+            if( !cl.adj(id_n[_i1], id_n[_i2]) )
+                cost += cl.vars[id_n[_i1]].states() * cl.vars[id_n[_i2]].states();
         }
-    }
-    return *lowest;
-}    
+
+    return cost;
+}
 
 
 } // end of namespace dai
index c0e65c0..31a448e 100644 (file)
@@ -33,6 +33,10 @@ void JTree::setProperties( const PropertySet &opts ) {
         props.inference = opts.getStringAs<Properties::InfType>("inference");
     else
         props.inference = Properties::InfType::SUMPROD;
+    if( opts.hasKey("heuristic") )
+        props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
+    else
+        props.heuristic = Properties::HeuristicType::MINFILL;
 }
 
 
@@ -41,6 +45,7 @@ PropertySet JTree::getProperties() const {
     opts.Set( "verbose", props.verbose );
     opts.Set( "updates", props.updates );
     opts.Set( "inference", props.inference );
+    opts.Set( "heuristic", props.heuristic );
     return opts;
 }
 
@@ -50,6 +55,7 @@ string JTree::printProperties() const {
     s << "[";
     s << "verbose=" << props.verbose << ",";
     s << "updates=" << props.updates << ",";
+    s << "heuristic=" << props.heuristic << ",";
     s << "inference=" << props.inference << "]";
     return s.str();
 }
@@ -77,8 +83,25 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
         if( props.verbose >= 3 )
             cerr << "Maximal clusters: " << _cg << endl;
 
-        // Use MinFill heuristic to guess optimal elimination sequence
-        vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
+        // Use heuristic to guess optimal elimination sequence
+        greedyVariableElimination::eliminationCostFunction ec(NULL);
+        switch( (size_t)props.heuristic ) {
+            case Properties::HeuristicType::MINNEIGHBORS:
+                ec = eliminationCost_MinNeighbors;
+                break;
+            case Properties::HeuristicType::MINWEIGHT:
+                ec = eliminationCost_MinWeight;
+                break;
+            case Properties::HeuristicType::MINFILL:
+                ec = eliminationCost_MinFill;
+                break;
+            case Properties::HeuristicType::WEIGHTEDMINFILL:
+                ec = eliminationCost_WeightedMinFill;
+                break;
+            default:
+                DAI_THROW(UNKNOWN_ENUM_VALUE);
+        }
+        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ) ).eraseNonMaximal().toVector();
         if( props.verbose >= 3 )
             cerr << "VarElim result: " << ElimVec << endl;
 
@@ -497,7 +520,7 @@ Factor JTree::calcMarginal( const VarSet& vs ) {
 }
 
 
-std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
+std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn ) {
     ClusterGraph _cg;
 
     // Copy factors
@@ -508,11 +531,11 @@ std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
     _cg.eraseNonMaximal();
 
     // Obtain elimination sequence
-    vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
+    vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ) ).eraseNonMaximal().toVector();
 
     // Calculate treewidth
     size_t treewidth = 0;
-    size_t nrstates = 0;
+    double nrstates = 0.0;
     for( size_t i = 0; i < ElimVec.size(); i++ ) {
         if( ElimVec[i].size() > treewidth )
             treewidth = ElimVec[i].size();
@@ -521,13 +544,7 @@ std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
             nrstates = s;
     }
 
-    return pair<size_t,size_t>(treewidth, nrstates);
-}
-
-
-std::pair<size_t,size_t> treewidth( const FactorGraph & fg )
-{
-    return boundTreewidth( fg );
+    return make_pair(treewidth, nrstates);
 }
 
 
index 9b919af..e42cb9e 100644 (file)
@@ -93,9 +93,15 @@ int main( int argc, char *argv[] ) {
         cout << "Binary variables?      " << fg.isBinary() << endl;
         cout << "Pairwise interactions? " << fg.isPairwise() << endl;
         if( calc_tw ) {
-            std::pair<size_t,size_t> tw = boundTreewidth(fg);
-            cout << "Treewidth:           " << tw.first << endl;
-            cout << "Largest cluster for JTree has " << tw.second << " states " << endl;
+            std::pair<size_t,size_t> tw;
+            tw = boundTreewidth(fg, &eliminationCost_MinNeighbors);
+            cout << "Treewidth (MinNeighbors):     " << tw.first << " (" << tw.second << " states)" << endl;
+            tw = boundTreewidth(fg, &eliminationCost_MinWeight);
+            cout << "Treewidth (MinWeight):        " << tw.first << " (" << tw.second << " states)" << endl;
+            tw = boundTreewidth(fg, &eliminationCost_MinFill);
+            cout << "Treewidth (MinFill):          " << tw.first << " (" << tw.second << " states)" << endl;
+            tw = boundTreewidth(fg, &eliminationCost_WeightedMinFill);
+            cout << "Treewidth (WeightedMinFill):  " << tw.first << " (" << tw.second << " states)" << endl;
         }
         long double stsp = 1.0;
         for( size_t i = 0; i < fg.nrVars(); i++ )