New git HEAD version
[libdai.git] / src / gibbs.cpp
index bce5b0f..85b2379 100644 (file)
@@ -1,21 +1,13 @@
-/*  Copyright (C) 2008  Frederik Eaton [frederik at ofb dot net]
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
+ */
 
 
-    This file is part of libDAI.
 
 
-    libDAI is free software; you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation; either version 2 of the License, or
-    (at your option) any later version.
-
-    libDAI is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with libDAI; if not, write to the Free Software
-    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
-*/
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_GIBBS
 
 
 #include <iostream>
 
 
 #include <iostream>
@@ -34,13 +26,22 @@ namespace dai {
 using namespace std;
 
 
 using namespace std;
 
 
-const char *Gibbs::Name = "GIBBS";
-
-
 void Gibbs::setProperties( const PropertySet &opts ) {
 void Gibbs::setProperties( const PropertySet &opts ) {
-    assert( opts.hasKey("iters") );
-    props.iters = opts.getStringAs<size_t>("iters");
+    DAI_ASSERT( opts.hasKey("maxiter") );
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
 
 
+    if( opts.hasKey("restart") )
+        props.restart = opts.getStringAs<size_t>("restart");
+    else
+        props.restart = props.maxiter;
+    if( opts.hasKey("burnin") )
+        props.burnin = opts.getStringAs<size_t>("burnin");
+    else
+        props.burnin = 0;
+    if( opts.hasKey("maxtime") )
+        props.maxtime = opts.getStringAs<Real>("maxtime");
+    else
+        props.maxtime = INFINITY;
     if( opts.hasKey("verbose") )
         props.verbose = opts.getStringAs<size_t>("verbose");
     else
     if( opts.hasKey("verbose") )
         props.verbose = opts.getStringAs<size_t>("verbose");
     else
@@ -50,8 +51,11 @@ void Gibbs::setProperties( const PropertySet &opts ) {
 
 PropertySet Gibbs::getProperties() const {
     PropertySet opts;
 
 PropertySet Gibbs::getProperties() const {
     PropertySet opts;
-    opts.Set( "iters", props.iters );
-    opts.Set( "verbose", props.verbose );
+    opts.set( "maxiter", props.maxiter );
+    opts.set( "maxtime", props.maxtime );
+    opts.set( "restart", props.restart );
+    opts.set( "burnin", props.burnin );
+    opts.set( "verbose", props.verbose );
     return opts;
 }
 
     return opts;
 }
 
@@ -59,198 +63,218 @@ PropertySet Gibbs::getProperties() const {
 string Gibbs::printProperties() const {
     stringstream s( stringstream::out );
     s << "[";
 string Gibbs::printProperties() const {
     stringstream s( stringstream::out );
     s << "[";
-    s << "iters=" << props.iters << ",";
+    s << "maxiter=" << props.maxiter << ",";
+    s << "maxtime=" << props.maxtime << ",";
+    s << "restart=" << props.restart << ",";
+    s << "burnin=" << props.burnin << ",";
     s << "verbose=" << props.verbose << "]";
     return s.str();
 }
 
 
 void Gibbs::construct() {
     s << "verbose=" << props.verbose << "]";
     return s.str();
 }
 
 
 void Gibbs::construct() {
+    _sample_count = 0;
+
     _var_counts.clear();
     _var_counts.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts.push_back( _count_t( var(i).states(), 0 ) );
     _var_counts.clear();
     _var_counts.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts.push_back( _count_t( var(i).states(), 0 ) );
-    
+
     _factor_counts.clear();
     _factor_counts.reserve( nrFactors() );
     for( size_t I = 0; I < nrFactors(); I++ )
     _factor_counts.clear();
     _factor_counts.reserve( nrFactors() );
     for( size_t I = 0; I < nrFactors(); I++ )
-        _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
+        _factor_counts.push_back( _count_t( factor(I).nrStates(), 0 ) );
 
 
-    _sample_count = 0;
-
-    _factor_entries.clear();
-    _factor_entries.resize( nrFactors(), 0 );
+    _iters = 0;
 
     _state.clear();
     _state.resize( nrVars(), 0 );
 
     _state.clear();
     _state.resize( nrVars(), 0 );
-}
 
 
+    _max_state.clear();
+    _max_state.resize( nrVars(), 0 );
 
 
-void Gibbs::calc_factor_entries() {
-    for( size_t I = 0; I < nrFactors(); I++ )
-        _factor_entries[I] = get_factor_entry( I );
+    _max_score = logScore( _max_state );
 }
 
 }
 
-void Gibbs::update_factor_entries( size_t i ) {
-    foreach( const Neighbor &I, nbV(i) )
-        _factor_entries[I] = get_factor_entry( I );
-}
 
 
-
-void Gibbs::update_counts() {
+void Gibbs::updateCounts() {
+    _sample_count++;
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts[i][_state[i]]++;
     for( size_t I = 0; I < nrFactors(); I++ )
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts[i][_state[i]]++;
     for( size_t I = 0; I < nrFactors(); I++ )
-        _factor_counts[I][_factor_entries[I]]++;
-//        _factor_counts[I][get_factor_entry(I)]++;
-    _sample_count++;
+        _factor_counts[I][getFactorEntry(I)]++;
+    Real score = logScore( _state );
+    if( score > _max_score ) {
+        _max_state = _state;
+        _max_score = score;
+    }
 }
 
 
 }
 
 
-inline size_t Gibbs::get_factor_entry( size_t I ) {
+size_t Gibbs::getFactorEntry( size_t I ) {
     size_t f_entry = 0;
     size_t f_entry = 0;
-    VarSet::const_reverse_iterator check = factor(I).vars().rbegin();
     for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
     for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
-        size_t j = nbF(I)[_j];     // FIXME
-        assert( var(j) == *check );
+        // note that iterating over nbF(I) yields the same ordering
+        // of variables as iterating over factor(I).vars()
+        size_t j = nbF(I)[_j];
         f_entry *= var(j).states();
         f_entry += _state[j];
         f_entry *= var(j).states();
         f_entry += _state[j];
-        check++;
     }
     return f_entry;
 }
 
 
     }
     return f_entry;
 }
 
 
-inline size_t Gibbs::get_factor_entry_interval( size_t I, size_t i ) {
+size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
     size_t skip = 1;
     size_t skip = 1;
-    VarSet::const_iterator check = factor(I).vars().begin();
     for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
     for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
-        size_t j = nbF(I)[_j];     // FIXME
-        assert( var(j) == *check );
+        // note that iterating over nbF(I) yields the same ordering
+        // of variables as iterating over factor(I).vars()
+        size_t j = nbF(I)[_j];
         if( i == j )
             break;
         else
             skip *= var(j).states();
         if( i == j )
             break;
         else
             skip *= var(j).states();
-        check++;
     }
     return skip;
 }
 
 
     }
     return skip;
 }
 
 
