Merge branch 'master' of git.tuebingen.mpg.de:libdai
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2008-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <dai/gibbs.h>
18 #include <dai/util.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 using namespace std;
26
27
28 void Gibbs::setProperties( const PropertySet &opts ) {
29 DAI_ASSERT( opts.hasKey("maxiter") );
30 props.maxiter = opts.getStringAs<size_t>("maxiter");
31
32 if( opts.hasKey("restart") )
33 props.restart = opts.getStringAs<size_t>("restart");
34 else
35 props.restart = props.maxiter;
36 if( opts.hasKey("burnin") )
37 props.burnin = opts.getStringAs<size_t>("burnin");
38 else
39 props.burnin = 0;
40 if( opts.hasKey("maxtime") )
41 props.maxtime = opts.getStringAs<Real>("maxtime");
42 else
43 props.maxtime = INFINITY;
44 if( opts.hasKey("verbose") )
45 props.verbose = opts.getStringAs<size_t>("verbose");
46 else
47 props.verbose = 0;
48 }
49
50
51 PropertySet Gibbs::getProperties() const {
52 PropertySet opts;
53 opts.set( "maxiter", props.maxiter );
54 opts.set( "maxtime", props.maxtime );
55 opts.set( "restart", props.restart );
56 opts.set( "burnin", props.burnin );
57 opts.set( "verbose", props.verbose );
58 return opts;
59 }
60
61
62 string Gibbs::printProperties() const {
63 stringstream s( stringstream::out );
64 s << "[";
65 s << "maxiter=" << props.maxiter << ",";
66 s << "maxtime=" << props.maxtime << ",";
67 s << "restart=" << props.restart << ",";
68 s << "burnin=" << props.burnin << ",";
69 s << "verbose=" << props.verbose << "]";
70 return s.str();
71 }
72
73
74 void Gibbs::construct() {
75 _sample_count = 0;
76
77 _var_counts.clear();
78 _var_counts.reserve( nrVars() );
79 for( size_t i = 0; i < nrVars(); i++ )
80 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
81
82 _factor_counts.clear();
83 _factor_counts.reserve( nrFactors() );
84 for( size_t I = 0; I < nrFactors(); I++ )
85 _factor_counts.push_back( _count_t( factor(I).nrStates(), 0 ) );
86
87 _iters = 0;
88
89 _state.clear();
90 _state.resize( nrVars(), 0 );
91
92 _max_state.clear();
93 _max_state.resize( nrVars(), 0 );
94
95 _max_score = logScore( _max_state );
96 }
97
98
99 void Gibbs::updateCounts() {
100 _sample_count++;
101 for( size_t i = 0; i < nrVars(); i++ )
102 _var_counts[i][_state[i]]++;
103 for( size_t I = 0; I < nrFactors(); I++ )
104 _factor_counts[I][getFactorEntry(I)]++;
105 Real score = logScore( _state );
106 if( score > _max_score ) {
107 _max_state = _state;
108 _max_score = score;
109 }
110 }
111
112
113 size_t Gibbs::getFactorEntry( size_t I ) {
114 size_t f_entry = 0;
115 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
116 // note that iterating over nbF(I) yields the same ordering
117 // of variables as iterating over factor(I).vars()
118 size_t j = nbF(I)[_j];
119 f_entry *= var(j).states();
120 f_entry += _state[j];
121 }
122 return f_entry;
123 }
124
125
126 size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
127 size_t skip = 1;
128 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
129 // note that iterating over nbF(I) yields the same ordering
130 // of variables as iterating over factor(I).vars()
131 size_t j = nbF(I)[_j];
132 if( i == j )
133 break;
134 else
135 skip *= var(j).states();
136 }
137 return skip;
138 }
139
140
141 Prob Gibbs::getVarDist( size_t i ) {
142 DAI_ASSERT( i < nrVars() );
143 size_t i_states = var(i).states();
144 Prob i_given_MB( i_states, 1.0 );
145
146 // use Markov blanket of var(i) to calculate distribution
147 foreach( const Neighbor &I, nbV(i) ) {
148 const Factor &f_I = factor(I);
149 size_t I_skip = getFactorEntryDiff( I, i );
150 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
151 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
152 i_given_MB.set( st_i, i_given_MB[st_i] * f_I[I_entry] );
153 I_entry += I_skip;
154 }
155 }
156
157 if( i_given_MB.sum() == 0.0 )
158 // If no state of i is allowed, use uniform distribution
159 // FIXME is that indeed the right thing to do?
160 i_given_MB = Prob( i_states );
161 else
162 i_given_MB.normalize();
163 return i_given_MB;
164 }
165
166
167 void Gibbs::resampleVar( size_t i ) {
168 _state[i] = getVarDist(i).draw();
169 }
170
171
172 void Gibbs::randomizeState() {
173 for( size_t i = 0; i < nrVars(); i++ )
174 _state[i] = rnd( var(i).states() );
175 }
176
177
178 void Gibbs::init() {
179 _sample_count = 0;
180 for( size_t i = 0; i < nrVars(); i++ )
181 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
182 for( size_t I = 0; I < nrFactors(); I++ )
183 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
184 _iters = 0;
185 }
186
187
188 Real Gibbs::run() {
189 if( props.verbose >= 1 )
190 cerr << "Starting " << identify() << "...";
191 if( props.verbose >= 3 )
192 cerr << endl;
193
194 double tic = toc();
195
196 for( ; _iters < props.maxiter && (toc() - tic) < props.maxtime; _iters++ ) {
197 if( (_iters % props.restart) == 0 )
198 randomizeState();
199 for( size_t i = 0; i < nrVars(); i++ )
200 resampleVar( i );
201 if( (_iters % props.restart) > props.burnin )
202 updateCounts();
203 }
204
205 if( props.verbose >= 3 ) {
206 for( size_t i = 0; i < nrVars(); i++ ) {
207 cerr << "Belief for variable " << var(i) << ": " << beliefV(i) << endl;
208 cerr << "Counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
209 }
210 }
211
212 if( props.verbose >= 3 )
213 cerr << name() << "::run: ran " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
214
215 if( _iters == 0 )
216 return INFINITY;
217 else
218 return std::pow( _iters, -0.5 );
219 }
220
221
222 Factor Gibbs::beliefV( size_t i ) const {
223 if( _sample_count == 0 )
224 return Factor( var(i) );
225 else
226 return Factor( var(i), _var_counts[i] ).normalized();
227 }
228
229
230 Factor Gibbs::beliefF( size_t I ) const {
231 if( _sample_count == 0 )
232 return Factor( factor(I).vars() );
233 else
234 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
235 }
236
237
238 vector<Factor> Gibbs::beliefs() const {
239 vector<Factor> result;
240 for( size_t i = 0; i < nrVars(); ++i )
241 result.push_back( beliefV(i) );
242 for( size_t I = 0; I < nrFactors(); ++I )
243 result.push_back( beliefF(I) );
244 return result;
245 }
246
247
248 Factor Gibbs::belief( const VarSet &ns ) const {
249 if( ns.size() == 0 )
250 return Factor();
251 else if( ns.size() == 1 )
252 return beliefV( findVar( *(ns.begin()) ) );
253 else {
254 size_t I;
255 for( I = 0; I < nrFactors(); I++ )
256 if( factor(I).vars() >> ns )
257 break;
258 if( I == nrFactors() )
259 DAI_THROW(BELIEF_NOT_AVAILABLE);
260 return beliefF(I).marginal(ns);
261 }
262 }
263
264
265 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t maxiter ) {
266 PropertySet gibbsProps;
267 gibbsProps.set( "maxiter", maxiter );
268 gibbsProps.set( "burnin", size_t(0) );
269 gibbsProps.set( "verbose", size_t(0) );
270 Gibbs gibbs( fg, gibbsProps );
271 gibbs.run();
272 return gibbs.state();
273 }
274
275
276 } // end of namespace dai