Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / src / bp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/dai_config.h>
10 #ifdef DAI_WITH_BP
11
12
13 #include <iostream>
14 #include <sstream>
15 #include <map>
16 #include <set>
17 #include <algorithm>
18 #include <dai/bp.h>
19 #include <dai/util.h>
20 #include <dai/properties.h>
21
22
23 namespace dai {
24
25
26 using namespace std;
27
28
// DAI_BP_FAST selects between the two code paths in this file:
//   nonzero -> BP::construct() precomputes a state-index table for every edge,
//              and the message routines use those tables (OPTIMIZED VERSION);
//   zero    -> the message routines fall back to plain Factor arithmetic
//              (UNOPTIMIZED VERSION) and no index tables are built.
/// \todo Make DAI_BP_FAST a compile-time choice, as it is a memory/speed tradeoff
#define DAI_BP_FAST 1
31
32
33 void BP::setProperties( const PropertySet &opts ) {
34 DAI_ASSERT( opts.hasKey("tol") );
35 DAI_ASSERT( opts.hasKey("logdomain") );
36 DAI_ASSERT( opts.hasKey("updates") );
37
38 props.tol = opts.getStringAs<Real>("tol");
39 props.logdomain = opts.getStringAs<bool>("logdomain");
40 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
41
42 if( opts.hasKey("maxiter") )
43 props.maxiter = opts.getStringAs<size_t>("maxiter");
44 else
45 props.maxiter = 10000;
46 if( opts.hasKey("maxtime") )
47 props.maxtime = opts.getStringAs<Real>("maxtime");
48 else
49 props.maxtime = INFINITY;
50 if( opts.hasKey("verbose") )
51 props.verbose = opts.getStringAs<size_t>("verbose");
52 else
53 props.verbose = 0;
54 if( opts.hasKey("damping") )
55 props.damping = opts.getStringAs<Real>("damping");
56 else
57 props.damping = 0.0;
58 if( opts.hasKey("inference") )
59 props.inference = opts.getStringAs<Properties::InfType>("inference");
60 else
61 props.inference = Properties::InfType::SUMPROD;
62 }
63
64
65 PropertySet BP::getProperties() const {
66 PropertySet opts;
67 opts.set( "tol", props.tol );
68 opts.set( "maxiter", props.maxiter );
69 opts.set( "maxtime", props.maxtime );
70 opts.set( "verbose", props.verbose );
71 opts.set( "logdomain", props.logdomain );
72 opts.set( "updates", props.updates );
73 opts.set( "damping", props.damping );
74 opts.set( "inference", props.inference );
75 return opts;
76 }
77
78
79 string BP::printProperties() const {
80 stringstream s( stringstream::out );
81 s << "[";
82 s << "tol=" << props.tol << ",";
83 s << "maxiter=" << props.maxiter << ",";
84 s << "maxtime=" << props.maxtime << ",";
85 s << "verbose=" << props.verbose << ",";
86 s << "logdomain=" << props.logdomain << ",";
87 s << "updates=" << props.updates << ",";
88 s << "damping=" << props.damping << ",";
89 s << "inference=" << props.inference << "]";
90 return s.str();
91 }
92
93
94 void BP::construct() {
95 // create edge properties
96 _edges.clear();
97 _edges.reserve( nrVars() );
98 _edge2lut.clear();
99 if( props.updates == Properties::UpdateType::SEQMAX )
100 _edge2lut.reserve( nrVars() );
101 for( size_t i = 0; i < nrVars(); ++i ) {
102 _edges.push_back( vector<EdgeProp>() );
103 _edges[i].reserve( nbV(i).size() );
104 if( props.updates == Properties::UpdateType::SEQMAX ) {
105 _edge2lut.push_back( vector<LutType::iterator>() );
106 _edge2lut[i].reserve( nbV(i).size() );
107 }
108 bforeach( const Neighbor &I, nbV(i) ) {
109 EdgeProp newEP;
110 newEP.message = Prob( var(i).states() );
111 newEP.newMessage = Prob( var(i).states() );
112
113 if( DAI_BP_FAST ) {
114 newEP.index.reserve( factor(I).nrStates() );
115 for( IndexFor k( var(i), factor(I).vars() ); k.valid(); ++k )
116 newEP.index.push_back( k );
117 }
118
119 newEP.residual = 0.0;
120 _edges[i].push_back( newEP );
121 if( props.updates == Properties::UpdateType::SEQMAX )
122 _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
123 }
124 }
125
126 // create old beliefs
127 _oldBeliefsV.clear();
128 _oldBeliefsV.reserve( nrVars() );
129 for( size_t i = 0; i < nrVars(); ++i )
130 _oldBeliefsV.push_back( Factor( var(i) ) );
131 _oldBeliefsF.clear();
132 _oldBeliefsF.reserve( nrFactors() );
133 for( size_t I = 0; I < nrFactors(); ++I )
134 _oldBeliefsF.push_back( Factor( factor(I).vars() ) );
135
136 // create update sequence
137 _updateSeq.clear();
138 _updateSeq.reserve( nrEdges() );
139 for( size_t I = 0; I < nrFactors(); I++ )
140 bforeach( const Neighbor &i, nbF(I) )
141 _updateSeq.push_back( Edge( i, i.dual ) );
142 }
143
144
145 void BP::init() {
146 Real c = props.logdomain ? 0.0 : 1.0;
147 for( size_t i = 0; i < nrVars(); ++i ) {
148 bforeach( const Neighbor &I, nbV(i) ) {
149 message( i, I.iter ).fill( c );
150 newMessage( i, I.iter ).fill( c );
151 if( props.updates == Properties::UpdateType::SEQMAX )
152 updateResidual( i, I.iter, 0.0 );
153 }
154 }
155 _iters = 0;
156 }
157
158
159 void BP::findMaxResidual( size_t &i, size_t &_I ) {
160 DAI_ASSERT( !_lut.empty() );
161 LutType::const_iterator largestEl = _lut.end();
162 --largestEl;
163 i = largestEl->second.first;
164 _I = largestEl->second.second;
165 }
166
167
168 Prob BP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
169 Factor Fprod( factor(I) );
170 Prob &prod = Fprod.p();
171 if( props.logdomain )
172 prod.takeLog();
173
174 // Calculate product of incoming messages and factor I
175 bforeach( const Neighbor &j, nbF(I) )
176 if( !(without_i && (j == i)) ) {
177 // prod_j will be the product of messages coming into j
178 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
179 bforeach( const Neighbor &J, nbV(j) )
180 if( J != I ) { // for all J in nb(j) \ I
181 if( props.logdomain )
182 prod_j += message( j, J.iter );
183 else
184 prod_j *= message( j, J.iter );
185 }
186
187 // multiply prod with prod_j
188 if( !DAI_BP_FAST ) {
189 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
190 if( props.logdomain )
191 Fprod += Factor( var(j), prod_j );
192 else
193 Fprod *= Factor( var(j), prod_j );
194 } else {
195 // OPTIMIZED VERSION
196 size_t _I = j.dual;
197 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
198 const ind_t &ind = index(j, _I);
199
200 for( size_t r = 0; r < prod.size(); ++r )
201 if( props.logdomain )
202 prod.set( r, prod[r] + prod_j[ind[r]] );
203 else
204 prod.set( r, prod[r] * prod_j[ind[r]] );
205 }
206 }
207 return prod;
208 }
209
210
211 void BP::calcNewMessage( size_t i, size_t _I ) {
212 // calculate updated message I->i
213 size_t I = nbV(i,_I);
214
215 Prob marg;
216 if( factor(I).vars().size() == 1 ) // optimization
217 marg = factor(I).p();
218 else {
219 Factor Fprod( factor(I) );
220 Prob &prod = Fprod.p();
221 prod = calcIncomingMessageProduct( I, true, i );
222
223 if( props.logdomain ) {
224 prod -= prod.max();
225 prod.takeExp();
226 }
227
228 // Marginalize onto i
229 if( !DAI_BP_FAST ) {
230 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
231 if( props.inference == Properties::InfType::SUMPROD )
232 marg = Fprod.marginal( var(i) ).p();
233 else
234 marg = Fprod.maxMarginal( var(i) ).p();
235 } else {
236 // OPTIMIZED VERSION
237 marg = Prob( var(i).states(), 0.0 );
238 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
239 const ind_t ind = index(i,_I);
240 if( props.inference == Properties::InfType::SUMPROD )
241 for( size_t r = 0; r < prod.size(); ++r )
242 marg.set( ind[r], marg[ind[r]] + prod[r] );
243 else
244 for( size_t r = 0; r < prod.size(); ++r )
245 if( prod[r] > marg[ind[r]] )
246 marg.set( ind[r], prod[r] );
247 marg.normalize();
248 }
249 }
250
251 // Store result
252 if( props.logdomain )
253 newMessage(i,_I) = marg.log();
254 else
255 newMessage(i,_I) = marg;
256
257 // Update the residual if necessary
258 if( props.updates == Properties::UpdateType::SEQMAX )
259 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
260 }
261
262
263 // BP::run does not check for NANs for performance reasons
264 // Somehow NaNs do not often occur in BP...
265 Real BP::run() {
266 if( props.verbose >= 1 )
267 cerr << "Starting " << identify() << "...";
268 if( props.verbose >= 3)
269 cerr << endl;
270
271 double tic = toc();
272
273 // do several passes over the network until maximum number of iterations has
274 // been reached or until the maximum belief difference is smaller than tolerance
275 Real maxDiff = INFINITY;
276 for( ; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
277 if( props.updates == Properties::UpdateType::SEQMAX ) {
278 if( _iters == 0 ) {
279 // do the first pass
280 for( size_t i = 0; i < nrVars(); ++i )
281 bforeach( const Neighbor &I, nbV(i) )
282 calcNewMessage( i, I.iter );
283 }
284 // Maximum-Residual BP [\ref EMK06]
285 for( size_t t = 0; t < _updateSeq.size(); ++t ) {
286 // update the message with the largest residual
287 size_t i, _I;
288 findMaxResidual( i, _I );
289 updateMessage( i, _I );
290
291 // I->i has been updated, which means that residuals for all
292 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
293 bforeach( const Neighbor &J, nbV(i) ) {
294 if( J.iter != _I ) {
295 bforeach( const Neighbor &j, nbF(J) ) {
296 size_t _J = j.dual;
297 if( j != i )
298 calcNewMessage( j, _J );
299 }
300 }
301 }
302 }
303 } else if( props.updates == Properties::UpdateType::PARALL ) {
304 // Parallel updates
305 for( size_t i = 0; i < nrVars(); ++i )
306 bforeach( const Neighbor &I, nbV(i) )
307 calcNewMessage( i, I.iter );
308
309 for( size_t i = 0; i < nrVars(); ++i )
310 bforeach( const Neighbor &I, nbV(i) )
311 updateMessage( i, I.iter );
312 } else {
313 // Sequential updates
314 if( props.updates == Properties::UpdateType::SEQRND )
315 random_shuffle( _updateSeq.begin(), _updateSeq.end(), rnd );
316
317 bforeach( const Edge &e, _updateSeq ) {
318 calcNewMessage( e.first, e.second );
319 updateMessage( e.first, e.second );
320 }
321 }
322
323 // calculate new beliefs and compare with old ones
324 maxDiff = -INFINITY;
325 for( size_t i = 0; i < nrVars(); ++i ) {
326 Factor b( beliefV(i) );
327 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsV[i], DISTLINF ) );
328 _oldBeliefsV[i] = b;
329 }
330 for( size_t I = 0; I < nrFactors(); ++I ) {
331 Factor b( beliefF(I) );
332 maxDiff = std::max( maxDiff, dist( b, _oldBeliefsF[I], DISTLINF ) );
333 _oldBeliefsF[I] = b;
334 }
335
336 if( props.verbose >= 3 )
337 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
338 }
339
340 if( maxDiff > _maxdiff )
341 _maxdiff = maxDiff;
342
343 if( props.verbose >= 1 ) {
344 if( maxDiff > props.tol ) {
345 if( props.verbose == 1 )
346 cerr << endl;
347 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
348 } else {
349 if( props.verbose >= 3 )
350 cerr << name() << "::run: ";
351 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
352 }
353 }
354
355 return maxDiff;
356 }
357
358
359 void BP::calcBeliefV( size_t i, Prob &p ) const {
360 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
361 bforeach( const Neighbor &I, nbV(i) )
362 if( props.logdomain )
363 p += newMessage( i, I.iter );
364 else
365 p *= newMessage( i, I.iter );
366 }
367
368
369 Factor BP::beliefV( size_t i ) const {
370 Prob p;
371 calcBeliefV( i, p );
372
373 if( props.logdomain ) {
374 p -= p.max();
375 p.takeExp();
376 }
377 p.normalize();
378
379 return( Factor( var(i), p ) );
380 }
381
382
383 Factor BP::beliefF( size_t I ) const {
384 Prob p;
385 calcBeliefF( I, p );
386
387 if( props.logdomain ) {
388 p -= p.max();
389 p.takeExp();
390 }
391 p.normalize();
392
393 return( Factor( factor(I).vars(), p ) );
394 }
395
396
397 vector<Factor> BP::beliefs() const {
398 vector<Factor> result;
399 for( size_t i = 0; i < nrVars(); ++i )
400 result.push_back( beliefV(i) );
401 for( size_t I = 0; I < nrFactors(); ++I )
402 result.push_back( beliefF(I) );
403 return result;
404 }
405
406
407 Factor BP::belief( const VarSet &ns ) const {
408 if( ns.size() == 0 )
409 return Factor();
410 else if( ns.size() == 1 )
411 return beliefV( findVar( *(ns.begin() ) ) );
412 else {
413 size_t I;
414 for( I = 0; I < nrFactors(); I++ )
415 if( factor(I).vars() >> ns )
416 break;
417 if( I == nrFactors() )
418 DAI_THROW(BELIEF_NOT_AVAILABLE);
419 return beliefF(I).marginal(ns);
420 }
421 }
422
423
424 Real BP::logZ() const {
425 Real sum = 0.0;
426 for( size_t i = 0; i < nrVars(); ++i )
427 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
428 for( size_t I = 0; I < nrFactors(); ++I )
429 sum -= dist( beliefF(I), factor(I), DISTKL );
430 return sum;
431 }
432
433
434 void BP::init( const VarSet &ns ) {
435 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
436 size_t ni = findVar( *n );
437 bforeach( const Neighbor &I, nbV( ni ) ) {
438 Real val = props.logdomain ? 0.0 : 1.0;
439 message( ni, I.iter ).fill( val );
440 newMessage( ni, I.iter ).fill( val );
441 if( props.updates == Properties::UpdateType::SEQMAX )
442 updateResidual( ni, I.iter, 0.0 );
443 }
444 }
445 _iters = 0;
446 }
447
448
449 void BP::updateMessage( size_t i, size_t _I ) {
450 if( recordSentMessages )
451 _sentMessages.push_back(make_pair(i,_I));
452 if( props.damping == 0.0 ) {
453 message(i,_I) = newMessage(i,_I);
454 if( props.updates == Properties::UpdateType::SEQMAX )
455 updateResidual( i, _I, 0.0 );
456 } else {
457 if( props.logdomain )
458 message(i,_I) = (message(i,_I) * props.damping) + (newMessage(i,_I) * (1.0 - props.damping));
459 else
460 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
461 if( props.updates == Properties::UpdateType::SEQMAX )
462 updateResidual( i, _I, dist( newMessage(i,_I), message(i,_I), DISTLINF ) );
463 }
464 }
465
466
467 void BP::updateResidual( size_t i, size_t _I, Real r ) {
468 EdgeProp* pEdge = &_edges[i][_I];
469 pEdge->residual = r;
470
471 // rearrange look-up table (delete and reinsert new key)
472 _lut.erase( _edge2lut[i][_I] );
473 _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
474 }
475
476
477 } // end of namespace dai
478
479
480 #endif