4581cfa9dcc355a8d114d78095d82aab0a05a3bb
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4 Giuseppe Passino
5
6 This file is part of libDAI.
7
8 libDAI is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 libDAI is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with libDAI; if not, write to the Free Software
20 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
21 */
22
23
24 #include <iostream>
25 #include <sstream>
26 #include <map>
27 #include <set>
28 #include <algorithm>
29 #include <stack>
30 #include <dai/bp.h>
31 #include <dai/util.h>
32 #include <dai/properties.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
/// Name of this inference algorithm; used by identify() and as prefix of the log output of run()
const char *BP::Name = "BP";


/// When nonzero, construct() precomputes the index tables and calcNewMessage() and
/// beliefF() take their optimized code paths; when zero, the simple (slow) versions are used
#define DAI_BP_FAST 1
45
46
47 void BP::setProperties( const PropertySet &opts ) {
48 assert( opts.hasKey("tol") );
49 assert( opts.hasKey("maxiter") );
50 assert( opts.hasKey("logdomain") );
51 assert( opts.hasKey("updates") );
52
53 props.tol = opts.getStringAs<double>("tol");
54 props.maxiter = opts.getStringAs<size_t>("maxiter");
55 props.logdomain = opts.getStringAs<bool>("logdomain");
56 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
57
58 if( opts.hasKey("verbose") )
59 props.verbose = opts.getStringAs<size_t>("verbose");
60 else
61 props.verbose = 0;
62 if( opts.hasKey("damping") )
63 props.damping = opts.getStringAs<double>("damping");
64 else
65 props.damping = 0.0;
66 if( opts.hasKey("inference") )
67 props.inference = opts.getStringAs<Properties::InfType>("inference");
68 else
69 props.inference = Properties::InfType::SUMPROD;
70 }
71
72
73 PropertySet BP::getProperties() const {
74 PropertySet opts;
75 opts.Set( "tol", props.tol );
76 opts.Set( "maxiter", props.maxiter );
77 opts.Set( "verbose", props.verbose );
78 opts.Set( "logdomain", props.logdomain );
79 opts.Set( "updates", props.updates );
80 opts.Set( "damping", props.damping );
81 opts.Set( "inference", props.inference );
82 return opts;
83 }
84
85
86 string BP::printProperties() const {
87 stringstream s( stringstream::out );
88 s << "[";
89 s << "tol=" << props.tol << ",";
90 s << "maxiter=" << props.maxiter << ",";
91 s << "verbose=" << props.verbose << ",";
92 s << "logdomain=" << props.logdomain << ",";
93 s << "updates=" << props.updates << ",";
94 s << "damping=" << props.damping << ",";
95 s << "inference=" << props.inference << "]";
96 return s.str();
97 }
98
99
100 void BP::construct() {
101 // create edge properties
102 _edges.clear();
103 _edges.reserve( nrVars() );
104 for( size_t i = 0; i < nrVars(); ++i ) {
105 _edges.push_back( vector<EdgeProp>() );
106 _edges[i].reserve( nbV(i).size() );
107 foreach( const Neighbor &I, nbV(i) ) {
108 EdgeProp newEP;
109 newEP.message = Prob( var(i).states() );
110 newEP.newMessage = Prob( var(i).states() );
111
112 if( DAI_BP_FAST ) {
113 newEP.index.reserve( factor(I).states() );
114 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
115 newEP.index.push_back( k );
116 }
117
118 newEP.residual = 0.0;
119 _edges[i].push_back( newEP );
120 }
121 }
122 }
123
124
125 void BP::init() {
126 double c = props.logdomain ? 0.0 : 1.0;
127 for( size_t i = 0; i < nrVars(); ++i ) {
128 foreach( const Neighbor &I, nbV(i) ) {
129 message( i, I.iter ).fill( c );
130 newMessage( i, I.iter ).fill( c );
131 }
132 }
133 }
134
135
136 void BP::findMaxResidual( size_t &i, size_t &_I ) {
137 i = 0;
138 _I = 0;
139 double maxres = residual( i, _I );
140 for( size_t j = 0; j < nrVars(); ++j )
141 foreach( const Neighbor &I, nbV(j) )
142 if( residual( j, I.iter ) > maxres ) {
143 i = j;
144 _I = I.iter;
145 maxres = residual( i, _I );
146 }
147 }
148
149
150 void BP::calcNewMessage( size_t i, size_t _I ) {
151 // calculate updated message I->i
152 size_t I = nbV(i,_I);
153
154 if( !DAI_BP_FAST ) {
155 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
156 Factor prod( factor( I ) );
157 foreach( const Neighbor &j, nbF(I) )
158 if( j != i ) { // for all j in I \ i
159 foreach( const Neighbor &J, nbV(j) )
160 if( J != I ) { // for all J in nb(j) \ I
161 prod *= Factor( var(j), message(j, J.iter) );
162 }
163 }
164 newMessage(i,_I) = prod.marginal( var(i) ).p();
165 } else {
166 /* OPTIMIZED VERSION */
167 Prob prod( factor(I).p() );
168 if( props.logdomain )
169 prod.takeLog();
170
171 // Calculate product of incoming messages and factor I
172 foreach( const Neighbor &j, nbF(I) ) {
173 if( j != i ) { // for all j in I \ i
174 size_t _I = j.dual;
175 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
176 const ind_t &ind = index(j, _I);
177
178 // prod_j will be the product of messages coming into j
179 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
180 foreach( const Neighbor &J, nbV(j) )
181 if( J != I ) { // for all J in nb(j) \ I
182 if( props.logdomain )
183 prod_j += message( j, J.iter );
184 else
185 prod_j *= message( j, J.iter );
186 }
187
188 // multiply prod with prod_j
189 for( size_t r = 0; r < prod.size(); ++r )
190 if( props.logdomain )
191 prod[r] += prod_j[ind[r]];
192 else
193 prod[r] *= prod_j[ind[r]];
194 }
195 }
196 if( props.logdomain ) {
197 prod -= prod.maxVal();
198 prod.takeExp();
199 }
200
201 // Marginalize onto i
202 Prob marg( var(i).states(), 0.0 );
203 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
204 const ind_t ind = index(i,_I);
205 if( props.inference == Properties::InfType::SUMPROD )
206 for( size_t r = 0; r < prod.size(); ++r )
207 marg[ind[r]] += prod[r];
208 else
209 for( size_t r = 0; r < prod.size(); ++r )
210 if( prod[r] > marg[ind[r]] )
211 marg[ind[r]] = prod[r];
212 marg.normalize();
213
214 // Store result
215 if( props.logdomain )
216 newMessage(i,_I) = marg.log();
217 else
218 newMessage(i,_I) = marg;
219 }
220 }
221
222
223 // BP::run does not check for NANs for performance reasons
224 // Somehow NaNs do not often occur in BP...
225 double BP::run() {
226 if( props.verbose >= 1 )
227 cout << "Starting " << identify() << "...";
228 if( props.verbose >= 3)
229 cout << endl;
230
231 double tic = toc();
232 Diffs diffs(nrVars(), 1.0);
233
234 vector<Edge> update_seq;
235
236 vector<Factor> old_beliefs;
237 old_beliefs.reserve( nrVars() );
238 for( size_t i = 0; i < nrVars(); ++i )
239 old_beliefs.push_back( beliefV(i) );
240
241 size_t nredges = nrEdges();
242
243 if( props.updates == Properties::UpdateType::SEQMAX ) {
244 // do the first pass
245 for( size_t i = 0; i < nrVars(); ++i )
246 foreach( const Neighbor &I, nbV(i) ) {
247 calcNewMessage( i, I.iter );
248 // calculate initial residuals
249 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
250 }
251 } else {
252 update_seq.reserve( nredges );
253 /// \todo Investigate whether performance increases by switching the order of following two loops:
254 for( size_t i = 0; i < nrVars(); ++i )
255 foreach( const Neighbor &I, nbV(i) )
256 update_seq.push_back( Edge( i, I.iter ) );
257 }
258
259 // do several passes over the network until maximum number of iterations has
260 // been reached or until the maximum belief difference is smaller than tolerance
261 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
262 if( props.updates == Properties::UpdateType::SEQMAX ) {
263 // Residuals-BP by Koller et al.
264 for( size_t t = 0; t < nredges; ++t ) {
265 // update the message with the largest residual
266 size_t i, _I;
267 findMaxResidual( i, _I );
268 updateMessage( i, _I );
269
270 // I->i has been updated, which means that residuals for all
271 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
272 foreach( const Neighbor &J, nbV(i) ) {
273 if( J.iter != _I ) {
274 foreach( const Neighbor &j, nbF(J) ) {
275 size_t _J = j.dual;
276 if( j != i ) {
277 calcNewMessage( j, _J );
278 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
279 }
280 }
281 }
282 }
283 }
284 } else if( props.updates == Properties::UpdateType::PARALL ) {
285 // Parallel updates
286 for( size_t i = 0; i < nrVars(); ++i )
287 foreach( const Neighbor &I, nbV(i) )
288 calcNewMessage( i, I.iter );
289
290 for( size_t i = 0; i < nrVars(); ++i )
291 foreach( const Neighbor &I, nbV(i) )
292 updateMessage( i, I.iter );
293 } else {
294 // Sequential updates
295 if( props.updates == Properties::UpdateType::SEQRND )
296 random_shuffle( update_seq.begin(), update_seq.end() );
297
298 foreach( const Edge &e, update_seq ) {
299 calcNewMessage( e.first, e.second );
300 updateMessage( e.first, e.second );
301 }
302 }
303
304 // calculate new beliefs and compare with old ones
305 for( size_t i = 0; i < nrVars(); ++i ) {
306 Factor nb( beliefV(i) );
307 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
308 old_beliefs[i] = nb;
309 }
310
311 if( props.verbose >= 3 )
312 cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
313 }
314
315 if( diffs.maxDiff() > _maxdiff )
316 _maxdiff = diffs.maxDiff();
317
318 if( props.verbose >= 1 ) {
319 if( diffs.maxDiff() > props.tol ) {
320 if( props.verbose == 1 )
321 cout << endl;
322 cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
323 } else {
324 if( props.verbose >= 3 )
325 cout << Name << "::run: ";
326 cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
327 }
328 }
329
330 return diffs.maxDiff();
331 }
332
333
334 void BP::calcBeliefV( size_t i, Prob &p ) const {
335 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
336 foreach( const Neighbor &I, nbV(i) )
337 if( props.logdomain )
338 p += newMessage( i, I.iter );
339 else
340 p *= newMessage( i, I.iter );
341 }
342
343
344 void BP::calcBeliefF( size_t I, Prob &p ) const {
345 p = factor(I).p();
346 if( props.logdomain )
347 p.takeLog();
348
349 foreach( const Neighbor &j, nbF(I) ) {
350 size_t _I = j.dual;
351 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
352 const ind_t & ind = index(j, _I);
353
354 // prod_j will be the product of messages coming into j
355 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
356 foreach( const Neighbor &J, nbV(j) ) {
357 if( J != I ) { // for all J in nb(j) \ I
358 if( props.logdomain )
359 prod_j += newMessage( j, J.iter );
360 else
361 prod_j *= newMessage( j, J.iter );
362 }
363 }
364
365 // multiply p with prod_j
366 for( size_t r = 0; r < p.size(); ++r ) {
367 if( props.logdomain )
368 p[r] += prod_j[ind[r]];
369 else
370 p[r] *= prod_j[ind[r]];
371 }
372 }
373 }
374
375
376 Factor BP::beliefV( size_t i ) const {
377 Prob p;
378 calcBeliefV( i, p );
379
380 if( props.logdomain ) {
381 p -= p.maxVal();
382 p.takeExp();
383 }
384
385 p.normalize();
386 return( Factor( var(i), p ) );
387 }
388
389
390 Factor BP::belief( const Var &n ) const {
391 return( beliefV( findVar( n ) ) );
392 }
393
394
395 vector<Factor> BP::beliefs() const {
396 vector<Factor> result;
397 for( size_t i = 0; i < nrVars(); ++i )
398 result.push_back( beliefV(i) );
399 for( size_t I = 0; I < nrFactors(); ++I )
400 result.push_back( beliefF(I) );
401 return result;
402 }
403
404
405 Factor BP::belief( const VarSet &ns ) const {
406 if( ns.size() == 1 )
407 return belief( *(ns.begin()) );
408 else {
409 size_t I;
410 for( I = 0; I < nrFactors(); I++ )
411 if( factor(I).vars() >> ns )
412 break;
413 assert( I != nrFactors() );
414 return beliefF(I).marginal(ns);
415 }
416 }
417
418
419 Factor BP::beliefF( size_t I ) const {
420 if( !DAI_BP_FAST ) {
421 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
422
423 Factor prod( factor(I) );
424 foreach( const Neighbor &j, nbF(I) ) {
425 foreach( const Neighbor &J, nbV(j) ) {
426 if( J != I ) // for all J in nb(j) \ I
427 prod *= Factor( var(j), newMessage(j, J.iter) );
428 }
429 }
430 return prod.normalized();
431 } else {
432 /* OPTIMIZED VERSION */
433
434 Prob prod;
435 calcBeliefF( I, prod );
436
437 if( props.logdomain ) {
438 prod -= prod.maxVal();
439 prod.takeExp();
440 }
441 prod.normalize();
442
443 Factor result( factor(I).vars(), prod );
444
445 return( result );
446 }
447 }
448
449
450 Real BP::logZ() const {
451 Real sum = 0.0;
452 for(size_t i = 0; i < nrVars(); ++i )
453 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
454 for( size_t I = 0; I < nrFactors(); ++I )
455 sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
456 return sum;
457 }
458
459
460 string BP::identify() const {
461 return string(Name) + printProperties();
462 }
463
464
465 void BP::init( const VarSet &ns ) {
466 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
467 size_t ni = findVar( *n );
468 foreach( const Neighbor &I, nbV( ni ) )
469 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
470 }
471 }
472
473
474 std::vector<size_t> BP::findMaximum() const {
475 std::vector<size_t> maximum( nrVars() );
476 std::vector<bool> visitedVars( nrVars(), false );
477 std::vector<bool> visitedFactors( nrFactors(), false );
478 std::stack<size_t> scheduledFactors;
479 for( size_t i = 0; i < nrVars(); ++i ) {
480 if( visitedVars[i] )
481 continue;
482 visitedVars[i] = true;
483
484 // Maximise with respect to variable i
485 Prob prod;
486 calcBeliefV( i, prod );
487 maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
488
489 foreach( const Neighbor &I, nbV(i) )
490 if( !visitedFactors[I] )
491 scheduledFactors.push(I);
492
493 while( !scheduledFactors.empty() ){
494 size_t I = scheduledFactors.top();
495 scheduledFactors.pop();
496 if( visitedFactors[I] )
497 continue;
498 visitedFactors[I] = true;
499
500 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
501 bool allDetermined = true;
502 foreach( const Neighbor &j, nbF(I) )
503 if( !visitedVars[j.node] ) {
504 allDetermined = false;
505 break;
506 }
507 if( allDetermined )
508 continue;
509
510 // Calculate product of incoming messages on factor I
511 Prob prod2;
512 calcBeliefF( I, prod2 );
513
514 // The allowed configuration is restrained according to the variables assigned so far:
515 // pick the argmax amongst the allowed states
516 Real maxProb = std::numeric_limits<Real>::min();
517 State maxState( factor(I).vars() );
518 for( State s( factor(I).vars() ); s.valid(); ++s ){
519 // First, calculate whether this state is consistent with variables that
520 // have been assigned already
521 bool allowedState = true;
522 foreach( const Neighbor &j, nbF(I) )
523 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
524 allowedState = false;
525 break;
526 }
527 // If it is consistent, check if its probability is larger than what we have seen so far
528 if( allowedState && prod2[s] > maxProb ) {
529 maxState = s;
530 maxProb = prod2[s];
531 }
532 }
533
534 // Decode the argmax
535 foreach( const Neighbor &j, nbF(I) ) {
536 if( visitedVars[j.node] ) {
537 // We have already visited j earlier - hopefully our state is consistent
538 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
539 std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
540 } else {
541 // We found a consistent state for variable j
542 visitedVars[j.node] = true;
543 maximum[j.node] = maxState( var(j.node) );
544 foreach( const Neighbor &J, nbV(j) )
545 if( !visitedFactors[J] )
546 scheduledFactors.push(J);
547 }
548 }
549 }
550 }
551 return maximum;
552 }
553
554
555 } // end of namespace dai