Some small documentation updates
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <stack>
18 #include <dai/bp.h>
19 #include <dai/util.h>
20 #include <dai/properties.h>
21
22
23 namespace dai {
24
25
26 using namespace std;
27
28
// Name of this inference algorithm; used by identify() and as prefix of the
// diagnostic messages written by run().
const char *BP::Name = "BP";


// When nonzero, BP precomputes per-edge index tables (EdgeProp::index, built in
// construct()) and uses them to multiply in and marginalize messages directly on
// the probability tables; when zero, the simpler but slower Factor-based
// operations are used instead.
#define DAI_BP_FAST 1
33
34
35 void BP::setProperties( const PropertySet &opts ) {
36 DAI_ASSERT( opts.hasKey("tol") );
37 DAI_ASSERT( opts.hasKey("logdomain") );
38 DAI_ASSERT( opts.hasKey("updates") );
39
40 props.tol = opts.getStringAs<Real>("tol");
41 props.logdomain = opts.getStringAs<bool>("logdomain");
42 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
43
44 if( opts.hasKey("maxiter") )
45 props.maxiter = opts.getStringAs<size_t>("maxiter");
46 else
47 props.maxiter = 10000;
48 if( opts.hasKey("maxtime") )
49 props.maxtime = opts.getStringAs<Real>("maxtime");
50 else
51 props.maxtime = INFINITY;
52 if( opts.hasKey("verbose") )
53 props.verbose = opts.getStringAs<size_t>("verbose");
54 else
55 props.verbose = 0;
56 if( opts.hasKey("damping") )
57 props.damping = opts.getStringAs<Real>("damping");
58 else
59 props.damping = 0.0;
60 if( opts.hasKey("inference") )
61 props.inference = opts.getStringAs<Properties::InfType>("inference");
62 else
63 props.inference = Properties::InfType::SUMPROD;
64 }
65
66
67 PropertySet BP::getProperties() const {
68 PropertySet opts;
69 opts.set( "tol", props.tol );
70 opts.set( "maxiter", props.maxiter );
71 opts.set( "maxtime", props.maxtime );
72 opts.set( "verbose", props.verbose );
73 opts.set( "logdomain", props.logdomain );
74 opts.set( "updates", props.updates );
75 opts.set( "damping", props.damping );
76 opts.set( "inference", props.inference );
77 return opts;
78 }
79
80
81 string BP::printProperties() const {
82 stringstream s( stringstream::out );
83 s << "[";
84 s << "tol=" << props.tol << ",";
85 s << "maxiter=" << props.maxiter << ",";
86 s << "maxtime=" << props.maxtime << ",";
87 s << "verbose=" << props.verbose << ",";
88 s << "logdomain=" << props.logdomain << ",";
89 s << "updates=" << props.updates << ",";
90 s << "damping=" << props.damping << ",";
91 s << "inference=" << props.inference << "]";
92 return s.str();
93 }
94
95
96 void BP::construct() {
97 // create edge properties
98 _edges.clear();
99 _edges.reserve( nrVars() );
100 _edge2lut.clear();
101 if( props.updates == Properties::UpdateType::SEQMAX )
102 _edge2lut.reserve( nrVars() );
103 for( size_t i = 0; i < nrVars(); ++i ) {
104 _edges.push_back( vector<EdgeProp>() );
105 _edges[i].reserve( nbV(i).size() );
106 if( props.updates == Properties::UpdateType::SEQMAX ) {
107 _edge2lut.push_back( vector<LutType::iterator>() );
108 _edge2lut[i].reserve( nbV(i).size() );
109 }
110 foreach( const Neighbor &I, nbV(i) ) {
111 EdgeProp newEP;
112 newEP.message = Prob( var(i).states() );
113 newEP.newMessage = Prob( var(i).states() );
114
115 if( DAI_BP_FAST ) {
116 newEP.index.reserve( factor(I).nrStates() );
117 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
118 newEP.index.push_back( k );
119 }
120
121 newEP.residual = 0.0;
122 _edges[i].push_back( newEP );
123 if( props.updates == Properties::UpdateType::SEQMAX )
124 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
125 }
126 }
127
128 // create old beliefs
129 _oldBeliefsV.clear();
130 _oldBeliefsV.reserve( nrVars() );
131 for( size_t i = 0; i < nrVars(); ++i )
132 _oldBeliefsV.push_back( Factor( var(i) ) );
133 _oldBeliefsF.clear();
134 _oldBeliefsF.reserve( nrFactors() );
135 for( size_t I = 0; I < nrFactors(); ++I )
136 _oldBeliefsF.push_back( Factor( factor(I).vars() ) );
137
138 // create update sequence
139 _updateSeq.clear();
140 _updateSeq.reserve( nrEdges() );
141 for( size_t I = 0; I < nrFactors(); I++ )
142 foreach( const Neighbor &i, nbF(I) )
143 _updateSeq.push_back( Edge( i, i.dual ) );
144 }
145
146
147 void BP::init() {
148 Real c = props.logdomain ? 0.0 : 1.0;
149 for( size_t i = 0; i < nrVars(); ++i ) {
150 foreach( const Neighbor &I, nbV(i) ) {
151 message( i, I.iter ).fill( c );
152 newMessage( i, I.iter ).fill( c );
153 if( props.updates == Properties::UpdateType::SEQMAX )
154 updateResidual( i, I.iter, 0.0 );
155 }
156 }
157 _iters = 0;
158 }
159
160
161 void BP::findMaxResidual( size_t &i, size_t &_I ) {
162 DAI_ASSERT( !_lut.empty() );
163 LutType::const_iterator largestEl = _lut.end();
164 --largestEl;
165 i = largestEl->second.first;
166 _I = largestEl->second.second;
167 }
168
169
170 Prob BP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
171 Factor Fprod( factor(I) );
172 Prob &prod = Fprod.p();
173 if( props.logdomain )
174 prod.takeLog();
175
176 // Calculate product of incoming messages and factor I
177 foreach( const Neighbor &j, nbF(I) )
178 if( !(without_i && (j == i)) ) {
179 // prod_j will be the product of messages coming into j
180 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
181 foreach( const Neighbor &J, nbV(j) )
182 if( J != I ) { // for all J in nb(j) \ I
183 if( props.logdomain )
184 prod_j += message( j, J.iter );
185 else
186 prod_j *= message( j, J.iter );
187 }
188
189 // multiply prod with prod_j
190 if( !DAI_BP_FAST ) {
191 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
192 if( props.logdomain )
193 Fprod += Factor( var(j), prod_j );
194 else
195 Fprod *= Factor( var(j), prod_j );
196 } else {
197 // OPTIMIZED VERSION
198 size_t _I = j.dual;
199 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
200 const ind_t &ind = index(j, _I);
201
202 for( size_t r = 0; r < prod.size(); ++r )
203 if( props.logdomain )
204 prod.set( r, prod[r] + prod_j[ind[r]] );
205 else
206 prod.set( r, prod[r] * prod_j[ind[r]] );
207 }
208 }
209 return prod;
210 }
211
212
213 void BP::calcNewMessage( size_t i, size_t _I ) {
214 // calculate updated message I->i
215 size_t I = nbV(i,_I);
216
217 Prob marg;
218 if( factor(I).vars().size() == 1 ) // optimization
219 marg = factor(I).p();
220 else {
221 Factor Fprod( factor(I) );
222 Prob &prod = Fprod.p();
223 prod = calcIncomingMessageProduct( I, true, i );
224
225 if( props.logdomain ) {
226 prod -= prod.max();
227 prod.takeExp();
228 }
229
230 // Marginalize onto i
231 if( !DAI_BP_FAST ) {
232 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
233 if( props.inference == Properties::InfType::SUMPROD )
234 marg = Fprod.marginal( var(i) ).p();
235 else
236 marg = Fprod.maxMarginal( var(i) ).p();
237 } else {
238 // OPTIMIZED VERSION
239 marg = Prob( var(i).states(), 0.0 );
240 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
241 const ind_t ind = index(i,_I);
242 if( props.inference == Properties::InfType::SUMPROD )
243 for( size_t r = 0; r < prod.size(); ++r )
244 marg.set( ind[r], marg[ind[r]] + prod[r] );
245 else
246 for( size_t r = 0; r < prod.size(); ++r )
247 if( prod[r] > marg[ind[r]] )
248 marg.set( ind[r], prod[r] );
249 marg.normalize();
250 }
251 }
252
253 // Store result
254 if( props.logdomain )
255 newMessage(i,_I) = marg.log();
256 else
257 newMessage(i,_I) = marg;
258
259 // Update the residual if necessary
260 if( props.updates == Properties::UpdateType::SEQMAX )
261 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
262 }
263
264
// BP::run does not check for NaNs, for performance reasons
// (in practice, NaNs rarely occur in BP).
//
// Performs BP passes until one of three things happens:
//   - the maximum change of any variable or factor belief between two passes
//     drops to props.tol or below;
//   - _iters reaches props.maxiter (_iters is not reset here, so passes of
//     earlier run() calls count too);
//   - props.maxtime seconds have elapsed.
// Returns the maximum belief difference of the last pass (INFINITY if no pass
// was performed) and raises _maxdiff to that value if it is larger.
Real BP::run() {
    if( props.verbose >= 1 )
        cerr << "Starting " << identify() << "...";
    if( props.verbose >= 3)
        cerr << endl;

    double tic = toc();

    // do several passes over the network until maximum number of iterations has
    // been reached or until the maximum belief difference is smaller than tolerance
    Real maxDiff = INFINITY;
    for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
        if( props.updates == Properties::UpdateType::SEQMAX ) {
            if( _iters == 0 ) {
                // do the first pass: compute every new message, which also sets
                // the residual of every edge
                for( size_t i = 0; i < nrVars(); ++i )
                    foreach( const Neighbor &I, nbV(i) )
                        calcNewMessage( i, I.iter );
            }
            // Maximum-Residual BP [\ref EMK06]: one pass consists of as many
            // single-message updates as there are entries in _updateSeq
            for( size_t t = 0; t < _updateSeq.size(); ++t ) {
                // update the message with the largest residual
                size_t i, _I;
                findMaxResidual( i, _I );
                updateMessage( i, _I );

                // I->i has been updated, which means that residuals for all
                // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
                foreach( const Neighbor &J, nbV(i) ) {
                    if( J.iter != _I ) {
                        foreach( const Neighbor &j, nbF(J) ) {
                            size_t _J = j.dual;
                            if( j != i )
                                calcNewMessage( j, _J );
                        }
                    }
                }
            }
        } else if( props.updates == Properties::UpdateType::PARALL ) {
            // Parallel updates: first compute all new messages from the current
            // ones, then commit them all
            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    calcNewMessage( i, I.iter );

            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    updateMessage( i, I.iter );
        } else {
            // Sequential updates along _updateSeq (reshuffled every pass for SEQRND);
            // each message is committed immediately after it is computed
            if( props.updates == Properties::UpdateType::SEQRND )
                random_shuffle( _updateSeq.begin(), _updateSeq.end() );

            foreach( const Edge &e, _updateSeq ) {
                calcNewMessage( e.first, e.second );
                updateMessage( e.first, e.second );
            }
        }

        // calculate new beliefs and compare with old ones (L-infinity distance)
        maxDiff = -INFINITY;
        for( size_t i = 0; i < nrVars(); ++i ) {
            Factor b( beliefV(i) );
            maxDiff = std::max( maxDiff, dist( b, _oldBeliefsV[i], DISTLINF ) );
            _oldBeliefsV[i] = b;
        }
        for( size_t I = 0; I < nrFactors(); ++I ) {
            Factor b( beliefF(I) );
            maxDiff = std::max( maxDiff, dist( b, _oldBeliefsF[I], DISTLINF ) );
            _oldBeliefsF[I] = b;
        }

        if( props.verbose >= 3 )
            cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
    }

    if( maxDiff > _maxdiff )
        _maxdiff = maxDiff;

    if( props.verbose >= 1 ) {
        if( maxDiff > props.tol ) {
            if( props.verbose == 1 )
                cerr << endl;
            cerr << Name << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
        } else {
            if( props.verbose >= 3 )
                cerr << Name << "::run: ";
            cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
        }
    }

    return maxDiff;
}
359
360
361 void BP::calcBeliefV( size_t i, Prob &p ) const {
362 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
363 foreach( const Neighbor &I, nbV(i) )
364 if( props.logdomain )
365 p += newMessage( i, I.iter );
366 else
367 p *= newMessage( i, I.iter );
368 }
369
370
371 Factor BP::beliefV( size_t i ) const {
372 Prob p;
373 calcBeliefV( i, p );
374
375 if( props.logdomain ) {
376 p -= p.max();
377 p.takeExp();
378 }
379 p.normalize();
380
381 return( Factor( var(i), p ) );
382 }
383
384
385 Factor BP::beliefF( size_t I ) const {
386 Prob p;
387 calcBeliefF( I, p );
388
389 if( props.logdomain ) {
390 p -= p.max();
391 p.takeExp();
392 }
393 p.normalize();
394
395 return( Factor( factor(I).vars(), p ) );
396 }
397
398
399 vector<Factor> BP::beliefs() const {
400 vector<Factor> result;
401 for( size_t i = 0; i < nrVars(); ++i )
402 result.push_back( beliefV(i) );
403 for( size_t I = 0; I < nrFactors(); ++I )
404 result.push_back( beliefF(I) );
405 return result;
406 }
407
408
409 Factor BP::belief( const VarSet &ns ) const {
410 if( ns.size() == 0 )
411 return Factor();
412 else if( ns.size() == 1 )
413 return beliefV( findVar( *(ns.begin() ) ) );
414 else {
415 size_t I;
416 for( I = 0; I < nrFactors(); I++ )
417 if( factor(I).vars() >> ns )
418 break;
419 if( I == nrFactors() )
420 DAI_THROW(BELIEF_NOT_AVAILABLE);
421 return beliefF(I).marginal(ns);
422 }
423 }
424
425
426 Real BP::logZ() const {
427 Real sum = 0.0;
428 for( size_t i = 0; i < nrVars(); ++i )
429 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
430 for( size_t I = 0; I < nrFactors(); ++I )
431 sum -= dist( beliefF(I), factor(I), DISTKL );
432 return sum;
433 }
434
435
436 string BP::identify() const {
437 return string(Name) + printProperties();
438 }
439
440
441 void BP::init( const VarSet &ns ) {
442 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
443 size_t ni = findVar( *n );
444 foreach( const Neighbor &I, nbV( ni ) ) {
445 Real val = props.logdomain ? 0.0 : 1.0;
446 message( ni, I.iter ).fill( val );
447 newMessage( ni, I.iter ).fill( val );
448 if( props.updates == Properties::UpdateType::SEQMAX )
449 updateResidual( ni, I.iter, 0.0 );
450 }
451 }
452 _iters = 0;
453 }
454
455
456 void BP::updateMessage( size_t i, size_t _I ) {
457 if( recordSentMessages )
458 _sentMessages.push_back(make_pair(i,_I));
459 if( props.damping == 0.0 ) {
460 message(i,_I) = newMessage(i,_I);
461 if( props.updates == Properties::UpdateType::SEQMAX )
462 updateResidual( i, _I, 0.0 );
463 } else {
464 if( props.logdomain )
465 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
466 else
467 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
468 if( props.updates == Properties::UpdateType::SEQMAX )
469 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), DISTLINF ) );
470 }
471 }
472
473
474 void BP::updateResidual( size_t i, size_t _I, Real r ) {
475 EdgeProp* pEdge = &_edges[i][_I];
476 pEdge->residual = r;
477
478 // rearrange look-up table (delete and reinsert new key)
479 _lut.erase( _edge2lut[i][_I] );
480 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
481 }
482
483
484 std::vector<size_t> BP::findMaximum() const {
485 vector<size_t> maximum( nrVars() );
486 vector<bool> visitedVars( nrVars(), false );
487 vector<bool> visitedFactors( nrFactors(), false );
488 stack<size_t> scheduledFactors;
489 for( size_t i = 0; i < nrVars(); ++i ) {
490 if( visitedVars[i] )
491 continue;
492 visitedVars[i] = true;
493
494 // Maximise with respect to variable i
495 Prob prod;
496 calcBeliefV( i, prod );
497 maximum[i] = prod.argmax().first;
498
499 foreach( const Neighbor &I, nbV(i) )
500 if( !visitedFactors[I] )
501 scheduledFactors.push(I);
502
503 while( !scheduledFactors.empty() ){
504 size_t I = scheduledFactors.top();
505 scheduledFactors.pop();
506 if( visitedFactors[I] )
507 continue;
508 visitedFactors[I] = true;
509
510 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
511 bool allDetermined = true;
512 foreach( const Neighbor &j, nbF(I) )
513 if( !visitedVars[j.node] ) {
514 allDetermined = false;
515 break;
516 }
517 if( allDetermined )
518 continue;
519
520 // Calculate product of incoming messages on factor I
521 Prob prod2;
522 calcBeliefF( I, prod2 );
523
524 // The allowed configuration is restrained according to the variables assigned so far:
525 // pick the argmax amongst the allowed states
526 Real maxProb = -numeric_limits<Real>::max();
527 State maxState( factor(I).vars() );
528 for( State s( factor(I).vars() ); s.valid(); ++s ){
529 // First, calculate whether this state is consistent with variables that
530 // have been assigned already
531 bool allowedState = true;
532 foreach( const Neighbor &j, nbF(I) )
533 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
534 allowedState = false;
535 break;
536 }
537 // If it is consistent, check if its probability is larger than what we have seen so far
538 if( allowedState && prod2[s] > maxProb ) {
539 maxState = s;
540 maxProb = prod2[s];
541 }
542 }
543
544 // Decode the argmax
545 foreach( const Neighbor &j, nbF(I) ) {
546 if( visitedVars[j.node] ) {
547 // We have already visited j earlier - hopefully our state is consistent
548 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
549 cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << endl;
550 } else {
551 // We found a consistent state for variable j
552 visitedVars[j.node] = true;
553 maximum[j.node] = maxState( var(j.node) );
554 foreach( const Neighbor &J, nbV(j) )
555 if( !visitedFactors[J] )
556 scheduledFactors.push(J);
557 }
558 }
559 }
560 }
561 return maximum;
562 }
563
564
565 } // end of namespace dai