[Frederik Eaton] Major cleanup of BBP and CBP code and documentation
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4 Giuseppe Passino
5
6 This file is part of libDAI.
7
8 libDAI is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 libDAI is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with libDAI; if not, write to the Free Software
20 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
21 */
22
23
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stack>
#include <stdexcept>
#include <dai/bp.h>
#include <dai/util.h>
#include <dai/properties.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 const char *BP::Name = "BP";
42
43
44 #define DAI_BP_FAST 1
45
46
47 void BP::setProperties( const PropertySet &opts ) {
48 assert( opts.hasKey("tol") );
49 assert( opts.hasKey("maxiter") );
50 assert( opts.hasKey("logdomain") );
51 assert( opts.hasKey("updates") );
52
53 props.tol = opts.getStringAs<double>("tol");
54 props.maxiter = opts.getStringAs<size_t>("maxiter");
55 props.logdomain = opts.getStringAs<bool>("logdomain");
56 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
57
58 if( opts.hasKey("verbose") )
59 props.verbose = opts.getStringAs<size_t>("verbose");
60 else
61 props.verbose = 0;
62 if( opts.hasKey("damping") )
63 props.damping = opts.getStringAs<double>("damping");
64 else
65 props.damping = 0.0;
66 if( opts.hasKey("inference") )
67 props.inference = opts.getStringAs<Properties::InfType>("inference");
68 else
69 props.inference = Properties::InfType::SUMPROD;
70 }
71
72
73 PropertySet BP::getProperties() const {
74 PropertySet opts;
75 opts.Set( "tol", props.tol );
76 opts.Set( "maxiter", props.maxiter );
77 opts.Set( "verbose", props.verbose );
78 opts.Set( "logdomain", props.logdomain );
79 opts.Set( "updates", props.updates );
80 opts.Set( "damping", props.damping );
81 opts.Set( "inference", props.inference );
82 return opts;
83 }
84
85
86 string BP::printProperties() const {
87 stringstream s( stringstream::out );
88 s << "[";
89 s << "tol=" << props.tol << ",";
90 s << "maxiter=" << props.maxiter << ",";
91 s << "verbose=" << props.verbose << ",";
92 s << "logdomain=" << props.logdomain << ",";
93 s << "updates=" << props.updates << ",";
94 s << "damping=" << props.damping << ",";
95 s << "inference=" << props.inference << "]";
96 return s.str();
97 }
98
99
100 void BP::construct() {
101 // create edge properties
102 _edges.clear();
103 _edges.reserve( nrVars() );
104 _edge2lut.clear();
105 if( props.updates == Properties::UpdateType::SEQMAX )
106 _edge2lut.reserve( nrVars() );
107 for( size_t i = 0; i < nrVars(); ++i ) {
108 _edges.push_back( vector<EdgeProp>() );
109 _edges[i].reserve( nbV(i).size() );
110 if( props.updates == Properties::UpdateType::SEQMAX ) {
111 _edge2lut.push_back( vector<LutType::iterator>() );
112 _edge2lut[i].reserve( nbV(i).size() );
113 }
114 foreach( const Neighbor &I, nbV(i) ) {
115 EdgeProp newEP;
116 newEP.message = Prob( var(i).states() );
117 newEP.newMessage = Prob( var(i).states() );
118
119 if( DAI_BP_FAST ) {
120 newEP.index.reserve( factor(I).states() );
121 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
122 newEP.index.push_back( k );
123 }
124
125 newEP.residual = 0.0;
126 _edges[i].push_back( newEP );
127 if( props.updates == Properties::UpdateType::SEQMAX )
128 _edge2lut[i].push_back( _lut.insert( std::make_pair( newEP.residual, std::make_pair( i, _edges[i].size() - 1 ))) );
129 }
130 }
131 }
132
133
134 void BP::init() {
135 double c = props.logdomain ? 0.0 : 1.0;
136 for( size_t i = 0; i < nrVars(); ++i ) {
137 foreach( const Neighbor &I, nbV(i) ) {
138 message( i, I.iter ).fill( c );
139 newMessage( i, I.iter ).fill( c );
140 if( props.updates == Properties::UpdateType::SEQMAX )
141 updateResidual( i, I.iter, 0.0 );
142 }
143 }
144 }
145
146
147 void BP::findMaxResidual( size_t &i, size_t &_I ) {
148 /*
149 i = 0;
150 _I = 0;
151 double maxres = residual( i, _I );
152 for( size_t j = 0; j < nrVars(); ++j )
153 foreach( const Neighbor &I, nbV(j) )
154 if( residual( j, I.iter ) > maxres ) {
155 i = j;
156 _I = I.iter;
157 maxres = residual( i, _I );
158 }
159 */
160 assert( !_lut.empty() );
161 LutType::const_iterator largestEl = _lut.end();
162 --largestEl;
163 i = largestEl->second.first;
164 _I = largestEl->second.second;
165 }
166
167
168 void BP::calcNewMessage( size_t i, size_t _I ) {
169 // calculate updated message I->i
170 size_t I = nbV(i,_I);
171
172 if( !DAI_BP_FAST ) {
173 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
174 Factor prod( factor( I ) );
175 foreach( const Neighbor &j, nbF(I) )
176 if( j != i ) { // for all j in I \ i
177 foreach( const Neighbor &J, nbV(j) )
178 if( J != I ) { // for all J in nb(j) \ I
179 prod *= Factor( var(j), message(j, J.iter) );
180 }
181 }
182 newMessage(i,_I) = prod.marginal( var(i) ).p();
183 } else {
184 /* OPTIMIZED VERSION */
185 Prob prod( factor(I).p() );
186 if( props.logdomain )
187 prod.takeLog();
188
189 // Calculate product of incoming messages and factor I
190 foreach( const Neighbor &j, nbF(I) ) {
191 if( j != i ) { // for all j in I \ i
192 size_t _I = j.dual;
193 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
194 const ind_t &ind = index(j, _I);
195
196 // prod_j will be the product of messages coming into j
197 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
198 foreach( const Neighbor &J, nbV(j) )
199 if( J != I ) { // for all J in nb(j) \ I
200 if( props.logdomain )
201 prod_j += message( j, J.iter );
202 else
203 prod_j *= message( j, J.iter );
204 }
205
206 // multiply prod with prod_j
207 for( size_t r = 0; r < prod.size(); ++r )
208 if( props.logdomain )
209 prod[r] += prod_j[ind[r]];
210 else
211 prod[r] *= prod_j[ind[r]];
212 }
213 }
214 if( props.logdomain ) {
215 prod -= prod.max();
216 prod.takeExp();
217 }
218
219 // Marginalize onto i
220 Prob marg( var(i).states(), 0.0 );
221 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
222 const ind_t ind = index(i,_I);
223 if( props.inference == Properties::InfType::SUMPROD )
224 for( size_t r = 0; r < prod.size(); ++r )
225 marg[ind[r]] += prod[r];
226 else
227 for( size_t r = 0; r < prod.size(); ++r )
228 if( prod[r] > marg[ind[r]] )
229 marg[ind[r]] = prod[r];
230 marg.normalize();
231
232 // Store result
233 if( props.logdomain )
234 newMessage(i,_I) = marg.log();
235 else
236 newMessage(i,_I) = marg;
237 }
238
239 // Update the residual if necessary
240 if( props.updates == Properties::UpdateType::SEQMAX )
241 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
242 }
243
244
245 // BP::run does not check for NANs for performance reasons
246 // Somehow NaNs do not often occur in BP...
247 double BP::run() {
248 if( props.verbose >= 1 )
249 cerr << "Starting " << identify() << "...";
250 if( props.verbose >= 3)
251 cerr << endl;
252
253 double tic = toc();
254 Diffs diffs(nrVars(), 1.0);
255
256 vector<Edge> update_seq;
257
258 vector<Factor> old_beliefs;
259 old_beliefs.reserve( nrVars() );
260 for( size_t i = 0; i < nrVars(); ++i )
261 old_beliefs.push_back( beliefV(i) );
262
263 size_t nredges = nrEdges();
264
265 if( props.updates == Properties::UpdateType::SEQMAX ) {
266 // do the first pass
267 for( size_t i = 0; i < nrVars(); ++i )
268 foreach( const Neighbor &I, nbV(i) ) {
269 calcNewMessage( i, I.iter );
270 }
271 } else {
272 update_seq.reserve( nredges );
273 /// \todo Investigate whether performance increases by switching the order of following two loops:
274 for( size_t i = 0; i < nrVars(); ++i )
275 foreach( const Neighbor &I, nbV(i) )
276 update_seq.push_back( Edge( i, I.iter ) );
277 }
278
279 // do several passes over the network until maximum number of iterations has
280 // been reached or until the maximum belief difference is smaller than tolerance
281 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
282 if( props.updates == Properties::UpdateType::SEQMAX ) {
283 // Residuals-BP by Koller et al.
284 for( size_t t = 0; t < nredges; ++t ) {
285 // update the message with the largest residual
286 size_t i, _I;
287 findMaxResidual( i, _I );
288 updateMessage( i, _I );
289
290 // I->i has been updated, which means that residuals for all
291 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
292 foreach( const Neighbor &J, nbV(i) ) {
293 if( J.iter != _I ) {
294 foreach( const Neighbor &j, nbF(J) ) {
295 size_t _J = j.dual;
296 if( j != i )
297 calcNewMessage( j, _J );
298 }
299 }
300 }
301 }
302 } else if( props.updates == Properties::UpdateType::PARALL ) {
303 // Parallel updates
304 for( size_t i = 0; i < nrVars(); ++i )
305 foreach( const Neighbor &I, nbV(i) )
306 calcNewMessage( i, I.iter );
307
308 for( size_t i = 0; i < nrVars(); ++i )
309 foreach( const Neighbor &I, nbV(i) )
310 updateMessage( i, I.iter );
311 } else {
312 // Sequential updates
313 if( props.updates == Properties::UpdateType::SEQRND )
314 random_shuffle( update_seq.begin(), update_seq.end() );
315
316 foreach( const Edge &e, update_seq ) {
317 calcNewMessage( e.first, e.second );
318 updateMessage( e.first, e.second );
319 }
320 }
321
322 // calculate new beliefs and compare with old ones
323 for( size_t i = 0; i < nrVars(); ++i ) {
324 Factor nb( beliefV(i) );
325 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
326 old_beliefs[i] = nb;
327 }
328
329 if( props.verbose >= 3 )
330 cerr << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
331 }
332
333 if( diffs.maxDiff() > _maxdiff )
334 _maxdiff = diffs.maxDiff();
335
336 if( props.verbose >= 1 ) {
337 if( diffs.maxDiff() > props.tol ) {
338 if( props.verbose == 1 )
339 cerr << endl;
340 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
341 } else {
342 if( props.verbose >= 3 )
343 cerr << Name << "::run: ";
344 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
345 }
346 }
347
348 return diffs.maxDiff();
349 }
350
351
352 void BP::calcBeliefV( size_t i, Prob &p ) const {
353 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
354 foreach( const Neighbor &I, nbV(i) )
355 if( props.logdomain )
356 p += newMessage( i, I.iter );
357 else
358 p *= newMessage( i, I.iter );
359 }
360
361
362 void BP::calcBeliefF( size_t I, Prob &p ) const {
363 p = factor(I).p();
364 if( props.logdomain )
365 p.takeLog();
366
367 foreach( const Neighbor &j, nbF(I) ) {
368 size_t _I = j.dual;
369 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
370 const ind_t & ind = index(j, _I);
371
372 // prod_j will be the product of messages coming into j
373 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
374 foreach( const Neighbor &J, nbV(j) ) {
375 if( J != I ) { // for all J in nb(j) \ I
376 if( props.logdomain )
377 prod_j += newMessage( j, J.iter );
378 else
379 prod_j *= newMessage( j, J.iter );
380 }
381 }
382
383 // multiply p with prod_j
384 for( size_t r = 0; r < p.size(); ++r ) {
385 if( props.logdomain )
386 p[r] += prod_j[ind[r]];
387 else
388 p[r] *= prod_j[ind[r]];
389 }
390 }
391 }
392
393
394 Factor BP::beliefV( size_t i ) const {
395 Prob p;
396 calcBeliefV( i, p );
397
398 if( props.logdomain ) {
399 p -= p.max();
400 p.takeExp();
401 }
402
403 p.normalize();
404 return( Factor( var(i), p ) );
405 }
406
407
408 Factor BP::belief( const Var &n ) const {
409 return( beliefV( findVar( n ) ) );
410 }
411
412
413 vector<Factor> BP::beliefs() const {
414 vector<Factor> result;
415 for( size_t i = 0; i < nrVars(); ++i )
416 result.push_back( beliefV(i) );
417 for( size_t I = 0; I < nrFactors(); ++I )
418 result.push_back( beliefF(I) );
419 return result;
420 }
421
422
423 Factor BP::belief( const VarSet &ns ) const {
424 if( ns.size() == 1 )
425 return belief( *(ns.begin()) );
426 else {
427 size_t I;
428 for( I = 0; I < nrFactors(); I++ )
429 if( factor(I).vars() >> ns )
430 break;
431 assert( I != nrFactors() );
432 return beliefF(I).marginal(ns);
433 }
434 }
435
436
437 Factor BP::beliefF( size_t I ) const {
438 if( !DAI_BP_FAST ) {
439 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
440
441 Factor prod( factor(I) );
442 foreach( const Neighbor &j, nbF(I) ) {
443 foreach( const Neighbor &J, nbV(j) ) {
444 if( J != I ) // for all J in nb(j) \ I
445 prod *= Factor( var(j), newMessage(j, J.iter) );
446 }
447 }
448 return prod.normalized();
449 } else {
450 /* OPTIMIZED VERSION */
451
452 Prob prod;
453 calcBeliefF( I, prod );
454
455 if( props.logdomain ) {
456 prod -= prod.max();
457 prod.takeExp();
458 }
459 prod.normalize();
460
461 Factor result( factor(I).vars(), prod );
462
463 return( result );
464 }
465 }
466
467
468 Real BP::logZ() const {
469 Real sum = 0.0;
470 for(size_t i = 0; i < nrVars(); ++i )
471 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
472 for( size_t I = 0; I < nrFactors(); ++I )
473 sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
474 return sum;
475 }
476
477
478 string BP::identify() const {
479 return string(Name) + printProperties();
480 }
481
482
483 void BP::init( const VarSet &ns ) {
484 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
485 size_t ni = findVar( *n );
486 foreach( const Neighbor &I, nbV( ni ) ) {
487 double val = props.logdomain ? 0.0 : 1.0;
488 message( ni, I.iter ).fill( val );
489 newMessage( ni, I.iter ).fill( val );
490 if( props.updates == Properties::UpdateType::SEQMAX )
491 updateResidual( ni, I.iter, 0.0 );
492 }
493 }
494 }
495
496
497 void BP::updateMessage( size_t i, size_t _I ) {
498 if( recordSentMessages )
499 _sentMessages.push_back(make_pair(i,_I));
500 if( props.damping == 0.0 ) {
501 message(i,_I) = newMessage(i,_I);
502 if( props.updates == Properties::UpdateType::SEQMAX )
503 updateResidual( i, _I, 0.0 );
504 } else {
505 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
506 if( props.updates == Properties::UpdateType::SEQMAX )
507 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), Prob::DISTLINF ) );
508 }
509 }
510
511
512 void BP::updateResidual( size_t i, size_t _I, double r ) {
513 EdgeProp* pEdge = &_edges[i][_I];
514 pEdge->residual = r;
515
516 // rearrange look-up table (delete and reinsert new key)
517 _lut.erase( _edge2lut[i][_I] );
518 _edge2lut[i][_I] = _lut.insert( std::make_pair( r, std::make_pair(i, _I) ) );
519 }
520
521
522 std::vector<size_t> BP::findMaximum() const {
523 std::vector<size_t> maximum( nrVars() );
524 std::vector<bool> visitedVars( nrVars(), false );
525 std::vector<bool> visitedFactors( nrFactors(), false );
526 std::stack<size_t> scheduledFactors;
527 for( size_t i = 0; i < nrVars(); ++i ) {
528 if( visitedVars[i] )
529 continue;
530 visitedVars[i] = true;
531
532 // Maximise with respect to variable i
533 Prob prod;
534 calcBeliefV( i, prod );
535 maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
536
537 foreach( const Neighbor &I, nbV(i) )
538 if( !visitedFactors[I] )
539 scheduledFactors.push(I);
540
541 while( !scheduledFactors.empty() ){
542 size_t I = scheduledFactors.top();
543 scheduledFactors.pop();
544 if( visitedFactors[I] )
545 continue;
546 visitedFactors[I] = true;
547
548 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
549 bool allDetermined = true;
550 foreach( const Neighbor &j, nbF(I) )
551 if( !visitedVars[j.node] ) {
552 allDetermined = false;
553 break;
554 }
555 if( allDetermined )
556 continue;
557
558 // Calculate product of incoming messages on factor I
559 Prob prod2;
560 calcBeliefF( I, prod2 );
561
562 // The allowed configuration is restrained according to the variables assigned so far:
563 // pick the argmax amongst the allowed states
564 Real maxProb = std::numeric_limits<Real>::min();
565 State maxState( factor(I).vars() );
566 for( State s( factor(I).vars() ); s.valid(); ++s ){
567 // First, calculate whether this state is consistent with variables that
568 // have been assigned already
569 bool allowedState = true;
570 foreach( const Neighbor &j, nbF(I) )
571 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
572 allowedState = false;
573 break;
574 }
575 // If it is consistent, check if its probability is larger than what we have seen so far
576 if( allowedState && prod2[s] > maxProb ) {
577 maxState = s;
578 maxProb = prod2[s];
579 }
580 }
581
582 // Decode the argmax
583 foreach( const Neighbor &j, nbF(I) ) {
584 if( visitedVars[j.node] ) {
585 // We have already visited j earlier - hopefully our state is consistent
586 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
587 std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
588 } else {
589 // We found a consistent state for variable j
590 visitedVars[j.node] = true;
591 maximum[j.node] = maxState( var(j.node) );
592 foreach( const Neighbor &J, nbV(j) )
593 if( !visitedFactors[J] )
594 scheduledFactors.push(J);
595 }
596 }
597 }
598 }
599 return maximum;
600 }
601
602
603 } // end of namespace dai