Improved ClusterGraph and JTree (added 'maxmem' property)
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 4 Aug 2010 08:37:03 +0000 (10:37 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 4 Aug 2010 08:37:03 +0000 (10:37 +0200)
ChangeLog
include/dai/clustergraph.h
include/dai/exceptions.h
include/dai/jtree.h
src/clustergraph.cpp
src/exceptions.cpp
src/jtree.cpp
tests/unit/clustergraph_test.cpp
utils/fginfo.cpp

index 05786ee..b1f3ffe 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,13 @@
 git HEAD
 --------
+* Improved ClusterGraph
+  - added ClusterGraph(const FactorGraph& fg, bool onlyMaximal) constructor
+  - findVar( const Var& n ) no longer throws an exception if the variable is
+    not found
+  - added findCluster( const VarSet& cl )
+  - added elimVar( size_t i )
+  - added 'maxStates' argument to VarElim( EliminationChoice f, size_t maxStates=0 )
+* Improved JTree (added 'maxmem' property)
 * Improved HAK (added 'maxtime' property)
 * Improved TreeEP (added 'maxtime' property)
 * Added FactorGraph::logScore( const std::vector<size_t>& statevec )
index a8c1596..25fc185 100644 (file)
@@ -23,6 +23,7 @@
 #include <vector>
 #include <dai/varset.h>
 #include <dai/bipgraph.h>
+#include <dai/factorgraph.h>
 
 
 namespace dai {
@@ -31,6 +32,7 @@ namespace dai {
     /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
     /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
      *  One may think of a ClusterGraph as a FactorGraph without the actual factor values.
+     *  \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
      */
     class ClusterGraph {
         public:
@@ -58,6 +60,12 @@ namespace dai {
 
             /// Construct from vector of VarSet 's
             ClusterGraph( const std::vector<VarSet>& cls );
+
+            /// Construct from a factor graph
+            /** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
+             *  and only the maximal factors in \a fg if \a onlyMaximal == \c true.
+             */
+            ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
         //@}
 
         /// \name Queries
@@ -90,19 +98,34 @@ namespace dai {
             }
 
             /// Returns the index of variable \a n
-            /** \throw OBJECT_NOT_FOUND if the variable does not occur in the cluster graph
-             */
             size_t findVar( const Var& n ) const {
-                size_t r = find( _vars.begin(), _vars.end(), n ) - _vars.begin();
-                if( r == _vars.size() )
-                    DAI_THROW(OBJECT_NOT_FOUND);
-                return r;
+                return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
             }
 
+            /// Returns the index of a cluster \a cl
+            size_t findCluster( const VarSet& cl ) const {
+                return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
+            }
+
+/*            /// Returns the index of a cluster \a _cl
+            size_t findCluster( const SmallSet<size_t>& _cl ) const {
+                if( _cl.size() == 0 ) {
+                    for( size_t I = 0; I < nrClusters(); I++ )
+                        if( cluster(I).size() == 0 )
+                            return I;
+                } else {
+                    size_t i = _cl.front();
+                    foreach( const Neighbor& I, _G.nb1(i) )
+                        if( _G.nb2Set(I) == _cl )
+                            return I;
+                }
+                return nrClusters();
+            }*/
+
             /// Returns union of clusters that contain the \a i 'th variable
             VarSet Delta( size_t i ) const {
                 VarSet result;
-                foreach( const Neighbor &I, _G.nb1(i) )
+                foreach( const NeighborI, _G.nb1(i) )
                     result |= _clusters[I];
                 return result;
             }
@@ -117,7 +140,7 @@ namespace dai {
                 if( i1 == i2 )
                     return false;
                 bool result = false;
-                foreach( const Neighbor &I, _G.nb1(i1) )
+                foreach( const NeighborI, _G.nb1(i1) )
                     if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
                         result = true;
                         break;
@@ -131,8 +154,8 @@ namespace dai {
                 const VarSet & clI = _clusters[I];
                 bool maximal = true;
                 // The following may not be optimal, since it may repeatedly test the same cluster *J
-                foreach( const Neighbor &i, _G.nb2(I) ) {
-                    foreach( const Neighbor &J, _G.nb1(i) )
+                foreach( const Neighbori, _G.nb2(I) ) {
+                    foreach( const NeighborJ, _G.nb1(i) )
                         if( (J != I) && (clI << _clusters[J]) ) {
                             maximal = false;
                             break;
@@ -146,14 +169,18 @@ namespace dai {
 
         /// \name Operations
         //@{
-            /// Inserts a cluster (if it does not already exist)
-            void insert( const VarSet& cl ) {
-                if( find( _clusters.begin(), _clusters.end(), cl ) == _clusters.end() ) {
+            /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
+            /** \note This function could be better optimized if the index of one variable in \a cl would be known.
+             *        If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
+             */
+            size_t insert( const VarSet& cl ) {
+                size_t index = findCluster( cl );  // OPTIMIZE ME
+                if( index == _clusters.size() ) {
                     _clusters.push_back( cl );
                     // add variables (if necessary) and calculate neighborhood of new cluster
                     std::vector<size_t> nbs;
                     for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
-                        size_t iter = find( _vars.begin(), _vars.end(), *n ) - _vars.begin();
+                        size_t iter = findVar( *n );  // OPTIMIZE ME
                         nbs.push_back( iter );
                         if( iter == _vars.size() ) {
                             _G.addNode1();
@@ -162,8 +189,22 @@ namespace dai {
                     }
                     _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
                 }
+                return index;
             }
 
+/*            /// Inserts a cluster (if it does not already exist), assuming no new variables have to be created
+            size_t insert( const SmallSet<size_t>& _cl ) {
+                size_t index = findCluster( _cl );
+                if( index == _clusters.size() ) {
+                    VarSet cl;
+                    foreach( size_t i, _cl )
+                        cl |= var(i);
+                    _clusters.push_back( cl );
+                    _G.addNode2( _cl.begin(), _cl.end(), _cl.size() );
+                }
+                return index;
+            }*/
+
             /// Erases all clusters that are not maximal
             ClusterGraph& eraseNonMaximal() {
                 for( size_t I = 0; I < _G.nrNodes2(); ) {
@@ -178,12 +219,63 @@ namespace dai {
 
             /// Erases all clusters that contain the \a i 'th variable
             ClusterGraph& eraseSubsuming( size_t i ) {
+                DAI_ASSERT( i < nrVars() );
                 while( _G.nb1(i).size() ) {
                     _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
                     _G.eraseNode2( _G.nb1(i)[0] );
                 }
                 return *this;
             }
+
+            /// Eliminates variable with index \a i, without deleting the variable itself
+            /** \note This function can be better optimized
+             */
+            VarSet elimVar( size_t i ) {
+                DAI_ASSERT( i < nrVars() );
+                VarSet Di = Delta( i );
+
+//                if( 1 ) { // unoptimized, transparent code
+                    VarSet di = delta( i );
+                    insert( di );
+                    eraseSubsuming( i );
+                    eraseNonMaximal();
+/*                } else { // partially optimized code
+                    SmallSet<size_t> nbI = _G.delta1( i, false );
+                    size_t I = insert( nbI );
+
+                    while( _G.nb1(i).size() ) {
+                        size_t J = _G.nb1(i,0);
+                        _clusters.erase( _clusters.begin() + J );
+                        _G.eraseNode2( J );
+                        if( I > J )
+                            I--;
+                    }
+
+                    bool di_maximal = true;
+                    foreach( size_t j, nbI ) {
+                        for( size_t _J = 0; _J < _G.nb1(j).size(); ) {
+                            size_t J = _G.nb1(j,_J);
+                            SmallSet<size_t> indJ = _G.nb2Set( J );
+                            if( indJ << nbI && indJ.size() != nbI.size() ) {
+                                _clusters.erase( _clusters.begin() + J );
+                                _G.eraseNode2( J );
+                                if( I > J )
+                                    I--;
+                            } else {
+                                if( di_maximal && indJ >> nbI && indJ.size() != nbI.size() )
+                                    di_maximal = false;
+                                _J++;
+                            }
+                        }
+                    }
+                    if( !di_maximal ) {
+                        _clusters.erase( _clusters.begin() + I );
+                        _G.eraseNode2( I );
+                    }
+                }*/
+
+                return Di;
+            }
         //@}
 
         /// \name Input/Ouput
@@ -200,10 +292,12 @@ namespace dai {
             /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
             /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
              *  \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
+             *  \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
+             *  \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
              *  \return A set of elimination "cliques".
              */
             template<class EliminationChoice>
-            ClusterGraph VarElim( EliminationChoice f ) const {
+            ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
                 // Make a copy
                 ClusterGraph cl(*this);
                 cl.eraseNonMaximal();
@@ -216,13 +310,16 @@ namespace dai {
                     varindices.insert( i );
 
                 // Do variable elimination
+                size_t totalStates = 0;
                 while( !varindices.empty() ) {
                     size_t i = f( cl, varindices );
-                    DAI_ASSERT( i < _vars.size() );
-                    result.insert( cl.Delta( i ) );
-                    cl.insert( cl.delta( i ) );
-                    cl.eraseSubsuming( i );
-                    cl.eraseNonMaximal();
+                    VarSet Di = cl.elimVar( i );
+                    result.insert( Di );
+                    if( maxStates ) {
+                        totalStates += Di.nrStates();
+                        if( totalStates > maxStates )
+                            DAI_THROW(OUT_OF_MEMORY);
+                    }
                     varindices.erase( i );
                 }
 
index 7275569..874b323 100644 (file)
@@ -98,6 +98,7 @@ class Exception : public std::runtime_error {
                    FACTORGRAPH_NOT_CONNECTED,
                    INTERNAL_ERROR,
                    RUNTIME_ERROR,
+                   OUT_OF_MEMORY,
                    NUM_ERRORS};  // NUM_ERRORS should be the last entry
 
         /// Constructor
index 1693fa2..c27be06 100644 (file)
@@ -98,6 +98,9 @@ class JTree : public DAIAlgRG {
 
             /// Heuristic to use for constructing the junction tree
             HeuristicType heuristic;
+
+            /// Maximum memory to use in bytes (0 means unlimited)
+            size_t maxmem;
         } props;
 
         /// Name of this inference algorithm
@@ -200,9 +203,11 @@ class JTree : public DAIAlgRG {
 
 /// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
 /** \relates JTree
+ *  \param maxStates maximum total number of states in outer regions of junction tree (0 means no limit)
+ *  \throws OUT_OF_MEMORY if the total number of states becomes larger than maxStates
  *  \return a pair (number of variables in largest clique, number of states in largest clique)
  */
-std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn );
+std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates=0 );
 
 
 } // end of namespace dai
index 46afcc1..1b2722a 100644 (file)
@@ -45,6 +45,22 @@ ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : _G(), _vars(), _
 }
 
 
+ClusterGraph::ClusterGraph( const FactorGraph& fg, bool onlyMaximal ) : _G( fg.bipGraph() ), _vars(), _clusters() {
+    // copy variables
+    _vars.reserve( fg.nrVars() );
+    for( size_t i = 0; i < fg.nrVars(); i++ )
+        _vars.push_back( fg.var(i) );
+
+    // copy clusters
+    _clusters.reserve( fg.nrFactors() );
+    for( size_t I = 0; I < fg.nrFactors(); I++ )
+        _clusters.push_back( fg.factor(I).vars() );
+
+    if( onlyMaximal )
+        eraseNonMaximal();
+}
+
+
 size_t sequentialVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ) {
     return cl.findVar( seq.at(i++) );
 }
index 0fabc6a..887cc28 100644 (file)
@@ -38,7 +38,8 @@ namespace dai {
         "Multiple undo levels unsupported",
         "FactorGraph is not connected",
         "Internal error",
-        "Runtime error"
+        "Runtime error",
+        "Out of memory"
     };
 
 
index 44bdb61..3347703 100644 (file)
@@ -39,6 +39,10 @@ void JTree::setProperties( const PropertySet &opts ) {
         props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
     else
         props.heuristic = Properties::HeuristicType::MINFILL;
+    if( opts.hasKey("maxmem") )
+        props.maxmem = opts.getStringAs<size_t>("maxmem");
+    else
+        props.maxmem = 0;
 }
 
 
@@ -48,6 +52,7 @@ PropertySet JTree::getProperties() const {
     opts.set( "updates", props.updates );
     opts.set( "inference", props.inference );
     opts.set( "heuristic", props.heuristic );
+    opts.set( "maxmem", props.maxmem );
     return opts;
 }
 
@@ -58,7 +63,8 @@ string JTree::printProperties() const {
     s << "verbose=" << props.verbose << ",";
     s << "updates=" << props.updates << ",";
     s << "heuristic=" << props.heuristic << ",";
-    s << "inference=" << props.inference << "]";
+    s << "inference=" << props.inference << ",";
+    s << "maxmem=" << props.maxmem << "]";
     return s.str();
 }
 
@@ -67,21 +73,11 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
     setProperties( opts );
 
     if( automatic ) {
-        // Create ClusterGraph which contains factors as clusters
-        vector<VarSet> cl;
-        cl.reserve( fg.nrFactors() );
-        for( size_t I = 0; I < fg.nrFactors(); I++ )
-            cl.push_back( fg.factor(I).vars() );
-        ClusterGraph _cg( cl );
-
+        // Create ClusterGraph which contains maximal factors as clusters
+        ClusterGraph _cg( fg, true );
         if( props.verbose >= 3 )
             cerr << "Initial clusters: " << _cg << endl;
 
-        // Retain only maximal clusters
-        _cg.eraseNonMaximal();
-        if( props.verbose >= 3 )
-            cerr << "Maximal clusters: " << _cg << endl;
-
         // Use heuristic to guess optimal elimination sequence
         greedyVariableElimination::eliminationCostFunction ec(NULL);
         switch( (size_t)props.heuristic ) {
@@ -100,10 +96,23 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
             default:
                 DAI_THROW(UNKNOWN_ENUM_VALUE);
         }
-        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ) ).eraseNonMaximal().clusters();
+        size_t fudge = 6; // this yields a rough estimate of the memory needed (for some reason not yet clearly understood)
+        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ), props.maxmem / (sizeof(Real) * fudge) ).eraseNonMaximal().clusters();
         if( props.verbose >= 3 )
             cerr << "VarElim result: " << ElimVec << endl;
 
+        // Estimate memory needed (rough upper bound)
+        long double memneeded = 0;
+        foreach( const VarSet& cl, ElimVec )
+            memneeded += cl.nrStates();
+        memneeded *= sizeof(Real) * fudge;
+        if( props.verbose >= 1 ) {
+            cerr << "Estimate of needed memory: " << memneeded / 1024 << "kB" << endl;
+            cerr << "Maximum memory: " << props.maxmem / 1024 << "kB" << endl;
+        }
+        if( props.maxmem && memneeded > props.maxmem )
+            DAI_THROW(OUT_OF_MEMORY);
+
         // Generate the junction tree corresponding to the elimination sequence
         GenerateJT( fg, ElimVec );
     }
@@ -553,18 +562,12 @@ Factor JTree::calcMarginal( const VarSet& vs ) {
 }
 
 
-std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn ) {
-    ClusterGraph _cg;
-
-    // Copy factors
-    for( size_t I = 0; I < fg.nrFactors(); I++ )
-        _cg.insert( fg.factor(I).vars() );
-
-    // Retain only maximal clusters
-    _cg.eraseNonMaximal();
+std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates ) {
+    // Create cluster graph from factor graph
+    ClusterGraph _cg( fg, true );
 
     // Obtain elimination sequence
-    vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ) ).eraseNonMaximal().clusters();
+    vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ), maxStates ).eraseNonMaximal().clusters();
 
     // Calculate treewidth
     size_t treewidth = 0;
index d245b83..2d86253 100644 (file)
@@ -27,6 +27,11 @@ const double tol = 1e-8;
 
 
 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+    Var v0( 0, 2 );
+    Var v1( 1, 3 );
+    Var v2( 2, 2 );
+    Var v3( 3, 4 );
+
     ClusterGraph G;
     BOOST_CHECK_EQUAL( G.clusters(), std::vector<VarSet>() );
     BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
@@ -36,12 +41,10 @@ BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
     BOOST_CHECK_THROW( G.var( 0 ), Exception );
     BOOST_CHECK_THROW( G.cluster( 0 ), Exception );
 #endif
-    BOOST_CHECK_THROW( G.findVar( Var( 0, 2 ) ), Exception );
+    BOOST_CHECK_EQUAL( G.findVar( v0 ), 0 );
+    BOOST_CHECK_EQUAL( G.findCluster( v0 ), 0 );
+    BOOST_CHECK_EQUAL( G.findCluster( VarSet(v0,v1) ), 0 );
 
-    Var v0( 0, 2 );
-    Var v1( 1, 3 );
-    Var v2( 2, 2 );
-    Var v3( 3, 4 );
     std::vector<Var> vs;
     vs.push_back( v0 );
     vs.push_back( v1 );
@@ -67,6 +70,13 @@ BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
     BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
     BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
     BOOST_CHECK_EQUAL( G2.findVar( v3 ), 3 );
+    BOOST_CHECK_EQUAL( G2.findVar( Var(4, 2) ), 4 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v01 ), 0 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v12 ), 1 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v23 ), 2 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v13 ), 3 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v02 ), 4 );
+    BOOST_CHECK_EQUAL( G2.findCluster( v03 ), 4 );
 
     ClusterGraph Gb( G );
     BOOST_CHECK( G.bipGraph() == Gb.bipGraph() );
@@ -87,6 +97,48 @@ BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
     BOOST_CHECK( G2.bipGraph() == G2c.bipGraph() );
     BOOST_CHECK( G2.vars() == G2c.vars() );
     BOOST_CHECK( G2.clusters() == G2c.clusters() );
+
+    std::vector<Factor> facs;
+    facs.push_back( Factor( v01 ) );
+    facs.push_back( Factor( v12 ) );
+    facs.push_back( Factor( v1 ) );
+    facs.push_back( Factor( v2 ) );
+    facs.push_back( Factor( v23 ) );
+    facs.push_back( Factor( v13 ) );
+    FactorGraph F3( facs );
+    ClusterGraph G3( F3, false );
+    BOOST_CHECK_EQUAL( G3.bipGraph(), F3.bipGraph() );
+    BOOST_CHECK_EQUAL( G3.nrVars(), F3.nrVars() );
+    for( size_t i = 0; i < 4; i++ )
+        BOOST_CHECK_EQUAL( G3.var(i), F3.var(i) );
+    BOOST_CHECK_EQUAL( G3.vars(), F3.vars() );
+    BOOST_CHECK_EQUAL( G3.nrClusters(), F3.nrFactors() );
+    for( size_t I = 0; I < 6; I++ )
+        BOOST_CHECK_EQUAL( G3.cluster(I), F3.factor(I).vars() );
+
+    ClusterGraph G4( FactorGraph( facs ), true );
+    BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
+    BOOST_CHECK_EQUAL( G4.var(0), v0 );
+    BOOST_CHECK_EQUAL( G4.var(1), v1 );
+    BOOST_CHECK_EQUAL( G4.var(2), v2 );
+    BOOST_CHECK_EQUAL( G4.var(3), v3 );
+    BOOST_CHECK_EQUAL( G4.nrClusters(), 4 );
+    BOOST_CHECK_EQUAL( G4.cluster(0), v01 );
+    BOOST_CHECK_EQUAL( G4.cluster(1), v12 );
+    BOOST_CHECK_EQUAL( G4.cluster(2), v23 );
+    BOOST_CHECK_EQUAL( G4.cluster(3), v13 );
+    typedef BipartiteGraph::Edge Edge;
+    std::vector<Edge> edges;
+    edges.push_back( Edge(0, 0) );
+    edges.push_back( Edge(1, 0) );
+    edges.push_back( Edge(1, 1) );
+    edges.push_back( Edge(1, 3) );
+    edges.push_back( Edge(2, 1) );
+    edges.push_back( Edge(2, 2) );
+    edges.push_back( Edge(3, 2) );
+    edges.push_back( Edge(3, 3) );
+    BipartiteGraph G4G( 4, 4, edges.begin(), edges.end() );
+    BOOST_CHECK_EQUAL( G4.bipGraph(), G4G );
 }
 
 
@@ -140,6 +192,17 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK_EQUAL( G.findVar( v2 ), 2 );
     BOOST_CHECK_EQUAL( G.findVar( v3 ), 3 );
     BOOST_CHECK_EQUAL( G.findVar( v4 ), 4 );
+    BOOST_CHECK_EQUAL( G.findCluster( v01 ), 0 );
+    BOOST_CHECK_EQUAL( G.findCluster( v12 ), 1 );
+    BOOST_CHECK_EQUAL( G.findCluster( v123 ), 2 );
+    BOOST_CHECK_EQUAL( G.findCluster( v34 ), 3 );
+    BOOST_CHECK_EQUAL( G.findCluster( v04 ), 4 );
+    BOOST_CHECK_EQUAL( G.findCluster( v02 ), 5 );
+    BOOST_CHECK_EQUAL( G.findCluster( v03 ), 5 );
+    BOOST_CHECK_EQUAL( G.findCluster( v13 ), 5 );
+    BOOST_CHECK_EQUAL( G.findCluster( v14 ), 5 );
+    BOOST_CHECK_EQUAL( G.findCluster( v23 ), 5 );
+    BOOST_CHECK_EQUAL( G.findCluster( v24 ), 5 );
     BipartiteGraph H( 5, 5 );
     H.addEdge( 0, 0 );
     H.addEdge( 1, 0 );
@@ -328,6 +391,63 @@ BOOST_AUTO_TEST_CASE( VarElimTest ) {
     H.addEdge( 4, 4 );
     BOOST_CHECK( G.bipGraph() == H );
 
+    G = Gorg;
+    BOOST_CHECK_EQUAL( G.elimVar( 0 ), v14 | v0 );
+    BOOST_CHECK_EQUAL( G.vars(), Gorg.vars() );
+    BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+    BOOST_CHECK_EQUAL( G.cluster(0), v123 );
+    BOOST_CHECK_EQUAL( G.cluster(1), v34 );
+    BOOST_CHECK_EQUAL( G.cluster(2), v14 );
+    BipartiteGraph H_0( 5, 3 );
+    H_0.addEdge( 1, 0 ); H_0.addEdge( 2, 0 ); H_0.addEdge( 3, 0 ); H_0.addEdge( 3, 1 ); H_0.addEdge( 4, 1 ); H_0.addEdge( 1, 2 ); H_0.addEdge( 4, 2 );
+    BOOST_CHECK( G.bipGraph() == H_0 );
+
+    G = Gorg;
+    BOOST_CHECK_EQUAL( G.elimVar( 1 ), v01 | v23 );
+    BOOST_CHECK_EQUAL( G.vars(), Gorg.vars() );
+    BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+    BOOST_CHECK_EQUAL( G.cluster(0), v34 );
+    BOOST_CHECK_EQUAL( G.cluster(1), v04 );
+    BOOST_CHECK_EQUAL( G.cluster(2), v02 | v3 );
+    BipartiteGraph H_1( 5, 3 );
+    H_1.addEdge( 3, 0 ); H_1.addEdge( 4, 0 ); H_1.addEdge( 0, 1 ); H_1.addEdge( 4, 1 ); H_1.addEdge( 0, 2 ); H_1.addEdge( 2, 2 ); H_1.addEdge( 3, 2 );
+    BOOST_CHECK( G.bipGraph() == H_1 );
+
+    G = Gorg;
+    BOOST_CHECK_EQUAL( G.elimVar( 2 ), v123 );
+    BOOST_CHECK_EQUAL( G.vars(), Gorg.vars() );
+    BOOST_CHECK_EQUAL( G.nrClusters(), 4 );
+    BOOST_CHECK_EQUAL( G.cluster(0), v01 );
+    BOOST_CHECK_EQUAL( G.cluster(1), v34 );
+    BOOST_CHECK_EQUAL( G.cluster(2), v04 );
+    BOOST_CHECK_EQUAL( G.cluster(3), v13 );
+    BipartiteGraph H_2( 5, 4 );
+    H_2.addEdge( 0, 0 ); H_2.addEdge( 1, 0 ); H_2.addEdge( 3, 1 ); H_2.addEdge( 4, 1 ); H_2.addEdge( 0, 2 ); H_2.addEdge( 4, 2 ); H_2.addEdge( 1, 3 ); H_2.addEdge( 3, 3 );
+    BOOST_CHECK( G.bipGraph() == H_2 );
+
+    G = Gorg;
+    BOOST_CHECK_EQUAL( G.elimVar( 3 ), v12 | v34 );
+    BOOST_CHECK_EQUAL( G.vars(), Gorg.vars() );
+    BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+    BOOST_CHECK_EQUAL( G.cluster(0), v01 );
+    BOOST_CHECK_EQUAL( G.cluster(1), v04 );
+    BOOST_CHECK_EQUAL( G.cluster(2), v12 | v4 );
+    BipartiteGraph H_3( 5, 3 );
+    H_3.addEdge( 0, 0 ); H_3.addEdge( 1, 0 ); H_3.addEdge( 0, 1 ); H_3.addEdge( 4, 1 ); H_3.addEdge( 1, 2 ); H_3.addEdge( 2, 2 ); H_3.addEdge( 4, 2 );
+    BOOST_CHECK( G.bipGraph() == H_3 );
+
+    G = Gorg;
+    BOOST_CHECK_EQUAL( G.elimVar( 4 ), v03 | v4 );
+    BOOST_CHECK_EQUAL( G.vars(), Gorg.vars() );
+    BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+    BOOST_CHECK_EQUAL( G.cluster(0), v01 );
+    BOOST_CHECK_EQUAL( G.cluster(1), v123 );
+    BOOST_CHECK_EQUAL( G.cluster(2), v03 );
+    BipartiteGraph H_4( 5, 3 );
+    H_4.addEdge( 0, 0 ); H_4.addEdge( 1, 0 ); H_4.addEdge( 1, 1 ); H_4.addEdge( 2, 1 ); H_4.addEdge( 3, 1 ); H_4.addEdge( 0, 2 ); H_4.addEdge( 3, 2 );
+    BOOST_CHECK( G.bipGraph() == H_4 );
+
+    G = Gorg;
     BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 0 ), 1 );
     BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 1 ), 2 );
     BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 2 ), 0 );
