a07ee0331a2264b3fba4fd9fd86925e07d0fa2e3
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2008-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <dai/gibbs.h>
18 #include <dai/util.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 using namespace std;
26
27
28 const char *Gibbs::Name = "GIBBS";
29
30
31 void Gibbs::setProperties( const PropertySet &opts ) {
32 /// \deprecated The 'iters' property has been renamed into 'maxiters'
33 DAI_ASSERT( opts.hasKey("maxiter") || opts.hasKey("iters") );
34 if( opts.hasKey("maxiter") )
35 props.maxiter = opts.getStringAs<size_t>("maxiter");
36 else
37 props.maxiter = opts.getStringAs<size_t>("iters");
38
39 if( opts.hasKey("restart") )
40 props.restart = opts.getStringAs<size_t>("restart");
41 else
42 props.restart = props.maxiter;
43 if( opts.hasKey("burnin") )
44 props.burnin = opts.getStringAs<size_t>("burnin");
45 else
46 props.burnin = 0;
47 if( opts.hasKey("maxtime") )
48 props.maxtime = opts.getStringAs<Real>("maxtime");
49 else
50 props.maxtime = INFINITY;
51 if( opts.hasKey("verbose") )
52 props.verbose = opts.getStringAs<size_t>("verbose");
53 else
54 props.verbose = 0;
55 }
56
57
58 PropertySet Gibbs::getProperties() const {
59 PropertySet opts;
60 opts.set( "maxiter", props.maxiter );
61 opts.set( "maxtime", props.maxtime );
62 opts.set( "restart", props.restart );
63 opts.set( "burnin", props.burnin );
64 opts.set( "verbose", props.verbose );
65 return opts;
66 }
67
68
69 string Gibbs::printProperties() const {
70 stringstream s( stringstream::out );
71 s << "[";
72 s << "maxiter=" << props.maxiter << ",";
73 s << "maxtime=" << props.maxtime << ",";
74 s << "restart=" << props.restart << ",";
75 s << "burnin=" << props.burnin << ",";
76 s << "verbose=" << props.verbose << "]";
77 return s.str();
78 }
79
80
81 void Gibbs::construct() {
82 _sample_count = 0;
83
84 _var_counts.clear();
85 _var_counts.reserve( nrVars() );
86 for( size_t i = 0; i < nrVars(); i++ )
87 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
88
89 _factor_counts.clear();
90 _factor_counts.reserve( nrFactors() );
91 for( size_t I = 0; I < nrFactors(); I++ )
92 _factor_counts.push_back( _count_t( factor(I).nrStates(), 0 ) );
93
94 _iters = 0;
95
96 _state.clear();
97 _state.resize( nrVars(), 0 );
98
99 _max_state.clear();
100 _max_state.resize( nrVars(), 0 );
101
102 _max_score = logScore( _max_state );
103 }
104
105
106 void Gibbs::updateCounts() {
107 _sample_count++;
108 for( size_t i = 0; i < nrVars(); i++ )
109 _var_counts[i][_state[i]]++;
110 for( size_t I = 0; I < nrFactors(); I++ )
111 _factor_counts[I][getFactorEntry(I)]++;
112 Real score = logScore( _state );
113 if( score > _max_score ) {
114 _max_state = _state;
115 _max_score = score;
116 }
117 }
118
119
120 size_t Gibbs::getFactorEntry( size_t I ) {
121 size_t f_entry = 0;
122 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
123 // note that iterating over nbF(I) yields the same ordering
124 // of variables as iterating over factor(I).vars()
125 size_t j = nbF(I)[_j];
126 f_entry *= var(j).states();
127 f_entry += _state[j];
128 }
129 return f_entry;
130 }
131
132
133 size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
134 size_t skip = 1;
135 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
136 // note that iterating over nbF(I) yields the same ordering
137 // of variables as iterating over factor(I).vars()
138 size_t j = nbF(I)[_j];
139 if( i == j )
140 break;
141 else
142 skip *= var(j).states();
143 }
144 return skip;
145 }
146
147
148 Prob Gibbs::getVarDist( size_t i ) {
149 DAI_ASSERT( i < nrVars() );
150 size_t i_states = var(i).states();
151 Prob i_given_MB( i_states, 1.0 );
152
153 // use Markov blanket of var(i) to calculate distribution
154 foreach( const Neighbor &I, nbV(i) ) {
155 const Factor &f_I = factor(I);
156 size_t I_skip = getFactorEntryDiff( I, i );
157 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
158 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
159 i_given_MB.set( st_i, i_given_MB[st_i] * f_I[I_entry] );
160 I_entry += I_skip;
161 }
162 }
163
164 if( i_given_MB.sum() == 0.0 )
165 // If no state of i is allowed, use uniform distribution
166 // FIXME is that indeed the right thing to do?
167 i_given_MB = Prob( i_states );
168 else
169 i_given_MB.normalize();
170 return i_given_MB;
171 }
172
173
174 void Gibbs::resampleVar( size_t i ) {
175 _state[i] = getVarDist(i).draw();
176 }
177
178
179 void Gibbs::randomizeState() {
180 for( size_t i = 0; i < nrVars(); i++ )
181 _state[i] = rnd( var(i).states() );
182 }
183
184
185 void Gibbs::init() {
186 _sample_count = 0;
187 for( size_t i = 0; i < nrVars(); i++ )
188 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
189 for( size_t I = 0; I < nrFactors(); I++ )
190 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
191 _iters = 0;
192 }
193
194
195 Real Gibbs::run() {
196 if( props.verbose >= 1 )
197 cerr << "Starting " << identify() << "...";
198 if( props.verbose >= 3 )
199 cerr << endl;
200
201 double tic = toc();
202
203 for( ; _iters < props.maxiter && (toc() - tic) < props.maxtime; _iters++ ) {
204 if( (_iters % props.restart) == 0 )
205 randomizeState();
206 for( size_t i = 0; i < nrVars(); i++ )
207 resampleVar( i );
208 if( (_iters % props.restart) > props.burnin )
209 updateCounts();
210 }
211
212 if( props.verbose >= 3 ) {
213 for( size_t i = 0; i < nrVars(); i++ ) {
214 cerr << "Belief for variable " << var(i) << ": " << beliefV(i) << endl;
215 cerr << "Counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
216 }
217 }
218
219 if( props.verbose >= 3 )
220 cerr << Name << "::run: ran " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
221
222 if( _iters == 0 )
223 return INFINITY;
224 else
225 return std::pow( _iters, -0.5 );
226 }
227
228
229 Factor Gibbs::beliefV( size_t i ) const {
230 if( _sample_count == 0 )
231 return Factor( var(i) );
232 else
233 return Factor( var(i), _var_counts[i] ).normalized();
234 }
235
236
237 Factor Gibbs::beliefF( size_t I ) const {
238 if( _sample_count == 0 )
239 return Factor( factor(I).vars() );
240 else
241 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
242 }
243
244
245 vector<Factor> Gibbs::beliefs() const {
246 vector<Factor> result;
247 for( size_t i = 0; i < nrVars(); ++i )
248 result.push_back( beliefV(i) );
249 for( size_t I = 0; I < nrFactors(); ++I )
250 result.push_back( beliefF(I) );
251 return result;
252 }
253
254
255 Factor Gibbs::belief( const VarSet &ns ) const {
256 if( ns.size() == 0 )
257 return Factor();
258 else if( ns.size() == 1 )
259 return beliefV( findVar( *(ns.begin()) ) );
260 else {
261 size_t I;
262 for( I = 0; I < nrFactors(); I++ )
263 if( factor(I).vars() >> ns )
264 break;
265 if( I == nrFactors() )
266 DAI_THROW(BELIEF_NOT_AVAILABLE);
267 return beliefF(I).marginal(ns);
268 }
269 }
270
271
272 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters ) {
273 PropertySet gibbsProps;
274 gibbsProps.set( "maxiter", iters );
275 gibbsProps.set( "burnin", size_t(0) );
276 gibbsProps.set( "verbose", size_t(0) );
277 Gibbs gibbs( fg, gibbsProps );
278 gibbs.run();
279 return gibbs.state();
280 }
281
282
283 } // end of namespace dai