Improved error messages of Evidence::addEvidenceTabFile
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 #include <iostream>
12 #include <sstream>
13 #include <map>
14 #include <set>
15 #include <algorithm>
16 #include <dai/gibbs.h>
17 #include <dai/util.h>
18 #include <dai/properties.h>
19
20
21 namespace dai {
22
23
24 using namespace std;
25
26
27 const char *Gibbs::Name = "GIBBS";
28
29
30 void Gibbs::setProperties( const PropertySet &opts ) {
31 DAI_ASSERT( opts.hasKey("iters") );
32 props.iters = opts.getStringAs<size_t>("iters");
33
34 if( opts.hasKey("burnin") )
35 props.burnin = opts.getStringAs<size_t>("burnin");
36 else
37 props.burnin = 0;
38
39 if( opts.hasKey("verbose") )
40 props.verbose = opts.getStringAs<size_t>("verbose");
41 else
42 props.verbose = 0;
43 }
44
45
46 PropertySet Gibbs::getProperties() const {
47 PropertySet opts;
48 opts.Set( "iters", props.iters );
49 opts.Set( "burnin", props.burnin );
50 opts.Set( "verbose", props.verbose );
51 return opts;
52 }
53
54
55 string Gibbs::printProperties() const {
56 stringstream s( stringstream::out );
57 s << "[";
58 s << "iters=" << props.iters << ",";
59 s << "burnin=" << props.burnin << ",";
60 s << "verbose=" << props.verbose << "]";
61 return s.str();
62 }
63
64
65 void Gibbs::construct() {
66 _var_counts.clear();
67 _var_counts.reserve( nrVars() );
68 for( size_t i = 0; i < nrVars(); i++ )
69 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
70
71 _factor_counts.clear();
72 _factor_counts.reserve( nrFactors() );
73 for( size_t I = 0; I < nrFactors(); I++ )
74 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
75
76 _sample_count = 0;
77
78 _state.clear();
79 _state.resize( nrVars(), 0 );
80 }
81
82
83 void Gibbs::updateCounts() {
84 _sample_count++;
85 if( _sample_count > props.burnin ) {
86 for( size_t i = 0; i < nrVars(); i++ )
87 _var_counts[i][_state[i]]++;
88 for( size_t I = 0; I < nrFactors(); I++ )
89 _factor_counts[I][getFactorEntry(I)]++;
90 }
91 }
92
93
94 inline size_t Gibbs::getFactorEntry( size_t I ) {
95 size_t f_entry = 0;
96 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
97 // note that iterating over nbF(I) yields the same ordering
98 // of variables as iterating over factor(I).vars()
99 size_t j = nbF(I)[_j];
100 f_entry *= var(j).states();
101 f_entry += _state[j];
102 }
103 return f_entry;
104 }
105
106
107 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
108 size_t skip = 1;
109 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
110 // note that iterating over nbF(I) yields the same ordering
111 // of variables as iterating over factor(I).vars()
112 size_t j = nbF(I)[_j];
113 if( i == j )
114 break;
115 else
116 skip *= var(j).states();
117 }
118 return skip;
119 }
120
121
122 Prob Gibbs::getVarDist( size_t i ) {
123 DAI_ASSERT( i < nrVars() );
124 size_t i_states = var(i).states();
125 Prob i_given_MB( i_states, 1.0 );
126
127 // use Markov blanket of var(i) to calculate distribution
128 foreach( const Neighbor &I, nbV(i) ) {
129 const Factor &f_I = factor(I);
130 size_t I_skip = getFactorEntryDiff( I, i );
131 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
132 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
133 i_given_MB[st_i] *= f_I[I_entry];
134 I_entry += I_skip;
135 }
136 }
137
138 if( i_given_MB.sum() == 0.0 )
139 // If no state of i is allowed, use uniform distribution
140 // FIXME is that indeed the right thing to do?
141 i_given_MB = Prob( i_states );
142 else
143 i_given_MB.normalize();
144 return i_given_MB;
145 }
146
147
148 inline void Gibbs::resampleVar( size_t i ) {
149 _state[i] = getVarDist(i).draw();
150 }
151
152
153 void Gibbs::randomizeState() {
154 for( size_t i = 0; i < nrVars(); i++ )
155 _state[i] = rnd( var(i).states() );
156 }
157
158
159 void Gibbs::init() {
160 for( size_t i = 0; i < nrVars(); i++ )
161 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
162 for( size_t I = 0; I < nrFactors(); I++ )
163 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
164 _sample_count = 0;
165 }
166
167
168 Real Gibbs::run() {
169 if( props.verbose >= 1 )
170 cerr << "Starting " << identify() << "...";
171 if( props.verbose >= 3 )
172 cerr << endl;
173
174 double tic = toc();
175
176 randomizeState();
177
178 for( size_t iter = 0; iter < props.iters; iter++ ) {
179 for( size_t j = 0; j < nrVars(); j++ )
180 resampleVar( j );
181 updateCounts();
182 }
183
184 if( props.verbose >= 3 ) {
185 for( size_t i = 0; i < nrVars(); i++ ) {
186 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
187 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
188 }
189 }
190
191 if( props.verbose >= 3 )
192 cerr << Name << "::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
193
194 return 0.0;
195 }
196
197
198 Factor Gibbs::beliefV( size_t i ) const {
199 return Factor( var(i), _var_counts[i] ).normalized();
200 }
201
202
203 Factor Gibbs::beliefF( size_t I ) const {
204 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
205 }
206
207
208 Factor Gibbs::belief( const Var &n ) const {
209 return( beliefV( findVar( n ) ) );
210 }
211
212
213 vector<Factor> Gibbs::beliefs() const {
214 vector<Factor> result;
215 for( size_t i = 0; i < nrVars(); ++i )
216 result.push_back( beliefV(i) );
217 for( size_t I = 0; I < nrFactors(); ++I )
218 result.push_back( beliefF(I) );
219 return result;
220 }
221
222
223 Factor Gibbs::belief( const VarSet &ns ) const {
224 if( ns.size() == 1 )
225 return belief( *(ns.begin()) );
226 else {
227 size_t I;
228 for( I = 0; I < nrFactors(); I++ )
229 if( factor(I).vars() >> ns )
230 break;
231 DAI_ASSERT( I != nrFactors() );
232 return beliefF(I).marginal(ns);
233 }
234 }
235
236
237 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters, size_t burnin ) {
238 PropertySet gibbsProps;
239 gibbsProps.Set("iters", iters);
240 gibbsProps.Set("burnin", burnin);
241 gibbsProps.Set("verbose", size_t(0));
242 Gibbs gibbs( fg, gibbsProps );
243 gibbs.run();
244 return gibbs.state();
245 }
246
247
248 } // end of namespace dai