/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
 */
9 #include <dai/dai_config.h>
20 #include <dai/properties.h>
30 /// \todo Make DAI_BP_FAST a compile-time choice, as it is a memory/speed tradeoff
33 void BP::setProperties( const PropertySet
&opts
) {
34 DAI_ASSERT( opts
.hasKey("tol") );
35 DAI_ASSERT( opts
.hasKey("logdomain") );
36 DAI_ASSERT( opts
.hasKey("updates") );
38 props
.tol
= opts
.getStringAs
<Real
>("tol");
39 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
40 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
42 if( opts
.hasKey("maxiter") )
43 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
45 props
.maxiter
= 10000;
46 if( opts
.hasKey("maxtime") )
47 props
.maxtime
= opts
.getStringAs
<Real
>("maxtime");
49 props
.maxtime
= INFINITY
;
50 if( opts
.hasKey("verbose") )
51 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
54 if( opts
.hasKey("damping") )
55 props
.damping
= opts
.getStringAs
<Real
>("damping");
58 if( opts
.hasKey("inference") )
59 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
61 props
.inference
= Properties::InfType::SUMPROD
;
65 PropertySet
BP::getProperties() const {
67 opts
.set( "tol", props
.tol
);
68 opts
.set( "maxiter", props
.maxiter
);
69 opts
.set( "maxtime", props
.maxtime
);
70 opts
.set( "verbose", props
.verbose
);
71 opts
.set( "logdomain", props
.logdomain
);
72 opts
.set( "updates", props
.updates
);
73 opts
.set( "damping", props
.damping
);
74 opts
.set( "inference", props
.inference
);
79 string
BP::printProperties() const {
80 stringstream
s( stringstream::out
);
82 s
<< "tol=" << props
.tol
<< ",";
83 s
<< "maxiter=" << props
.maxiter
<< ",";
84 s
<< "maxtime=" << props
.maxtime
<< ",";
85 s
<< "verbose=" << props
.verbose
<< ",";
86 s
<< "logdomain=" << props
.logdomain
<< ",";
87 s
<< "updates=" << props
.updates
<< ",";
88 s
<< "damping=" << props
.damping
<< ",";
89 s
<< "inference=" << props
.inference
<< "]";
94 void BP::construct() {
95 // create edge properties
97 _edges
.reserve( nrVars() );
99 if( props
.updates
== Properties::UpdateType::SEQMAX
)
100 _edge2lut
.reserve( nrVars() );
101 for( size_t i
= 0; i
< nrVars(); ++i
) {
102 _edges
.push_back( vector
<EdgeProp
>() );
103 _edges
[i
].reserve( nbV(i
).size() );
104 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
105 _edge2lut
.push_back( vector
<LutType::iterator
>() );
106 _edge2lut
[i
].reserve( nbV(i
).size() );
108 bforeach( const Neighbor
&I
, nbV(i
) ) {
110 newEP
.message
= Prob( var(i
).states() );
111 newEP
.newMessage
= Prob( var(i
).states() );
114 newEP
.index
.reserve( factor(I
).nrStates() );
115 for( IndexFor
k( var(i
), factor(I
).vars() ); k
.valid(); ++k
)
116 newEP
.index
.push_back( k
);
119 newEP
.residual
= 0.0;
120 _edges
[i
].push_back( newEP
);
121 if( props
.updates
== Properties::UpdateType::SEQMAX
)
122 _edge2lut
[i
].push_back( _lut
.insert( make_pair( newEP
.residual
, make_pair( i
, _edges
[i
].size() - 1 ))) );
126 // create old beliefs
127 _oldBeliefsV
.clear();
128 _oldBeliefsV
.reserve( nrVars() );
129 for( size_t i
= 0; i
< nrVars(); ++i
)
130 _oldBeliefsV
.push_back( Factor( var(i
) ) );
131 _oldBeliefsF
.clear();
132 _oldBeliefsF
.reserve( nrFactors() );
133 for( size_t I
= 0; I
< nrFactors(); ++I
)
134 _oldBeliefsF
.push_back( Factor( factor(I
).vars() ) );
136 // create update sequence
138 _updateSeq
.reserve( nrEdges() );
139 for( size_t I
= 0; I
< nrFactors(); I
++ )
140 bforeach( const Neighbor
&i
, nbF(I
) )
141 _updateSeq
.push_back( Edge( i
, i
.dual
) );
146 Real c
= props
.logdomain
? 0.0 : 1.0;
147 for( size_t i
= 0; i
< nrVars(); ++i
) {
148 bforeach( const Neighbor
&I
, nbV(i
) ) {
149 message( i
, I
.iter
).fill( c
);
150 newMessage( i
, I
.iter
).fill( c
);
151 if( props
.updates
== Properties::UpdateType::SEQMAX
)
152 updateResidual( i
, I
.iter
, 0.0 );
159 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
160 DAI_ASSERT( !_lut
.empty() );
161 LutType::const_iterator largestEl
= _lut
.end();
163 i
= largestEl
->second
.first
;
164 _I
= largestEl
->second
.second
;
168 Prob
BP::calcIncomingMessageProduct( size_t I
, bool without_i
, size_t i
) const {
169 Factor
Fprod( factor(I
) );
170 Prob
&prod
= Fprod
.p();
171 if( props
.logdomain
)
174 // Calculate product of incoming messages and factor I
175 bforeach( const Neighbor
&j
, nbF(I
) )
176 if( !(without_i
&& (j
== i
)) ) {
177 // prod_j will be the product of messages coming into j
178 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
179 bforeach( const Neighbor
&J
, nbV(j
) )
180 if( J
!= I
) { // for all J in nb(j) \ I
181 if( props
.logdomain
)
182 prod_j
+= message( j
, J
.iter
);
184 prod_j
*= message( j
, J
.iter
);
187 // multiply prod with prod_j
189 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
190 if( props
.logdomain
)
191 Fprod
+= Factor( var(j
), prod_j
);
193 Fprod
*= Factor( var(j
), prod_j
);
197 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
198 const ind_t
&ind
= index(j
, _I
);
200 for( size_t r
= 0; r
< prod
.size(); ++r
)
201 if( props
.logdomain
)
202 prod
.set( r
, prod
[r
] + prod_j
[ind
[r
]] );
204 prod
.set( r
, prod
[r
] * prod_j
[ind
[r
]] );
211 void BP::calcNewMessage( size_t i
, size_t _I
) {
212 // calculate updated message I->i
213 size_t I
= nbV(i
,_I
);
216 if( factor(I
).vars().size() == 1 ) // optimization
217 marg
= factor(I
).p();
219 Factor
Fprod( factor(I
) );
220 Prob
&prod
= Fprod
.p();
221 prod
= calcIncomingMessageProduct( I
, true, i
);
223 if( props
.logdomain
) {
228 // Marginalize onto i
230 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
231 if( props
.inference
== Properties::InfType::SUMPROD
)
232 marg
= Fprod
.marginal( var(i
) ).p();
234 marg
= Fprod
.maxMarginal( var(i
) ).p();
237 marg
= Prob( var(i
).states(), 0.0 );
238 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
239 const ind_t ind
= index(i
,_I
);
240 if( props
.inference
== Properties::InfType::SUMPROD
)
241 for( size_t r
= 0; r
< prod
.size(); ++r
)
242 marg
.set( ind
[r
], marg
[ind
[r
]] + prod
[r
] );
244 for( size_t r
= 0; r
< prod
.size(); ++r
)
245 if( prod
[r
] > marg
[ind
[r
]] )
246 marg
.set( ind
[r
], prod
[r
] );
252 if( props
.logdomain
)
253 newMessage(i
,_I
) = marg
.log();
255 newMessage(i
,_I
) = marg
;
257 // Update the residual if necessary
258 if( props
.updates
== Properties::UpdateType::SEQMAX
)
259 updateResidual( i
, _I
, dist( newMessage( i
, _I
), message( i
, _I
), DISTLINF
) );
263 // BP::run does not check for NANs for performance reasons
264 // Somehow NaNs do not often occur in BP...
266 if( props
.verbose
>= 1 )
267 cerr
<< "Starting " << identify() << "...";
268 if( props
.verbose
>= 3)
273 // do several passes over the network until maximum number of iterations has
274 // been reached or until the maximum belief difference is smaller than tolerance
275 Real maxDiff
= INFINITY
;
276 for( ; _iters
< props
.maxiter
&& maxDiff
> props
.tol
&& (toc() - tic
) < props
.maxtime
; _iters
++ ) {
277 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
280 for( size_t i
= 0; i
< nrVars(); ++i
)
281 bforeach( const Neighbor
&I
, nbV(i
) )
282 calcNewMessage( i
, I
.iter
);
284 // Maximum-Residual BP [\ref EMK06]
285 for( size_t t
= 0; t
< _updateSeq
.size(); ++t
) {
286 // update the message with the largest residual
288 findMaxResidual( i
, _I
);
289 updateMessage( i
, _I
);
291 // I->i has been updated, which means that residuals for all
292 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
293 bforeach( const Neighbor
&J
, nbV(i
) ) {
295 bforeach( const Neighbor
&j
, nbF(J
) ) {
298 calcNewMessage( j
, _J
);
303 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
305 for( size_t i
= 0; i
< nrVars(); ++i
)
306 bforeach( const Neighbor
&I
, nbV(i
) )
307 calcNewMessage( i
, I
.iter
);
309 for( size_t i
= 0; i
< nrVars(); ++i
)
310 bforeach( const Neighbor
&I
, nbV(i
) )
311 updateMessage( i
, I
.iter
);
313 // Sequential updates
314 if( props
.updates
== Properties::UpdateType::SEQRND
)
315 random_shuffle( _updateSeq
.begin(), _updateSeq
.end(), rnd
);
317 bforeach( const Edge
&e
, _updateSeq
) {
318 calcNewMessage( e
.first
, e
.second
);
319 updateMessage( e
.first
, e
.second
);
323 // calculate new beliefs and compare with old ones
325 for( size_t i
= 0; i
< nrVars(); ++i
) {
326 Factor
b( beliefV(i
) );
327 maxDiff
= std::max( maxDiff
, dist( b
, _oldBeliefsV
[i
], DISTLINF
) );
330 for( size_t I
= 0; I
< nrFactors(); ++I
) {
331 Factor
b( beliefF(I
) );
332 maxDiff
= std::max( maxDiff
, dist( b
, _oldBeliefsF
[I
], DISTLINF
) );
336 if( props
.verbose
>= 3 )
337 cerr
<< name() << "::run: maxdiff " << maxDiff
<< " after " << _iters
+1 << " passes" << endl
;
340 if( maxDiff
> _maxdiff
)
343 if( props
.verbose
>= 1 ) {
344 if( maxDiff
> props
.tol
) {
345 if( props
.verbose
== 1 )
347 cerr
<< name() << "::run: WARNING: not converged after " << _iters
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << maxDiff
<< endl
;
349 if( props
.verbose
>= 3 )
350 cerr
<< name() << "::run: ";
351 cerr
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
359 void BP::calcBeliefV( size_t i
, Prob
&p
) const {
360 p
= Prob( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
361 bforeach( const Neighbor
&I
, nbV(i
) )
362 if( props
.logdomain
)
363 p
+= newMessage( i
, I
.iter
);
365 p
*= newMessage( i
, I
.iter
);
369 Factor
BP::beliefV( size_t i
) const {
373 if( props
.logdomain
) {
379 return( Factor( var(i
), p
) );
383 Factor
BP::beliefF( size_t I
) const {
387 if( props
.logdomain
) {
393 return( Factor( factor(I
).vars(), p
) );
397 vector
<Factor
> BP::beliefs() const {
398 vector
<Factor
> result
;
399 for( size_t i
= 0; i
< nrVars(); ++i
)
400 result
.push_back( beliefV(i
) );
401 for( size_t I
= 0; I
< nrFactors(); ++I
)
402 result
.push_back( beliefF(I
) );
407 Factor
BP::belief( const VarSet
&ns
) const {
410 else if( ns
.size() == 1 )
411 return beliefV( findVar( *(ns
.begin() ) ) );
414 for( I
= 0; I
< nrFactors(); I
++ )
415 if( factor(I
).vars() >> ns
)
417 if( I
== nrFactors() )
418 DAI_THROW(BELIEF_NOT_AVAILABLE
);
419 return beliefF(I
).marginal(ns
);
424 Real
BP::logZ() const {
426 for( size_t i
= 0; i
< nrVars(); ++i
)
427 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
428 for( size_t I
= 0; I
< nrFactors(); ++I
)
429 sum
-= dist( beliefF(I
), factor(I
), DISTKL
);
434 void BP::init( const VarSet
&ns
) {
435 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
436 size_t ni
= findVar( *n
);
437 bforeach( const Neighbor
&I
, nbV( ni
) ) {
438 Real val
= props
.logdomain
? 0.0 : 1.0;
439 message( ni
, I
.iter
).fill( val
);
440 newMessage( ni
, I
.iter
).fill( val
);
441 if( props
.updates
== Properties::UpdateType::SEQMAX
)
442 updateResidual( ni
, I
.iter
, 0.0 );
449 void BP::updateMessage( size_t i
, size_t _I
) {
450 if( recordSentMessages
)
451 _sentMessages
.push_back(make_pair(i
,_I
));
452 if( props
.damping
== 0.0 ) {
453 message(i
,_I
) = newMessage(i
,_I
);
454 if( props
.updates
== Properties::UpdateType::SEQMAX
)
455 updateResidual( i
, _I
, 0.0 );
457 if( props
.logdomain
)
458 message(i
,_I
) = (message(i
,_I
) * props
.damping
) + (newMessage(i
,_I
) * (1.0 - props
.damping
));
460 message(i
,_I
) = (message(i
,_I
) ^ props
.damping
) * (newMessage(i
,_I
) ^ (1.0 - props
.damping
));
461 if( props
.updates
== Properties::UpdateType::SEQMAX
)
462 updateResidual( i
, _I
, dist( newMessage(i
,_I
), message(i
,_I
), DISTLINF
) );
467 void BP::updateResidual( size_t i
, size_t _I
, Real r
) {
468 EdgeProp
* pEdge
= &_edges
[i
][_I
];
471 // rearrange look-up table (delete and reinsert new key)
472 _lut
.erase( _edge2lut
[i
][_I
] );
473 _edge2lut
[i
][_I
] = _lut
.insert( make_pair( r
, make_pair(i
, _I
) ) );
477 } // end of namespace dai