Removed erroneous 'inline' directives in gibbs.cpp
[libdai.git] / src / gibbs.cpp
1 /* Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net],
2 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/gibbs.h>
#include <dai/util.h>
#include <dai/properties.h>
30
31
32 namespace dai {
33
34
35 using namespace std;
36
37
38 const char *Gibbs::Name = "GIBBS";
39
40
41 void Gibbs::setProperties( const PropertySet &opts ) {
42 assert( opts.hasKey("iters") );
43 props.iters = opts.getStringAs<size_t>("iters");
44
45 if( opts.hasKey("verbose") )
46 props.verbose = opts.getStringAs<size_t>("verbose");
47 else
48 props.verbose = 0;
49 }
50
51
52 PropertySet Gibbs::getProperties() const {
53 PropertySet opts;
54 opts.Set( "iters", props.iters );
55 opts.Set( "verbose", props.verbose );
56 return opts;
57 }
58
59
60 string Gibbs::printProperties() const {
61 stringstream s( stringstream::out );
62 s << "[";
63 s << "iters=" << props.iters << ",";
64 s << "verbose=" << props.verbose << "]";
65 return s.str();
66 }
67
68
69 void Gibbs::construct() {
70 _var_counts.clear();
71 _var_counts.reserve( nrVars() );
72 for( size_t i = 0; i < nrVars(); i++ )
73 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
74
75 _factor_counts.clear();
76 _factor_counts.reserve( nrFactors() );
77 for( size_t I = 0; I < nrFactors(); I++ )
78 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
79
80 _sample_count = 0;
81
82 _state.clear();
83 _state.resize( nrVars(), 0 );
84 }
85
86
87 void Gibbs::updateCounts() {
88 for( size_t i = 0; i < nrVars(); i++ )
89 _var_counts[i][_state[i]]++;
90 for( size_t I = 0; I < nrFactors(); I++ )
91 _factor_counts[I][getFactorEntry(I)]++;
92 _sample_count++;
93 }
94
95
96 inline size_t Gibbs::getFactorEntry( size_t I ) {
97 size_t f_entry = 0;
98 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
99 // note that iterating over nbF(I) yields the same ordering
100 // of variables as iterating over factor(I).vars()
101 size_t j = nbF(I)[_j];
102 f_entry *= var(j).states();
103 f_entry += _state[j];
104 }
105 return f_entry;
106 }
107
108
109 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
110 size_t skip = 1;
111 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
112 // note that iterating over nbF(I) yields the same ordering
113 // of variables as iterating over factor(I).vars()
114 size_t j = nbF(I)[_j];
115 if( i == j )
116 break;
117 else
118 skip *= var(j).states();
119 }
120 return skip;
121 }
122
123
124 Prob Gibbs::getVarDist( size_t i ) {
125 assert( i < nrVars() );
126 size_t i_states = var(i).states();
127 Prob i_given_MB( i_states, 1.0 );
128
129 // use Markov blanket of var(i) to calculate distribution
130 foreach( const Neighbor &I, nbV(i) ) {
131 const Factor &f_I = factor(I);
132 size_t I_skip = getFactorEntryDiff( I, i );
133 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
134 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
135 i_given_MB[st_i] *= f_I[I_entry];
136 I_entry += I_skip;
137 }
138 }
139
140 if( i_given_MB.sum() == 0.0 )
141 // If no state of i is allowed, use uniform distribution
142 // FIXME is that indeed the right thing to do?
143 i_given_MB = Prob( i_states );
144 else
145 i_given_MB.normalize();
146 return i_given_MB;
147 }
148
149
150 inline void Gibbs::resampleVar( size_t i ) {
151 // draw randomly from conditional distribution and update _state
152 _state[i] = getVarDist(i).draw();
153 }
154
155
156 void Gibbs::randomizeState() {
157 for( size_t i = 0; i < nrVars(); i++ )
158 _state[i] = rnd_int( 0, var(i).states() - 1 );
159 }
160
161
162 void Gibbs::init() {
163 for( size_t i = 0; i < nrVars(); i++ )
164 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
165 for( size_t I = 0; I < nrFactors(); I++ )
166 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
167 _sample_count = 0;
168 }
169
170
171 double Gibbs::run() {
172 if( props.verbose >= 1 )
173 cerr << "Starting " << identify() << "...";
174 if( props.verbose >= 3 )
175 cerr << endl;
176
177 double tic = toc();
178
179 randomizeState();
180
181 for( size_t iter = 0; iter < props.iters; iter++ ) {
182 for( size_t j = 0; j < nrVars(); j++ )
183 resampleVar( j );
184 updateCounts();
185 }
186
187 if( props.verbose >= 3 ) {
188 for( size_t i = 0; i < nrVars(); i++ ) {
189 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
190 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
191 }
192 }
193
194 if( props.verbose >= 3 )
195 cerr << Name << "::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
196
197 return 0.0;
198 }
199
200
201 Factor Gibbs::beliefV( size_t i ) const {
202 return Factor( var(i), _var_counts[i].begin() ).normalized();
203 }
204
205
206 Factor Gibbs::beliefF( size_t I ) const {
207 return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
208 }
209
210
211 Factor Gibbs::belief( const Var &n ) const {
212 return( beliefV( findVar( n ) ) );
213 }
214
215
216 vector<Factor> Gibbs::beliefs() const {
217 vector<Factor> result;
218 for( size_t i = 0; i < nrVars(); ++i )
219 result.push_back( beliefV(i) );
220 for( size_t I = 0; I < nrFactors(); ++I )
221 result.push_back( beliefF(I) );
222 return result;
223 }
224
225
226 Factor Gibbs::belief( const VarSet &ns ) const {
227 if( ns.size() == 1 )
228 return belief( *(ns.begin()) );
229 else {
230 size_t I;
231 for( I = 0; I < nrFactors(); I++ )
232 if( factor(I).vars() >> ns )
233 break;
234 assert( I != nrFactors() );
235 return beliefF(I).marginal(ns);
236 }
237 }
238
239
240 } // end of namespace dai