-Prob Gibbs::get_var_dist( size_t i ) {
-    assert( i < nrVars() );
+Prob Gibbs::getVarDist( size_t i ) {
+    DAI_ASSERT( i < nrVars() );
     size_t i_states = var(i).states();
     Prob i_given_MB( i_states, 1.0 );
 
     size_t i_states = var(i).states();
     Prob i_given_MB( i_states, 1.0 );
 
-    // use markov blanket of var(i) to calculate distribution
-    foreach( const Neighbor &I, nbV(i) ) {
+    // use Markov blanket of var(i) to calculate distribution
+    bforeach( const Neighbor &I, nbV(i) ) {
         const Factor &f_I = factor(I);
         const Factor &f_I = factor(I);
-        size_t I_skip = get_factor_entry_interval( I, i );
-//        size_t I_entry = get_factor_entry(I) - (_state[i] * I_skip);
-        size_t I_entry = _factor_entries[I] - (_state[i] * I_skip);
+        size_t I_skip = getFactorEntryDiff( I, i );
+        size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
         for( size_t st_i = 0; st_i < i_states; st_i++ ) {
         for( size_t st_i = 0; st_i < i_states; st_i++ ) {
-            i_given_MB[st_i] *= f_I[I_entry];
+            i_given_MB.set( st_i, i_given_MB[st_i] * f_I[I_entry] );
             I_entry += I_skip;
         }
     }
 
             I_entry += I_skip;
         }
     }
 
-    return i_given_MB.normalized();
+    if( i_given_MB.sum() == 0.0 )
+        // If no state of i is allowed, use uniform distribution
+        // FIXME is that indeed the right thing to do?
+        i_given_MB = Prob( i_states );
+    else
+        i_given_MB.normalize();
+    return i_given_MB;
 }
 
 
 }
 
 
-inline void Gibbs::resample_var( size_t i ) {
-    // draw randomly from conditional distribution and update _state
-    size_t new_state = get_var_dist(i).draw();
-    if( new_state != _state[i] ) {
-        _state[i] = new_state;
-        update_factor_entries( i );
-    }
+void Gibbs::resampleVar( size_t i ) {
+    _state[i] = getVarDist(i).draw();
 }
 
 
 }
 
 
-void Gibbs::randomize_state() {
+void Gibbs::randomizeState() {
     for( size_t i = 0; i < nrVars(); i++ )
     for( size_t i = 0; i < nrVars(); i++ )
-        _state[i] = rnd_int( 0, var(i).states() - 1 );
+        _state[i] = rnd( var(i).states() );
 }
 
 
 void Gibbs::init() {
 }
 
 
 void Gibbs::init() {
+    _sample_count = 0;
     for( size_t i = 0; i < nrVars(); i++ )
         fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
     for( size_t I = 0; I < nrFactors(); I++ )
         fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
     for( size_t i = 0; i < nrVars(); i++ )
         fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
     for( size_t I = 0; I < nrFactors(); I++ )
         fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
-    _sample_count = 0;
+    _iters = 0;
 }
 
 
 }
 
 
