Finished release 0.2.4
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2008-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <dai/gibbs.h>
18 #include <dai/util.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 using namespace std;
26
27
28 const char *Gibbs::Name = "GIBBS";
29
30
31 void Gibbs::setProperties( const PropertySet &opts ) {
32 DAI_ASSERT( opts.hasKey("iters") );
33 props.iters = opts.getStringAs<size_t>("iters");
34
35 if( opts.hasKey("burnin") )
36 props.burnin = opts.getStringAs<size_t>("burnin");
37 else
38 props.burnin = 0;
39
40 if( opts.hasKey("verbose") )
41 props.verbose = opts.getStringAs<size_t>("verbose");
42 else
43 props.verbose = 0;
44 }
45
46
47 PropertySet Gibbs::getProperties() const {
48 PropertySet opts;
49 opts.Set( "iters", props.iters );
50 opts.Set( "burnin", props.burnin );
51 opts.Set( "verbose", props.verbose );
52 return opts;
53 }
54
55
56 string Gibbs::printProperties() const {
57 stringstream s( stringstream::out );
58 s << "[";
59 s << "iters=" << props.iters << ",";
60 s << "burnin=" << props.burnin << ",";
61 s << "verbose=" << props.verbose << "]";
62 return s.str();
63 }
64
65
66 void Gibbs::construct() {
67 _var_counts.clear();
68 _var_counts.reserve( nrVars() );
69 for( size_t i = 0; i < nrVars(); i++ )
70 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
71
72 _factor_counts.clear();
73 _factor_counts.reserve( nrFactors() );
74 for( size_t I = 0; I < nrFactors(); I++ )
75 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
76
77 _sample_count = 0;
78
79 _state.clear();
80 _state.resize( nrVars(), 0 );
81 }
82
83
84 void Gibbs::updateCounts() {
85 _sample_count++;
86 if( _sample_count > props.burnin ) {
87 for( size_t i = 0; i < nrVars(); i++ )
88 _var_counts[i][_state[i]]++;
89 for( size_t I = 0; I < nrFactors(); I++ )
90 _factor_counts[I][getFactorEntry(I)]++;
91 }
92 }
93
94
95 inline size_t Gibbs::getFactorEntry( size_t I ) {
96 size_t f_entry = 0;
97 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
98 // note that iterating over nbF(I) yields the same ordering
99 // of variables as iterating over factor(I).vars()
100 size_t j = nbF(I)[_j];
101 f_entry *= var(j).states();
102 f_entry += _state[j];
103 }
104 return f_entry;
105 }
106
107
108 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
109 size_t skip = 1;
110 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
111 // note that iterating over nbF(I) yields the same ordering
112 // of variables as iterating over factor(I).vars()
113 size_t j = nbF(I)[_j];
114 if( i == j )
115 break;
116 else
117 skip *= var(j).states();
118 }
119 return skip;
120 }
121
122
123 Prob Gibbs::getVarDist( size_t i ) {
124 DAI_ASSERT( i < nrVars() );
125 size_t i_states = var(i).states();
126 Prob i_given_MB( i_states, 1.0 );
127
128 // use Markov blanket of var(i) to calculate distribution
129 foreach( const Neighbor &I, nbV(i) ) {
130 const Factor &f_I = factor(I);
131 size_t I_skip = getFactorEntryDiff( I, i );
132 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
133 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
134 i_given_MB[st_i] *= f_I[I_entry];
135 I_entry += I_skip;
136 }
137 }
138
139 if( i_given_MB.sum() == 0.0 )
140 // If no state of i is allowed, use uniform distribution
141 // FIXME is that indeed the right thing to do?
142 i_given_MB = Prob( i_states );
143 else
144 i_given_MB.normalize();
145 return i_given_MB;
146 }
147
148
149 inline void Gibbs::resampleVar( size_t i ) {
150 _state[i] = getVarDist(i).draw();
151 }
152
153
154 void Gibbs::randomizeState() {
155 for( size_t i = 0; i < nrVars(); i++ )
156 _state[i] = rnd( var(i).states() );
157 }
158
159
160 void Gibbs::init() {
161 for( size_t i = 0; i < nrVars(); i++ )
162 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
163 for( size_t I = 0; I < nrFactors(); I++ )
164 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
165 _sample_count = 0;
166 }
167
168
169 Real Gibbs::run() {
170 if( props.verbose >= 1 )
171 cerr << "Starting " << identify() << "...";
172 if( props.verbose >= 3 )
173 cerr << endl;
174
175 double tic = toc();
176
177 randomizeState();
178
179 for( size_t iter = 0; iter < props.iters; iter++ ) {
180 for( size_t j = 0; j < nrVars(); j++ )
181 resampleVar( j );
182 updateCounts();
183 }
184
185 if( props.verbose >= 3 ) {
186 for( size_t i = 0; i < nrVars(); i++ ) {
187 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
188 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
189 }
190 }
191
192 if( props.verbose >= 3 )
193 cerr << Name << "::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
194
195 return 0.0;
196 }
197
198
199 Factor Gibbs::beliefV( size_t i ) const {
200 return Factor( var(i), _var_counts[i] ).normalized();
201 }
202
203
204 Factor Gibbs::beliefF( size_t I ) const {
205 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
206 }
207
208
209 vector<Factor> Gibbs::beliefs() const {
210 vector<Factor> result;
211 for( size_t i = 0; i < nrVars(); ++i )
212 result.push_back( beliefV(i) );
213 for( size_t I = 0; I < nrFactors(); ++I )
214 result.push_back( beliefF(I) );
215 return result;
216 }
217
218
219 Factor Gibbs::belief( const VarSet &ns ) const {
220 if( ns.size() == 0 )
221 return Factor();
222 else if( ns.size() == 1 )
223 return beliefV( findVar( *(ns.begin()) ) );
224 else {
225 size_t I;
226 for( I = 0; I < nrFactors(); I++ )
227 if( factor(I).vars() >> ns )
228 break;
229 if( I == nrFactors() )
230 DAI_THROW(BELIEF_NOT_AVAILABLE);
231 return beliefF(I).marginal(ns);
232 }
233 }
234
235
236 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters ) {
237 PropertySet gibbsProps;
238 gibbsProps.Set("iters", iters);
239 gibbsProps.Set("burnin", size_t(0));
240 gibbsProps.Set("verbose", size_t(0));
241 Gibbs gibbs( fg, gibbsProps );
242 gibbs.run();
243 return gibbs.state();
244 }
245
246
247 } // end of namespace dai