index 80da9d5..311bcb8 100644 (file)
@@ -75,9 +75,10 @@ int main( int argc, char *argv[] ) {
     if( argc != 3 ) {
         // Display help message if number of command line arguments is incorrect
         cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
-        cout << "Usage: ./fginfo <in.fg> <tw>" << endl << endl;
+        cout << "Usage: ./fginfo <in.fg> <maxstates>" << endl << endl;
         cout << "Reports some detailed information about the factor graph <in.fg>." << endl;
-        cout << "Also calculates treewidth (which may take some time) unless <tw> == 0." << endl;
+        cout << "Also calculates treewidth, with maximum total number of states" << endl;
+        cout << "given by <maxstates>, where 0 means unlimited." << endl << endl;
         return 1;
     } else {
         // Read factorgraph
@@ -95,18 +96,53 @@ int main( int argc, char *argv[] ) {
         cout << "Has negatives:         " << hasNegatives(fg.factors()) << endl;
         cout << "Binary variables?      " << fg.isBinary() << endl;
         cout << "Pairwise interactions? " << fg.isPairwise() << endl;
-        // Calculate treewidth using various heuristics, if requested
-        if( calc_tw ) {
-            std::pair<size_t,double> tw;
-            tw = boundTreewidth(fg, &eliminationCost_MinNeighbors);
-            cout << "Treewidth (MinNeighbors):     " << tw.first << " (" << tw.second << " states)" << endl;
-            tw = boundTreewidth(fg, &eliminationCost_MinWeight);
-            cout << "Treewidth (MinWeight):        " << tw.first << " (" << tw.second << " states)" << endl;
-            tw = boundTreewidth(fg, &eliminationCost_MinFill);
-            cout << "Treewidth (MinFill):          " << tw.first << " (" << tw.second << " states)" << endl;
-            tw = boundTreewidth(fg, &eliminationCost_WeightedMinFill);
-            cout << "Treewidth (WeightedMinFill):  " << tw.first << " (" << tw.second << " states)" << endl;
+        
+        // Calculate treewidth using various heuristics
+        std::pair<size_t,double> tw;
+        cout << "Treewidth (MinNeighbors):     ";
+        try {
+            tw = boundTreewidth(fg, &eliminationCost_MinNeighbors, maxstates );
+            cout << tw.first << " (" << tw.second << " states)" << endl;
+        } catch( Exception &e ) {
+            if( e.code() == Exception::OUT_OF_MEMORY )
+                cout << "> " << maxstates << endl;
+            else
+                cout << "an exception occurred" << endl;
+        }
+        
+        cout << "Treewidth (MinWeight):        ";
+        try {
+            tw = boundTreewidth(fg, &eliminationCost_MinWeight, maxstates );
+            cout << tw.first << " (" << tw.second << " states)" << endl;
+        } catch( Exception &e ) {
+            if( e.code() == Exception::OUT_OF_MEMORY )
+                cout << "> " << maxstates << endl;
+            else
+                cout << "an exception occurred" << endl;
+        }
+        
+        cout << "Treewidth (MinFill):          ";
+        try {
+            tw = boundTreewidth(fg, &eliminationCost_MinFill, maxstates );
+            cout << tw.first << " (" << tw.second << " states)" << endl;
+        } catch( Exception &e ) {
+            if( e.code() == Exception::OUT_OF_MEMORY )
+                cout << "> " << maxstates << endl;
+            else
+                cout << "an exception occurred" << endl;
+        }
+
+        cout << "Treewidth (WeightedMinFill):  ";
+        try {
+            tw = boundTreewidth(fg, &eliminationCost_WeightedMinFill, maxstates );
+            cout << tw.first << " (" << tw.second << " states)" << endl;
+        } catch( Exception &e ) {
+            if( e.code() == Exception::OUT_OF_MEMORY )
+                cout << "> " << maxstates << endl;
+            else
+                cout << "an exception occurred" << endl;
         }
+        
         // Calculate total state space
         long double stsp = 1.0;
         for( size_t i = 0; i < fg.nrVars(); i++ )