Cleaned up TProb<T> and some misc other stuff, improved State class, and
[libdai.git] / src / gibbs.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 #include <iostream>
12 #include <sstream>
13 #include <map>
14 #include <set>
15 #include <algorithm>
16 #include <dai/gibbs.h>
17 #include <dai/util.h>
18 #include <dai/properties.h>
19
20
21 namespace dai {
22
23
24 using namespace std;
25
26
// Identifier of this inference algorithm; also used as the prefix of the
// verbose timing message printed by Gibbs::run().
const char *Gibbs::Name = "GIBBS";
28
29
30 void Gibbs::setProperties( const PropertySet &opts ) {
31 DAI_ASSERT( opts.hasKey("iters") );
32 props.iters = opts.getStringAs<size_t>("iters");
33
34 if( opts.hasKey("verbose") )
35 props.verbose = opts.getStringAs<size_t>("verbose");
36 else
37 props.verbose = 0;
38 }
39
40
41 PropertySet Gibbs::getProperties() const {
42 PropertySet opts;
43 opts.Set( "iters", props.iters );
44 opts.Set( "verbose", props.verbose );
45 return opts;
46 }
47
48
49 string Gibbs::printProperties() const {
50 stringstream s( stringstream::out );
51 s << "[";
52 s << "iters=" << props.iters << ",";
53 s << "verbose=" << props.verbose << "]";
54 return s.str();
55 }
56
57
58 void Gibbs::construct() {
59 _var_counts.clear();
60 _var_counts.reserve( nrVars() );
61 for( size_t i = 0; i < nrVars(); i++ )
62 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
63
64 _factor_counts.clear();
65 _factor_counts.reserve( nrFactors() );
66 for( size_t I = 0; I < nrFactors(); I++ )
67 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
68
69 _sample_count = 0;
70
71 _state.clear();
72 _state.resize( nrVars(), 0 );
73 }
74
75
76 void Gibbs::updateCounts() {
77 for( size_t i = 0; i < nrVars(); i++ )
78 _var_counts[i][_state[i]]++;
79 for( size_t I = 0; I < nrFactors(); I++ )
80 _factor_counts[I][getFactorEntry(I)]++;
81 _sample_count++;
82 }
83
84
85 inline size_t Gibbs::getFactorEntry( size_t I ) {
86 size_t f_entry = 0;
87 for( int _j = nbF(I).size() - 1; _j >= 0; _j-- ) {
88 // note that iterating over nbF(I) yields the same ordering
89 // of variables as iterating over factor(I).vars()
90 size_t j = nbF(I)[_j];
91 f_entry *= var(j).states();
92 f_entry += _state[j];
93 }
94 return f_entry;
95 }
96
97
98 inline size_t Gibbs::getFactorEntryDiff( size_t I, size_t i ) {
99 size_t skip = 1;
100 for( size_t _j = 0; _j < nbF(I).size(); _j++ ) {
101 // note that iterating over nbF(I) yields the same ordering
102 // of variables as iterating over factor(I).vars()
103 size_t j = nbF(I)[_j];
104 if( i == j )
105 break;
106 else
107 skip *= var(j).states();
108 }
109 return skip;
110 }
111
112
113 Prob Gibbs::getVarDist( size_t i ) {
114 DAI_ASSERT( i < nrVars() );
115 size_t i_states = var(i).states();
116 Prob i_given_MB( i_states, 1.0 );
117
118 // use Markov blanket of var(i) to calculate distribution
119 foreach( const Neighbor &I, nbV(i) ) {
120 const Factor &f_I = factor(I);
121 size_t I_skip = getFactorEntryDiff( I, i );
122 size_t I_entry = getFactorEntry(I) - (_state[i] * I_skip);
123 for( size_t st_i = 0; st_i < i_states; st_i++ ) {
124 i_given_MB[st_i] *= f_I[I_entry];
125 I_entry += I_skip;
126 }
127 }
128
129 if( i_given_MB.sum() == 0.0 )
130 // If no state of i is allowed, use uniform distribution
131 // FIXME is that indeed the right thing to do?
132 i_given_MB = Prob( i_states );
133 else
134 i_given_MB.normalize();
135 return i_given_MB;
136 }
137
138
139 inline void Gibbs::resampleVar( size_t i ) {
140 // draw randomly from conditional distribution and update _state
141 _state[i] = getVarDist(i).draw();
142 }
143
144
145 void Gibbs::randomizeState() {
146 for( size_t i = 0; i < nrVars(); i++ )
147 _state[i] = rnd_int( 0, var(i).states() - 1 );
148 }
149
150
151 void Gibbs::init() {
152 for( size_t i = 0; i < nrVars(); i++ )
153 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
154 for( size_t I = 0; I < nrFactors(); I++ )
155 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
156 _sample_count = 0;
157 }
158
159
160 Real Gibbs::run() {
161 if( props.verbose >= 1 )
162 cerr << "Starting " << identify() << "...";
163 if( props.verbose >= 3 )
164 cerr << endl;
165
166 double tic = toc();
167
168 randomizeState();
169
170 for( size_t iter = 0; iter < props.iters; iter++ ) {
171 for( size_t j = 0; j < nrVars(); j++ )
172 resampleVar( j );
173 updateCounts();
174 }
175
176 if( props.verbose >= 3 ) {
177 for( size_t i = 0; i < nrVars(); i++ ) {
178 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
179 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i] ) << endl;
180 }
181 }
182
183 if( props.verbose >= 3 )
184 cerr << Name << "::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
185
186 return 0.0;
187 }
188
189
190 Factor Gibbs::beliefV( size_t i ) const {
191 return Factor( var(i), _var_counts[i] ).normalized();
192 }
193
194
195 Factor Gibbs::beliefF( size_t I ) const {
196 return Factor( factor(I).vars(), _factor_counts[I] ).normalized();
197 }
198
199
200 Factor Gibbs::belief( const Var &n ) const {
201 return( beliefV( findVar( n ) ) );
202 }
203
204
205 vector<Factor> Gibbs::beliefs() const {
206 vector<Factor> result;
207 for( size_t i = 0; i < nrVars(); ++i )
208 result.push_back( beliefV(i) );
209 for( size_t I = 0; I < nrFactors(); ++I )
210 result.push_back( beliefF(I) );
211 return result;
212 }
213
214
215 Factor Gibbs::belief( const VarSet &ns ) const {
216 if( ns.size() == 1 )
217 return belief( *(ns.begin()) );
218 else {
219 size_t I;
220 for( I = 0; I < nrFactors(); I++ )
221 if( factor(I).vars() >> ns )
222 break;
223 DAI_ASSERT( I != nrFactors() );
224 return beliefF(I).marginal(ns);
225 }
226 }
227
228
229 } // end of namespace dai