5944d623f8b9b9d3e7169146eb225c3d0838edb6
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <algorithm>
17 #include <dai/bp.h>
18 #include <dai/util.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 using namespace std;
26
27
28 #define DAI_BP_FAST 1
29
30
31 void BP::setProperties( const PropertySet &opts ) {
32 DAI_ASSERT( opts.hasKey("tol") );
33 DAI_ASSERT( opts.hasKey("logdomain") );
34 DAI_ASSERT( opts.hasKey("updates") );
35
36 props.tol = opts.getStringAs<Real>("tol");
37 props.logdomain = opts.getStringAs<bool>("logdomain");
38 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
39
40 if( opts.hasKey("maxiter") )
41 props.maxiter = opts.getStringAs<size_t>("maxiter");
42 else
43 props.maxiter = 10000;
44 if( opts.hasKey("maxtime") )
45 props.maxtime = opts.getStringAs<Real>("maxtime");
46 else
47 props.maxtime = INFINITY;
48 if( opts.hasKey("verbose") )
49 props.verbose = opts.getStringAs<size_t>("verbose");
50 else
51 props.verbose = 0;
52 if( opts.hasKey("damping") )
53 props.damping = opts.getStringAs<Real>("damping");
54 else
55 props.damping = 0.0;
56 if( opts.hasKey("inference") )
57 props.inference = opts.getStringAs<Properties::InfType>("inference");
58 else
59 props.inference = Properties::InfType::SUMPROD;
60 }
61
62
63 PropertySet BP::getProperties() const {
64 PropertySet opts;
65 opts.set( "tol", props.tol );
66 opts.set( "maxiter", props.maxiter );
67 opts.set( "maxtime", props.maxtime );
68 opts.set( "verbose", props.verbose );
69 opts.set( "logdomain", props.logdomain );
70 opts.set( "updates", props.updates );
71 opts.set( "damping", props.damping );
72 opts.set( "inference", props.inference );
73 return opts;
74 }
75
76
77 string BP::printProperties() const {
78 stringstream s( stringstream::out );
79 s << "[";
80 s << "tol=" << props.tol << ",";
81 s << "maxiter=" << props.maxiter << ",";
82 s << "maxtime=" << props.maxtime << ",";
83 s << "verbose=" << props.verbose << ",";
84 s << "logdomain=" << props.logdomain << ",";
85 s << "updates=" << props.updates << ",";
86 s << "damping=" << props.damping << ",";
87 s << "inference=" << props.inference << "]";
88 return s.str();
89 }
90
91
92 void BP::construct() {
93 // create edge properties
94 _edges.clear();
95 _edges.reserve( nrVars() );
96 _edge2lut.clear();
97 if( props.updates == Properties::UpdateType::SEQMAX )
98 _edge2lut.reserve( nrVars() );
99 for( size_t i = 0; i < nrVars(); ++i ) {
100 _edges.push_back( vector<EdgeProp>() );
101 _edges[i].reserve( nbV(i).size() );
102 if( props.updates == Properties::UpdateType::SEQMAX ) {
103 _edge2lut.push_back( vector<LutType::iterator>() );
104 _edge2lut[i].reserve( nbV(i).size() );
105 }
106 foreach( const Neighbor &I, nbV(i) ) {
107 EdgeProp newEP;
108 newEP.message = Prob( var(i).states() );
109 newEP.newMessage = Prob( var(i).states() );
110
111 if( DAI_BP_FAST ) {
112 newEP.index.reserve( factor(I).nrStates() );
113 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
114 newEP.index.push_back( k );
115 }
116
117 newEP.residual = 0.0;
118 _edges[i].push_back( newEP );
119 if( props.updates == Properties::UpdateType::SEQMAX )
120 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
121 }
122 }
123
124 // create old beliefs
125 _oldBeliefsV.clear();
126 _oldBeliefsV.reserve( nrVars() );
127 for( size_t i = 0; i < nrVars(); ++i )
128 _oldBeliefsV.push_back( Factor( var(i) ) );
129 _oldBeliefsF.clear();
130 _oldBeliefsF.reserve( nrFactors() );
131 for( size_t I = 0; I < nrFactors(); ++I )
132 _oldBeliefsF.push_back( Factor( factor(I).vars() ) );
133
134 // create update sequence
135 _updateSeq.clear();
136 _updateSeq.reserve( nrEdges() );
137 for( size_t I = 0; I < nrFactors(); I++ )
138 foreach( const Neighbor &i, nbF(I) )
139 _updateSeq.push_back( Edge( i, i.dual ) );
140 }
141
142
143 void BP::init() {
144 Real c = props.logdomain ? 0.0 : 1.0;
145 for( size_t i = 0; i < nrVars(); ++i ) {
146 foreach( const Neighbor &I, nbV(i) ) {
147 message( i, I.iter ).fill( c );
148 newMessage( i, I.iter ).fill( c );
149 if( props.updates == Properties::UpdateType::SEQMAX )
150 updateResidual( i, I.iter, 0.0 );
151 }
152 }
153 _iters = 0;
154 }
155
156
157 void BP::findMaxResidual( size_t &i, size_t &_I ) {
158 DAI_ASSERT( !_lut.empty() );
159 LutType::const_iterator largestEl = _lut.end();
160 --largestEl;
161 i = largestEl->second.first;
162 _I = largestEl->second.second;
163 }
164
165
166 Prob BP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
167 Factor Fprod( factor(I) );
168 Prob &prod = Fprod.p();
169 if( props.logdomain )
170 prod.takeLog();
171
172 // Calculate product of incoming messages and factor I
173 foreach( const Neighbor &j, nbF(I) )
174 if( !(without_i && (j == i)) ) {
175 // prod_j will be the product of messages coming into j
176 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
177 foreach( const Neighbor &J, nbV(j) )
178 if( J != I ) { // for all J in nb(j) \ I
179 if( props.logdomain )
180 prod_j += message( j, J.iter );
181 else
182 prod_j *= message( j, J.iter );
183 }
184
185 // multiply prod with prod_j
186 if( !DAI_BP_FAST ) {
187 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
188 if( props.logdomain )
189 Fprod += Factor( var(j), prod_j );
190 else
191 Fprod *= Factor( var(j), prod_j );
192 } else {
193 // OPTIMIZED VERSION
194 size_t _I = j.dual;
195 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
196 const ind_t &ind = index(j, _I);
197
198 for( size_t r = 0; r < prod.size(); ++r )
199 if( props.logdomain )
200 prod.set( r, prod[r] + prod_j[ind[r]] );
201 else
202 prod.set( r, prod[r] * prod_j[ind[r]] );
203 }
204 }
205 return prod;
206 }
207
208
209 void BP::calcNewMessage( size_t i, size_t _I ) {
210 // calculate updated message I->i
211 size_t I = nbV(i,_I);
212
213 Prob marg;
214 if( factor(I).vars().size() == 1 ) // optimization
215 marg = factor(I).p();
216 else {
217 Factor Fprod( factor(I) );
218 Prob &prod = Fprod.p();
219 prod = calcIncomingMessageProduct( I, true, i );
220
221 if( props.logdomain ) {
222 prod -= prod.max();
223 prod.takeExp();
224 }
225
226 // Marginalize onto i
227 if( !DAI_BP_FAST ) {
228 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
229 if( props.inference == Properties::InfType::SUMPROD )
230 marg = Fprod.marginal( var(i) ).p();
231 else
232 marg = Fprod.maxMarginal( var(i) ).p();
233 } else {
234 // OPTIMIZED VERSION
235 marg = Prob( var(i).states(), 0.0 );
236 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
237 const ind_t ind = index(i,_I);
238 if( props.inference == Properties::InfType::SUMPROD )
239 for( size_t r = 0; r < prod.size(); ++r )
240 marg.set( ind[r], marg[ind[r]] + prod[r] );
241 else
242 for( size_t r = 0; r < prod.size(); ++r )
243 if( prod[r] > marg[ind[r]] )
244 marg.set( ind[r], prod[r] );
245 marg.normalize();
246 }
247 }
248
249 // Store result
250 if( props.logdomain )
251 newMessage(i,_I) = marg.log();
252 else
253 newMessage(i,_I) = marg;
254
255 // Update the residual if necessary
256 if( props.updates == Properties::UpdateType::SEQMAX )
257 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
258 }
259
260
261 // BP::run does not check for NANs for performance reasons
262 // Somehow NaNs do not often occur in BP...
263 Real BP::run() {
264 if( props.verbose >= 1 )
265 cerr << "Starting " << identify() << "...";
266 if( props.verbose >= 3)
267 cerr << endl;
268
269 double tic = toc();
270
271 // do several passes over the network until maximum number of iterations has
272 // been reached or until the maximum belief difference is smaller than tolerance
273 Real maxDiff = INFINITY;
274 for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
275 if( props.updates == Properties::UpdateType::SEQMAX ) {
276 if( _iters == 0 ) {
277 // do the first pass
278 for( size_t i = 0; i < nrVars(); ++i )
279 foreach( const Neighbor &I, nbV(i) )
280 calcNewMessage( i, I.iter );
281 }
282 // Maximum-Residual BP [\ref EMK06]
283 for( size_t t = 0; t < _updateSeq.size(); ++t ) {
284 // update the message with the largest residual
285 size_t i, _I;
286 findMaxResidual( i, _I );
287 updateMessage( i, _I );
288
289 // I->i has been updated, which means that residuals for all
290 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
291 foreach( const Neighbor &J, nbV(i) ) {
292 if( J.iter != _I ) {
293 foreach( const Neighbor &j, nbF(J) ) {
294 size_t _J = j.dual;
295 if( j != i )
296 calcNewMessage( j, _J );
297 }
298 }
299 }
300 }
301 } else if( props.updates == Properties::UpdateType::PARALL ) {
302 // Parallel updates
303 for( size_t i = 0; i < nrVars(); ++i )
304 foreach( const Neighbor &I, nbV(i) )
305 calcNewMessage( i, I.iter );
306
307 for( size_t i = 0; i < nrVars(); ++i )
308 foreach( const Neighbor &I, nbV(i) )
309 updateMessage( i, I.iter );
310 } else {
311 // Sequential updates
312 if( props.updates == Properties::UpdateType::SEQRND )
313 random_shuffle( _updateSeq.begin(), _updateSeq.end(), rnd );
314
315 foreach( const Edge &e, _updateSeq ) {
316 calcNewMessage( e.first, e.second );
317 updateMessage( e.first, e.second );
318 }
319 }
320
321 // calculate new beliefs and compare with old ones
322 maxDiff = -INFINITY;
323 for( size_t i = 0; i < nrVars(); ++i ) {
324 Factor b( beliefV(i) );
325 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsV[i], DISTLINF ) );
326 _oldBeliefsV[i] = b;
327 }
328 for( size_t I = 0; I < nrFactors(); ++I ) {
329 Factor b( beliefF(I) );
330 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsF[I], DISTLINF ) );
331 _oldBeliefsF[I] = b;
332 }
333
334 if( props.verbose >= 3 )
335 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
336 }
337
338 if( maxDiff > _maxdiff )
339 _maxdiff = maxDiff;
340
341 if( props.verbose >= 1 ) {
342 if( maxDiff > props.tol ) {
343 if( props.verbose == 1 )
344 cerr << endl;
345 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
346 } else {
347 if( props.verbose >= 3 )
348 cerr << name() << "::run: ";
349 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
350 }
351 }
352
353 return maxDiff;
354 }
355
356
357 void BP::calcBeliefV( size_t i, Prob &p ) const {
358 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
359 foreach( const Neighbor &I, nbV(i) )
360 if( props.logdomain )
361 p += newMessage( i, I.iter );
362 else
363 p *= newMessage( i, I.iter );
364 }
365
366
367 Factor BP::beliefV( size_t i ) const {
368 Prob p;
369 calcBeliefV( i, p );
370
371 if( props.logdomain ) {
372 p -= p.max();
373 p.takeExp();
374 }
375 p.normalize();
376
377 return( Factor( var(i), p ) );
378 }
379
380
381 Factor BP::beliefF( size_t I ) const {
382 Prob p;
383 calcBeliefF( I, p );
384
385 if( props.logdomain ) {
386 p -= p.max();
387 p.takeExp();
388 }
389 p.normalize();
390
391 return( Factor( factor(I).vars(), p ) );
392 }
393
394
395 vector<Factor> BP::beliefs() const {
396 vector<Factor> result;
397 for( size_t i = 0; i < nrVars(); ++i )
398 result.push_back( beliefV(i) );
399 for( size_t I = 0; I < nrFactors(); ++I )
400 result.push_back( beliefF(I) );
401 return result;
402 }
403
404
405 Factor BP::belief( const VarSet &ns ) const {
406 if( ns.size() == 0 )
407 return Factor();
408 else if( ns.size() == 1 )
409 return beliefV( findVar( *(ns.begin() ) ) );
410 else {
411 size_t I;
412 for( I = 0; I < nrFactors(); I++ )
413 if( factor(I).vars() >> ns )
414 break;
415 if( I == nrFactors() )
416 DAI_THROW(BELIEF_NOT_AVAILABLE);
417 return beliefF(I).marginal(ns);
418 }
419 }
420
421
422 Real BP::logZ() const {
423 Real sum = 0.0;
424 for( size_t i = 0; i < nrVars(); ++i )
425 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
426 for( size_t I = 0; I < nrFactors(); ++I )
427 sum -= dist( beliefF(I), factor(I), DISTKL );
428 return sum;
429 }
430
431
432 void BP::init( const VarSet &ns ) {
433 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
434 size_t ni = findVar( *n );
435 foreach( const Neighbor &I, nbV( ni ) ) {
436 Real val = props.logdomain ? 0.0 : 1.0;
437 message( ni, I.iter ).fill( val );
438 newMessage( ni, I.iter ).fill( val );
439 if( props.updates == Properties::UpdateType::SEQMAX )
440 updateResidual( ni, I.iter, 0.0 );
441 }
442 }
443 _iters = 0;
444 }
445
446
447 void BP::updateMessage( size_t i, size_t _I ) {
448 if( recordSentMessages )
449 _sentMessages.push_back(make_pair(i,_I));
450 if( props.damping == 0.0 ) {
451 message(i,_I) = newMessage(i,_I);
452 if( props.updates == Properties::UpdateType::SEQMAX )
453 updateResidual( i, _I, 0.0 );
454 } else {
455 if( props.logdomain )
456 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
457 else
458 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
459 if( props.updates == Properties::UpdateType::SEQMAX )
460 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), DISTLINF ) );
461 }
462 }
463
464
465 void BP::updateResidual( size_t i, size_t _I, Real r ) {
466 EdgeProp* pEdge = &_edges[i][_I];
467 pEdge->residual = r;
468
469 // rearrange look-up table (delete and reinsert new key)
470 _lut.erase( _edge2lut[i][_I] );
471 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
472 }
473
474
475 } // end of namespace dai