-double Gibbs::run() {
+Real Gibbs::run() {
     if( props.verbose >= 1 )
     if( props.verbose >= 1 )
-        cout << "Starting " << identify() << "...";
+        cerr << "Starting " << identify() << "...";
     if( props.verbose >= 3 )
     if( props.verbose >= 3 )
-        cout << endl;
+        cerr << endl;
 
     double tic = toc();
 
     double tic = toc();
-    
-    randomize_state();
-
-    calc_factor_entries();
-    for( size_t iter = 0; iter < props.iters; iter++ ) {
-        for( size_t j = 0; j < nrVars(); j++ )
-            resample_var( j );
-        update_counts();
+
+    for( ; _iters < props.maxiter && (toc() - tic) < props.maxtime; _iters++ ) {
+        if( (_iters % props.restart) == 0 )
+            randomizeState();
+        for( size_t i = 0; i < nrVars(); i++ )
+            resampleVar( i );
+        if( (_iters % props.restart) > props.burnin )
+            updateCounts();
     }
 
     if( props.verbose >= 3 ) {
         for( size_t i = 0; i < nrVars(); i++ ) {
     }
 
     if( props.verbose >= 3 ) {
         for( size_t i = 0; i < nrVars(); i++ ) {
-            cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
-            cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
+            cerr << "Belief for variable " << var(i) << ": " << beliefV(i) << endl;
+            cerr << "Counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
         }
     }
         }
     }
-    
-    if( props.verbose >= 3 )
-        cout << "Gibbs::run:  ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
-
-    return 0.0;
-}
 
 
+    if( props.verbose >= 3 )
+        cerr << name() << "::run:  ran " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
 
 
-inline Factor Gibbs::beliefV( size_t i ) const {
-    return Factor( var(i), _var_counts[i].begin() ).normalized();
+    if( _iters == 0 )
+        return INFINITY;
+    else
+        return std::pow( _iters, -0.5 );
 }
 
 
 }
 
 
-inline Factor Gibbs::beliefF( size_t I ) const {
-    return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
+Factor Gibbs::beliefV( size_t i ) const {
+    if( _sample_count == 0 )
+        return Factor( var(i) );
+    else
+        return Factor( var(i), _var_counts[i] ).normalized();
 }
 
 
 }
 
 
-Factor Gibbs::belief( const Var &n ) const {
-    return( beliefV( findVar( n ) ) );
+Factor Gibbs::beliefF( size_t I ) const {
+    if( _sample_count == 0 )
+        return Factor( factor(I).vars() );
+    else
+        return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
 }
 
 
 vector<Factor> Gibbs::beliefs() const {
     vector<Factor> result;
 }
 
 
 vector<Factor> Gibbs::beliefs() const {
     vector<Factor> result;
-    for( size_t i = 0; i < nrVars(); i++ )
+    for( size_t i = 0; i < nrVars(); ++i )
         result.push_back( beliefV(i) );
         result.push_back( beliefV(i) );
-    for( size_t I = 0; I < nrFactors(); I++ )
+    for( size_t I = 0; I < nrFactors(); ++I )
         result.push_back( beliefF(I) );
     return result;
 }
 
 
 Factor Gibbs::belief( const VarSet &ns ) const {
         result.push_back( beliefF(I) );
     return result;
 }
 
 
 Factor Gibbs::belief( const VarSet &ns ) const {
-    if( ns.size() == 1 )
-        return belief( *(ns.begin()) );
+    if( ns.size() == 0 )
+        return Factor();
+    else if( ns.size() == 1 )
+        return beliefV( findVar( *(ns.begin()) ) );
     else {
         size_t I;
         for( I = 0; I < nrFactors(); I++ )
             if( factor(I).vars() >> ns )
                 break;
     else {
         size_t I;
         for( I = 0; I < nrFactors(); I++ )
             if( factor(I).vars() >> ns )
                 break;
-        assert( I != nrFactors() );
+        if( I == nrFactors() )
+            DAI_THROW(BELIEF_NOT_AVAILABLE);
         return beliefF(I).marginal(ns);
     }
 }
 
 
         return beliefF(I).marginal(ns);
     }
 }
 
 
+std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t maxiter ) {
+    PropertySet gibbsProps;
+    gibbsProps.set( "maxiter", maxiter );
+    gibbsProps.set( "burnin", size_t(0) );
+    gibbsProps.set( "verbose", size_t(0) );
+    Gibbs gibbs( fg, gibbsProps );
+    gibbs.run();
+    return gibbs.state();
+}
+
+
 } // end of namespace dai
 } // end of namespace dai
+
+
+#endif