Fixed regression FBP and bugs in TRWBP
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <stack>
18 #include <dai/bp.h>
19 #include <dai/util.h>
20 #include <dai/properties.h>
21
22
23 namespace dai {
24
25
26 using namespace std;
27
28
/// Name of this inference algorithm (used e.g. in identify() and in verbose output)
const char *BP::Name = "BP";


/// Compile-time switch: when nonzero, message and belief computations use the
/// IndexFor tables precalculated per edge in construct() instead of the simpler
/// but slower Factor arithmetic.
#define DAI_BP_FAST 1
33
34
35 void BP::setProperties( const PropertySet &opts ) {
36 DAI_ASSERT( opts.hasKey("tol") );
37 DAI_ASSERT( opts.hasKey("maxiter") );
38 DAI_ASSERT( opts.hasKey("logdomain") );
39 DAI_ASSERT( opts.hasKey("updates") );
40
41 props.tol = opts.getStringAs<Real>("tol");
42 props.maxiter = opts.getStringAs<size_t>("maxiter");
43 props.logdomain = opts.getStringAs<bool>("logdomain");
44 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
45
46 if( opts.hasKey("verbose") )
47 props.verbose = opts.getStringAs<size_t>("verbose");
48 else
49 props.verbose = 0;
50 if( opts.hasKey("damping") )
51 props.damping = opts.getStringAs<Real>("damping");
52 else
53 props.damping = 0.0;
54 if( opts.hasKey("inference") )
55 props.inference = opts.getStringAs<Properties::InfType>("inference");
56 else
57 props.inference = Properties::InfType::SUMPROD;
58 }
59
60
61 PropertySet BP::getProperties() const {
62 PropertySet opts;
63 opts.Set( "tol", props.tol );
64 opts.Set( "maxiter", props.maxiter );
65 opts.Set( "verbose", props.verbose );
66 opts.Set( "logdomain", props.logdomain );
67 opts.Set( "updates", props.updates );
68 opts.Set( "damping", props.damping );
69 opts.Set( "inference", props.inference );
70 return opts;
71 }
72
73
74 string BP::printProperties() const {
75 stringstream s( stringstream::out );
76 s << "[";
77 s << "tol=" << props.tol << ",";
78 s << "maxiter=" << props.maxiter << ",";
79 s << "verbose=" << props.verbose << ",";
80 s << "logdomain=" << props.logdomain << ",";
81 s << "updates=" << props.updates << ",";
82 s << "damping=" << props.damping << ",";
83 s << "inference=" << props.inference << "]";
84 return s.str();
85 }
86
87
88 void BP::construct() {
89 // create edge properties
90 _edges.clear();
91 _edges.reserve( nrVars() );
92 _edge2lut.clear();
93 if( props.updates == Properties::UpdateType::SEQMAX )
94 _edge2lut.reserve( nrVars() );
95 for( size_t i = 0; i < nrVars(); ++i ) {
96 _edges.push_back( vector<EdgeProp>() );
97 _edges[i].reserve( nbV(i).size() );
98 if( props.updates == Properties::UpdateType::SEQMAX ) {
99 _edge2lut.push_back( vector<LutType::iterator>() );
100 _edge2lut[i].reserve( nbV(i).size() );
101 }
102 foreach( const Neighbor &I, nbV(i) ) {
103 EdgeProp newEP;
104 newEP.message = Prob( var(i).states() );
105 newEP.newMessage = Prob( var(i).states() );
106
107 if( DAI_BP_FAST ) {
108 newEP.index.reserve( factor(I).states() );
109 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
110 newEP.index.push_back( k );
111 }
112
113 newEP.residual = 0.0;
114 _edges[i].push_back( newEP );
115 if( props.updates == Properties::UpdateType::SEQMAX )
116 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
117 }
118 }
119 }
120
121
122 void BP::init() {
123 Real c = props.logdomain ? 0.0 : 1.0;
124 for( size_t i = 0; i < nrVars(); ++i ) {
125 foreach( const Neighbor &I, nbV(i) ) {
126 message( i, I.iter ).fill( c );
127 newMessage( i, I.iter ).fill( c );
128 if( props.updates == Properties::UpdateType::SEQMAX )
129 updateResidual( i, I.iter, 0.0 );
130 }
131 }
132 }
133
134
135 void BP::findMaxResidual( size_t &i, size_t &_I ) {
136 DAI_ASSERT( !_lut.empty() );
137 LutType::const_iterator largestEl = _lut.end();
138 --largestEl;
139 i = largestEl->second.first;
140 _I = largestEl->second.second;
141 }
142
143
144 void BP::calcNewMessage( size_t i, size_t _I ) {
145 // calculate updated message I->i
146 size_t I = nbV(i,_I);
147
148 Prob marg;
149 if( factor(I).vars().size() == 1 ) // optimization
150 marg = factor(I).p();
151 else {
152 Factor Fprod( factor(I) );
153 Prob &prod = Fprod.p();
154 if( props.logdomain )
155 prod.takeLog();
156
157 // Calculate product of incoming messages and factor I
158 foreach( const Neighbor &j, nbF(I) )
159 if( j != i ) { // for all j in I \ i
160 // prod_j will be the product of messages coming into j
161 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
162 foreach( const Neighbor &J, nbV(j) )
163 if( J != I ) { // for all J in nb(j) \ I
164 if( props.logdomain )
165 prod_j += message( j, J.iter );
166 else
167 prod_j *= message( j, J.iter );
168 }
169
170 // multiply prod with prod_j
171 if( !DAI_BP_FAST ) {
172 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
173 if( props.logdomain )
174 Fprod += Factor( var(j), prod_j );
175 else
176 Fprod *= Factor( var(j), prod_j );
177 } else {
178 /* OPTIMIZED VERSION */
179 size_t _I = j.dual;
180 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
181 const ind_t &ind = index(j, _I);
182 for( size_t r = 0; r < prod.size(); ++r )
183 if( props.logdomain )
184 prod[r] += prod_j[ind[r]];
185 else
186 prod[r] *= prod_j[ind[r]];
187 }
188 }
189
190 if( props.logdomain ) {
191 prod -= prod.max();
192 prod.takeExp();
193 }
194
195 // Marginalize onto i
196 if( !DAI_BP_FAST ) {
197 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
198 if( props.inference == Properties::InfType::SUMPROD )
199 marg = Fprod.marginal( var(i) ).p();
200 else
201 marg = Fprod.maxMarginal( var(i) ).p();
202 } else {
203 /* OPTIMIZED VERSION */
204 marg = Prob( var(i).states(), 0.0 );
205 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
206 const ind_t ind = index(i,_I);
207 if( props.inference == Properties::InfType::SUMPROD )
208 for( size_t r = 0; r < prod.size(); ++r )
209 marg[ind[r]] += prod[r];
210 else
211 for( size_t r = 0; r < prod.size(); ++r )
212 if( prod[r] > marg[ind[r]] )
213 marg[ind[r]] = prod[r];
214 marg.normalize();
215 }
216 }
217
218 // Store result
219 if( props.logdomain )
220 newMessage(i,_I) = marg.log();
221 else
222 newMessage(i,_I) = marg;
223
224 // Update the residual if necessary
225 if( props.updates == Properties::UpdateType::SEQMAX )
226 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
227 }
228
229
// BP::run does not check for NANs for performance reasons
// Somehow NaNs do not often occur in BP...
/// Runs BP until the beliefs converge or props.maxiter passes have been made.
/** One pass updates the messages according to props.updates:
 *  - SEQMAX: residual BP; nrEdges() single-message updates per pass, each time
 *    committing the message with the largest residual and recomputing the affected
 *    neighbouring messages;
 *  - PARALL: all new messages are computed first, then all are committed;
 *  - any other update type: sequential updates along update_seq (one entry per edge),
 *    which is shuffled at the start of every pass for SEQRND.
 *  After each pass all variable and factor beliefs are compared (L-infinity distance)
 *  with those of the previous pass; the loop stops once the largest difference is
 *  at most props.tol. _maxdiff is raised to the final difference if that is larger.
 *  \return the maximum belief difference observed in the last pass
 */
Real BP::run() {
    if( props.verbose >= 1 )
        cerr << "Starting " << identify() << "...";
    if( props.verbose >= 3)
        cerr << endl;

    double tic = toc();
    Real maxDiff = INFINITY;

    // Beliefs of the previous pass, used for the convergence test
    vector<Factor> oldBeliefsV, oldBeliefsF;
    oldBeliefsV.reserve( nrVars() );
    for( size_t i = 0; i < nrVars(); ++i )
        oldBeliefsV.push_back( beliefV(i) );
    oldBeliefsF.reserve( nrFactors() );
    for( size_t I = 0; I < nrFactors(); ++I )
        oldBeliefsF.push_back( beliefF(I) );

    size_t nredges = nrEdges();
    vector<Edge> update_seq;
    if( props.updates == Properties::UpdateType::SEQMAX ) {
        // do the first pass
        // (computes every new message once so that all residuals are filled in)
        for( size_t i = 0; i < nrVars(); ++i )
            foreach( const Neighbor &I, nbV(i) )
                calcNewMessage( i, I.iter );
    } else {
        // Sequential schedule: one entry (i, position of I in nbV(i)) per edge, grouped by factor
        update_seq.reserve( nredges );
        for( size_t I = 0; I < nrFactors(); I++ )
            foreach( const Neighbor &i, nbF(I) )
                update_seq.push_back( Edge( i, i.dual ) );
    }

    // do several passes over the network until maximum number of iterations has
    // been reached or until the maximum belief difference is smaller than tolerance
    for( _iters=0; _iters < props.maxiter && maxDiff > props.tol; ++_iters ) {
        if( props.updates == Properties::UpdateType::SEQMAX ) {
            // Residuals-BP by Koller et al.
            for( size_t t = 0; t < nredges; ++t ) {
                // update the message with the largest residual
                size_t i, _I;
                findMaxResidual( i, _I );
                updateMessage( i, _I );

                // I->i has been updated, which means that residuals for all
                // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
                foreach( const Neighbor &J, nbV(i) ) {
                    if( J.iter != _I ) {
                        foreach( const Neighbor &j, nbF(J) ) {
                            size_t _J = j.dual;
                            if( j != i )
                                calcNewMessage( j, _J );
                        }
                    }
                }
            }
        } else if( props.updates == Properties::UpdateType::PARALL ) {
            // Parallel updates
            // (first compute every new message from the old ones, then commit them all)
            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    calcNewMessage( i, I.iter );

            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    updateMessage( i, I.iter );
        } else {
            // Sequential updates
            if( props.updates == Properties::UpdateType::SEQRND )
                random_shuffle( update_seq.begin(), update_seq.end() );

            foreach( const Edge &e, update_seq ) {
                calcNewMessage( e.first, e.second );
                updateMessage( e.first, e.second );
            }
        }

        // calculate new beliefs and compare with old ones
        maxDiff = -INFINITY;
        for( size_t i = 0; i < nrVars(); ++i ) {
            Factor b( beliefV(i) );
            maxDiff = std::max( maxDiff, dist( b, oldBeliefsV[i], Prob::DISTLINF ) );
            oldBeliefsV[i] = b;
        }
        for( size_t I = 0; I < nrFactors(); ++I ) {
            Factor b( beliefF(I) );
            maxDiff = std::max( maxDiff, dist( b, oldBeliefsF[I], Prob::DISTLINF ) );
            oldBeliefsF[I] = b;
        }

        if( props.verbose >= 3 )
            cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
    }

    // keep track of the largest final difference seen over all runs
    if( maxDiff > _maxdiff )
        _maxdiff = maxDiff;

    if( props.verbose >= 1 ) {
        if( maxDiff > props.tol ) {
            if( props.verbose == 1 )
                cerr << endl;
            cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
        } else {
            if( props.verbose >= 3 )
                cerr << Name << "::run: ";
            cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
        }
    }

    return maxDiff;
}
340
341
342 void BP::calcBeliefV( size_t i, Prob &p ) const {
343 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
344 foreach( const Neighbor &I, nbV(i) )
345 if( props.logdomain )
346 p += newMessage( i, I.iter );
347 else
348 p *= newMessage( i, I.iter );
349 }
350
351
352 void BP::calcBeliefF( size_t I, Prob &p ) const {
353 Factor Fprod( factor( I ) );
354 Prob &prod = Fprod.p();
355
356 if( props.logdomain )
357 prod.takeLog();
358
359 foreach( const Neighbor &j, nbF(I) ) {
360 // prod_j will be the product of messages coming into j
361 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
362 foreach( const Neighbor &J, nbV(j) )
363 if( J != I ) { // for all J in nb(j) \ I
364 if( props.logdomain )
365 prod_j += newMessage( j, J.iter );
366 else
367 prod_j *= newMessage( j, J.iter );
368 }
369
370 // multiply prod with prod_j
371 if( !DAI_BP_FAST ) {
372 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
373 if( props.logdomain )
374 Fprod += Factor( var(j), prod_j );
375 else
376 Fprod *= Factor( var(j), prod_j );
377 } else {
378 /* OPTIMIZED VERSION */
379 size_t _I = j.dual;
380 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
381 const ind_t & ind = index(j, _I);
382
383 for( size_t r = 0; r < prod.size(); ++r ) {
384 if( props.logdomain )
385 prod[r] += prod_j[ind[r]];
386 else
387 prod[r] *= prod_j[ind[r]];
388 }
389 }
390 }
391
392 p = prod;
393 }
394
395
396 Factor BP::beliefV( size_t i ) const {
397 Prob p;
398 calcBeliefV( i, p );
399
400 if( props.logdomain ) {
401 p -= p.max();
402 p.takeExp();
403 }
404 p.normalize();
405
406 return( Factor( var(i), p ) );
407 }
408
409
410 Factor BP::beliefF( size_t I ) const {
411 Prob p;
412 calcBeliefF( I, p );
413
414 if( props.logdomain ) {
415 p -= p.max();
416 p.takeExp();
417 }
418 p.normalize();
419
420 return( Factor( factor(I).vars(), p ) );
421 }
422
423
424 vector<Factor> BP::beliefs() const {
425 vector<Factor> result;
426 for( size_t i = 0; i < nrVars(); ++i )
427 result.push_back( beliefV(i) );
428 for( size_t I = 0; I < nrFactors(); ++I )
429 result.push_back( beliefF(I) );
430 return result;
431 }
432
433
434 Factor BP::belief( const VarSet &ns ) const {
435 if( ns.size() == 0 )
436 return Factor();
437 else if( ns.size() == 1 )
438 return beliefV( findVar( *(ns.begin() ) ) );
439 else {
440 size_t I;
441 for( I = 0; I < nrFactors(); I++ )
442 if( factor(I).vars() >> ns )
443 break;
444 if( I == nrFactors() )
445 DAI_THROW(BELIEF_NOT_AVAILABLE);
446 return beliefF(I).marginal(ns);
447 }
448 }
449
450
451 Real BP::logZ() const {
452 Real sum = 0.0;
453 for( size_t i = 0; i < nrVars(); ++i )
454 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
455 for( size_t I = 0; I < nrFactors(); ++I )
456 sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
457 return sum;
458 }
459
460
461 string BP::identify() const {
462 return string(Name) + printProperties();
463 }
464
465
466 void BP::init( const VarSet &ns ) {
467 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
468 size_t ni = findVar( *n );
469 foreach( const Neighbor &I, nbV( ni ) ) {
470 Real val = props.logdomain ? 0.0 : 1.0;
471 message( ni, I.iter ).fill( val );
472 newMessage( ni, I.iter ).fill( val );
473 if( props.updates == Properties::UpdateType::SEQMAX )
474 updateResidual( ni, I.iter, 0.0 );
475 }
476 }
477 }
478
479
480 void BP::updateMessage( size_t i, size_t _I ) {
481 if( recordSentMessages )
482 _sentMessages.push_back(make_pair(i,_I));
483 if( props.damping == 0.0 ) {
484 message(i,_I) = newMessage(i,_I);
485 if( props.updates == Properties::UpdateType::SEQMAX )
486 updateResidual( i, _I, 0.0 );
487 } else {
488 if( props.logdomain )
489 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
490 else
491 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
492 if( props.updates == Properties::UpdateType::SEQMAX )
493 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), Prob::DISTLINF ) );
494 }
495 }
496
497
498 void BP::updateResidual( size_t i, size_t _I, Real r ) {
499 EdgeProp* pEdge = &_edges[i][_I];
500 pEdge->residual = r;
501
502 // rearrange look-up table (delete and reinsert new key)
503 _lut.erase( _edge2lut[i][_I] );
504 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
505 }
506
507
508 std::vector<size_t> BP::findMaximum() const {
509 vector<size_t> maximum( nrVars() );
510 vector<bool> visitedVars( nrVars(), false );
511 vector<bool> visitedFactors( nrFactors(), false );
512 stack<size_t> scheduledFactors;
513 for( size_t i = 0; i < nrVars(); ++i ) {
514 if( visitedVars[i] )
515 continue;
516 visitedVars[i] = true;
517
518 // Maximise with respect to variable i
519 Prob prod;
520 calcBeliefV( i, prod );
521 maximum[i] = prod.argmax().first;
522
523 foreach( const Neighbor &I, nbV(i) )
524 if( !visitedFactors[I] )
525 scheduledFactors.push(I);
526
527 while( !scheduledFactors.empty() ){
528 size_t I = scheduledFactors.top();
529 scheduledFactors.pop();
530 if( visitedFactors[I] )
531 continue;
532 visitedFactors[I] = true;
533
534 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
535 bool allDetermined = true;
536 foreach( const Neighbor &j, nbF(I) )
537 if( !visitedVars[j.node] ) {
538 allDetermined = false;
539 break;
540 }
541 if( allDetermined )
542 continue;
543
544 // Calculate product of incoming messages on factor I
545 Prob prod2;
546 calcBeliefF( I, prod2 );
547
548 // The allowed configuration is restrained according to the variables assigned so far:
549 // pick the argmax amongst the allowed states
550 Real maxProb = numeric_limits<Real>::min();
551 State maxState( factor(I).vars() );
552 for( State s( factor(I).vars() ); s.valid(); ++s ){
553 // First, calculate whether this state is consistent with variables that
554 // have been assigned already
555 bool allowedState = true;
556 foreach( const Neighbor &j, nbF(I) )
557 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
558 allowedState = false;
559 break;
560 }
561 // If it is consistent, check if its probability is larger than what we have seen so far
562 if( allowedState && prod2[s] > maxProb ) {
563 maxState = s;
564 maxProb = prod2[s];
565 }
566 }
567
568 // Decode the argmax
569 foreach( const Neighbor &j, nbF(I) ) {
570 if( visitedVars[j.node] ) {
571 // We have already visited j earlier - hopefully our state is consistent
572 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
573 cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << endl;
574 } else {
575 // We found a consistent state for variable j
576 visitedVars[j.node] = true;
577 maximum[j.node] = maxState( var(j.node) );
578 foreach( const Neighbor &J, nbV(j) )
579 if( !visitedFactors[J] )
580 scheduledFactors.push(J);
581 }
582 }
583 }
584 }
585 return maximum;
586 }
587
588
589 } // end of namespace dai