Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>
#include <dai/bp.h>
#include <dai/diffs.h>
#include <dai/util.h>
#include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP";
40
41
42 void BP::setProperties( const PropertySet &opts ) {
43 assert( opts.hasKey("tol") );
44 assert( opts.hasKey("maxiter") );
45 assert( opts.hasKey("verbose") );
46 assert( opts.hasKey("logdomain") );
47 assert( opts.hasKey("updates") );
48
49 props.tol = opts.getStringAs<double>("tol");
50 props.maxiter = opts.getStringAs<size_t>("maxiter");
51 props.verbose = opts.getStringAs<size_t>("verbose");
52 props.logdomain = opts.getStringAs<bool>("logdomain");
53 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
54 }
55
56
57 PropertySet BP::getProperties() const {
58 PropertySet opts;
59 opts.Set( "tol", props.tol );
60 opts.Set( "maxiter", props.maxiter );
61 opts.Set( "verbose", props.verbose );
62 opts.Set( "logdomain", props.logdomain );
63 opts.Set( "updates", props.updates );
64 return opts;
65 }
66
67
68 void BP::create() {
69 // create edge properties
70 edges.clear();
71 edges.reserve( nrVars() );
72 for( size_t i = 0; i < nrVars(); ++i ) {
73 edges.push_back( vector<EdgeProp>() );
74 edges[i].reserve( nbV(i).size() );
75 foreach( const Neighbor &I, nbV(i) ) {
76 EdgeProp newEP;
77 newEP.message = Prob( var(i).states() );
78 newEP.newMessage = Prob( var(i).states() );
79
80 newEP.index.reserve( factor(I).states() );
81 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
82 newEP.index.push_back( k );
83
84 newEP.residual = 0.0;
85 edges[i].push_back( newEP );
86 }
87 }
88 }
89
90
91 void BP::init() {
92 for( size_t i = 0; i < nrVars(); ++i ) {
93 foreach( const Neighbor &I, nbV(i) ) {
94 if( props.logdomain ) {
95 message( i, I.iter ).fill( 0.0 );
96 newMessage( i, I.iter ).fill( 0.0 );
97 } else {
98 message( i, I.iter ).fill( 1.0 );
99 newMessage( i, I.iter ).fill( 1.0 );
100 }
101 }
102 }
103 }
104
105
106 void BP::findMaxResidual( size_t &i, size_t &_I ) {
107 i = 0;
108 _I = 0;
109 double maxres = residual( i, _I );
110 for( size_t j = 0; j < nrVars(); ++j )
111 foreach( const Neighbor &I, nbV(j) )
112 if( residual( j, I.iter ) > maxres ) {
113 i = j;
114 _I = I.iter;
115 maxres = residual( i, _I );
116 }
117 }
118
119
120 void BP::calcNewMessage( size_t i, size_t _I ) {
121 // calculate updated message I->i
122 size_t I = nbV(i,_I);
123
124 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
125
126 Factor prod( factor( I ) );
127 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
128 if( *j != i ) { // for all j in I \ i
129 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
130 if( *J != I ) { // for all J in nb(j) \ I
131 prod *= Factor( *j, message(*j,*J) );
132 Factor marg = prod.marginal(var(i));
133 */
134
135 Prob prod( factor(I).p() );
136 if( props.logdomain )
137 prod.takeLog();
138
139 // Calculate product of incoming messages and factor I
140 foreach( const Neighbor &j, nbF(I) ) {
141 if( j != i ) { // for all j in I \ i
142 size_t _I = j.dual;
143 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
144 const ind_t & ind = index(j, _I);
145
146 // prod_j will be the product of messages coming into j
147 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
148 foreach( const Neighbor &J, nbV(j) )
149 if( J != I ) { // for all J in nb(j) \ I
150 if( props.logdomain )
151 prod_j += message( j, J.iter );
152 else
153 prod_j *= message( j, J.iter );
154 }
155
156 // multiply prod with prod_j
157 for( size_t r = 0; r < prod.size(); ++r )
158 if( props.logdomain )
159 prod[r] += prod_j[ind[r]];
160 else
161 prod[r] *= prod_j[ind[r]];
162 }
163 }
164 if( props.logdomain ) {
165 prod -= prod.maxVal();
166 prod.takeExp();
167 }
168
169 // Marginalize onto i
170 Prob marg( var(i).states(), 0.0 );
171 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
172 const ind_t ind = index(i,_I);
173 for( size_t r = 0; r < prod.size(); ++r )
174 marg[ind[r]] += prod[r];
175 marg.normalize( Prob::NORMPROB );
176
177 // Store result
178 if( props.logdomain )
179 newMessage(i,_I) = marg.log();
180 else
181 newMessage(i,_I) = marg;
182 }
183
184
185 // BP::run does not check for NANs for performance reasons
186 // Somehow NaNs do not often occur in BP...
187 double BP::run() {
188 if( props.verbose >= 1 )
189 cout << "Starting " << identify() << "...";
190 if( props.verbose >= 3)
191 cout << endl;
192
193 double tic = toc();
194 Diffs diffs(nrVars(), 1.0);
195
196 vector<Edge> update_seq;
197
198 vector<Factor> old_beliefs;
199 old_beliefs.reserve( nrVars() );
200 for( size_t i = 0; i < nrVars(); ++i )
201 old_beliefs.push_back( beliefV(i) );
202
203 size_t iter = 0;
204 size_t nredges = nrEdges();
205
206 if( props.updates == Properties::UpdateType::SEQMAX ) {
207 // do the first pass
208 for( size_t i = 0; i < nrVars(); ++i )
209 foreach( const Neighbor &I, nbV(i) ) {
210 calcNewMessage( i, I.iter );
211 // calculate initial residuals
212 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
213 }
214 } else {
215 update_seq.reserve( nredges );
216 for( size_t i = 0; i < nrVars(); ++i )
217 foreach( const Neighbor &I, nbV(i) )
218 update_seq.push_back( Edge( i, I.iter ) );
219 }
220
221 // do several passes over the network until maximum number of iterations has
222 // been reached or until the maximum belief difference is smaller than tolerance
223 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
224 if( props.updates == Properties::UpdateType::SEQMAX ) {
225 // Residuals-BP by Koller et al.
226 for( size_t t = 0; t < nredges; ++t ) {
227 // update the message with the largest residual
228
229 size_t i, _I;
230 findMaxResidual( i, _I );
231 message( i, _I ) = newMessage( i, _I );
232 residual( i, _I ) = 0.0;
233
234 // I->i has been updated, which means that residuals for all
235 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
236 foreach( const Neighbor &J, nbV(i) ) {
237 if( J.iter != _I ) {
238 foreach( const Neighbor &j, nbF(J) ) {
239 size_t _J = j.dual;
240 if( j != i ) {
241 calcNewMessage( j, _J );
242 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
243 }
244 }
245 }
246 }
247 }
248 } else if( props.updates == Properties::UpdateType::PARALL ) {
249 // Parallel updates
250 for( size_t i = 0; i < nrVars(); ++i )
251 foreach( const Neighbor &I, nbV(i) )
252 calcNewMessage( i, I.iter );
253
254 for( size_t i = 0; i < nrVars(); ++i )
255 foreach( const Neighbor &I, nbV(i) )
256 message( i, I.iter ) = newMessage( i, I.iter );
257 } else {
258 // Sequential updates
259 if( props.updates == Properties::UpdateType::SEQRND )
260 random_shuffle( update_seq.begin(), update_seq.end() );
261
262 foreach( const Edge &e, update_seq ) {
263 calcNewMessage( e.first, e.second );
264 message( e.first, e.second ) = newMessage( e.first, e.second );
265 }
266 }
267
268 // calculate new beliefs and compare with old ones
269 for( size_t i = 0; i < nrVars(); ++i ) {
270 Factor nb( beliefV(i) );
271 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
272 old_beliefs[i] = nb;
273 }
274
275 if( props.verbose >= 3 )
276 cout << "BP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
277 }
278
279 if( diffs.maxDiff() > maxdiff )
280 maxdiff = diffs.maxDiff();
281
282 if( props.verbose >= 1 ) {
283 if( diffs.maxDiff() > props.tol ) {
284 if( props.verbose == 1 )
285 cout << endl;
286 cout << "BP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
287 } else {
288 if( props.verbose >= 3 )
289 cout << "BP::run: ";
290 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
291 }
292 }
293
294 return diffs.maxDiff();
295 }
296
297
298 Factor BP::beliefV( size_t i ) const {
299 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
300 foreach( const Neighbor &I, nbV(i) )
301 if( props.logdomain )
302 prod += newMessage( i, I.iter );
303 else
304 prod *= newMessage( i, I.iter );
305 if( props.logdomain ) {
306 prod -= prod.maxVal();
307 prod.takeExp();
308 }
309
310 prod.normalize( Prob::NORMPROB );
311 return( Factor( var(i), prod ) );
312 }
313
314
315 Factor BP::belief (const Var &n) const {
316 return( beliefV( findVar( n ) ) );
317 }
318
319
320 vector<Factor> BP::beliefs() const {
321 vector<Factor> result;
322 for( size_t i = 0; i < nrVars(); ++i )
323 result.push_back( beliefV(i) );
324 for( size_t I = 0; I < nrFactors(); ++I )
325 result.push_back( beliefF(I) );
326 return result;
327 }
328
329
330 Factor BP::belief( const VarSet &ns ) const {
331 if( ns.size() == 1 )
332 return belief( *(ns.begin()) );
333 else {
334 size_t I;
335 for( I = 0; I < nrFactors(); I++ )
336 if( factor(I).vars() >> ns )
337 break;
338 assert( I != nrFactors() );
339 return beliefF(I).marginal(ns);
340 }
341 }
342
343
344 Factor BP::beliefF (size_t I) const {
345 Prob prod( factor(I).p() );
346 if( props.logdomain )
347 prod.takeLog();
348
349 foreach( const Neighbor &j, nbF(I) ) {
350 size_t _I = j.dual;
351 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
352 const ind_t & ind = index(j, _I);
353
354 // prod_j will be the product of messages coming into j
355 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
356 foreach( const Neighbor &J, nbV(j) ) {
357 if( J != I ) { // for all J in nb(j) \ I
358 if( props.logdomain )
359 prod_j += newMessage( j, J.iter );
360 else
361 prod_j *= newMessage( j, J.iter );
362 }
363 }
364
365 // multiply prod with prod_j
366 for( size_t r = 0; r < prod.size(); ++r ) {
367 if( props.logdomain )
368 prod[r] += prod_j[ind[r]];
369 else
370 prod[r] *= prod_j[ind[r]];
371 }
372 }
373
374 if( props.logdomain ) {
375 prod -= prod.maxVal();
376 prod.takeExp();
377 }
378
379 Factor result( factor(I).vars(), prod );
380 result.normalize( Prob::NORMPROB );
381
382 return( result );
383
384 /* UNOPTIMIZED VERSION
385
386 Factor prod( factor(I) );
387 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
388 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
389 if( *J != I )
390 prod *= Factor( var(*i), newMessage(*i,*J)) );
391 }
392 return prod.normalize( Prob::NORMPROB );*/
393 }
394
395
396 Real BP::logZ() const {
397 Real sum = 0.0;
398 for(size_t i = 0; i < nrVars(); ++i )
399 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
400 for( size_t I = 0; I < nrFactors(); ++I )
401 sum -= KL_dist( beliefF(I), factor(I) );
402 return sum;
403 }
404
405
406 string BP::identify() const {
407 stringstream result (stringstream::out);
408 result << Name << getProperties();
409 return result.str();
410 }
411
412
413 void BP::init( const VarSet &ns ) {
414 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
415 size_t ni = findVar( *n );
416 foreach( const Neighbor &I, nbV( ni ) )
417 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
418 }
419 }
420
421
422 } // end of namespace dai