e3043cc0169911236c8c3241fadf1264db66907f
1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
20 #include <dai/properties.h>
// Name string of this algorithm. identify() puts it in front of the printed
// properties, and the diagnostics written to cerr by run() are prefixed with it.
29 const char *BP::Name
= "BP";
// Copies the BP settings found in opts into props.
//   opts  property set to read from; every value is fetched with getStringAs<T>.
// The keys "tol", "maxiter", "logdomain" and "updates" must be present (DAI_ASSERT).
// The keys "verbose", "damping" and "inference" are read only when opts has them.
35 void BP::setProperties( const PropertySet
&opts
) {
// Mandatory keys: fail the assertion early instead of reading a missing key.
36 DAI_ASSERT( opts
.hasKey("tol") );
37 DAI_ASSERT( opts
.hasKey("maxiter") );
38 DAI_ASSERT( opts
.hasKey("logdomain") );
39 DAI_ASSERT( opts
.hasKey("updates") );
41 props
.tol
= opts
.getStringAs
<double>("tol");
42 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
43 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
44 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
// Optional keys follow.
// NOTE(review): no fallback value for "verbose" or "damping" is visible in this
// listing; presumably the skipped lines hold else-branches - confirm upstream.
46 if( opts
.hasKey("verbose") )
47 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
50 if( opts
.hasKey("damping") )
51 props
.damping
= opts
.getStringAs
<double>("damping");
54 if( opts
.hasKey("inference") )
55 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
// NOTE(review): as listed, nothing guards this assignment, so it would overwrite
// the value read just above. The listing skips the line between the two
// statements, which presumably holds an `else` - confirm against upstream.
57 props
.inference
= Properties::InfType::SUMPROD
;
// Exports the current settings: stores tol, maxiter, verbose, logdomain, updates,
// damping and inference from props into opts under keys of the same names.
// Return type is PropertySet.
// NOTE(review): the declaration of opts and the return statement are not part of
// this listing; presumably opts is a local PropertySet that is returned - confirm.
61 PropertySet
BP::getProperties() const {
63 opts
.Set( "tol", props
.tol
);
64 opts
.Set( "maxiter", props
.maxiter
);
65 opts
.Set( "verbose", props
.verbose
);
66 opts
.Set( "logdomain", props
.logdomain
);
67 opts
.Set( "updates", props
.updates
);
68 opts
.Set( "damping", props
.damping
);
69 opts
.Set( "inference", props
.inference
);
// Formats the settings as text. Writes "key=value" items for tol, maxiter, verbose,
// logdomain, updates, damping and inference into an output-only stringstream,
// separated by "," and terminated by "]".
// NOTE(review): the statement that opens the bracket and the return statement are
// not part of this listing; presumably "[" is written first and s.str() is
// returned - confirm.
74 string
BP::printProperties() const {
75 stringstream
s( stringstream::out
);
77 s
<< "tol=" << props
.tol
<< ",";
78 s
<< "maxiter=" << props
.maxiter
<< ",";
79 s
<< "verbose=" << props
.verbose
<< ",";
80 s
<< "logdomain=" << props
.logdomain
<< ",";
81 s
<< "updates=" << props
.updates
<< ",";
82 s
<< "damping=" << props
.damping
<< ",";
// Last item closes the bracket instead of adding another comma.
83 s
<< "inference=" << props
.inference
<< "]";
// Builds the per-edge storage. For every variable i, _edges[i] receives one EdgeProp
// per neighboring factor I, holding:
//   - message and newMessage, each a Prob sized by var(i).states(),
//   - index, filled by iterating IndexFor( var(i), factor(I).vars() ),
//   - residual, set to 0.0.
// With SEQMAX updates, every edge is also inserted into _lut (keyed by its residual,
// valued by the pair (i, position in _edges[i])) and the iterator returned by the
// insert is stored in _edge2lut[i].
// NOTE(review): the declaration of newEP is not part of this listing; presumably
// an EdgeProp local to the inner loop - confirm.
88 void BP::construct() {
89 // create edge properties
91 _edges
.reserve( nrVars() );
93 if( props
.updates
== Properties::UpdateType::SEQMAX
)
94 _edge2lut
.reserve( nrVars() );
95 for( size_t i
= 0; i
< nrVars(); ++i
) {
96 _edges
.push_back( vector
<EdgeProp
>() );
97 _edges
[i
].reserve( nbV(i
).size() );
98 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
99 _edge2lut
.push_back( vector
<LutType::iterator
>() );
100 _edge2lut
[i
].reserve( nbV(i
).size() );
102 foreach( const Neighbor
&I
, nbV(i
) ) {
104 newEP
.message
= Prob( var(i
).states() );
105 newEP
.newMessage
= Prob( var(i
).states() );
// Precompute the index table so the message routines can look entries up
// instead of iterating IndexFor again.
108 newEP
.index
.reserve( factor(I
).states() );
109 for( IndexFor
k( var(i
), factor(I
).vars() ); k
.valid(); ++k
)
110 newEP
.index
.push_back( k
);
113 newEP
.residual
= 0.0;
114 _edges
[i
].push_back( newEP
);
// The edge was just appended, so its position is size() - 1.
115 if( props
.updates
== Properties::UpdateType::SEQMAX
)
116 _edge2lut
[i
].push_back( _lut
.insert( make_pair( newEP
.residual
, make_pair( i
, _edges
[i
].size() - 1 ))) );
// Resets all messages. For every variable i and every neighboring factor I, both
// message(i, I.iter) and newMessage(i, I.iter) are filled with c, which is 0.0 when
// props.logdomain is set and 1.0 otherwise. With SEQMAX updates the residual of
// every edge is set to 0.0 through updateResidual.
// NOTE(review): the signature line of this function is not part of this listing;
// presumably this is the body of the no-argument init() - confirm.
123 double c
= props
.logdomain
? 0.0 : 1.0;
124 for( size_t i
= 0; i
< nrVars(); ++i
) {
125 foreach( const Neighbor
&I
, nbV(i
) ) {
126 message( i
, I
.iter
).fill( c
);
127 newMessage( i
, I
.iter
).fill( c
);
128 if( props
.updates
== Properties::UpdateType::SEQMAX
)
129 updateResidual( i
, I
.iter
, 0.0 );
// Reports an edge taken from the residual look-up table _lut.
//   i   (out) receives the first component of the pair stored in the chosen entry
//   _I  (out) receives the second component of that pair
// _lut must not be empty (DAI_ASSERT).
// NOTE(review): as listed, largestEl still equals _lut.end() when it is
// dereferenced. The listing skips the line after its initialisation, which
// presumably steps the iterator back to the last entry (the largest key if _lut
// is ordered by residual) - confirm against upstream.
135 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
136 DAI_ASSERT( !_lut
.empty() );
137 LutType::const_iterator largestEl
= _lut
.end();
139 i
= largestEl
->second
.first
;
140 _I
= largestEl
->second
.second
;
// Computes the pending message from factor I = nbV(i,_I) to variable i and stores
// it in newMessage(i,_I).
//   i   index of the receiving variable
//   _I  index of the sending factor within nbV(i)
// Steps visible here:
//   1. Start from a copy of factor(I) (Fprod; prod refers to its Prob).
//   2. For every other variable j of I, combine the current messages into j from all
//      factors except I (prod_j), then combine prod_j into the factor.
//      "Combine" is += when props.logdomain is set and *= otherwise.
//   3. Marginalize onto i: SUMPROD uses a sum, the other branch a maximum.
//   4. Store the result, through marg.log() in the log domain.
//   5. With SEQMAX updates, set the edge residual to the DISTLINF distance between
//      newMessage(i,_I) and message(i,_I).
// NOTE(review): both an "UNOPTIMIZED" and an "OPTIMIZED" variant of steps 2 and 3
// are listed; the lines selecting between them, the declaration of marg, and the
// bodies of the props.logdomain conditionals around steps 1 and 3 are not part of
// this listing - confirm against upstream.
144 void BP::calcNewMessage( size_t i
, size_t _I
) {
145 // calculate updated message I->i
146 size_t I
= nbV(i
,_I
);
148 Factor
Fprod( factor(I
) );
149 Prob
&prod
= Fprod
.p();
150 if( props
.logdomain
)
153 // Calculate product of incoming messages and factor I
154 foreach( const Neighbor
&j
, nbF(I
) )
155 if( j
!= i
) { // for all j in I \ i
156 // prod_j will be the product of messages coming into j
// Neutral element: 0.0 for sums of logs, 1.0 for products.
157 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
158 foreach( const Neighbor
&J
, nbV(j
) )
159 if( J
!= I
) { // for all J in nb(j) \ I
160 if( props
.logdomain
)
161 prod_j
+= message( j
, J
.iter
);
163 prod_j
*= message( j
, J
.iter
);
166 // multiply prod with prod_j
168 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
169 if( props
.logdomain
)
170 Fprod
+= Factor( var(j
), prod_j
);
172 Fprod
*= Factor( var(j
), prod_j
);
174 /* OPTIMIZED VERSION */
// NOTE(review): the _I used in this lookup is presumably a local declared on a
// line absent from the listing (the edge index of I as seen from j), not the
// function parameter - confirm.
176 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
177 const ind_t
&ind
= index(j
, _I
);
178 for( size_t r
= 0; r
< prod
.size(); ++r
)
179 if( props
.logdomain
)
180 prod
[r
] += prod_j
[ind
[r
]];
182 prod
[r
] *= prod_j
[ind
[r
]];
186 if( props
.logdomain
) {
191 // Marginalize onto i
194 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
195 if( props
.inference
== Properties::InfType::SUMPROD
)
196 marg
= Fprod
.marginal( var(i
) ).p();
198 marg
= Fprod
.maxMarginal( var(i
) ).p();
200 /* OPTIMIZED VERSION */
201 marg
= Prob( var(i
).states(), 0.0 );
// NOTE(review): ind is declared by value here, whereas the earlier lookup binds a
// reference; if ind_t is a container this copies the whole table on every call -
// presumably unintended, confirm.
202 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
203 const ind_t ind
= index(i
,_I
);
204 if( props
.inference
== Properties::InfType::SUMPROD
)
205 for( size_t r
= 0; r
< prod
.size(); ++r
)
206 marg
[ind
[r
]] += prod
[r
];
// Other inference type: keep the largest entry per state of i. marg starts at
// 0.0, so only entries above 0.0 replace it.
208 for( size_t r
= 0; r
< prod
.size(); ++r
)
209 if( prod
[r
] > marg
[ind
[r
]] )
210 marg
[ind
[r
]] = prod
[r
];
215 if( props
.logdomain
)
216 newMessage(i
,_I
) = marg
.log();
218 newMessage(i
,_I
) = marg
;
220 // Update the residual if necessary
221 if( props
.updates
== Properties::UpdateType::SEQMAX
)
222 updateResidual( i
, _I
, dist( newMessage( i
, _I
), message( i
, _I
), Prob::DISTLINF
) );
// Body of BP::run: passes messages until _iters reaches props.maxiter or
// diffs.maxDiff() is no longer above props.tol, then returns diffs.maxDiff().
// Update schedules, selected by props.updates:
//   SEQMAX  - per pass, nrEdges() times: take the edge reported by findMaxResidual,
//             commit it with updateMessage, then recompute pending messages for
//             edges (j, _J) reached through the neighbors of i.
//   PARALL  - compute all pending messages first, then commit all of them.
//   other   - walk update_seq, computing and committing one edge at a time;
//             SEQRND shuffles update_seq before each pass.
// After each pass the variable beliefs are compared with old_beliefs (DISTLINF)
// and the distances are pushed into diffs.
// NOTE(review): the signature line, the declarations of tic, i, _I and _J, the
// refresh of old_beliefs and several else/brace lines are not part of this
// listing - confirm against upstream.
226 // BP::run does not check for NANs for performance reasons
227 // Somehow NaNs do not often occur in BP...
229 if( props
.verbose
>= 1 )
230 cerr
<< "Starting " << identify() << "...";
231 if( props
.verbose
>= 3)
235 Diffs
diffs(nrVars(), 1.0);
237 vector
<Edge
> update_seq
;
// Snapshot of the variable beliefs before the first pass.
239 vector
<Factor
> old_beliefs
;
240 old_beliefs
.reserve( nrVars() );
241 for( size_t i
= 0; i
< nrVars(); ++i
)
242 old_beliefs
.push_back( beliefV(i
) );
244 size_t nredges
= nrEdges();
// SEQMAX needs a pending message (and hence a residual) for every edge up front.
246 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
248 for( size_t i
= 0; i
< nrVars(); ++i
)
249 foreach( const Neighbor
&I
, nbV(i
) ) {
250 calcNewMessage( i
, I
.iter
);
// Fixed edge order used by the sequential schedules.
253 update_seq
.reserve( nredges
);
254 /// \todo Investigate whether performance increases by switching the order of following two loops:
255 for( size_t i
= 0; i
< nrVars(); ++i
)
256 foreach( const Neighbor
&I
, nbV(i
) )
257 update_seq
.push_back( Edge( i
, I
.iter
) );
260 // do several passes over the network until maximum number of iterations has
261 // been reached or until the maximum belief difference is smaller than tolerance
262 for( _iters
=0; _iters
< props
.maxiter
&& diffs
.maxDiff() > props
.tol
; ++_iters
) {
263 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
264 // Residuals-BP by Koller et al.
265 for( size_t t
= 0; t
< nredges
; ++t
) {
266 // update the message with the largest residual
268 findMaxResidual( i
, _I
);
269 updateMessage( i
, _I
);
271 // I->i has been updated, which means that residuals for all
272 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
273 foreach( const Neighbor
&J
, nbV(i
) ) {
275 foreach( const Neighbor
&j
, nbF(J
) ) {
278 calcNewMessage( j
, _J
);
283 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
// Two separate sweeps, so every pending message is computed from the messages of
// the previous pass.
285 for( size_t i
= 0; i
< nrVars(); ++i
)
286 foreach( const Neighbor
&I
, nbV(i
) )
287 calcNewMessage( i
, I
.iter
);
289 for( size_t i
= 0; i
< nrVars(); ++i
)
290 foreach( const Neighbor
&I
, nbV(i
) )
291 updateMessage( i
, I
.iter
);
293 // Sequential updates
294 if( props
.updates
== Properties::UpdateType::SEQRND
)
295 random_shuffle( update_seq
.begin(), update_seq
.end() );
297 foreach( const Edge
&e
, update_seq
) {
298 calcNewMessage( e
.first
, e
.second
);
299 updateMessage( e
.first
, e
.second
);
303 // calculate new beliefs and compare with old ones
304 for( size_t i
= 0; i
< nrVars(); ++i
) {
305 Factor
nb( beliefV(i
) );
306 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
310 if( props
.verbose
>= 3 )
311 cerr
<< Name
<< "::run: maxdiff " << diffs
.maxDiff() << " after " << _iters
+1 << " passes" << endl
;
// _maxdiff only ever grows: it keeps the largest final maxDiff() seen so far.
314 if( diffs
.maxDiff() > _maxdiff
)
315 _maxdiff
= diffs
.maxDiff();
// Final report: a warning when the tolerance was not reached, otherwise the
// number of passes used.
317 if( props
.verbose
>= 1 ) {
318 if( diffs
.maxDiff() > props
.tol
) {
319 if( props
.verbose
== 1 )
321 cerr
<< Name
<< "::run: WARNING: not converged within " << props
.maxiter
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << diffs
.maxDiff() << endl
;
323 if( props
.verbose
>= 3 )
324 cerr
<< Name
<< "::run: ";
325 cerr
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
329 return diffs
.maxDiff();
// Combines the pending messages into variable i.
//   i  index of the variable
//   p  (out) overwritten with a Prob over var(i).states() entries that starts at the
//      neutral element (0.0 in the log domain, 1.0 otherwise) and then absorbs
//      newMessage(i, I.iter) for every neighboring factor I: by += in the log
//      domain, by *= otherwise.
// NOTE(review): the line between the two update statements is absent from this
// listing; presumably an `else` - confirm.
333 void BP::calcBeliefV( size_t i
, Prob
&p
) const {
334 p
= Prob( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
335 foreach( const Neighbor
&I
, nbV(i
) )
336 if( props
.logdomain
)
337 p
+= newMessage( i
, I
.iter
);
339 p
*= newMessage( i
, I
.iter
);
// Combines factor I with the pending messages into its variables.
//   I  index of the factor
//   p  (out) result parameter
// Starts from a copy of factor(I) (Fprod; prod refers to its Prob). For every
// variable j of I, prod_j collects newMessage(j, J.iter) over all factors J of j
// other than I, and prod_j is then combined into the factor. Sums are used when
// props.logdomain is set, products otherwise.
// NOTE(review): the assignment to p, the declaration of the _I used in the index
// lookup, the body of the first props.logdomain conditional and the lines
// selecting between the "UNOPTIMIZED" and "OPTIMIZED" variants are not part of
// this listing - confirm against upstream.
343 void BP::calcBeliefF( size_t I
, Prob
&p
) const {
344 Factor
Fprod( factor( I
) );
345 Prob
&prod
= Fprod
.p();
347 if( props
.logdomain
)
350 foreach( const Neighbor
&j
, nbF(I
) ) {
351 // prod_j will be the product of messages coming into j
// Neutral element: 0.0 for sums of logs, 1.0 for products.
352 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
353 foreach( const Neighbor
&J
, nbV(j
) )
354 if( J
!= I
) { // for all J in nb(j) \ I
355 if( props
.logdomain
)
356 prod_j
+= newMessage( j
, J
.iter
);
358 prod_j
*= newMessage( j
, J
.iter
);
361 // multiply prod with prod_j
363 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
364 if( props
.logdomain
)
365 Fprod
+= Factor( var(j
), prod_j
);
367 Fprod
*= Factor( var(j
), prod_j
);
369 /* OPTIMIZED VERSION */
371 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
372 const ind_t
& ind
= index(j
, _I
);
374 for( size_t r
= 0; r
< prod
.size(); ++r
) {
375 if( props
.logdomain
)
376 prod
[r
] += prod_j
[ind
[r
]];
378 prod
[r
] *= prod_j
[ind
[r
]];
// Returns the belief of variable i as a Factor over var(i) built from p.
// NOTE(review): the lines that declare and fill p, and the body of the
// props.logdomain branch, are not part of this listing; presumably p comes from
// calcBeliefV and is converted from logs and normalized - confirm.
387 Factor
BP::beliefV( size_t i
) const {
391 if( props
.logdomain
) {
397 return( Factor( var(i
), p
) );
// Returns the belief of factor I as a Factor over factor(I).vars() built from p.
// NOTE(review): the lines that declare and fill p, and the body of the
// props.logdomain branch, are not part of this listing; presumably p comes from
// calcBeliefF and is converted from logs and normalized - confirm.
401 Factor
BP::beliefF( size_t I
) const {
405 if( props
.logdomain
) {
411 return( Factor( factor(I
).vars(), p
) );
// Returns the belief of a single variable: looks up the index of n with findVar
// and forwards to beliefV.
415 Factor
BP::belief( const Var
&n
) const {
416 return( beliefV( findVar( n
) ) );
// Collects all beliefs into result: first beliefV(i) for every variable in index
// order, then beliefF(I) for every factor in index order.
// NOTE(review): the return statement is not part of this listing; presumably
// result is returned - confirm.
420 vector
<Factor
> BP::beliefs() const {
421 vector
<Factor
> result
;
422 for( size_t i
= 0; i
< nrVars(); ++i
)
423 result
.push_back( beliefV(i
) );
424 for( size_t I
= 0; I
< nrFactors(); ++I
)
425 result
.push_back( beliefF(I
) );
// Returns a belief over the variable set ns.
// Two paths are visible:
//   - forward the first element of ns to the single-variable belief();
//   - otherwise scan the factors for one whose vars() satisfies `>> ns`
//     (presumably a superset test - confirm), assert that the scan stopped before
//     nrFactors(), and return that factor's belief marginalized onto ns.
// NOTE(review): the condition guarding the first return, the declaration of I and
// the statement that ends the scan are not part of this listing; presumably the
// first path is taken for one-element sets and the scan breaks at the first
// match - confirm.
430 Factor
BP::belief( const VarSet
&ns
) const {
432 return belief( *(ns
.begin()) );
435 for( I
= 0; I
< nrFactors(); I
++ )
436 if( factor(I
).vars() >> ns
)
438 DAI_ASSERT( I
!= nrFactors() );
439 return beliefF(I
).marginal(ns
);
// Accumulates an estimate of the log partition sum from the current beliefs:
//   sum += (1 - number of neighbors of i) * entropy of beliefV(i), per variable i
//   sum -= DISTKL distance between beliefF(I) and factor(I), per factor I
// (presumably the Bethe approximation - confirm).
// The literal 1.0 makes the subtraction floating-point, so the unsigned size()
// cannot wrap around.
// NOTE(review): the declaration of sum and the return statement are not part of
// this listing; presumably sum starts at zero and is returned - confirm.
444 Real
BP::logZ() const {
446 for(size_t i
= 0; i
< nrVars(); ++i
)
447 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
448 for( size_t I
= 0; I
< nrFactors(); ++I
)
449 sum
-= dist( beliefF(I
), factor(I
), Prob::DISTKL
);
// Returns a description of this object: Name followed by the text produced by
// printProperties().
454 string
BP::identify() const {
455 return string(Name
) + printProperties();
// Resets the messages of the variables in ns only.
//   ns  set of variables whose edges are reset
// Each variable is located with findVar. For every neighboring factor, message and
// newMessage are filled with the neutral value (0.0 when props.logdomain is set,
// 1.0 otherwise). With SEQMAX updates the edge residual is set to 0.0 through
// updateResidual.
459 void BP::init( const VarSet
&ns
) {
460 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
461 size_t ni
= findVar( *n
);
462 foreach( const Neighbor
&I
, nbV( ni
) ) {
463 double val
= props
.logdomain
? 0.0 : 1.0;
464 message( ni
, I
.iter
).fill( val
);
465 newMessage( ni
, I
.iter
).fill( val
);
466 if( props
.updates
== Properties::UpdateType::SEQMAX
)
467 updateResidual( ni
, I
.iter
, 0.0 );
473 void BP::updateMessage( size_t i
, size_t _I
) {
474 if( recordSentMessages
)
475 _sentMessages
.push_back(make_pair(i
,_I
));
476 if( props
.damping
== 0.0 ) {
477 message(i
,_I
) = newMessage(i
,_I
);
478 if( props
.updates
== Properties::UpdateType::SEQMAX
)
479 updateResidual( i
, _I
, 0.0 );
481 message(i
,_I
) = (message(i
,_I
) ^ props
.damping
) * (newMessage(i
,_I
) ^ (1.0 - props
.damping
));
482 if( props
.updates
== Properties::UpdateType::SEQMAX
)
483 updateResidual( i
, _I
, dist( newMessage(i
,_I
), message(i
,_I
), Prob::DISTLINF
) );
// Moves edge (i,_I) to a new position in the residual look-up table.
//   i, _I  edge, addressed as _edges[i][_I]
//   r      new residual, used as the key of the re-inserted _lut entry
// The old _lut entry is erased through the iterator remembered in _edge2lut[i][_I];
// a new entry (r, (i,_I)) is inserted and its iterator is remembered in its place.
// NOTE(review): pEdge is not used on any line of this listing; presumably a
// skipped line stores r in pEdge->residual - confirm.
488 void BP::updateResidual( size_t i
, size_t _I
, double r
) {
489 EdgeProp
* pEdge
= &_edges
[i
][_I
];
492 // rearrange look-up table (delete and reinsert new key)
493 _lut
.erase( _edge2lut
[i
][_I
] );
494 _edge2lut
[i
][_I
] = _lut
.insert( make_pair( r
, make_pair(i
, _I
) ) );
498 std::vector
<size_t> BP::findMaximum() const {
499 vector
<size_t> maximum( nrVars() );
500 vector
<bool> visitedVars( nrVars(), false );
501 vector
<bool> visitedFactors( nrFactors(), false );
502 stack
<size_t> scheduledFactors
;
503 for( size_t i
= 0; i
< nrVars(); ++i
) {
506 visitedVars
[i
] = true;
508 // Maximise with respect to variable i
510 calcBeliefV( i
, prod
);
511 maximum
[i
] = max_element( prod
.begin(), prod
.end() ) - prod
.begin();
513 foreach( const Neighbor
&I
, nbV(i
) )
514 if( !visitedFactors
[I
] )
515 scheduledFactors
.push(I
);
517 while( !scheduledFactors
.empty() ){
518 size_t I
= scheduledFactors
.top();
519 scheduledFactors
.pop();
520 if( visitedFactors
[I
] )
522 visitedFactors
[I
] = true;
524 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
525 bool allDetermined
= true;
526 foreach( const Neighbor
&j
, nbF(I
) )
527 if( !visitedVars
[j
.node
] ) {
528 allDetermined
= false;
534 // Calculate product of incoming messages on factor I
536 calcBeliefF( I
, prod2
);
538 // The allowed configuration is restrained according to the variables assigned so far:
539 // pick the argmax amongst the allowed states
540 Real maxProb
= numeric_limits
<Real
>::min();
541 State
maxState( factor(I
).vars() );
542 for( State
s( factor(I
).vars() ); s
.valid(); ++s
){
543 // First, calculate whether this state is consistent with variables that
544 // have been assigned already
545 bool allowedState
= true;
546 foreach( const Neighbor
&j
, nbF(I
) )
547 if( visitedVars
[j
.node
] && maximum
[j
.node
] != s(var(j
.node
)) ) {
548 allowedState
= false;
551 // If it is consistent, check if its probability is larger than what we have seen so far
552 if( allowedState
&& prod2
[s
] > maxProb
) {
559 foreach( const Neighbor
&j
, nbF(I
) ) {
560 if( visitedVars
[j
.node
] ) {
561 // We have already visited j earlier - hopefully our state is consistent
562 if( maximum
[j
.node
] != maxState(var(j
.node
)) && props
.verbose
>= 1 )
563 cerr
<< "BP::findMaximum - warning: maximum not consistent due to loops." << endl
;
565 // We found a consistent state for variable j
566 visitedVars
[j
.node
] = true;
567 maximum
[j
.node
] = maxState( var(j
.node
) );
568 foreach( const Neighbor
&J
, nbV(j
) )
569 if( !visitedFactors
[J
] )
570 scheduledFactors
.push(J
);
579 } // end of namespace dai