3dca7e7fe4d8b8be49f21cfcf0e129f9352b122c
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <sstream>
11 #include <map>
12 #include <set>
13 #include <algorithm>
14 #include <dai/gibbs.h>
15 #include <dai/util.h>
16 #include <dai/properties.h>
17
18
19 namespace dai {
20
21
22 using namespace std;
23
24
25 void Gibbs::setProperties( const PropertySet &opts ) {
26 DAI_ASSERT( opts.hasKey("maxiter") );
27 props.maxiter = opts.getStringAs<size_t>("maxiter");
28
29 if( opts.hasKey("restart") )
30 props.restart = opts.getStringAs<size_t>("restart");
31 else
32 props.restart = props.maxiter;
33 if( opts.hasKey("burnin") )
34 props.burnin = opts.getStringAs<size_t>("burnin");
35 else
36 props.burnin = 0;
37 if( opts.hasKey("maxtime") )
38 props.maxtime = opts.getStringAs<Real>("maxtime");
39 else
40 props.maxtime = INFINITY;
41 if( opts.hasKey("verbose") )
42 props.verbose = opts.getStringAs<size_t>("verbose");
43 else
44 props.verbose = 0;
45 }
46
47
48 PropertySet Gibbs::getProperties() const {
49 PropertySet opts;
50 opts.set( "maxiter", props.maxiter );
51 opts.set( "maxtime", props.maxtime );
52 opts.set( "restart", props.restart );
53 opts.set( "burnin", props.burnin );
54 opts.set( "verbose", props.verbose );
55 return opts;
56 }
57
58
59 string Gibbs::printProperties() const {
60 stringstream s( stringstream::out );
61 s << "[";
62 s << "maxiter=" << props.maxiter << ",";
63 s << "maxtime=" << props.maxtime << ",";
64 s << "restart=" << props.restart << ",";
65 s << "burnin=" << props.burnin << ",";
66 s << "verbose=" << props.verbose << "]";
67 return s.str();
68 }
69
70
71 void Gibbs::construct() {
72 _sample_count = 0;
73
74 _var_counts.clear();
75 _var_counts.reserve( nrVars() );
76 for( size_t i = 0; i < nrVars(); i++ )
77 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
78
79 _factor_counts.clear();
80 _factor_counts.reserve( nrFactors() );
81 for( size_t I = 0; I < nrFactors(); I++ )
82 _factor_counts.push_back( _count_t( factor(I).nrStates(), 0 ) );
83
84 _iters = 0;
85
86 _state.clear();
87 _state.resize( nrVars(), 0 );
88
89 _max_state.clear();
90 _max_state.resize( nrVars(), 0 );
91
92 _max_score = logScore( _max_state );
93 }
94
95
96 void Gibbs::updateCounts() {
97 _sample_count++;
98 for( size_t i = 0; i < nrVars(); i++ )
99 _var_counts[i][_state[i]]++;
100 for( size_t I = 0; I < nrFactors(); I++ )
101 _factor_counts[I][getFactorEntry(I)]++;
102 Real score = logScore( _state );
103 if( score > _max_score ) {
104 _max_state = _state;
105 _max_score = score;
106 }
107 }
108
109
110 size_t Gibbs::getFactorEntry( size_t I ) {
111 size_t f_entry = 0;
112 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
113 // note that iterating over nbF(I) yields the same ordering
114 // of variables as iterating over factor(I).vars()
115 size_t j = nbF(I)[_j];
116 f_entry *= var(j).states();
117 f_entry += _state[j];
118 }
119 return f_entry;
120 }
121
122
123 size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
124 size_t skip = 1;
125 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
126 // note that iterating over nbF(I) yields the same ordering
127 // of variables as iterating over factor(I).vars()
128 size_t j = nbF(I)[_j];
129 if( i == j )
130 break;
131 else
132 skip *= var(j).states();
133 }
134 return skip;
135 }
136
137
138 Prob Gibbs::getVarDist( size_t i ) {
139 DAI_ASSERT( i < nrVars() );
140 size_t i_states = var(i).states();
141 Prob i_given_MB( i_states, 1.0 );
142
143 // use Markov blanket of var(i) to calculate distribution
144 bforeach( const Neighbor &I, nbV(i) ) {
145 const Factor &f_I = factor(I);
146 size_t I_skip = getFactorEntryDiff( I, i );
147 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
148 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
149 i_given_MB.set( st_i, i_given_MB[st_i] * f_I[I_entry] );
150 I_entry += I_skip;
151 }
152 }
153
154 if( i_given_MB.sum() == 0.0 )
155 // If no state of i is allowed, use uniform distribution
156 // FIXME is that indeed the right thing to do?
157 i_given_MB = Prob( i_states );
158 else
159 i_given_MB.normalize();
160 return i_given_MB;
161 }
162
163
164 void Gibbs::resampleVar( size_t i ) {
165 _state[i] = getVarDist(i).draw();
166 }
167
168
169 void Gibbs::randomizeState() {
170 for( size_t i = 0; i < nrVars(); i++ )
171 _state[i] = rnd( var(i).states() );
172 }
173
174
175 void Gibbs::init() {
176 _sample_count = 0;
177 for( size_t i = 0; i < nrVars(); i++ )
178 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
179 for( size_t I = 0; I < nrFactors(); I++ )
180 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
181 _iters = 0;
182 }
183
184
185 Real Gibbs::run() {
186 if( props.verbose >= 1 )
187 cerr << "Starting " << identify() << "...";
188 if( props.verbose >= 3 )
189 cerr << endl;
190
191 double tic = toc();
192
193 for( ; _iters < props.maxiter && (toc() - tic) < props.maxtime; _iters++ ) {
194 if( (_iters % props.restart) == 0 )
195 randomizeState();
196 for( size_t i = 0; i < nrVars(); i++ )
197 resampleVar( i );
198 if( (_iters % props.restart) > props.burnin )
199 updateCounts();
200 }
201
202 if( props.verbose >= 3 ) {
203 for( size_t i = 0; i < nrVars(); i++ ) {
204 cerr << "Belief for variable " << var(i) << ": " << beliefV(i) << endl;
205 cerr << "Counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
206 }
207 }
208
209 if( props.verbose >= 3 )
210 cerr << name() << "::run: ran " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
211
212 if( _iters == 0 )
213 return INFINITY;
214 else
215 return std::pow( _iters, -0.5 );
216 }
217
218
219 Factor Gibbs::beliefV( size_t i ) const {
220 if( _sample_count == 0 )
221 return Factor( var(i) );
222 else
223 return Factor( var(i), _var_counts[i] ).normalized();
224 }
225
226
227 Factor Gibbs::beliefF( size_t I ) const {
228 if( _sample_count == 0 )
229 return Factor( factor(I).vars() );
230 else
231 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
232 }
233
234
235 vector<Factor> Gibbs::beliefs() const {
236 vector<Factor> result;
237 for( size_t i = 0; i < nrVars(); ++i )
238 result.push_back( beliefV(i) );
239 for( size_t I = 0; I < nrFactors(); ++I )
240 result.push_back( beliefF(I) );
241 return result;
242 }
243
244
245 Factor Gibbs::belief( const VarSet &ns ) const {
246 if( ns.size() == 0 )
247 return Factor();
248 else if( ns.size() == 1 )
249 return beliefV( findVar( *(ns.begin()) ) );
250 else {
251 size_t I;
252 for( I = 0; I < nrFactors(); I++ )
253 if( factor(I).vars() >> ns )
254 break;
255 if( I == nrFactors() )
256 DAI_THROW(BELIEF_NOT_AVAILABLE);
257 return beliefF(I).marginal(ns);
258 }
259 }
260
261
262 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t maxiter ) {
263 PropertySet gibbsProps;
264 gibbsProps.set( "maxiter", maxiter );
265 gibbsProps.set( "burnin", size_t(0) );
266 gibbsProps.set( "verbose", size_t(0) );
267 Gibbs gibbs( fg, gibbsProps );
268 gibbs.run();
269 return gibbs.state();
270 }
271
272
273 } // end of namespace dai