Removed cache because it's not always a performance improvement
authorJoris Mooij <joris@jorismooij.nl>
Sun, 16 Nov 2008 14:37:55 +0000 (15:37 +0100)
committerJoris Mooij <joris@jorismooij.nl>
Sun, 16 Nov 2008 14:37:55 +0000 (15:37 +0100)
include/dai/gibbs.h
src/gibbs.cpp

index f49ab63..6b63bdf 100644 (file)
@@ -51,23 +51,20 @@ class Gibbs : public DAIAlgFG {
         size_t _sample_count;
         std::vector<_count_t> _var_counts;
         std::vector<_count_t> _factor_counts;
-        std::vector<size_t> _factor_entries;
         _state_t _state;
 
-        void update_counts();
-        void randomize_state();
-        Prob get_var_dist( size_t i );
-        void resample_var( size_t i );
-        size_t get_factor_entry( size_t I );
-        size_t get_factor_entry_interval( size_t I, size_t i );
-        void calc_factor_entries();
-        void update_factor_entries( size_t i );
+        void updateCounts();
+        void randomizeState();
+        Prob getVarDist( size_t i );
+        void resampleVar( size_t i );
+        size_t getFactorEntry( size_t I );
+        size_t getFactorEntryDiff( size_t I, size_t i );
 
     public:
         // default constructor
-        Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _factor_entries(), _state() {}
+        Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _state() {}
         // copy constructor
-        Gibbs(const Gibbs & x) : DAIAlgFG(x), _sample_count(x._sample_count), _var_counts(x._var_counts), _factor_counts(x._factor_counts), _factor_entries(x._factor_entries), _state(x._state) {}
+        Gibbs(const Gibbs & x) : DAIAlgFG(x), _sample_count(x._sample_count), _var_counts(x._var_counts), _factor_counts(x._factor_counts), _state(x._state) {}
         // construct Gibbs object from FactorGraph
         Gibbs( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg) {
             setProperties( opts );
@@ -80,7 +77,6 @@ class Gibbs : public DAIAlgFG {
                 _sample_count = x._sample_count;
                 _var_counts = x._var_counts;
                 _factor_counts = x._factor_counts;
-                _factor_entries = x._factor_entries;
                 _state = x._state;
             }
             return *this;
index bce5b0f..96a8a17 100644 (file)
@@ -78,66 +78,49 @@ void Gibbs::construct() {
 
     _sample_count = 0;
 
-    _factor_entries.clear();
-    _factor_entries.resize( nrFactors(), 0 );
-
     _state.clear();
     _state.resize( nrVars(), 0 );
 }
 
 
-void Gibbs::calc_factor_entries() {
-    for( size_t I = 0; I < nrFactors(); I++ )
-        _factor_entries[I] = get_factor_entry( I );
-}
-
-void Gibbs::update_factor_entries( size_t i ) {
-    foreach( const Neighbor &I, nbV(i) )
-        _factor_entries[I] = get_factor_entry( I );
-}
-
-
-void Gibbs::update_counts() {
+void Gibbs::updateCounts() {
     for( size_t i = 0; i < nrVars(); i++ )
         _var_counts[i][_state[i]]++;
     for( size_t I = 0; I < nrFactors(); I++ )
-        _factor_counts[I][_factor_entries[I]]++;
-//        _factor_counts[I][get_factor_entry(I)]++;
+        _factor_counts[I][getFactorEntry(I)]++;
     _sample_count++;
 }
 
 
-inline size_t Gibbs::get_factor_entry( size_t I ) {
+inline size_t Gibbs::getFactorEntry( size_t I ) {
     size_t f_entry = 0;
-    VarSet::const_reverse_iterator check = factor(I).vars().rbegin();
     for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
-        size_t j = nbF(I)[_j];     // FIXME
-        assert( var(j) == *check );
+        // note that iterating over nbF(I) yields the same ordering
+        // of variables as iterating over factor(I).vars()
+        size_t j = nbF(I)[_j];
         f_entry *= var(j).states();
         f_entry += _state[j];
-        check++;
     }
     return f_entry;
 }
 
 
-inline size_t Gibbs::get_factor_entry_interval( size_t I, size_t i ) {
+inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
     size_t skip = 1;
-    VarSet::const_iterator check = factor(I).vars().begin();
     for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
-        size_t j = nbF(I)[_j];     // FIXME
-        assert( var(j) == *check );
+        // note that iterating over nbF(I) yields the same ordering
+        // of variables as iterating over factor(I).vars()
+        size_t j = nbF(I)[_j];
         if( i == j )
             break;
         else
             skip *= var(j).states();
-        check++;
     }
     return skip;
 }
 
 
-Prob Gibbs::get_var_dist( size_t i ) {
+Prob Gibbs::getVarDist( size_t i ) {
     assert( i < nrVars() );
     size_t i_states = var(i).states();
     Prob i_given_MB( i_states, 1.0 );
@@ -145,9 +128,8 @@ Prob Gibbs::get_var_dist( size_t i ) {
     // use markov blanket of var(i) to calculate distribution
     foreach( const Neighbor &I, nbV(i) ) {
         const Factor &f_I = factor(I);
-        size_t I_skip = get_factor_entry_interval( I, i );
-//        size_t I_entry = get_factor_entry(I) - (_state[i] * I_skip);
-        size_t I_entry = _factor_entries[I] - (_state[i] * I_skip);
+        size_t I_skip = getFactorEntryDiff( I, i );
+        size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
         for( size_t st_i = 0; st_i < i_states; st_i++ ) {
             i_given_MB[st_i] *= f_I[I_entry];
             I_entry += I_skip;
@@ -158,17 +140,13 @@ Prob Gibbs::get_var_dist( size_t i ) {
 }
 
 
-inline void Gibbs::resample_var( size_t i ) {
+inline void Gibbs::resampleVar( size_t i ) {
     // draw randomly from conditional distribution and update _state
-    size_t new_state = get_var_dist(i).draw();
-    if( new_state != _state[i] ) {
-        _state[i] = new_state;
-        update_factor_entries( i );
-    }
+    _state[i] = getVarDist(i).draw();
 }
 
 
-void Gibbs::randomize_state() {
+void Gibbs::randomizeState() {
     for( size_t i = 0; i < nrVars(); i++ )
         _state[i] = rnd_int( 0, var(i).states() - 1 );
 }
@@ -191,13 +169,12 @@ double Gibbs::run() {
 
     double tic = toc();
     
-    randomize_state();
+    randomizeState();
 
-    calc_factor_entries();
     for( size_t iter = 0; iter < props.iters; iter++ ) {
         for( size_t j = 0; j < nrVars(); j++ )
-            resample_var( j );
-        update_counts();
+            resampleVar( j );
+        updateCounts();
     }
 
     if( props.verbose >= 3 ) {