/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
 *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
 */
20 #include <dai/properties.h>
29 const char *BP::Name
= "BP";
35 void BP::setProperties( const PropertySet
&opts
) {
36 DAI_ASSERT( opts
.hasKey("tol") );
37 DAI_ASSERT( opts
.hasKey("logdomain") );
38 DAI_ASSERT( opts
.hasKey("updates") );
40 props
.tol
= opts
.getStringAs
<Real
>("tol");
41 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
42 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
44 if( opts
.hasKey("maxiter") )
45 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
47 props
.maxiter
= 10000;
48 if( opts
.hasKey("maxtime") )
49 props
.maxtime
= opts
.getStringAs
<Real
>("maxtime");
51 props
.maxtime
= INFINITY
;
52 if( opts
.hasKey("verbose") )
53 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
56 if( opts
.hasKey("damping") )
57 props
.damping
= opts
.getStringAs
<Real
>("damping");
60 if( opts
.hasKey("inference") )
61 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
63 props
.inference
= Properties::InfType::SUMPROD
;
67 PropertySet
BP::getProperties() const {
69 opts
.set( "tol", props
.tol
);
70 opts
.set( "maxiter", props
.maxiter
);
71 opts
.set( "maxtime", props
.maxtime
);
72 opts
.set( "verbose", props
.verbose
);
73 opts
.set( "logdomain", props
.logdomain
);
74 opts
.set( "updates", props
.updates
);
75 opts
.set( "damping", props
.damping
);
76 opts
.set( "inference", props
.inference
);
81 string
BP::printProperties() const {
82 stringstream
s( stringstream::out
);
84 s
<< "tol=" << props
.tol
<< ",";
85 s
<< "maxiter=" << props
.maxiter
<< ",";
86 s
<< "maxtime=" << props
.maxtime
<< ",";
87 s
<< "verbose=" << props
.verbose
<< ",";
88 s
<< "logdomain=" << props
.logdomain
<< ",";
89 s
<< "updates=" << props
.updates
<< ",";
90 s
<< "damping=" << props
.damping
<< ",";
91 s
<< "inference=" << props
.inference
<< "]";
96 void BP::construct() {
97 // create edge properties
99 _edges
.reserve( nrVars() );
101 if( props
.updates
== Properties::UpdateType::SEQMAX
)
102 _edge2lut
.reserve( nrVars() );
103 for( size_t i
= 0; i
< nrVars(); ++i
) {
104 _edges
.push_back( vector
<EdgeProp
>() );
105 _edges
[i
].reserve( nbV(i
).size() );
106 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
107 _edge2lut
.push_back( vector
<LutType::iterator
>() );
108 _edge2lut
[i
].reserve( nbV(i
).size() );
110 foreach( const Neighbor
&I
, nbV(i
) ) {
112 newEP
.message
= Prob( var(i
).states() );
113 newEP
.newMessage
= Prob( var(i
).states() );
116 newEP
.index
.reserve( factor(I
).nrStates() );
117 for( IndexFor
k( var(i
), factor(I
).vars() ); k
.valid(); ++k
)
118 newEP
.index
.push_back( k
);
121 newEP
.residual
= 0.0;
122 _edges
[i
].push_back( newEP
);
123 if( props
.updates
== Properties::UpdateType::SEQMAX
)
124 _edge2lut
[i
].push_back( _lut
.insert( make_pair( newEP
.residual
, make_pair( i
, _edges
[i
].size() - 1 ))) );
128 // create old beliefs
130 oldBeliefsV
.reserve( nrVars() );
131 for( size_t i
= 0; i
< nrVars(); ++i
)
132 oldBeliefsV
.push_back( Factor( var(i
) ) );
134 oldBeliefsF
.reserve( nrFactors() );
135 for( size_t I
= 0; I
< nrFactors(); ++I
)
136 oldBeliefsF
.push_back( Factor( factor(I
).vars() ) );
138 // create update sequence
140 updateSeq
.reserve( nrEdges() );
141 for( size_t I
= 0; I
< nrFactors(); I
++ )
142 foreach( const Neighbor
&i
, nbF(I
) )
143 updateSeq
.push_back( Edge( i
, i
.dual
) );
148 Real c
= props
.logdomain
? 0.0 : 1.0;
149 for( size_t i
= 0; i
< nrVars(); ++i
) {
150 foreach( const Neighbor
&I
, nbV(i
) ) {
151 message( i
, I
.iter
).fill( c
);
152 newMessage( i
, I
.iter
).fill( c
);
153 if( props
.updates
== Properties::UpdateType::SEQMAX
)
154 updateResidual( i
, I
.iter
, 0.0 );
161 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
162 DAI_ASSERT( !_lut
.empty() );
163 LutType::const_iterator largestEl
= _lut
.end();
165 i
= largestEl
->second
.first
;
166 _I
= largestEl
->second
.second
;
170 Prob
BP::calcIncomingMessageProduct( size_t I
, bool without_i
, size_t i
) const {
171 Factor
Fprod( factor(I
) );
172 Prob
&prod
= Fprod
.p();
173 if( props
.logdomain
)
176 // Calculate product of incoming messages and factor I
177 foreach( const Neighbor
&j
, nbF(I
) )
178 if( !(without_i
&& (j
== i
)) ) {
179 // prod_j will be the product of messages coming into j
180 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
181 foreach( const Neighbor
&J
, nbV(j
) )
182 if( J
!= I
) { // for all J in nb(j) \ I
183 if( props
.logdomain
)
184 prod_j
+= message( j
, J
.iter
);
186 prod_j
*= message( j
, J
.iter
);
189 // multiply prod with prod_j
191 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
192 if( props
.logdomain
)
193 Fprod
+= Factor( var(j
), prod_j
);
195 Fprod
*= Factor( var(j
), prod_j
);
199 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
200 const ind_t
&ind
= index(j
, _I
);
202 for( size_t r
= 0; r
< prod
.size(); ++r
)
203 if( props
.logdomain
)
204 prod
.set( r
, prod
[r
] + prod_j
[ind
[r
]] );
206 prod
.set( r
, prod
[r
] * prod_j
[ind
[r
]] );
213 void BP::calcNewMessage( size_t i
, size_t _I
) {
214 // calculate updated message I->i
215 size_t I
= nbV(i
,_I
);
218 if( factor(I
).vars().size() == 1 ) // optimization
219 marg
= factor(I
).p();
221 Factor
Fprod( factor(I
) );
222 Prob
&prod
= Fprod
.p();
223 prod
= calcIncomingMessageProduct( I
, true, i
);
225 if( props
.logdomain
) {
230 // Marginalize onto i
232 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
233 if( props
.inference
== Properties::InfType::SUMPROD
)
234 marg
= Fprod
.marginal( var(i
) ).p();
236 marg
= Fprod
.maxMarginal( var(i
) ).p();
239 marg
= Prob( var(i
).states(), 0.0 );
240 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
241 const ind_t ind
= index(i
,_I
);
242 if( props
.inference
== Properties::InfType::SUMPROD
)
243 for( size_t r
= 0; r
< prod
.size(); ++r
)
244 marg
.set( ind
[r
], marg
[ind
[r
]] + prod
[r
] );
246 for( size_t r
= 0; r
< prod
.size(); ++r
)
247 if( prod
[r
] > marg
[ind
[r
]] )
248 marg
.set( ind
[r
], prod
[r
] );
254 if( props
.logdomain
)
255 newMessage(i
,_I
) = marg
.log();
257 newMessage(i
,_I
) = marg
;
259 // Update the residual if necessary
260 if( props
.updates
== Properties::UpdateType::SEQMAX
)
261 updateResidual( i
, _I
, dist( newMessage( i
, _I
), message( i
, _I
), DISTLINF
) );
265 // BP::run does not check for NANs for performance reasons
266 // Somehow NaNs do not often occur in BP...
268 if( props
.verbose
>= 1 )
269 cerr
<< "Starting " << identify() << "...";
270 if( props
.verbose
>= 3)
275 // do several passes over the network until maximum number of iterations has
276 // been reached or until the maximum belief difference is smaller than tolerance
277 Real maxDiff
= INFINITY
;
278 for( ; _iters
< props
.maxiter
&& maxDiff
> props
.tol
&& (toc() - tic
) < props
.maxtime
; _iters
++ ) {
279 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
282 for( size_t i
= 0; i
< nrVars(); ++i
)
283 foreach( const Neighbor
&I
, nbV(i
) )
284 calcNewMessage( i
, I
.iter
);
286 // Maximum-Residual BP [\ref EMK06]
287 for( size_t t
= 0; t
< updateSeq
.size(); ++t
) {
288 // update the message with the largest residual
290 findMaxResidual( i
, _I
);
291 updateMessage( i
, _I
);
293 // I->i has been updated, which means that residuals for all
294 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
295 foreach( const Neighbor
&J
, nbV(i
) ) {
297 foreach( const Neighbor
&j
, nbF(J
) ) {
300 calcNewMessage( j
, _J
);
305 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
307 for( size_t i
= 0; i
< nrVars(); ++i
)
308 foreach( const Neighbor
&I
, nbV(i
) )
309 calcNewMessage( i
, I
.iter
);
311 for( size_t i
= 0; i
< nrVars(); ++i
)
312 foreach( const Neighbor
&I
, nbV(i
) )
313 updateMessage( i
, I
.iter
);
315 // Sequential updates
316 if( props
.updates
== Properties::UpdateType::SEQRND
)
317 random_shuffle( updateSeq
.begin(), updateSeq
.end() );
319 foreach( const Edge
&e
, updateSeq
) {
320 calcNewMessage( e
.first
, e
.second
);
321 updateMessage( e
.first
, e
.second
);
325 // calculate new beliefs and compare with old ones
327 for( size_t i
= 0; i
< nrVars(); ++i
) {
328 Factor
b( beliefV(i
) );
329 maxDiff
= std::max( maxDiff
, dist( b
, oldBeliefsV
[i
], DISTLINF
) );
332 for( size_t I
= 0; I
< nrFactors(); ++I
) {
333 Factor
b( beliefF(I
) );
334 maxDiff
= std::max( maxDiff
, dist( b
, oldBeliefsF
[I
], DISTLINF
) );
338 if( props
.verbose
>= 3 )
339 cerr
<< Name
<< "::run: maxdiff " << maxDiff
<< " after " << _iters
+1 << " passes" << endl
;
342 if( maxDiff
> _maxdiff
)
345 if( props
.verbose
>= 1 ) {
346 if( maxDiff
> props
.tol
) {
347 if( props
.verbose
== 1 )
349 cerr
<< Name
<< "::run: WARNING: not converged after " << _iters
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << maxDiff
<< endl
;
351 if( props
.verbose
>= 3 )
352 cerr
<< Name
<< "::run: ";
353 cerr
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
361 void BP::calcBeliefV( size_t i
, Prob
&p
) const {
362 p
= Prob( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
363 foreach( const Neighbor
&I
, nbV(i
) )
364 if( props
.logdomain
)
365 p
+= newMessage( i
, I
.iter
);
367 p
*= newMessage( i
, I
.iter
);
371 Factor
BP::beliefV( size_t i
) const {
375 if( props
.logdomain
) {
381 return( Factor( var(i
), p
) );
385 Factor
BP::beliefF( size_t I
) const {
389 if( props
.logdomain
) {
395 return( Factor( factor(I
).vars(), p
) );
399 vector
<Factor
> BP::beliefs() const {
400 vector
<Factor
> result
;
401 for( size_t i
= 0; i
< nrVars(); ++i
)
402 result
.push_back( beliefV(i
) );
403 for( size_t I
= 0; I
< nrFactors(); ++I
)
404 result
.push_back( beliefF(I
) );
409 Factor
BP::belief( const VarSet
&ns
) const {
412 else if( ns
.size() == 1 )
413 return beliefV( findVar( *(ns
.begin() ) ) );
416 for( I
= 0; I
< nrFactors(); I
++ )
417 if( factor(I
).vars() >> ns
)
419 if( I
== nrFactors() )
420 DAI_THROW(BELIEF_NOT_AVAILABLE
);
421 return beliefF(I
).marginal(ns
);
426 Real
BP::logZ() const {
428 for( size_t i
= 0; i
< nrVars(); ++i
)
429 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
430 for( size_t I
= 0; I
< nrFactors(); ++I
)
431 sum
-= dist( beliefF(I
), factor(I
), DISTKL
);
436 string
BP::identify() const {
437 return string(Name
) + printProperties();
441 void BP::init( const VarSet
&ns
) {
442 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
443 size_t ni
= findVar( *n
);
444 foreach( const Neighbor
&I
, nbV( ni
) ) {
445 Real val
= props
.logdomain
? 0.0 : 1.0;
446 message( ni
, I
.iter
).fill( val
);
447 newMessage( ni
, I
.iter
).fill( val
);
448 if( props
.updates
== Properties::UpdateType::SEQMAX
)
449 updateResidual( ni
, I
.iter
, 0.0 );
456 void BP::updateMessage( size_t i
, size_t _I
) {
457 if( recordSentMessages
)
458 _sentMessages
.push_back(make_pair(i
,_I
));
459 if( props
.damping
== 0.0 ) {
460 message(i
,_I
) = newMessage(i
,_I
);
461 if( props
.updates
== Properties::UpdateType::SEQMAX
)
462 updateResidual( i
, _I
, 0.0 );
464 if( props
.logdomain
)
465 message(i
,_I
) = (message(i
,_I
) * props
.damping
) + (newMessage(i
,_I
) * (1.0 - props
.damping
));
467 message(i
,_I
) = (message(i
,_I
) ^ props
.damping
) * (newMessage(i
,_I
) ^ (1.0 - props
.damping
));
468 if( props
.updates
== Properties::UpdateType::SEQMAX
)
469 updateResidual( i
, _I
, dist( newMessage(i
,_I
), message(i
,_I
), DISTLINF
) );
474 void BP::updateResidual( size_t i
, size_t _I
, Real r
) {
475 EdgeProp
* pEdge
= &_edges
[i
][_I
];
478 // rearrange look-up table (delete and reinsert new key)
479 _lut
.erase( _edge2lut
[i
][_I
] );
480 _edge2lut
[i
][_I
] = _lut
.insert( make_pair( r
, make_pair(i
, _I
) ) );
484 std::vector
<size_t> BP::findMaximum() const {
485 vector
<size_t> maximum( nrVars() );
486 vector
<bool> visitedVars( nrVars(), false );
487 vector
<bool> visitedFactors( nrFactors(), false );
488 stack
<size_t> scheduledFactors
;
489 for( size_t i
= 0; i
< nrVars(); ++i
) {
492 visitedVars
[i
] = true;
494 // Maximise with respect to variable i
496 calcBeliefV( i
, prod
);
497 maximum
[i
] = prod
.argmax().first
;
499 foreach( const Neighbor
&I
, nbV(i
) )
500 if( !visitedFactors
[I
] )
501 scheduledFactors
.push(I
);
503 while( !scheduledFactors
.empty() ){
504 size_t I
= scheduledFactors
.top();
505 scheduledFactors
.pop();
506 if( visitedFactors
[I
] )
508 visitedFactors
[I
] = true;
510 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
511 bool allDetermined
= true;
512 foreach( const Neighbor
&j
, nbF(I
) )
513 if( !visitedVars
[j
.node
] ) {
514 allDetermined
= false;
520 // Calculate product of incoming messages on factor I
522 calcBeliefF( I
, prod2
);
524 // The allowed configuration is restrained according to the variables assigned so far:
525 // pick the argmax amongst the allowed states
526 Real maxProb
= -numeric_limits
<Real
>::max();
527 State
maxState( factor(I
).vars() );
528 for( State
s( factor(I
).vars() ); s
.valid(); ++s
){
529 // First, calculate whether this state is consistent with variables that
530 // have been assigned already
531 bool allowedState
= true;
532 foreach( const Neighbor
&j
, nbF(I
) )
533 if( visitedVars
[j
.node
] && maximum
[j
.node
] != s(var(j
.node
)) ) {
534 allowedState
= false;
537 // If it is consistent, check if its probability is larger than what we have seen so far
538 if( allowedState
&& prod2
[s
] > maxProb
) {
545 foreach( const Neighbor
&j
, nbF(I
) ) {
546 if( visitedVars
[j
.node
] ) {
547 // We have already visited j earlier - hopefully our state is consistent
548 if( maximum
[j
.node
] != maxState(var(j
.node
)) && props
.verbose
>= 1 )
549 cerr
<< "BP::findMaximum - warning: maximum not consistent due to loops." << endl
;
551 // We found a consistent state for variable j
552 visitedVars
[j
.node
] = true;
553 maximum
[j
.node
] = maxState( var(j
.node
) );
554 foreach( const Neighbor
&J
, nbV(j
) )
555 if( !visitedFactors
[J
] )
556 scheduledFactors
.push(J
);
565 } // end of namespace dai