4b165f50b6a1e71083981e2d2fe10ea970b2e354
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <dai/bp.h>
18 #include <dai/util.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 using namespace std;
26
27
// Name of this inference algorithm: identify() prepends it to the property
// string and run() uses it as the prefix of its diagnostic messages.
28 const char *BP::Name = "BP";
29
30
// Selects the optimized code paths. When nonzero, construct() precomputes a
// per-edge index table, and calcIncomingMessageProduct() / calcNewMessage()
// use that table instead of the simpler but slower Factor operations.
31 #define DAI_BP_FAST 1
32
33
34 void BP::setProperties( const PropertySet &opts ) {
35 DAI_ASSERT( opts.hasKey("tol") );
36 DAI_ASSERT( opts.hasKey("logdomain") );
37 DAI_ASSERT( opts.hasKey("updates") );
38
39 props.tol = opts.getStringAs<Real>("tol");
40 props.logdomain = opts.getStringAs<bool>("logdomain");
41 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
42
43 if( opts.hasKey("maxiter") )
44 props.maxiter = opts.getStringAs<size_t>("maxiter");
45 else
46 props.maxiter = 10000;
47 if( opts.hasKey("maxtime") )
48 props.maxtime = opts.getStringAs<Real>("maxtime");
49 else
50 props.maxtime = INFINITY;
51 if( opts.hasKey("verbose") )
52 props.verbose = opts.getStringAs<size_t>("verbose");
53 else
54 props.verbose = 0;
55 if( opts.hasKey("damping") )
56 props.damping = opts.getStringAs<Real>("damping");
57 else
58 props.damping = 0.0;
59 if( opts.hasKey("inference") )
60 props.inference = opts.getStringAs<Properties::InfType>("inference");
61 else
62 props.inference = Properties::InfType::SUMPROD;
63 }
64
65
66 PropertySet BP::getProperties() const {
67 PropertySet opts;
68 opts.set( "tol", props.tol );
69 opts.set( "maxiter", props.maxiter );
70 opts.set( "maxtime", props.maxtime );
71 opts.set( "verbose", props.verbose );
72 opts.set( "logdomain", props.logdomain );
73 opts.set( "updates", props.updates );
74 opts.set( "damping", props.damping );
75 opts.set( "inference", props.inference );
76 return opts;
77 }
78
79
80 string BP::printProperties() const {
81 stringstream s( stringstream::out );
82 s << "[";
83 s << "tol=" << props.tol << ",";
84 s << "maxiter=" << props.maxiter << ",";
85 s << "maxtime=" << props.maxtime << ",";
86 s << "verbose=" << props.verbose << ",";
87 s << "logdomain=" << props.logdomain << ",";
88 s << "updates=" << props.updates << ",";
89 s << "damping=" << props.damping << ",";
90 s << "inference=" << props.inference << "]";
91 return s.str();
92 }
93
94
95 void BP::construct() {
96 // create edge properties
97 _edges.clear();
98 _edges.reserve( nrVars() );
99 _edge2lut.clear();
100 if( props.updates == Properties::UpdateType::SEQMAX )
101 _edge2lut.reserve( nrVars() );
102 for( size_t i = 0; i < nrVars(); ++i ) {
103 _edges.push_back( vector<EdgeProp>() );
104 _edges[i].reserve( nbV(i).size() );
105 if( props.updates == Properties::UpdateType::SEQMAX ) {
106 _edge2lut.push_back( vector<LutType::iterator>() );
107 _edge2lut[i].reserve( nbV(i).size() );
108 }
109 foreach( const Neighbor &I, nbV(i) ) {
110 EdgeProp newEP;
111 newEP.message = Prob( var(i).states() );
112 newEP.newMessage = Prob( var(i).states() );
113
114 if( DAI_BP_FAST ) {
115 newEP.index.reserve( factor(I).nrStates() );
116 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
117 newEP.index.push_back( k );
118 }
119
120 newEP.residual = 0.0;
121 _edges[i].push_back( newEP );
122 if( props.updates == Properties::UpdateType::SEQMAX )
123 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
124 }
125 }
126
127 // create old beliefs
128 _oldBeliefsV.clear();
129 _oldBeliefsV.reserve( nrVars() );
130 for( size_t i = 0; i < nrVars(); ++i )
131 _oldBeliefsV.push_back( Factor( var(i) ) );
132 _oldBeliefsF.clear();
133 _oldBeliefsF.reserve( nrFactors() );
134 for( size_t I = 0; I < nrFactors(); ++I )
135 _oldBeliefsF.push_back( Factor( factor(I).vars() ) );
136
137 // create update sequence
138 _updateSeq.clear();
139 _updateSeq.reserve( nrEdges() );
140 for( size_t I = 0; I < nrFactors(); I++ )
141 foreach( const Neighbor &i, nbF(I) )
142 _updateSeq.push_back( Edge( i, i.dual ) );
143 }
144
145
146 void BP::init() {
147 Real c = props.logdomain ? 0.0 : 1.0;
148 for( size_t i = 0; i < nrVars(); ++i ) {
149 foreach( const Neighbor &I, nbV(i) ) {
150 message( i, I.iter ).fill( c );
151 newMessage( i, I.iter ).fill( c );
152 if( props.updates == Properties::UpdateType::SEQMAX )
153 updateResidual( i, I.iter, 0.0 );
154 }
155 }
156 _iters = 0;
157 }
158
159
160 void BP::findMaxResidual( size_t &i, size_t &_I ) {
161 DAI_ASSERT( !_lut.empty() );
162 LutType::const_iterator largestEl = _lut.end();
163 --largestEl;
164 i = largestEl->second.first;
165 _I = largestEl->second.second;
166 }
167
168
169 Prob BP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
170 Factor Fprod( factor(I) );
171 Prob &prod = Fprod.p();
172 if( props.logdomain )
173 prod.takeLog();
174
175 // Calculate product of incoming messages and factor I
176 foreach( const Neighbor &j, nbF(I) )
177 if( !(without_i && (j == i)) ) {
178 // prod_j will be the product of messages coming into j
179 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
180 foreach( const Neighbor &J, nbV(j) )
181 if( J != I ) { // for all J in nb(j) \ I
182 if( props.logdomain )
183 prod_j += message( j, J.iter );
184 else
185 prod_j *= message( j, J.iter );
186 }
187
188 // multiply prod with prod_j
189 if( !DAI_BP_FAST ) {
190 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
191 if( props.logdomain )
192 Fprod += Factor( var(j), prod_j );
193 else
194 Fprod *= Factor( var(j), prod_j );
195 } else {
196 // OPTIMIZED VERSION
197 size_t _I = j.dual;
198 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
199 const ind_t &ind = index(j, _I);
200
201 for( size_t r = 0; r < prod.size(); ++r )
202 if( props.logdomain )
203 prod.set( r, prod[r] + prod_j[ind[r]] );
204 else
205 prod.set( r, prod[r] * prod_j[ind[r]] );
206 }
207 }
208 return prod;
209 }
210
211
212 void BP::calcNewMessage( size_t i, size_t _I ) {
213 // calculate updated message I->i
214 size_t I = nbV(i,_I);
215
216 Prob marg;
217 if( factor(I).vars().size() == 1 ) // optimization
218 marg = factor(I).p();
219 else {
220 Factor Fprod( factor(I) );
221 Prob &prod = Fprod.p();
222 prod = calcIncomingMessageProduct( I, true, i );
223
224 if( props.logdomain ) {
225 prod -= prod.max();
226 prod.takeExp();
227 }
228
229 // Marginalize onto i
230 if( !DAI_BP_FAST ) {
231 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
232 if( props.inference == Properties::InfType::SUMPROD )
233 marg = Fprod.marginal( var(i) ).p();
234 else
235 marg = Fprod.maxMarginal( var(i) ).p();
236 } else {
237 // OPTIMIZED VERSION
238 marg = Prob( var(i).states(), 0.0 );
239 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
240 const ind_t ind = index(i,_I);
241 if( props.inference == Properties::InfType::SUMPROD )
242 for( size_t r = 0; r < prod.size(); ++r )
243 marg.set( ind[r], marg[ind[r]] + prod[r] );
244 else
245 for( size_t r = 0; r < prod.size(); ++r )
246 if( prod[r] > marg[ind[r]] )
247 marg.set( ind[r], prod[r] );
248 marg.normalize();
249 }
250 }
251
252 // Store result
253 if( props.logdomain )
254 newMessage(i,_I) = marg.log();
255 else
256 newMessage(i,_I) = marg;
257
258 // Update the residual if necessary
259 if( props.updates == Properties::UpdateType::SEQMAX )
260 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
261 }
262
263
264 // BP::run does not check for NANs for performance reasons
265 // Somehow NaNs do not often occur in BP...
266 Real BP::run() {
267 if( props.verbose >= 1 )
268 cerr << "Starting " << identify() << "...";
269 if( props.verbose >= 3)
270 cerr << endl;
271
272 double tic = toc();
273
274 // do several passes over the network until maximum number of iterations has
275 // been reached or until the maximum belief difference is smaller than tolerance
276 Real maxDiff = INFINITY;
277 for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
278 if( props.updates == Properties::UpdateType::SEQMAX ) {
279 if( _iters == 0 ) {
280 // do the first pass
281 for( size_t i = 0; i < nrVars(); ++i )
282 foreach( const Neighbor &I, nbV(i) )
283 calcNewMessage( i, I.iter );
284 }
285 // Maximum-Residual BP [\ref EMK06]
286 for( size_t t = 0; t < _updateSeq.size(); ++t ) {
287 // update the message with the largest residual
288 size_t i, _I;
289 findMaxResidual( i, _I );
290 updateMessage( i, _I );
291
292 // I->i has been updated, which means that residuals for all
293 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
294 foreach( const Neighbor &J, nbV(i) ) {
295 if( J.iter != _I ) {
296 foreach( const Neighbor &j, nbF(J) ) {
297 size_t _J = j.dual;
298 if( j != i )
299 calcNewMessage( j, _J );
300 }
301 }
302 }
303 }
304 } else if( props.updates == Properties::UpdateType::PARALL ) {
305 // Parallel updates
306 for( size_t i = 0; i < nrVars(); ++i )
307 foreach( const Neighbor &I, nbV(i) )
308 calcNewMessage( i, I.iter );
309
310 for( size_t i = 0; i < nrVars(); ++i )
311 foreach( const Neighbor &I, nbV(i) )
312 updateMessage( i, I.iter );
313 } else {
314 // Sequential updates
315 if( props.updates == Properties::UpdateType::SEQRND )
316 random_shuffle( _updateSeq.begin(), _updateSeq.end() );
317
318 foreach( const Edge &e, _updateSeq ) {
319 calcNewMessage( e.first, e.second );
320 updateMessage( e.first, e.second );
321 }
322 }
323
324 // calculate new beliefs and compare with old ones
325 maxDiff = -INFINITY;
326 for( size_t i = 0; i < nrVars(); ++i ) {
327 Factor b( beliefV(i) );
328 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsV[i], DISTLINF ) );
329 _oldBeliefsV[i] = b;
330 }
331 for( size_t I = 0; I < nrFactors(); ++I ) {
332 Factor b( beliefF(I) );
333 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsF[I], DISTLINF ) );
334 _oldBeliefsF[I] = b;
335 }
336
337 if( props.verbose >= 3 )
338 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
339 }
340
341 if( maxDiff > _maxdiff )
342 _maxdiff = maxDiff;
343
344 if( props.verbose >= 1 ) {
345 if( maxDiff > props.tol ) {
346 if( props.verbose == 1 )
347 cerr << endl;
348 cerr << Name << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
349 } else {
350 if( props.verbose >= 3 )
351 cerr << Name << "::run: ";
352 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
353 }
354 }
355
356 return maxDiff;
357 }
358
359
360 void BP::calcBeliefV( size_t i, Prob &p ) const {
361 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
362 foreach( const Neighbor &I, nbV(i) )
363 if( props.logdomain )
364 p += newMessage( i, I.iter );
365 else
366 p *= newMessage( i, I.iter );
367 }
368
369
370 Factor BP::beliefV( size_t i ) const {
371 Prob p;
372 calcBeliefV( i, p );
373
374 if( props.logdomain ) {
375 p -= p.max();
376 p.takeExp();
377 }
378 p.normalize();
379
380 return( Factor( var(i), p ) );
381 }
382
383
384 Factor BP::beliefF( size_t I ) const {
385 Prob p;
386 calcBeliefF( I, p );
387
388 if( props.logdomain ) {
389 p -= p.max();
390 p.takeExp();
391 }
392 p.normalize();
393
394 return( Factor( factor(I).vars(), p ) );
395 }
396
397
398 vector<Factor> BP::beliefs() const {
399 vector<Factor> result;
400 for( size_t i = 0; i < nrVars(); ++i )
401 result.push_back( beliefV(i) );
402 for( size_t I = 0; I < nrFactors(); ++I )
403 result.push_back( beliefF(I) );
404 return result;
405 }
406
407
408 Factor BP::belief( const VarSet &ns ) const {
409 if( ns.size() == 0 )
410 return Factor();
411 else if( ns.size() == 1 )
412 return beliefV( findVar( *(ns.begin() ) ) );
413 else {
414 size_t I;
415 for( I = 0; I < nrFactors(); I++ )
416 if( factor(I).vars() >> ns )
417 break;
418 if( I == nrFactors() )
419 DAI_THROW(BELIEF_NOT_AVAILABLE);
420 return beliefF(I).marginal(ns);
421 }
422 }
423
424
425 Real BP::logZ() const {
426 Real sum = 0.0;
427 for( size_t i = 0; i < nrVars(); ++i )
428 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
429 for( size_t I = 0; I < nrFactors(); ++I )
430 sum -= dist( beliefF(I), factor(I), DISTKL );
431 return sum;
432 }
433
434
435 string BP::identify() const {
436 return string(Name) + printProperties();
437 }
438
439
440 void BP::init( const VarSet &ns ) {
441 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
442 size_t ni = findVar( *n );
443 foreach( const Neighbor &I, nbV( ni ) ) {
444 Real val = props.logdomain ? 0.0 : 1.0;
445 message( ni, I.iter ).fill( val );
446 newMessage( ni, I.iter ).fill( val );
447 if( props.updates == Properties::UpdateType::SEQMAX )
448 updateResidual( ni, I.iter, 0.0 );
449 }
450 }
451 _iters = 0;
452 }
453
454
455 void BP::updateMessage( size_t i, size_t _I ) {
456 if( recordSentMessages )
457 _sentMessages.push_back(make_pair(i,_I));
458 if( props.damping == 0.0 ) {
459 message(i,_I) = newMessage(i,_I);
460 if( props.updates == Properties::UpdateType::SEQMAX )
461 updateResidual( i, _I, 0.0 );
462 } else {
463 if( props.logdomain )
464 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
465 else
466 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
467 if( props.updates == Properties::UpdateType::SEQMAX )
468 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), DISTLINF ) );
469 }
470 }
471
472
473 void BP::updateResidual( size_t i, size_t _I, Real r ) {
474 EdgeProp* pEdge = &_edges[i][_I];
475 pEdge->residual = r;
476
477 // rearrange look-up table (delete and reinsert new key)
478 _lut.erase( _edge2lut[i][_I] );
479 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
480 }
481
482
483 } // end of namespace dai