bce5b0f09f7a94cf2a8a91e162c199abe952d546
[libdai.git] / src / gibbs.cpp
1 /* Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 #include <iostream>
22 #include <sstream>
23 #include <map>
24 #include <set>
25 #include <algorithm>
26 #include <dai/gibbs.h>
27 #include <dai/util.h>
28 #include <dai/properties.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *Gibbs::Name = "GIBBS";
38
39
40 void Gibbs::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("iters") );
42 props.iters = opts.getStringAs<size_t>("iters");
43
44 if( opts.hasKey("verbose") )
45 props.verbose = opts.getStringAs<size_t>("verbose");
46 else
47 props.verbose = 0;
48 }
49
50
51 PropertySet Gibbs::getProperties() const {
52 PropertySet opts;
53 opts.Set( "iters", props.iters );
54 opts.Set( "verbose", props.verbose );
55 return opts;
56 }
57
58
59 string Gibbs::printProperties() const {
60 stringstream s( stringstream::out );
61 s << "[";
62 s << "iters=" << props.iters << ",";
63 s << "verbose=" << props.verbose << "]";
64 return s.str();
65 }
66
67
68 void Gibbs::construct() {
69 _var_counts.clear();
70 _var_counts.reserve( nrVars() );
71 for( size_t i = 0; i < nrVars(); i++ )
72 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
73
74 _factor_counts.clear();
75 _factor_counts.reserve( nrFactors() );
76 for( size_t I = 0; I < nrFactors(); I++ )
77 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
78
79 _sample_count = 0;
80
81 _factor_entries.clear();
82 _factor_entries.resize( nrFactors(), 0 );
83
84 _state.clear();
85 _state.resize( nrVars(), 0 );
86 }
87
88
89 void Gibbs::calc_factor_entries() {
90 for( size_t I = 0; I < nrFactors(); I++ )
91 _factor_entries[I] = get_factor_entry( I );
92 }
93
94 void Gibbs::update_factor_entries( size_t i ) {
95 foreach( const Neighbor &I, nbV(i) )
96 _factor_entries[I] = get_factor_entry( I );
97 }
98
99
100 void Gibbs::update_counts() {
101 for( size_t i = 0; i < nrVars(); i++ )
102 _var_counts[i][_state[i]]++;
103 for( size_t I = 0; I < nrFactors(); I++ )
104 _factor_counts[I][_factor_entries[I]]++;
105 // _factor_counts[I][get_factor_entry(I)]++;
106 _sample_count++;
107 }
108
109
110 inline size_t Gibbs::get_factor_entry( size_t I ) {
111 size_t f_entry = 0;
112 VarSet::const_reverse_iterator check = factor(I).vars().rbegin();
113 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
114 size_t j = nbF(I)[_j]; // FIXME
115 assert( var(j) == *check );
116 f_entry *= var(j).states();
117 f_entry += _state[j];
118 check++;
119 }
120 return f_entry;
121 }
122
123
124 inline size_t Gibbs::get_factor_entry_interval( size_t I, size_t i ) {
125 size_t skip = 1;
126 VarSet::const_iterator check = factor(I).vars().begin();
127 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
128 size_t j = nbF(I)[_j]; // FIXME
129 assert( var(j) == *check );
130 if( i == j )
131 break;
132 else
133 skip *= var(j).states();
134 check++;
135 }
136 return skip;
137 }
138
139
140 Prob Gibbs::get_var_dist( size_t i ) {
141 assert( i < nrVars() );
142 size_t i_states = var(i).states();
143 Prob i_given_MB( i_states, 1.0 );
144
145 // use markov blanket of var(i) to calculate distribution
146 foreach( const Neighbor &I, nbV(i) ) {
147 const Factor &f_I = factor(I);
148 size_t I_skip = get_factor_entry_interval( I, i );
149 // size_t I_entry = get_factor_entry(I) - (_state[i] * I_skip);
150 size_t I_entry = _factor_entries[I] - (_state[i] * I_skip);
151 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
152 i_given_MB[st_i] *= f_I[I_entry];
153 I_entry += I_skip;
154 }
155 }
156
157 return i_given_MB.normalized();
158 }
159
160
161 inline void Gibbs::resample_var( size_t i ) {
162 // draw randomly from conditional distribution and update _state
163 size_t new_state = get_var_dist(i).draw();
164 if( new_state != _state[i] ) {
165 _state[i] = new_state;
166 update_factor_entries( i );
167 }
168 }
169
170
171 void Gibbs::randomize_state() {
172 for( size_t i = 0; i < nrVars(); i++ )
173 _state[i] = rnd_int( 0, var(i).states() - 1 );
174 }
175
176
177 void Gibbs::init() {
178 for( size_t i = 0; i < nrVars(); i++ )
179 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
180 for( size_t I = 0; I < nrFactors(); I++ )
181 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
182 _sample_count = 0;
183 }
184
185
186 double Gibbs::run() {
187 if( props.verbose >= 1 )
188 cout << "Starting " << identify() << "...";
189 if( props.verbose >= 3 )
190 cout << endl;
191
192 double tic = toc();
193
194 randomize_state();
195
196 calc_factor_entries();
197 for( size_t iter = 0; iter < props.iters; iter++ ) {
198 for( size_t j = 0; j < nrVars(); j++ )
199 resample_var( j );
200 update_counts();
201 }
202
203 if( props.verbose >= 3 ) {
204 for( size_t i = 0; i < nrVars(); i++ ) {
205 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
206 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
207 }
208 }
209
210 if( props.verbose >= 3 )
211 cout << "Gibbs::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
212
213 return 0.0;
214 }
215
216
217 inline Factor Gibbs::beliefV( size_t i ) const {
218 return Factor( var(i), _var_counts[i].begin() ).normalized();
219 }
220
221
222 inline Factor Gibbs::beliefF( size_t I ) const {
223 return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
224 }
225
226
227 Factor Gibbs::belief( const Var &n ) const {
228 return( beliefV( findVar( n ) ) );
229 }
230
231
232 vector<Factor> Gibbs::beliefs() const {
233 vector<Factor> result;
234 for( size_t i = 0; i < nrVars(); i++ )
235 result.push_back( beliefV(i) );
236 for( size_t I = 0; I < nrFactors(); I++ )
237 result.push_back( beliefF(I) );
238 return result;
239 }
240
241
242 Factor Gibbs::belief( const VarSet &ns ) const {
243 if( ns.size() == 1 )
244 return belief( *(ns.begin()) );
245 else {
246 size_t I;
247 for( I = 0; I < nrFactors(); I++ )
248 if( factor(I).vars() >> ns )
249 break;
250 assert( I != nrFactors() );
251 return beliefF(I).marginal(ns);
252 }
253 }
254
255
256 } // end of namespace dai