Merge branch 'master' into gibbs
[libdai.git] / src / gibbs.cpp
1 /* Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/gibbs.h>
#include <dai/util.h>
#include <dai/properties.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *Gibbs::Name = "GIBBS";
38
39
40 void Gibbs::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("iters") );
42 props.iters = opts.getStringAs<size_t>("iters");
43
44 if( opts.hasKey("verbose") )
45 props.verbose = opts.getStringAs<size_t>("verbose");
46 else
47 props.verbose = 0;
48 }
49
50
51 PropertySet Gibbs::getProperties() const {
52 PropertySet opts;
53 opts.Set( "iters", props.iters );
54 opts.Set( "verbose", props.verbose );
55 return opts;
56 }
57
58
59 string Gibbs::printProperties() const {
60 stringstream s( stringstream::out );
61 s << "[";
62 s << "iters=" << props.iters << ",";
63 s << "verbose=" << props.verbose << "]";
64 return s.str();
65 }
66
67
68 void Gibbs::construct() {
69 _var_counts.clear();
70 _var_counts.reserve( nrVars() );
71 for( size_t i = 0; i < nrVars(); i++ )
72 _var_counts.push_back( _count_t( var(i).states(), 0 ) );
73 _factor_counts.clear();
74 _factor_counts.reserve( nrFactors() );
75 for( size_t I = 0; I < nrFactors(); I++ )
76 _factor_counts.push_back( _count_t( factor(I).states(), 0 ) );
77 _sample_count = 0;
78 }
79
80
81 void Gibbs::update_counts( _state_t &st ) {
82 for( size_t i = 0; i < nrVars(); i++ )
83 _var_counts[i][st[i]]++;
84 for( size_t I = 0; I < nrFactors(); I++ ) {
85 if( 0 ) {
86 /* multind mi( factor(I).vars() );
87 _state_t f_st( factor(I).vars().size() );
88 int k = 0;
89 foreach( size_t j, nbF(I) )
90 f_st[k++] = st[j];
91 _factor_counts[I][mi.li(f_st)]++;*/
92 } else {
93 size_t ent = get_factor_entry(st, I);
94 _factor_counts[I][ent]++;
95 }
96 }
97 _sample_count++;
98 }
99
100
101 inline
102 size_t Gibbs::get_factor_entry(const _state_t &st, int factor) {
103 size_t f_entry=0;
104 int rank = nbF(factor).size();
105 for(int j=rank-1; j>=0; j--) {
106 int jn = nbF(factor)[j];
107 f_entry *= var(jn).states();
108 f_entry += st[jn];
109 }
110 return f_entry;
111 }
112
113
114 Prob Gibbs::get_var_dist( _state_t &st, size_t i ) {
115 assert( st.size() == vars().size() );
116 assert( i < nrVars() );
117 if( 1 ) {
118 // use markov blanket of n to calculate distribution
119 size_t dim = var(i).states();
120 Neighbors &facts = nbV(i);
121
122 Prob values( dim, 1.0 );
123
124 for( size_t I = 0; I < facts.size(); I++ ) {
125 size_t fa = facts[I];
126 const Factor &f = factor(fa);
127 int save_ind = st[i];
128 for( size_t k = 0; k < dim; k++ ) {
129 st[i] = k;
130 int f_entry = get_factor_entry(st, fa);
131 values[k] *= f[f_entry];
132 }
133 st[i] = save_ind;
134 }
135
136 return values.normalized();
137 } else {
138 /* Var vi = var(i);
139 Factor d(vi);
140 assert(vi.states()>0);
141 assert(vi.label()>=0);
142 // loop over factors containing i (nbV(i)):
143 foreach(size_t I, nbV(i)) {
144 // use multind to find linear state for variables != i in factor
145 assert(I<nrFactors());
146 assert(factor(I).vars().size() > 0);
147 VarSet vs (factor(I).vars() / vi);
148 multind mi(vs);
149 _state_t I_st(vs.size());
150 int k=0;
151 foreach(size_t l, nbF(I)) {
152 if(l!=i) I_st[k++] = st[l];
153 }
154 // use slice(ns,ns_state) to get beliefs for variable i
155 // multiply all these beliefs together
156 d *= factor(I).slice(vs, mi.li(I_st));
157 }
158 d.p().normalize();
159 return d.p();*/
160 }
161 }
162
163
164 void Gibbs::resample_var( _state_t &st, size_t i ) {
165 // draw randomly from conditional distribution and update 'st'
166 st[i] = get_var_dist( st, i ).draw();
167 }
168
169
170 void Gibbs::randomize_state( _state_t &st ) {
171 assert( st.size() == nrVars() );
172 for( size_t i = 0; i < nrVars(); i++ )
173 st[i] = rnd_int( 0, var(i).states() - 1 );
174 }
175
176
177 void Gibbs::init() {
178 for( size_t i = 0; i < nrVars(); i++ )
179 fill( _var_counts[i].begin(), _var_counts[i].end(), 0 );
180 for( size_t I = 0; I < nrFactors(); I++ )
181 fill( _factor_counts[I].begin(), _factor_counts[I].end(), 0 );
182 _sample_count = 0;
183 }
184
185
186 double Gibbs::run() {
187 if( props.verbose >= 1 )
188 cout << "Starting " << identify() << "...";
189 if( props.verbose >= 3 )
190 cout << endl;
191
192 double tic = toc();
193
194 vector<size_t> state( nrVars() );
195 randomize_state( state );
196
197 for( size_t iter = 0; iter < props.iters; iter++ ) {
198 for( size_t j = 0; j < nrVars(); j++ )
199 resample_var( state, j );
200 update_counts( state );
201 }
202
203 if( props.verbose >= 3 ) {
204 for( size_t i = 0; i < nrVars(); i++ ) {
205 cerr << "belief for variable " << var(i) << ": " << beliefV(i) << endl;
206 cerr << "counts for variable " << var(i) << ": " << Prob( _var_counts[i].begin(), _var_counts[i].end() ) << endl;
207 }
208 }
209
210 if( props.verbose >= 3 )
211 cout << "Gibbs::run: ran " << props.iters << " passes (" << toc() - tic << " clocks)." << endl;
212
213 return 0.0;
214 }
215
216
217 Factor Gibbs::beliefV( size_t i ) const {
218 Prob p( _var_counts[i].begin(), _var_counts[i].end() );
219 p.normalize();
220 return( Factor( var(i), p ) );
221 }
222
223
224 Factor Gibbs::beliefF( size_t I ) const {
225 Prob p( _factor_counts[I].begin(), _factor_counts[I].end() );
226 p.normalize();
227 return( Factor( factor(I).vars(), p ) );
228 }
229
230
231 Factor Gibbs::belief( const Var &n ) const {
232 return( beliefV( findVar( n ) ) );
233 }
234
235
236 vector<Factor> Gibbs::beliefs() const {
237 vector<Factor> result;
238 for( size_t i = 0; i < nrVars(); i++ )
239 result.push_back( beliefV(i) );
240 for( size_t I = 0; I < nrFactors(); I++ )
241 result.push_back( beliefF(I) );
242 return result;
243 }
244
245
246 Factor Gibbs::belief( const VarSet &ns ) const {
247 if( ns.size() == 1 )
248 return belief( *(ns.begin()) );
249 else {
250 size_t I;
251 for( I = 0; I < nrFactors(); I++ )
252 if( factor(I).vars() >> ns )
253 break;
254 assert( I != nrFactors() );
255 return beliefF(I).marginal(ns);
256 }
257 }
258
259
260 } // end of namespace dai