Merge branch 'joris'
[libdai.git] / src / gibbs.cpp
1 /* Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net],
2 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/gibbs.h>
#include <dai/util.h>
#include <dai/properties.h>
30
31
32 namespace dai {
33
34
35 using namespace std;
36
37
// Short identifier of this inference algorithm; run() uses it as a prefix
// in its status output (e.g. "GIBBS::run: ...").
const char *Gibbs::Name = "GIBBS";
39
40
41 void Gibbs::setProperties( const PropertySet &opts ) {
42 assert( opts.hasKey("iters") );
43 props.iters = opts.getStringAs<size_t>("iters");
44
45 if( opts.hasKey("verbose") )
46 props.verbose = opts.getStringAs<size_t>("verbose");
47 else
48 props.verbose = 0;
49 }
50
51
52 PropertySet Gibbs::getProperties() const {
53 PropertySet opts;
54 opts.Set( "iters", props.iters );
55 opts.Set( "verbose", props.verbose );
56 return opts;
57 }
58
59
60 string Gibbs::printProperties() const {
61 stringstream s( stringstream::out );
62 s << "[";
63 s << "iters=" << props.iters << ",";
64 s << "verbose=" << props.verbose << "]";
65 return s.str();
66 }
67
68
69 void Gibbs::construct() {
70 _var_counts.clear();
71 _var_counts.reserve( nrVars() );
72 for( size_t i = 0; i < nrVars(); i++ )
73 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
74
75 _factor_counts.clear();
76 _factor_counts.reserve( nrFactors() );
77 for( size_t I = 0; I < nrFactors(); I++ )
78 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
79
80 _sample_count = 0;
81
82 _state.clear();
83 _state.resize( nrVars(), 0 );
84 }
85
86
87 void Gibbs::updateCounts() {
88 for( size_t i = 0; i < nrVars(); i++ )
89 _var_counts[i][_state[i]]++;
90 for( size_t I = 0; I < nrFactors(); I++ )
91 _factor_counts[I][getFactorEntry(I)]++;
92 _sample_count++;
93 }
94
95
96 inline size_t Gibbs::getFactorEntry( size_t I ) {
97 size_t f_entry = 0;
98 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
99 // note that iterating over nbF(I) yields the same ordering
100 // of variables as iterating over factor(I).vars()
101 size_t j = nbF(I)[_j];
102 f_entry *= var(j).states();
103 f_entry += _state[j];
104 }
105 return f_entry;
106 }
107
108
109 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
110 size_t skip = 1;
111 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
112 // note that iterating over nbF(I) yields the same ordering
113 // of variables as iterating over factor(I).vars()
114 size_t j = nbF(I)[_j];
115 if( i == j )
116 break;
117 else
118 skip *= var(j).states();
119 }
120 return skip;
121 }
122
123
124 Prob Gibbs::getVarDist( size_t i ) {
125 assert( i < nrVars() );
126 size_t i_states = var(i).states();
127 Prob i_given_MB( i_states, 1.0 );
128
129 // use markov blanket of var(i) to calculate distribution
130 foreach( const Neighbor &I, nbV(i) ) {
131 const Factor &f_I = factor(I);
132 size_t I_skip = getFactorEntryDiff( I, i );
133 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
134 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
135 i_given_MB[st_i] *= f_I[I_entry];
136 I_entry += I_skip;
137 }
138 }
139
140 return i_given_MB.normalized();
141 }
142
143
144 inline void Gibbs::resampleVar( size_t i ) {
145 // draw randomly from conditional distribution and update _state
146 _state[i] = getVarDist(i).draw();
147 }
148
149
150 void Gibbs::randomizeState() {
151 for( size_t i = 0; i < nrVars(); i++ )
152 _state[i] = rnd_int( 0, var(i).states() - 1 );
153 }
154
155
156 void Gibbs::init() {
157 for( size_t i = 0; i < nrVars(); i++ )
158 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
159 for( size_t I = 0; I < nrFactors(); I++ )
160 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
161 _sample_count = 0;
162 }
163
164
165 double Gibbs::run() {
166 if( props.verbose >= 1 )
167 cout << "Starting " << identify() << "...";
168 if( props.verbose >= 3 )
169 cout << endl;
170
171 double tic = toc();
172
173 randomizeState();
174
175 for( size_t iter = 0; iter < props.iters; iter++ ) {
176 for( size_t j = 0; j < nrVars(); j++ )
177 resampleVar( j );
178 updateCounts();
179 }
180
181 if( props.verbose >= 3 ) {
182 for( size_t i = 0; i < nrVars(); i++ ) {
183 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
184 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
185 }
186 }
187
188 if( props.verbose >= 3 )
189 cout << Name << "::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
190
191 return 0.0;
192 }
193
194
195 inline Factor Gibbs::beliefV( size_t i ) const {
196 return Factor( var(i), _var_counts[i].begin() ).normalized();
197 }
198
199
200 inline Factor Gibbs::beliefF( size_t I ) const {
201 return Factor( factor(I).vars(), _factor_counts[I].begin() ).normalized();
202 }
203
204
205 Factor Gibbs::belief( const Var &n ) const {
206 return( beliefV( findVar( n ) ) );
207 }
208
209
210 vector<Factor> Gibbs::beliefs() const {
211 vector<Factor> result;
212 for( size_t i = 0; i < nrVars(); ++i )
213 result.push_back( beliefV(i) );
214 for( size_t I = 0; I < nrFactors(); ++I )
215 result.push_back( beliefF(I) );
216 return result;
217 }
218
219
220 Factor Gibbs::belief( const VarSet &ns ) const {
221 if( ns.size() == 1 )
222 return belief( *(ns.begin()) );
223 else {
224 size_t I;
225 for( I = 0; I < nrFactors(); I++ )
226 if( factor(I).vars() >> ns )
227 break;
228 assert( I != nrFactors() );
229 return beliefF(I).marginal(ns);
230 }
231 }
232
233
234 } // end of namespace dai