Improved Gibbs and added FactorGraph::logScore( const std::vector<size_t>& statevec )
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 3 Aug 2010 14:29:43 +0000 (16:29 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Tue, 3 Aug 2010 14:29:43 +0000 (16:29 +0200)
ChangeLog
examples/example.cpp
include/dai/factorgraph.h
include/dai/gibbs.h
src/factorgraph.cpp
src/gibbs.cpp
tests/aliases.conf
tests/unit/factorgraph_test.cpp

index 11b4cc1..cc5a39e 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,11 @@
 git HEAD
 --------
+* Added FactorGraph::logScore( const std::vector<size_t>& statevec )
+* Improved Gibbs
+  - renamed 'iters' property into 'maxiter'
+  - added 'maxtime' property
+  - added 'restart' property
+  - added Gibbs::findMaximum()
 * Improved BP (added 'maxtime' property)
 * Added fromString<>( const std::string& x )
 * Added SmallSet::erase( const T& t )
index 1a51931..a7dae54 100644 (file)
@@ -18,25 +18,6 @@ using namespace dai;
 using namespace std;
 
 
-// Evaluates the log probability of the joint configuration vstate of the variables in factor graph fg
-Real evalState( const FactorGraph& fg, const std::vector<size_t> vstate ) {
-    // First, construct a map from Var objects to their states
-    // This decouples the representation of the joint state in vstate
-    // from the factor graph
-    map<Var, size_t> state;
-    for( size_t i = 0; i < vstate.size(); i++ )
-        state[fg.var(i)] = vstate[i];
-
-    // Evaluate the log probability of the joint configuration in state
-    // by summing the log factor entries of the factors in factor graph fg
-    // that correspond to the joint configuration
-    Real logZ = 0.0;
-    for( size_t I = 0; I < fg.nrFactors(); I++ )
-        logZ += dai::log( fg.factor(I)[calcLinearState( fg.factor(I).vars(), state )] );
-    return logZ;
-}
-
-
 int main( int argc, char *argv[] ) {
     if ( argc != 2 ) {
         cout << "Usage: " << argv[0] << " <filename.fg>" << endl << endl;
@@ -160,17 +141,17 @@ int main( int argc, char *argv[] ) {
             cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
 
         // Report exact MAP joint state
-        cout << "Exact MAP state (log probability = " << evalState( fg, jtmapstate ) << "):" << endl;
+        cout << "Exact MAP state (log score = " << fg.logScore( jtmapstate ) << "):" << endl;
         for( size_t i = 0; i < jtmapstate.size(); i++ )
             cout << fg.var(i) << ": " << jtmapstate[i] << endl;
 
         // Report max-product MAP joint state
-        cout << "Approximate (max-product) MAP state (log probability = " << evalState( fg, mpstate ) << "):" << endl;
+        cout << "Approximate (max-product) MAP state (log score = " << fg.logScore( mpstate ) << "):" << endl;
         for( size_t i = 0; i < mpstate.size(); i++ )
             cout << fg.var(i) << ": " << mpstate[i] << endl;
 
         // Report DecMAP joint state
-        cout << "Approximate DecMAP state (log probability = " << evalState( fg, decmapstate ) << "):" << endl;
+        cout << "Approximate DecMAP state (log score = " << fg.logScore( decmapstate ) << "):" << endl;
         for( size_t i = 0; i < decmapstate.size(); i++ )
             cout << fg.var(i) << ": " << decmapstate[i] << endl;
     }
index e30a11f..4c384d1 100644 (file)
@@ -237,6 +237,9 @@ class FactorGraph {
          *  strict subset of another factor domain.
          */
         std::vector<VarSet> maximalFactorDomains() const;
+
+        /// Evaluates the log score (i.e., minus the energy) of the joint configuration \a statevec
+        Real logScore( const std::vector<size_t>& statevec );
     //@}
 
     /// \name Backup/restore mechanism for factors
index 1dafa11..d89e305 100644 (file)
@@ -33,22 +33,34 @@ class Gibbs : public DAIAlgFG {
         typedef std::vector<size_t> _count_t;
         /// Type used to store the joint state of all variables
         typedef std::vector<size_t> _state_t;
-        /// Number of samples counted so far (excluding burn-in)
+        /// Number of samples counted so far (excluding burn-in periods)
         size_t _sample_count;
         /// State counts for each variable
         std::vector<_count_t> _var_counts;
         /// State counts for each factor
         std::vector<_count_t> _factor_counts;
+        /// Number of iterations done (including burn-in periods)
+        size_t _iters;
         /// Current joint state of all variables
         _state_t _state;
+        /// Joint state with maximum probability seen so far
+        _state_t _max_state;
+        /// Highest score so far
+        Real _max_score;
 
     public:
         /// Parameters for Gibbs
         struct Properties {
-            /// Total number of iterations
-            size_t iters;
+            /// Maximum number of iterations
+            size_t maxiter;
 
-            /// Number of "burn-in" iterations
+            /// Maximum time (in seconds)
+            double maxtime;
+
+            /// Number of iterations after which a random restart is made
+            size_t restart;
+
+            /// Number of "burn-in" iterations after each (re)start (for which no statistics are gathered)
             size_t burnin;
 
             /// Verbosity (amount of output sent to stderr)
@@ -60,13 +72,13 @@ class Gibbs : public DAIAlgFG {
 
     public:
         /// Default constructor
-        Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _state() {}
+        Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {}
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
         /** \param fg Factor graph.
          *  \param opts Parameters @see Properties
          */
-        Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _state() {
+        Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {
             setProperties( opts );
             construct();
         }
@@ -86,7 +98,8 @@ class Gibbs : public DAIAlgFG {
         virtual void init( const VarSet &/*ns*/ ) { init(); }
         virtual Real run();
         virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
-        virtual size_t Iterations() const { return props.iters; }
+        virtual size_t Iterations() const { return _iters; }
+        std::vector<std::size_t> findMaximum() const { return _max_state; }
         virtual void setProperties( const PropertySet &opts );
         virtual PropertySet getProperties() const;
         virtual std::string printProperties() const;
index 905bb34..416ecf8 100644 (file)
@@ -337,6 +337,23 @@ vector<VarSet> FactorGraph::maximalFactorDomains() const {
 }
 
 
+Real FactorGraph::logScore( const std::vector<size_t>& statevec ) {
+    // Construct a State object that represents statevec
+    // This decouples the representation of the joint state in statevec from the factor graph
+    map<Var, size_t> statemap;
+    for( size_t i = 0; i < statevec.size(); i++ )
+        statemap[var(i)] = statevec[i];
+    State S(statemap);
+
+    // Evaluate the log probability of the joint configuration in statevec
+    // by summing the log factor entries of the factors that correspond to this joint configuration
+    Real lS = 0.0;
+    for( size_t I = 0; I < nrFactors(); I++ )
+        lS += dai::log( factor(I)[S(factor(I).vars())] );
+    return lS;
+}
+
+
 void FactorGraph::clamp( size_t i, size_t x, bool backup ) {
     DAI_ASSERT( x <= var(i).states() );
     Factor mask( var(i), (Real)0 );
index c9e9e39..3d92d0a 100644 (file)
@@ -29,14 +29,21 @@ const char *Gibbs::Name = "GIBBS";
 
 
 void Gibbs::setProperties( const PropertySet &opts ) {
-    DAI_ASSERT( opts.hasKey("iters") );
-    props.iters = opts.getStringAs<size_t>("iters");
+    DAI_ASSERT( opts.hasKey("maxiter") );
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
 
+    if( opts.hasKey("restart") )
+        props.restart = opts.getStringAs<size_t>("restart");
+    else
+        props.restart = props.maxiter;
     if( opts.hasKey("burnin") )
         props.burnin = opts.getStringAs<size_t>("burnin");
     else
         props.burnin = 0;
-
+    if( opts.hasKey("maxtime") )
+        props.maxtime = opts.getStringAs<Real>("maxtime");
+    else
+        props.maxtime = INFINITY;
     if( opts.hasKey("verbose") )
         props.verbose = opts.getStringAs<size_t>("verbose");
     else
@@ -46,7 +53,9 @@ void Gibbs::setProperties( const PropertySet &opts ) {
 
 PropertySet Gibbs::getProperties() const {
     PropertySet opts;
-    opts.set( "iters", props.iters );
+    opts.set( "maxiter", props.maxiter );
+    opts.set( "maxtime", props.maxtime );
+    opts.set( "restart", props.restart );
     opts.set( "burnin", props.burnin );
     opts.set( "verbose", props.verbose );
     return opts;
@@ -56,7 +65,9 @@ PropertySet Gibbs::getProperties() const {
 string Gibbs::printProperties() const {
     stringstream s( stringstream::out );
     s << "[";
-    s << "iters=" << props.iters << ",";
+    s << "maxiter=" << props.maxiter << ",";
+    s << "maxtime=" << props.maxtime << ",";
+    s << "restart=" << props.restart << ",";
     s << "burnin=" << props.burnin << ",";
     s << "verbose=" << props.verbose << "]";
     return s.str();
@@ -64,6 +75,8 @@ string Gibbs::printProperties() const {
 
 
 void Gibbs::construct() {
+    _sample_count = 0;
+
     _var_counts.clear();
     _var_counts.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); i++ )
@@ -74,25 +87,33 @@ void Gibbs::construct() {
     for( size_t I = 0; I < nrFactors(); I++ )
         _factor_counts.push_back( _count_t( factor(I).nrStates(), 0 ) );
 
-    _sample_count = 0;
+    _iters = 0;
 
     _state.clear();
     _state.resize( nrVars(), 0 );
+
+    _max_state.clear();
+    _max_state.resize( nrVars(), 0 );
+
+    _max_score = logScore( _max_state );
 }
 
 
 void Gibbs::updateCounts() {
     _sample_count++;
-    if( _sample_count > props.burnin ) {
-        for( size_t i = 0; i < nrVars(); i++ )
-            _var_counts[i][_state[i]]++;
-        for( size_t I = 0; I < nrFactors(); I++ )
-            _factor_counts[I][getFactorEntry(I)]++;
+    for( size_t i = 0; i < nrVars(); i++ )
+        _var_counts[i][_state[i]]++;
+    for( size_t I = 0; I < nrFactors(); I++ )
+        _factor_counts[I][getFactorEntry(I)]++;
+    Real score = logScore( _state );
+    if( score > _max_score ) {
+        _max_state = _state;
+        _max_score = score;
     }
 }
 
 
-inline size_t Gibbs::getFactorEntry( size_t I ) {
+size_t Gibbs::getFactorEntry( size_t I ) {
     size_t f_entry = 0;
     for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
         // note that iterating over nbF(I) yields the same ordering
@@ -105,7 +126,7 @@ inline size_t Gibbs::getFactorEntry( size_t I ) {
 }
 
 
-inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
+size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
     size_t skip = 1;
     for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
         // note that iterating over nbF(I) yields the same ordering
@@ -146,7 +167,7 @@ Prob Gibbs::getVarDist( size_t i ) {
 }
 
 
-inline void Gibbs::resampleVar( size_t i ) {
+void Gibbs::resampleVar( size_t i ) {
     _state[i] = getVarDist(i).draw();
 }
 
@@ -158,12 +179,12 @@ void Gibbs::randomizeState() {
 
 
 void Gibbs::init() {
+    _sample_count = 0;
     for( size_t i = 0; i < nrVars(); i++ )
         fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
     for( size_t I = 0; I < nrFactors(); I++ )
         fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
-    _sample_count = 0;
-    randomizeState();
+    _iters = 0;
 }
 
 
@@ -175,31 +196,34 @@ Real Gibbs::run() {
 
     double tic = toc();
 
-    for( size_t iter = 0; iter < props.iters; iter++ ) {
+    for( ; _iters < props.maxiter && (toc() - tic) < props.maxtime; _iters++ ) {
+        if( (_iters % props.restart) == 0 )
+            randomizeState();
         for( size_t i = 0; i < nrVars(); i++ )
             resampleVar( i );
-        updateCounts();
+        if( (_iters % props.restart) > props.burnin )
+            updateCounts();
     }
 
     if( props.verbose >= 3 ) {
         for( size_t i = 0; i < nrVars(); i++ ) {
-            cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
-            cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
+            cerr << "Belief for variable " << var(i) << ": " << beliefV(i) << endl;
+            cerr << "Counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
         }
     }
 
     if( props.verbose >= 3 )
-        cerr << Name << "::run:  ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
+        cerr << Name << "::run:  ran " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
 
-    if( _sample_count == 0 )
+    if( _iters == 0 )
         return INFINITY;
     else
-        return 1.0 / _sample_count;
+        return std::pow( _iters, -0.5 );
 }
 
 
 Factor Gibbs::beliefV( size_t i ) const {
-    if( _sample_count <= props.burnin  )
+    if( _sample_count == 0 )
         return Factor( var(i) );
     else
         return Factor( var(i), _var_counts[i] ).normalized();
@@ -207,7 +231,7 @@ Factor Gibbs::beliefV( size_t i ) const {
 
 
 Factor Gibbs::beliefF( size_t I ) const {
-    if( _sample_count <= props.burnin  )
+    if( _sample_count == 0 )
         return Factor( factor(I).vars() );
     else
         return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
@@ -243,9 +267,9 @@ Factor Gibbs::belief( const VarSet &ns ) const {
 
 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters ) {
     PropertySet gibbsProps;
-    gibbsProps.set("iters", iters);
-    gibbsProps.set("burnin", size_t(0));
-    gibbsProps.set("verbose", size_t(0));
+    gibbsProps.set( "maxiter", iters );
+    gibbsProps.set( "burnin", size_t(0) );
+    gibbsProps.set( "verbose", size_t(0) );
     Gibbs gibbs( fg, gibbsProps );
     gibbs.run();
     return gibbs.state();
index 8be5faa..1e50296 100644 (file)
@@ -171,16 +171,7 @@ LCBP:                           LCBP_FULLCAVin_SEQRND
 
 # --- GIBBS -------------------
 
-GIBBS:                          GIBBS[iters=1000,burnin=100]
-GIBBS_1e1:                      GIBBS[iters=10,burnin=1]
-GIBBS_1e2:                      GIBBS[iters=100,burnin=10]
-GIBBS_1e3:                      GIBBS[iters=1000,burnin=100]
-GIBBS_1e4:                      GIBBS[iters=10000,burnin=1000]
-GIBBS_1e5:                      GIBBS[iters=100000,burnin=10000]
-GIBBS_1e6:                      GIBBS[iters=1000000,burnin=100000]
-GIBBS_1e7:                      GIBBS[iters=10000000,burnin=100000]
-GIBBS_1e8:                      GIBBS[iters=100000000,burnin=100000]
-GIBBS_1e9:                      GIBBS[iters=1000000000,burnin=100000]
+GIBBS:                          GIBBS[iters=10000,burnin=100,restart=1000]
 
 # --- CBP ---------------------
 
index 5ad84a7..3ed8d13 100644 (file)
@@ -139,6 +139,7 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
     BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
     BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 1 );
+    BOOST_CHECK_CLOSE( G0.logScore( std::vector<size_t>() ), (Real)0.0, tol );
 
     std::vector<Factor> facs;
     facs.push_back( Factor( v01 ) );
@@ -207,6 +208,7 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
+    BOOST_CHECK_CLOSE( G1.logScore( std::vector<size_t>(3,0) ), -dai::log((Real)32.0), tol ); 
 
     facs.push_back( Factor( v02 ) );
     H.addEdge( 0, 2 );
@@ -251,6 +253,7 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
+    BOOST_CHECK_CLOSE( G2.logScore( std::vector<size_t>(3,0) ), -dai::log((Real)128.0), tol );
 
     Var v3( 3, 3 );
     VarSet v03( v0, v3 );
@@ -312,6 +315,7 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
+    BOOST_CHECK_CLOSE( G3.logScore( std::vector<size_t>(4,0) ), -dai::log((Real)384.0), tol );
 
     facs.push_back( Factor( v123 ) );
     H.addEdge( 1, 3 );
@@ -368,6 +372,7 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
+    BOOST_CHECK_CLOSE( G4.logScore( std::vector<size_t>(4,0) ), -dai::log((Real)4608.0), tol );
 }