1fd98576cff334c1e899490ec5f4734017fdc1cf
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <sstream>
11 #include <map>
12 #include <set>
13 #include <algorithm>
14 #include <dai/bp.h>
15 #include <dai/util.h>
16 #include <dai/properties.h>
17
18
19 namespace dai {
20
21
22 using namespace std;
23
24
25 #define DAI_BP_FAST 1
26 /// \todo Make DAI_BP_FAST a compile-time choice, as it is a memory/speed tradeoff
27
28
29 void BP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("logdomain") );
32 DAI_ASSERT( opts.hasKey("updates") );
33
34 props.tol = opts.getStringAs<Real>("tol");
35 props.logdomain = opts.getStringAs<bool>("logdomain");
36 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
37
38 if( opts.hasKey("maxiter") )
39 props.maxiter = opts.getStringAs<size_t>("maxiter");
40 else
41 props.maxiter = 10000;
42 if( opts.hasKey("maxtime") )
43 props.maxtime = opts.getStringAs<Real>("maxtime");
44 else
45 props.maxtime = INFINITY;
46 if( opts.hasKey("verbose") )
47 props.verbose = opts.getStringAs<size_t>("verbose");
48 else
49 props.verbose = 0;
50 if( opts.hasKey("damping") )
51 props.damping = opts.getStringAs<Real>("damping");
52 else
53 props.damping = 0.0;
54 if( opts.hasKey("inference") )
55 props.inference = opts.getStringAs<Properties::InfType>("inference");
56 else
57 props.inference = Properties::InfType::SUMPROD;
58 }
59
60
61 PropertySet BP::getProperties() const {
62 PropertySet opts;
63 opts.set( "tol", props.tol );
64 opts.set( "maxiter", props.maxiter );
65 opts.set( "maxtime", props.maxtime );
66 opts.set( "verbose", props.verbose );
67 opts.set( "logdomain", props.logdomain );
68 opts.set( "updates", props.updates );
69 opts.set( "damping", props.damping );
70 opts.set( "inference", props.inference );
71 return opts;
72 }
73
74
75 string BP::printProperties() const {
76 stringstream s( stringstream::out );
77 s << "[";
78 s << "tol=" << props.tol << ",";
79 s << "maxiter=" << props.maxiter << ",";
80 s << "maxtime=" << props.maxtime << ",";
81 s << "verbose=" << props.verbose << ",";
82 s << "logdomain=" << props.logdomain << ",";
83 s << "updates=" << props.updates << ",";
84 s << "damping=" << props.damping << ",";
85 s << "inference=" << props.inference << "]";
86 return s.str();
87 }
88
89
90 void BP::construct() {
91 // create edge properties
92 _edges.clear();
93 _edges.reserve( nrVars() );
94 _edge2lut.clear();
95 if( props.updates == Properties::UpdateType::SEQMAX )
96 _edge2lut.reserve( nrVars() );
97 for( size_t i = 0; i < nrVars(); ++i ) {
98 _edges.push_back( vector<EdgeProp>() );
99 _edges[i].reserve( nbV(i).size() );
100 if( props.updates == Properties::UpdateType::SEQMAX ) {
101 _edge2lut.push_back( vector<LutType::iterator>() );
102 _edge2lut[i].reserve( nbV(i).size() );
103 }
104 bforeach( const Neighbor &I, nbV(i) ) {
105 EdgeProp newEP;
106 newEP.message = Prob( var(i).states() );
107 newEP.newMessage = Prob( var(i).states() );
108
109 if( DAI_BP_FAST ) {
110 newEP.index.reserve( factor(I).nrStates() );
111 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
112 newEP.index.push_back( k );
113 }
114
115 newEP.residual = 0.0;
116 _edges[i].push_back( newEP );
117 if( props.updates == Properties::UpdateType::SEQMAX )
118 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
119 }
120 }
121
122 // create old beliefs
123 _oldBeliefsV.clear();
124 _oldBeliefsV.reserve( nrVars() );
125 for( size_t i = 0; i < nrVars(); ++i )
126 _oldBeliefsV.push_back( Factor( var(i) ) );
127 _oldBeliefsF.clear();
128 _oldBeliefsF.reserve( nrFactors() );
129 for( size_t I = 0; I < nrFactors(); ++I )
130 _oldBeliefsF.push_back( Factor( factor(I).vars() ) );
131
132 // create update sequence
133 _updateSeq.clear();
134 _updateSeq.reserve( nrEdges() );
135 for( size_t I = 0; I < nrFactors(); I++ )
136 bforeach( const Neighbor &i, nbF(I) )
137 _updateSeq.push_back( Edge( i, i.dual ) );
138 }
139
140
141 void BP::init() {
142 Real c = props.logdomain ? 0.0 : 1.0;
143 for( size_t i = 0; i < nrVars(); ++i ) {
144 bforeach( const Neighbor &I, nbV(i) ) {
145 message( i, I.iter ).fill( c );
146 newMessage( i, I.iter ).fill( c );
147 if( props.updates == Properties::UpdateType::SEQMAX )
148 updateResidual( i, I.iter, 0.0 );
149 }
150 }
151 _iters = 0;
152 }
153
154
155 void BP::findMaxResidual( size_t &i, size_t &_I ) {
156 DAI_ASSERT( !_lut.empty() );
157 LutType::const_iterator largestEl = _lut.end();
158 --largestEl;
159 i = largestEl->second.first;
160 _I = largestEl->second.second;
161 }
162
163
164 Prob BP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
165 Factor Fprod( factor(I) );
166 Prob &prod = Fprod.p();
167 if( props.logdomain )
168 prod.takeLog();
169
170 // Calculate product of incoming messages and factor I
171 bforeach( const Neighbor &j, nbF(I) )
172 if( !(without_i && (j == i)) ) {
173 // prod_j will be the product of messages coming into j
174 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
175 bforeach( const Neighbor &J, nbV(j) )
176 if( J != I ) { // for all J in nb(j) \ I
177 if( props.logdomain )
178 prod_j += message( j, J.iter );
179 else
180 prod_j *= message( j, J.iter );
181 }
182
183 // multiply prod with prod_j
184 if( !DAI_BP_FAST ) {
185 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
186 if( props.logdomain )
187 Fprod += Factor( var(j), prod_j );
188 else
189 Fprod *= Factor( var(j), prod_j );
190 } else {
191 // OPTIMIZED VERSION
192 size_t _I = j.dual;
193 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
194 const ind_t &ind = index(j, _I);
195
196 for( size_t r = 0; r < prod.size(); ++r )
197 if( props.logdomain )
198 prod.set( r, prod[r] + prod_j[ind[r]] );
199 else
200 prod.set( r, prod[r] * prod_j[ind[r]] );
201 }
202 }
203 return prod;
204 }
205
206
207 void BP::calcNewMessage( size_t i, size_t _I ) {
208 // calculate updated message I->i
209 size_t I = nbV(i,_I);
210
211 Prob marg;
212 if( factor(I).vars().size() == 1 ) // optimization
213 marg = factor(I).p();
214 else {
215 Factor Fprod( factor(I) );
216 Prob &prod = Fprod.p();
217 prod = calcIncomingMessageProduct( I, true, i );
218
219 if( props.logdomain ) {
220 prod -= prod.max();
221 prod.takeExp();
222 }
223
224 // Marginalize onto i
225 if( !DAI_BP_FAST ) {
226 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
227 if( props.inference == Properties::InfType::SUMPROD )
228 marg = Fprod.marginal( var(i) ).p();
229 else
230 marg = Fprod.maxMarginal( var(i) ).p();
231 } else {
232 // OPTIMIZED VERSION
233 marg = Prob( var(i).states(), 0.0 );
234 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
235 const ind_t ind = index(i,_I);
236 if( props.inference == Properties::InfType::SUMPROD )
237 for( size_t r = 0; r < prod.size(); ++r )
238 marg.set( ind[r], marg[ind[r]] + prod[r] );
239 else
240 for( size_t r = 0; r < prod.size(); ++r )
241 if( prod[r] > marg[ind[r]] )
242 marg.set( ind[r], prod[r] );
243 marg.normalize();
244 }
245 }
246
247 // Store result
248 if( props.logdomain )
249 newMessage(i,_I) = marg.log();
250 else
251 newMessage(i,_I) = marg;
252
253 // Update the residual if necessary
254 if( props.updates == Properties::UpdateType::SEQMAX )
255 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
256 }
257
258
259 // BP::run does not check for NANs for performance reasons
260 // Somehow NaNs do not often occur in BP...
261 Real BP::run() {
262 if( props.verbose >= 1 )
263 cerr << "Starting " << identify() << "...";
264 if( props.verbose >= 3)
265 cerr << endl;
266
267 double tic = toc();
268
269 // do several passes over the network until maximum number of iterations has
270 // been reached or until the maximum belief difference is smaller than tolerance
271 Real maxDiff = INFINITY;
272 for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
273 if( props.updates == Properties::UpdateType::SEQMAX ) {
274 if( _iters == 0 ) {
275 // do the first pass
276 for( size_t i = 0; i < nrVars(); ++i )
277 bforeach( const Neighbor &I, nbV(i) )
278 calcNewMessage( i, I.iter );
279 }
280 // Maximum-Residual BP [\ref EMK06]
281 for( size_t t = 0; t < _updateSeq.size(); ++t ) {
282 // update the message with the largest residual
283 size_t i, _I;
284 findMaxResidual( i, _I );
285 updateMessage( i, _I );
286
287 // I->i has been updated, which means that residuals for all
288 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
289 bforeach( const Neighbor &J, nbV(i) ) {
290 if( J.iter != _I ) {
291 bforeach( const Neighbor &j, nbF(J) ) {
292 size_t _J = j.dual;
293 if( j != i )
294 calcNewMessage( j, _J );
295 }
296 }
297 }
298 }
299 } else if( props.updates == Properties::UpdateType::PARALL ) {
300 // Parallel updates
301 for( size_t i = 0; i < nrVars(); ++i )
302 bforeach( const Neighbor &I, nbV(i) )
303 calcNewMessage( i, I.iter );
304
305 for( size_t i = 0; i < nrVars(); ++i )
306 bforeach( const Neighbor &I, nbV(i) )
307 updateMessage( i, I.iter );
308 } else {
309 // Sequential updates
310 if( props.updates == Properties::UpdateType::SEQRND )
311 random_shuffle( _updateSeq.begin(), _updateSeq.end(), rnd );
312
313 bforeach( const Edge &e, _updateSeq ) {
314 calcNewMessage( e.first, e.second );
315 updateMessage( e.first, e.second );
316 }
317 }
318
319 // calculate new beliefs and compare with old ones
320 maxDiff = -INFINITY;
321 for( size_t i = 0; i < nrVars(); ++i ) {
322 Factor b( beliefV(i) );
323 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsV[i], DISTLINF ) );
324 _oldBeliefsV[i] = b;
325 }
326 for( size_t I = 0; I < nrFactors(); ++I ) {
327 Factor b( beliefF(I) );
328 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsF[I], DISTLINF ) );
329 _oldBeliefsF[I] = b;
330 }
331
332 if( props.verbose >= 3 )
333 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
334 }
335
336 if( maxDiff > _maxdiff )
337 _maxdiff = maxDiff;
338
339 if( props.verbose >= 1 ) {
340 if( maxDiff > props.tol ) {
341 if( props.verbose == 1 )
342 cerr << endl;
343 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
344 } else {
345 if( props.verbose >= 3 )
346 cerr << name() << "::run: ";
347 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
348 }
349 }
350
351 return maxDiff;
352 }
353
354
355 void BP::calcBeliefV( size_t i, Prob &p ) const {
356 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
357 bforeach( const Neighbor &I, nbV(i) )
358 if( props.logdomain )
359 p += newMessage( i, I.iter );
360 else
361 p *= newMessage( i, I.iter );
362 }
363
364
365 Factor BP::beliefV( size_t i ) const {
366 Prob p;
367 calcBeliefV( i, p );
368
369 if( props.logdomain ) {
370 p -= p.max();
371 p.takeExp();
372 }
373 p.normalize();
374
375 return( Factor( var(i), p ) );
376 }
377
378
379 Factor BP::beliefF( size_t I ) const {
380 Prob p;
381 calcBeliefF( I, p );
382
383 if( props.logdomain ) {
384 p -= p.max();
385 p.takeExp();
386 }
387 p.normalize();
388
389 return( Factor( factor(I).vars(), p ) );
390 }
391
392
393 vector<Factor> BP::beliefs() const {
394 vector<Factor> result;
395 for( size_t i = 0; i < nrVars(); ++i )
396 result.push_back( beliefV(i) );
397 for( size_t I = 0; I < nrFactors(); ++I )
398 result.push_back( beliefF(I) );
399 return result;
400 }
401
402
403 Factor BP::belief( const VarSet &ns ) const {
404 if( ns.size() == 0 )
405 return Factor();
406 else if( ns.size() == 1 )
407 return beliefV( findVar( *(ns.begin() ) ) );
408 else {
409 size_t I;
410 for( I = 0; I < nrFactors(); I++ )
411 if( factor(I).vars() >> ns )
412 break;
413 if( I == nrFactors() )
414 DAI_THROW(BELIEF_NOT_AVAILABLE);
415 return beliefF(I).marginal(ns);
416 }
417 }
418
419
420 Real BP::logZ() const {
421 Real sum = 0.0;
422 for( size_t i = 0; i < nrVars(); ++i )
423 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
424 for( size_t I = 0; I < nrFactors(); ++I )
425 sum -= dist( beliefF(I), factor(I), DISTKL );
426 return sum;
427 }
428
429
430 void BP::init( const VarSet &ns ) {
431 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
432 size_t ni = findVar( *n );
433 bforeach( const Neighbor &I, nbV( ni ) ) {
434 Real val = props.logdomain ? 0.0 : 1.0;
435 message( ni, I.iter ).fill( val );
436 newMessage( ni, I.iter ).fill( val );
437 if( props.updates == Properties::UpdateType::SEQMAX )
438 updateResidual( ni, I.iter, 0.0 );
439 }
440 }
441 _iters = 0;
442 }
443
444
445 void BP::updateMessage( size_t i, size_t _I ) {
446 if( recordSentMessages )
447 _sentMessages.push_back(make_pair(i,_I));
448 if( props.damping == 0.0 ) {
449 message(i,_I) = newMessage(i,_I);
450 if( props.updates == Properties::UpdateType::SEQMAX )
451 updateResidual( i, _I, 0.0 );
452 } else {
453 if( props.logdomain )
454 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
455 else
456 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
457 if( props.updates == Properties::UpdateType::SEQMAX )
458 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), DISTLINF ) );
459 }
460 }
461
462
463 void BP::updateResidual( size_t i, size_t _I, Real r ) {
464 EdgeProp* pEdge = &_edges[i][_I];
465 pEdge->residual = r;
466
467 // rearrange look-up table (delete and reinsert new key)
468 _lut.erase( _edge2lut[i][_I] );
469 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
470 }
471
472
473 } // end of namespace dai