Made state a member variable of Gibbs, and added a cache for factor entries.
authorJoris Mooij <joris@jorismooij.nl>
Sun, 16 Nov 2008 14:02:41 +0000 (15:02 +0100)
committerJoris Mooij <joris@jorismooij.nl>
Sun, 16 Nov 2008 14:02:41 +0000 (15:02 +0100)
include/dai/gibbs.h
src/gibbs.cpp

index fa9e296..f49ab63 100644 (file)
@@ -46,22 +46,28 @@ class Gibbs : public DAIAlgFG {
 
     protected:
         typedef std::vector<size_t> _count_t;
+        typedef std::vector<size_t> _state_t;
+
         size_t _sample_count;
         std::vector<_count_t> _var_counts;
         std::vector<_count_t> _factor_counts;
-
-        typedef std::vector<size_t> _state_t;
-        void update_counts(_state_t &st);
-        void randomize_state(_state_t &st);
-        Prob get_var_dist(_state_t &st, size_t i);
-        void resample_var(_state_t &st, size_t i);
-        size_t get_factor_entry(const _state_t &st, int factor);
+        std::vector<size_t> _factor_entries;
+        _state_t _state;
+
+        void update_counts();
+        void randomize_state();
+        Prob get_var_dist( size_t i );
+        void resample_var( size_t i );
+        size_t get_factor_entry( size_t I );
+        size_t get_factor_entry_interval( size_t I, size_t i );
+        void calc_factor_entries();
+        void update_factor_entries( size_t i );
 
     public:
         // default constructor
-        Gibbs() : DAIAlgFG() {}
+        Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _factor_entries(), _state() {}
         // copy constructor
-        Gibbs(const Gibbs & x) : DAIAlgFG(x), _sample_count(x._sample_count), _var_counts(x._var_counts), _factor_counts(x._factor_counts) {}
+        Gibbs(const Gibbs & x) : DAIAlgFG(x), _sample_count(x._sample_count), _var_counts(x._var_counts), _factor_counts(x._factor_counts), _factor_entries(x._factor_entries), _state(x._state) {}
         // construct Gibbs object from FactorGraph
         Gibbs( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg) {
             setProperties( opts );
@@ -74,6 +80,8 @@ class Gibbs : public DAIAlgFG {
                 _sample_count = x._sample_count;
                 _var_counts = x._var_counts;
                 _factor_counts = x._factor_counts;
+                _factor_entries = x._factor_entries;
+                _state = x._state;
             }
             return *this;
         }
index ca430d1..bce5b0f 100644 (file)
@@ -70,107 +70,107 @@ void Gibbs::construct() {
     _var_counts.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts.push_back( _count_t( var(i).states(), 0 ) );
+    
     _factor_counts.clear();
     _factor_counts.reserve( nrFactors() );
     for( size_t I = 0; I < nrFactors(); I++ )
         _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
+
     _sample_count = 0;
+
+    _factor_entries.clear();
+    _factor_entries.resize( nrFactors(), 0 );
+
+    _state.clear();
+    _state.resize( nrVars(), 0 );
 }
 
 
-void Gibbs::update_counts( _state_t &st ) {
+void Gibbs::calc_factor_entries() {
+    for( size_t I = 0; I < nrFactors(); I++ )
+        _factor_entries[I] = get_factor_entry( I );
+}
+
+void Gibbs::update_factor_entries( size_t i ) {
+    foreach( const Neighbor &I, nbV(i) )
+        _factor_entries[I] = get_factor_entry( I );
+}
+
+
+void Gibbs::update_counts() {
     for( size_t i = 0; i < nrVars(); i++ )
-        _var_counts[i][st[i]]++;
-    for( size_t I = 0; I < nrFactors(); I++ ) {
-        if( 0 ) {
-/*            multind mi( factor(I).vars() );
-            _state_t f_st( factor(I).vars().size() );
-            int k = 0;
-            foreach( size_t j, nbF(I) )
-                f_st[k++] = st[j];
-            _factor_counts[I][mi.li(f_st)]++;*/
-        } else {
-            size_t ent = get_factor_entry(st, I);
-            _factor_counts[I][ent]++;
-        }
-    }
+        _var_counts[i][_state[i]]++;
+    for( size_t I = 0; I < nrFactors(); I++ )
+        _factor_counts[I][_factor_entries[I]]++;
+//        _factor_counts[I][get_factor_entry(I)]++;
     _sample_count++;
 }
 
 
-inline
-size_t Gibbs::get_factor_entry(const _state_t &st, int factor) {
-  size_t f_entry=0;
-  int rank = nbF(factor).size();
-  for(int j=rank-1; j>=0; j--) {
-      int jn = nbF(factor)[j];
-      f_entry *= var(jn).states();
-      f_entry += st[jn];
-  }
-  return f_entry;
+inline size_t Gibbs::get_factor_entry( size_t I ) {
+    size_t f_entry = 0;
+    VarSet::const_reverse_iterator check = factor(I).vars().rbegin();
+    for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
+        size_t j = nbF(I)[_j];     // FIXME
+        assert( var(j) == *check );
+        f_entry *= var(j).states();
+        f_entry += _state[j];
+        check++;
+    }
+    return f_entry;
 }
 
 
-Prob Gibbs::get_var_dist( _state_t &st, size_t i ) {
-    assert( st.size() == vars().size() );
-    assert( i < nrVars() );
-    if( 1 ) {
-        // use markov blanket of n to calculate distribution
-        size_t dim = var(i).states();
-        Neighbors &facts = nbV(i);
-
-        Prob values( dim, 1.0 );
-
-        for( size_t I = 0; I < facts.size(); I++ ) {
-            size_t fa = facts[I];
-            const Factor &f = factor(fa);
-            int save_ind = st[i];
-            for( size_t k = 0; k < dim; k++ ) {
-                st[i] = k;
-                int f_entry = get_factor_entry(st, fa);
-                values[k] *= f[f_entry];
-            }
-            st[i] = save_ind;
-        }
+inline size_t Gibbs::get_factor_entry_interval( size_t I, size_t i ) {
+    size_t skip = 1;
+    VarSet::const_iterator check = factor(I).vars().begin();
+    for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
+        size_t j = nbF(I)[_j];     // FIXME
+        assert( var(j) == *check );
+        if( i == j )
+            break;
+        else
+            skip *= var(j).states();
+        check++;
+    }
+    return skip;
+}
+
 
-        return values.normalized();
-    } else {
-/*        Var vi = var(i);
-        Factor d(vi);
-        assert(vi.states()>0);
-        assert(vi.label()>=0);
-        // loop over factors containing i (nbV(i)):
-        foreach(size_t I, nbV(i)) {
-            // use multind to find linear state for variables != i in factor
-            assert(I<nrFactors());
-            assert(factor(I).vars().size() > 0);
-            VarSet vs (factor(I).vars() / vi);
-            multind mi(vs);
-            _state_t I_st(vs.size());
-            int k=0;
-            foreach(size_t l, nbF(I)) {
-                if(l!=i) I_st[k++] = st[l];
-            }
-            // use slice(ns,ns_state) to get beliefs for variable i
-            // multiply all these beliefs together
-            d *= factor(I).slice(vs, mi.li(I_st));
+Prob Gibbs::get_var_dist( size_t i ) {
+    assert( i < nrVars() );
+    size_t i_states = var(i).states();
+    Prob i_given_MB( i_states, 1.0 );
+
+    // use markov blanket of var(i) to calculate distribution
+    foreach( const Neighbor &I, nbV(i) ) {
+        const Factor &f_I = factor(I);
+        size_t I_skip = get_factor_entry_interval( I, i );
+//        size_t I_entry = get_factor_entry(I) - (_state[i] * I_skip);
+        size_t I_entry = _factor_entries[I] - (_state[i] * I_skip);
+        for( size_t st_i = 0; st_i < i_states; st_i++ ) {
+            i_given_MB[st_i] *= f_I[I_entry];
+            I_entry += I_skip;
         }
-        d.p().normalize();
-        return d.p();*/
     }
+
+    return i_given_MB.normalized();
 }
 
 
-void Gibbs::resample_var( _state_t &st, size_t i ) {
-    // draw randomly from conditional distribution and update 'st'
-    st[i] = get_var_dist( st, i ).draw();
+inline void Gibbs::resample_var( size_t i ) {
+    // draw randomly from conditional distribution and update _state
+    size_t new_state = get_var_dist(i).draw();
+    if( new_state != _state[i] ) {
+        _state[i] = new_state;
+        update_factor_entries( i );
+    }
 }
 
 
-void Gibbs::randomize_state( _state_t &st ) {
-    assert( st.size() == nrVars() );
+void Gibbs::randomize_state() {
     for( size_t i = 0; i < nrVars(); i++ )
-        st[i] = rnd_int( 0, var(i).states() - 1 );
+        _state[i] = rnd_int( 0, var(i).states() - 1 );
 }
 
 
@@ -191,13 +191,13 @@ double Gibbs::run() {
 
     double tic = toc();
     
-    vector<size_t> state( nrVars() );
-    randomize_state( state );
+    randomize_state();
 
+    calc_factor_entries();
     for( size_t iter = 0; iter < props.iters; iter++ ) {
         for( size_t j = 0; j < nrVars(); j++ )
-            resample_var( state, j );
-        update_counts( state );
+            resample_var( j );
+        update_counts();
     }
 
     if( props.verbose >= 3 ) {
@@ -214,17 +214,13 @@ double Gibbs::run() {
 }
 
 
-Factor Gibbs::beliefV( size_t i ) const {
-    Prob p( _var_counts[i].begin(), _var_counts[i].end() );
-    p.normalize();
-    return( Factor( var(i), p ) );
+inline Factor Gibbs::beliefV( size_t i ) const {
+    return Factor( var(i), _var_counts[i].begin() ).normalized();
 }
 
 
-Factor Gibbs::beliefF( size_t I ) const {
-    Prob p( _factor_counts[I].begin(), _factor_counts[I].end() );
-    p.normalize();
-    return( Factor( factor(I).vars(), p ) );
+inline Factor Gibbs::beliefF( size_t I ) const {
+    return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
 }