Removed cache because it's not always a performance improvement
[libdai.git] / src / gibbs.cpp
1 /* Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/gibbs.h>
#include <dai/util.h>
#include <dai/properties.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *Gibbs::Name = "GIBBS";
38
39
40 void Gibbs::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("iters") );
42 props.iters = opts.getStringAs<size_t>("iters");
43
44 if( opts.hasKey("verbose") )
45 props.verbose = opts.getStringAs<size_t>("verbose");
46 else
47 props.verbose = 0;
48 }
49
50
51 PropertySet Gibbs::getProperties() const {
52 PropertySet opts;
53 opts.Set( "iters", props.iters );
54 opts.Set( "verbose", props.verbose );
55 return opts;
56 }
57
58
59 string Gibbs::printProperties() const {
60 stringstream s( stringstream::out );
61 s << "[";
62 s << "iters=" << props.iters << ",";
63 s << "verbose=" << props.verbose << "]";
64 return s.str();
65 }
66
67
68 void Gibbs::construct() {
69 _var_counts.clear();
70 _var_counts.reserve( nrVars() );
71 for( size_t i = 0; i < nrVars(); i++ )
72 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
73
74 _factor_counts.clear();
75 _factor_counts.reserve( nrFactors() );
76 for( size_t I = 0; I < nrFactors(); I++ )
77 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
78
79 _sample_count = 0;
80
81 _state.clear();
82 _state.resize( nrVars(), 0 );
83 }
84
85
86 void Gibbs::updateCounts() {
87 for( size_t i = 0; i < nrVars(); i++ )
88 _var_counts[i][_state[i]]++;
89 for( size_t I = 0; I < nrFactors(); I++ )
90 _factor_counts[I][getFactorEntry(I)]++;
91 _sample_count++;
92 }
93
94
95 inline size_t Gibbs::getFactorEntry( size_t I ) {
96 size_t f_entry = 0;
97 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
98 // note that iterating over nbF(I) yields the same ordering
99 // of variables as iterating over factor(I).vars()
100 size_t j = nbF(I)[_j];
101 f_entry *= var(j).states();
102 f_entry += _state[j];
103 }
104 return f_entry;
105 }
106
107
108 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
109 size_t skip = 1;
110 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
111 // note that iterating over nbF(I) yields the same ordering
112 // of variables as iterating over factor(I).vars()
113 size_t j = nbF(I)[_j];
114 if( i == j )
115 break;
116 else
117 skip *= var(j).states();
118 }
119 return skip;
120 }
121
122
123 Prob Gibbs::getVarDist( size_t i ) {
124 assert( i < nrVars() );
125 size_t i_states = var(i).states();
126 Prob i_given_MB( i_states, 1.0 );
127
128 // use markov blanket of var(i) to calculate distribution
129 foreach( const Neighbor &I, nbV(i) ) {
130 const Factor &f_I = factor(I);
131 size_t I_skip = getFactorEntryDiff( I, i );
132 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
133 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
134 i_given_MB[st_i] *= f_I[I_entry];
135 I_entry += I_skip;
136 }
137 }
138
139 return i_given_MB.normalized();
140 }
141
142
143 inline void Gibbs::resampleVar( size_t i ) {
144 // draw randomly from conditional distribution and update _state
145 _state[i] = getVarDist(i).draw();
146 }
147
148
149 void Gibbs::randomizeState() {
150 for( size_t i = 0; i < nrVars(); i++ )
151 _state[i] = rnd_int( 0, var(i).states() - 1 );
152 }
153
154
155 void Gibbs::init() {
156 for( size_t i = 0; i < nrVars(); i++ )
157 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
158 for( size_t I = 0; I < nrFactors(); I++ )
159 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
160 _sample_count = 0;
161 }
162
163
164 double Gibbs::run() {
165 if( props.verbose >= 1 )
166 cout << "Starting " << identify() << "...";
167 if( props.verbose >= 3 )
168 cout << endl;
169
170 double tic = toc();
171
172 randomizeState();
173
174 for( size_t iter = 0; iter < props.iters; iter++ ) {
175 for( size_t j = 0; j < nrVars(); j++ )
176 resampleVar( j );
177 updateCounts();
178 }
179
180 if( props.verbose >= 3 ) {
181 for( size_t i = 0; i < nrVars(); i++ ) {
182 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
183 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
184 }
185 }
186
187 if( props.verbose >= 3 )
188 cout << "Gibbs::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
189
190 return 0.0;
191 }
192
193
194 inline Factor Gibbs::beliefV( size_t i ) const {
195 return Factor( var(i), _var_counts[i].begin() ).normalized();
196 }
197
198
199 inline Factor Gibbs::beliefF( size_t I ) const {
200 return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
201 }
202
203
204 Factor Gibbs::belief( const Var &n ) const {
205 return( beliefV( findVar( n ) ) );
206 }
207
208
209 vector<Factor> Gibbs::beliefs() const {
210 vector<Factor> result;
211 for( size_t i = 0; i < nrVars(); i++ )
212 result.push_back( beliefV(i) );
213 for( size_t I = 0; I < nrFactors(); I++ )
214 result.push_back( beliefF(I) );
215 return result;
216 }
217
218
219 Factor Gibbs::belief( const VarSet &ns ) const {
220 if( ns.size() == 1 )
221 return belief( *(ns.begin()) );
222 else {
223 size_t I;
224 for( I = 0; I < nrFactors(); I++ )
225 if( factor(I).vars() >> ns )
226 break;
227 assert( I != nrFactors() );
228 return beliefF(I).marginal(ns);
229 }
230 }
231
232
233 } // end of namespace dai