Several changes by Giuseppe Passino
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4 Giuseppe Passino
5
6 This file is part of libDAI.
7
8 libDAI is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 libDAI is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with libDAI; if not, write to the Free Software
20 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
21 */
22
23
24 #include <iostream>
25 #include <sstream>
26 #include <map>
27 #include <set>
28 #include <algorithm>
29 #include <stack>
30 #include <dai/bp.h>
31 #include <dai/util.h>
32 #include <dai/properties.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 const char *BP::Name = "BP";
42
43
44 void BP::setProperties( const PropertySet &opts ) {
45 assert( opts.hasKey("tol") );
46 assert( opts.hasKey("maxiter") );
47 assert( opts.hasKey("logdomain") );
48 assert( opts.hasKey("updates") );
49
50 props.tol = opts.getStringAs<double>("tol");
51 props.maxiter = opts.getStringAs<size_t>("maxiter");
52 props.logdomain = opts.getStringAs<bool>("logdomain");
53 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
54
55 if( opts.hasKey("verbose") )
56 props.verbose = opts.getStringAs<size_t>("verbose");
57 else
58 props.verbose = 0;
59 if( opts.hasKey("damping") )
60 props.damping = opts.getStringAs<double>("damping");
61 else
62 props.damping = 0.0;
63 if( opts.hasKey("inference") )
64 props.inference = opts.getStringAs<Properties::InfType>("inference");
65 else
66 props.inference = Properties::InfType::SUMPROD;
67 }
68
69
70 PropertySet BP::getProperties() const {
71 PropertySet opts;
72 opts.Set( "tol", props.tol );
73 opts.Set( "maxiter", props.maxiter );
74 opts.Set( "verbose", props.verbose );
75 opts.Set( "logdomain", props.logdomain );
76 opts.Set( "updates", props.updates );
77 opts.Set( "damping", props.damping );
78 opts.Set( "inference", props.inference );
79 return opts;
80 }
81
82
83 string BP::printProperties() const {
84 stringstream s( stringstream::out );
85 s << "[";
86 s << "tol=" << props.tol << ",";
87 s << "maxiter=" << props.maxiter << ",";
88 s << "verbose=" << props.verbose << ",";
89 s << "logdomain=" << props.logdomain << ",";
90 s << "updates=" << props.updates << ",";
91 s << "damping=" << props.damping << ",";
92 s << "inference=" << props.inference << "]";
93 return s.str();
94 }
95
96
97 void BP::construct() {
98 // create edge properties
99 _edges.clear();
100 _edges.reserve( nrVars() );
101 for( size_t i = 0; i < nrVars(); ++i ) {
102 _edges.push_back( vector<EdgeProp>() );
103 _edges[i].reserve( nbV(i).size() );
104 foreach( const Neighbor &I, nbV(i) ) {
105 EdgeProp newEP;
106 newEP.message = Prob( var(i).states() );
107 newEP.newMessage = Prob( var(i).states() );
108
109 newEP.index.reserve( factor(I).states() );
110 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
111 newEP.index.push_back( k );
112
113 newEP.residual = 0.0;
114 _edges[i].push_back( newEP );
115 }
116 }
117 }
118
119
120 void BP::init() {
121 double c = props.logdomain ? 0.0 : 1.0;
122 for( size_t i = 0; i < nrVars(); ++i ) {
123 foreach( const Neighbor &I, nbV(i) ) {
124 message( i, I.iter ).fill( c );
125 newMessage( i, I.iter ).fill( c );
126 }
127 }
128 }
129
130
131 void BP::findMaxResidual( size_t &i, size_t &_I ) {
132 i = 0;
133 _I = 0;
134 double maxres = residual( i, _I );
135 for( size_t j = 0; j < nrVars(); ++j )
136 foreach( const Neighbor &I, nbV(j) )
137 if( residual( j, I.iter ) > maxres ) {
138 i = j;
139 _I = I.iter;
140 maxres = residual( i, _I );
141 }
142 }
143
144
145 void BP::calcNewMessage( size_t i, size_t _I ) {
146 // calculate updated message I->i
147 size_t I = nbV(i,_I);
148
149 if( 0 == 1 ) {
150 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
151 Factor prod( factor( I ) );
152 foreach( const Neighbor &j, nbF(I) )
153 if( j != i ) { // for all j in I \ i
154 foreach( const Neighbor &J, nbV(j) )
155 if( J != I ) { // for all J in nb(j) \ I
156 prod *= Factor( var(j), message(j, J.iter) );
157 }
158 }
159 newMessage(i,_I) = prod.marginal( var(i) ).p();
160 } else {
161 /* OPTIMIZED VERSION */
162 Prob prod( factor(I).p() );
163 if( props.logdomain )
164 prod.takeLog();
165
166 // Calculate product of incoming messages and factor I
167 foreach( const Neighbor &j, nbF(I) ) {
168 if( j != i ) { // for all j in I \ i
169 size_t _I = j.dual;
170 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
171 const ind_t &ind = index(j, _I);
172
173 // prod_j will be the product of messages coming into j
174 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
175 foreach( const Neighbor &J, nbV(j) )
176 if( J != I ) { // for all J in nb(j) \ I
177 if( props.logdomain )
178 prod_j += message( j, J.iter );
179 else
180 prod_j *= message( j, J.iter );
181 }
182
183 // multiply prod with prod_j
184 for( size_t r = 0; r < prod.size(); ++r )
185 if( props.logdomain )
186 prod[r] += prod_j[ind[r]];
187 else
188 prod[r] *= prod_j[ind[r]];
189 }
190 }
191 if( props.logdomain ) {
192 prod -= prod.maxVal();
193 prod.takeExp();
194 }
195
196 // Marginalize onto i
197 Prob marg( var(i).states(), 0.0 );
198 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
199 const ind_t ind = index(i,_I);
200 if( props.inference == Properties::InfType::SUMPROD )
201 for( size_t r = 0; r < prod.size(); ++r )
202 marg[ind[r]] += prod[r];
203 else
204 for( size_t r = 0; r < prod.size(); ++r )
205 if( prod[r] > marg[ind[r]] )
206 marg[ind[r]] = prod[r];
207 marg.normalize();
208
209 // Store result
210 if( props.logdomain )
211 newMessage(i,_I) = marg.log();
212 else
213 newMessage(i,_I) = marg;
214 }
215 }
216
217
218 // BP::run does not check for NANs for performance reasons
219 // Somehow NaNs do not often occur in BP...
220 double BP::run() {
221 if( props.verbose >= 1 )
222 cout << "Starting " << identify() << "...";
223 if( props.verbose >= 3)
224 cout << endl;
225
226 double tic = toc();
227 Diffs diffs(nrVars(), 1.0);
228
229 vector<Edge> update_seq;
230
231 vector<Factor> old_beliefs;
232 old_beliefs.reserve( nrVars() );
233 for( size_t i = 0; i < nrVars(); ++i )
234 old_beliefs.push_back( beliefV(i) );
235
236 size_t nredges = nrEdges();
237
238 if( props.updates == Properties::UpdateType::SEQMAX ) {
239 // do the first pass
240 for( size_t i = 0; i < nrVars(); ++i )
241 foreach( const Neighbor &I, nbV(i) ) {
242 calcNewMessage( i, I.iter );
243 // calculate initial residuals
244 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
245 }
246 } else {
247 update_seq.reserve( nredges );
248 for( size_t i = 0; i < nrVars(); ++i )
249 foreach( const Neighbor &I, nbV(i) )
250 update_seq.push_back( Edge( i, I.iter ) );
251 }
252
253 // do several passes over the network until maximum number of iterations has
254 // been reached or until the maximum belief difference is smaller than tolerance
255 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
256 if( props.updates == Properties::UpdateType::SEQMAX ) {
257 // Residuals-BP by Koller et al.
258 for( size_t t = 0; t < nredges; ++t ) {
259 // update the message with the largest residual
260 size_t i, _I;
261 findMaxResidual( i, _I );
262 updateMessage( i, _I );
263
264 // I->i has been updated, which means that residuals for all
265 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
266 foreach( const Neighbor &J, nbV(i) ) {
267 if( J.iter != _I ) {
268 foreach( const Neighbor &j, nbF(J) ) {
269 size_t _J = j.dual;
270 if( j != i ) {
271 calcNewMessage( j, _J );
272 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
273 }
274 }
275 }
276 }
277 }
278 } else if( props.updates == Properties::UpdateType::PARALL ) {
279 // Parallel updates
280 for( size_t i = 0; i < nrVars(); ++i )
281 foreach( const Neighbor &I, nbV(i) )
282 calcNewMessage( i, I.iter );
283
284 for( size_t i = 0; i < nrVars(); ++i )
285 foreach( const Neighbor &I, nbV(i) )
286 updateMessage( i, I.iter );
287 } else {
288 // Sequential updates
289 if( props.updates == Properties::UpdateType::SEQRND )
290 random_shuffle( update_seq.begin(), update_seq.end() );
291
292 foreach( const Edge &e, update_seq ) {
293 calcNewMessage( e.first, e.second );
294 updateMessage( e.first, e.second );
295 }
296 }
297
298 // calculate new beliefs and compare with old ones
299 for( size_t i = 0; i < nrVars(); ++i ) {
300 Factor nb( beliefV(i) );
301 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
302 old_beliefs[i] = nb;
303 }
304
305 if( props.verbose >= 3 )
306 cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
307 }
308
309 if( diffs.maxDiff() > _maxdiff )
310 _maxdiff = diffs.maxDiff();
311
312 if( props.verbose >= 1 ) {
313 if( diffs.maxDiff() > props.tol ) {
314 if( props.verbose == 1 )
315 cout << endl;
316 cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
317 } else {
318 if( props.verbose >= 3 )
319 cout << Name << "::run: ";
320 cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
321 }
322 }
323
324 return diffs.maxDiff();
325 }
326
327
328 void BP::calcBeliefV( size_t i, Prob &p ) const {
329 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
330 foreach( const Neighbor &I, nbV(i) )
331 if( props.logdomain )
332 p += newMessage( i, I.iter );
333 else
334 p *= newMessage( i, I.iter );
335 }
336
337
338 void BP::calcBeliefF( size_t I, Prob &p ) const {
339 p = factor(I).p();
340 if( props.logdomain )
341 p.takeLog();
342
343 foreach( const Neighbor &j, nbF(I) ) {
344 size_t _I = j.dual;
345 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
346 const ind_t & ind = index(j, _I);
347
348 // prod_j will be the product of messages coming into j
349 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
350 foreach( const Neighbor &J, nbV(j) ) {
351 if( J != I ) { // for all J in nb(j) \ I
352 if( props.logdomain )
353 prod_j += newMessage( j, J.iter );
354 else
355 prod_j *= newMessage( j, J.iter );
356 }
357 }
358
359 // multiply p with prod_j
360 for( size_t r = 0; r < p.size(); ++r ) {
361 if( props.logdomain )
362 p[r] += prod_j[ind[r]];
363 else
364 p[r] *= prod_j[ind[r]];
365 }
366 }
367 }
368
369
370 Factor BP::beliefV( size_t i ) const {
371 Prob p;
372 calcBeliefV( i, p );
373
374 if( props.logdomain ) {
375 p -= p.maxVal();
376 p.takeExp();
377 }
378
379 p.normalize();
380 return( Factor( var(i), p ) );
381 }
382
383
384 Factor BP::belief( const Var &n ) const {
385 return( beliefV( findVar( n ) ) );
386 }
387
388
389 vector<Factor> BP::beliefs() const {
390 vector<Factor> result;
391 for( size_t i = 0; i < nrVars(); ++i )
392 result.push_back( beliefV(i) );
393 for( size_t I = 0; I < nrFactors(); ++I )
394 result.push_back( beliefF(I) );
395 return result;
396 }
397
398
399 Factor BP::belief( const VarSet &ns ) const {
400 if( ns.size() == 1 )
401 return belief( *(ns.begin()) );
402 else {
403 size_t I;
404 for( I = 0; I < nrFactors(); I++ )
405 if( factor(I).vars() >> ns )
406 break;
407 assert( I != nrFactors() );
408 return beliefF(I).marginal(ns);
409 }
410 }
411
412
413 Factor BP::beliefF( size_t I ) const {
414 if( 0 == 1 ) {
415 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
416
417 Factor prod( factor(I) );
418 foreach( const Neighbor &j, nbF(I) ) {
419 foreach( const Neighbor &J, nbV(j) ) {
420 if( J != I ) // for all J in nb(j) \ I
421 prod *= Factor( var(j), newMessage(j, J.iter) );
422 }
423 }
424 return prod.normalized();
425 } else {
426 /* OPTIMIZED VERSION */
427
428 Prob prod;
429 calcBeliefF( I, prod );
430
431 if( props.logdomain ) {
432 prod -= prod.maxVal();
433 prod.takeExp();
434 }
435 prod.normalize();
436
437 Factor result( factor(I).vars(), prod );
438
439 return( result );
440 }
441 }
442
443
444 Real BP::logZ() const {
445 Real sum = 0.0;
446 for(size_t i = 0; i < nrVars(); ++i )
447 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
448 for( size_t I = 0; I < nrFactors(); ++I )
449 sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
450 return sum;
451 }
452
453
454 string BP::identify() const {
455 return string(Name) + printProperties();
456 }
457
458
459 void BP::init( const VarSet &ns ) {
460 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
461 size_t ni = findVar( *n );
462 foreach( const Neighbor &I, nbV( ni ) )
463 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
464 }
465 }
466
467
468 std::vector<size_t> BP::findMaximum() const {
469 std::vector<size_t> maximum( nrVars() );
470 std::vector<bool> visitedVars( nrVars(), false );
471 std::vector<bool> visitedFactors( nrFactors(), false );
472 std::stack<size_t> scheduledFactors;
473 for( size_t i = 0; i < nrVars(); ++i ) {
474 if( visitedVars[i] )
475 continue;
476 visitedVars[i] = true;
477
478 // Maximise with respect to variable i
479 Prob prod;
480 calcBeliefV( i, prod );
481 maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
482
483 foreach( const Neighbor &I, nbV(i) )
484 if( !visitedFactors[I] )
485 scheduledFactors.push(I);
486
487 while( !scheduledFactors.empty() ){
488 size_t I = scheduledFactors.top();
489 scheduledFactors.pop();
490 if( visitedFactors[I] )
491 continue;
492 visitedFactors[I] = true;
493
494 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
495 bool allDetermined = true;
496 foreach( const Neighbor &j, nbF(I) )
497 if( !visitedVars[j.node] ) {
498 allDetermined = false;
499 break;
500 }
501 if( allDetermined )
502 continue;
503
504 // Calculate product of incoming messages on factor I
505 Prob prod2;
506 calcBeliefF( I, prod2 );
507
508 // The allowed configuration is restrained according to the variables assigned so far:
509 // pick the argmax amongst the allowed states
510 Real maxProb = std::numeric_limits<Real>::min();
511 State maxState( factor(I).vars() );
512 for( State s( factor(I).vars() ); s.valid(); ++s ){
513 // First, calculate whether this state is consistent with variables that
514 // have been assigned already
515 bool allowedState = true;
516 foreach( const Neighbor &j, nbF(I) )
517 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
518 allowedState = false;
519 break;
520 }
521 // If it is consistent, check if its probability is larger than what we have seen so far
522 if( allowedState && prod2[s] > maxProb ) {
523 maxState = s;
524 maxProb = prod2[s];
525 }
526 }
527
528 // Decode the argmax
529 foreach( const Neighbor &j, nbF(I) ) {
530 if( visitedVars[j.node] ) {
531 // We have already visited j earlier - hopefully our state is consistent
532 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
533 std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
534 } else {
535 // We found a consistent state for variable j
536 visitedVars[j.node] = true;
537 maximum[j.node] = maxState( var(j.node) );
538 foreach( const Neighbor &J, nbV(j) )
539 if( !visitedFactors[J] )
540 scheduledFactors.push(J);
541 }
542 }
543 }
544 }
545 return maximum;
546 }
547
548
549 } // end of namespace dai