/*  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
    Radboud University Nijmegen, The Netherlands /
    Max Planck Institute for Biological Cybernetics, Germany

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
32 #include <dai/properties.h>
41 const char *BP::Name
= "BP";
47 void BP::setProperties( const PropertySet
&opts
) {
48 assert( opts
.hasKey("tol") );
49 assert( opts
.hasKey("maxiter") );
50 assert( opts
.hasKey("logdomain") );
51 assert( opts
.hasKey("updates") );
53 props
.tol
= opts
.getStringAs
<double>("tol");
54 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
55 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
56 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
58 if( opts
.hasKey("verbose") )
59 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
62 if( opts
.hasKey("damping") )
63 props
.damping
= opts
.getStringAs
<double>("damping");
66 if( opts
.hasKey("inference") )
67 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
69 props
.inference
= Properties::InfType::SUMPROD
;
73 PropertySet
BP::getProperties() const {
75 opts
.Set( "tol", props
.tol
);
76 opts
.Set( "maxiter", props
.maxiter
);
77 opts
.Set( "verbose", props
.verbose
);
78 opts
.Set( "logdomain", props
.logdomain
);
79 opts
.Set( "updates", props
.updates
);
80 opts
.Set( "damping", props
.damping
);
81 opts
.Set( "inference", props
.inference
);
86 string
BP::printProperties() const {
87 stringstream
s( stringstream::out
);
89 s
<< "tol=" << props
.tol
<< ",";
90 s
<< "maxiter=" << props
.maxiter
<< ",";
91 s
<< "verbose=" << props
.verbose
<< ",";
92 s
<< "logdomain=" << props
.logdomain
<< ",";
93 s
<< "updates=" << props
.updates
<< ",";
94 s
<< "damping=" << props
.damping
<< ",";
95 s
<< "inference=" << props
.inference
<< "]";
100 void BP::construct() {
101 // create edge properties
103 _edges
.reserve( nrVars() );
105 if( props
.updates
== Properties::UpdateType::SEQMAX
)
106 _edge2lut
.reserve( nrVars() );
107 for( size_t i
= 0; i
< nrVars(); ++i
) {
108 _edges
.push_back( vector
<EdgeProp
>() );
109 _edges
[i
].reserve( nbV(i
).size() );
110 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
111 _edge2lut
.push_back( vector
<LutType::iterator
>() );
112 _edge2lut
[i
].reserve( nbV(i
).size() );
114 foreach( const Neighbor
&I
, nbV(i
) ) {
116 newEP
.message
= Prob( var(i
).states() );
117 newEP
.newMessage
= Prob( var(i
).states() );
120 newEP
.index
.reserve( factor(I
).states() );
121 for( IndexFor
k( var(i
), factor(I
).vars() ); k
>= 0; ++k
)
122 newEP
.index
.push_back( k
);
125 newEP
.residual
= 0.0;
126 _edges
[i
].push_back( newEP
);
127 if( props
.updates
== Properties::UpdateType::SEQMAX
)
128 _edge2lut
[i
].push_back( _lut
.insert( std::make_pair( newEP
.residual
, std::make_pair( i
, _edges
[i
].size() - 1 ))) );
135 double c
= props
.logdomain
? 0.0 : 1.0;
136 for( size_t i
= 0; i
< nrVars(); ++i
) {
137 foreach( const Neighbor
&I
, nbV(i
) ) {
138 message( i
, I
.iter
).fill( c
);
139 newMessage( i
, I
.iter
).fill( c
);
140 if( props
.updates
== Properties::UpdateType::SEQMAX
)
141 updateResidual( i
, I
.iter
, 0.0 );
147 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
151 double maxres = residual( i, _I );
152 for( size_t j = 0; j < nrVars(); ++j )
153 foreach( const Neighbor &I, nbV(j) )
154 if( residual( j, I.iter ) > maxres ) {
157 maxres = residual( i, _I );
160 assert( !_lut
.empty() );
161 LutType::const_iterator largestEl
= _lut
.end();
163 i
= largestEl
->second
.first
;
164 _I
= largestEl
->second
.second
;
168 void BP::calcNewMessage( size_t i
, size_t _I
) {
169 // calculate updated message I->i
170 size_t I
= nbV(i
,_I
);
173 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
174 Factor
prod( factor( I
) );
175 foreach( const Neighbor
&j
, nbF(I
) )
176 if( j
!= i
) { // for all j in I \ i
177 foreach( const Neighbor
&J
, nbV(j
) )
178 if( J
!= I
) { // for all J in nb(j) \ I
179 prod
*= Factor( var(j
), message(j
, J
.iter
) );
182 newMessage(i
,_I
) = prod
.marginal( var(i
) ).p();
184 /* OPTIMIZED VERSION */
185 Prob
prod( factor(I
).p() );
186 if( props
.logdomain
)
189 // Calculate product of incoming messages and factor I
190 foreach( const Neighbor
&j
, nbF(I
) ) {
191 if( j
!= i
) { // for all j in I \ i
193 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
194 const ind_t
&ind
= index(j
, _I
);
196 // prod_j will be the product of messages coming into j
197 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
198 foreach( const Neighbor
&J
, nbV(j
) )
199 if( J
!= I
) { // for all J in nb(j) \ I
200 if( props
.logdomain
)
201 prod_j
+= message( j
, J
.iter
);
203 prod_j
*= message( j
, J
.iter
);
206 // multiply prod with prod_j
207 for( size_t r
= 0; r
< prod
.size(); ++r
)
208 if( props
.logdomain
)
209 prod
[r
] += prod_j
[ind
[r
]];
211 prod
[r
] *= prod_j
[ind
[r
]];
214 if( props
.logdomain
) {
215 prod
-= prod
.maxVal();
219 // Marginalize onto i
220 Prob
marg( var(i
).states(), 0.0 );
221 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
222 const ind_t ind
= index(i
,_I
);
223 if( props
.inference
== Properties::InfType::SUMPROD
)
224 for( size_t r
= 0; r
< prod
.size(); ++r
)
225 marg
[ind
[r
]] += prod
[r
];
227 for( size_t r
= 0; r
< prod
.size(); ++r
)
228 if( prod
[r
] > marg
[ind
[r
]] )
229 marg
[ind
[r
]] = prod
[r
];
233 if( props
.logdomain
)
234 newMessage(i
,_I
) = marg
.log();
236 newMessage(i
,_I
) = marg
;
239 // Update the residual if necessary
240 if( props
.updates
== Properties::UpdateType::SEQMAX
)
241 updateResidual( i
, _I
, dist( newMessage( i
, _I
), message( i
, _I
), Prob::DISTLINF
) );
245 // BP::run does not check for NANs for performance reasons
246 // Somehow NaNs do not often occur in BP...
248 if( props
.verbose
>= 1 )
249 cout
<< "Starting " << identify() << "...";
250 if( props
.verbose
>= 3)
254 Diffs
diffs(nrVars(), 1.0);
256 vector
<Edge
> update_seq
;
258 vector
<Factor
> old_beliefs
;
259 old_beliefs
.reserve( nrVars() );
260 for( size_t i
= 0; i
< nrVars(); ++i
)
261 old_beliefs
.push_back( beliefV(i
) );
263 size_t nredges
= nrEdges();
265 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
267 for( size_t i
= 0; i
< nrVars(); ++i
)
268 foreach( const Neighbor
&I
, nbV(i
) ) {
269 calcNewMessage( i
, I
.iter
);
272 update_seq
.reserve( nredges
);
273 /// \todo Investigate whether performance increases by switching the order of following two loops:
274 for( size_t i
= 0; i
< nrVars(); ++i
)
275 foreach( const Neighbor
&I
, nbV(i
) )
276 update_seq
.push_back( Edge( i
, I
.iter
) );
279 // do several passes over the network until maximum number of iterations has
280 // been reached or until the maximum belief difference is smaller than tolerance
281 for( _iters
=0; _iters
< props
.maxiter
&& diffs
.maxDiff() > props
.tol
; ++_iters
) {
282 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
283 // Residuals-BP by Koller et al.
284 for( size_t t
= 0; t
< nredges
; ++t
) {
285 // update the message with the largest residual
287 findMaxResidual( i
, _I
);
288 updateMessage( i
, _I
);
290 // I->i has been updated, which means that residuals for all
291 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
292 foreach( const Neighbor
&J
, nbV(i
) ) {
294 foreach( const Neighbor
&j
, nbF(J
) ) {
297 calcNewMessage( j
, _J
);
302 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
304 for( size_t i
= 0; i
< nrVars(); ++i
)
305 foreach( const Neighbor
&I
, nbV(i
) )
306 calcNewMessage( i
, I
.iter
);
308 for( size_t i
= 0; i
< nrVars(); ++i
)
309 foreach( const Neighbor
&I
, nbV(i
) )
310 updateMessage( i
, I
.iter
);
312 // Sequential updates
313 if( props
.updates
== Properties::UpdateType::SEQRND
)
314 random_shuffle( update_seq
.begin(), update_seq
.end() );
316 foreach( const Edge
&e
, update_seq
) {
317 calcNewMessage( e
.first
, e
.second
);
318 updateMessage( e
.first
, e
.second
);
322 // calculate new beliefs and compare with old ones
323 for( size_t i
= 0; i
< nrVars(); ++i
) {
324 Factor
nb( beliefV(i
) );
325 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
329 if( props
.verbose
>= 3 )
330 cout
<< Name
<< "::run: maxdiff " << diffs
.maxDiff() << " after " << _iters
+1 << " passes" << endl
;
333 if( diffs
.maxDiff() > _maxdiff
)
334 _maxdiff
= diffs
.maxDiff();
336 if( props
.verbose
>= 1 ) {
337 if( diffs
.maxDiff() > props
.tol
) {
338 if( props
.verbose
== 1 )
340 cout
<< Name
<< "::run: WARNING: not converged within " << props
.maxiter
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << diffs
.maxDiff() << endl
;
342 if( props
.verbose
>= 3 )
343 cout
<< Name
<< "::run: ";
344 cout
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
348 return diffs
.maxDiff();
352 void BP::calcBeliefV( size_t i
, Prob
&p
) const {
353 p
= Prob( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
354 foreach( const Neighbor
&I
, nbV(i
) )
355 if( props
.logdomain
)
356 p
+= newMessage( i
, I
.iter
);
358 p
*= newMessage( i
, I
.iter
);
362 void BP::calcBeliefF( size_t I
, Prob
&p
) const {
364 if( props
.logdomain
)
367 foreach( const Neighbor
&j
, nbF(I
) ) {
369 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
370 const ind_t
& ind
= index(j
, _I
);
372 // prod_j will be the product of messages coming into j
373 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
374 foreach( const Neighbor
&J
, nbV(j
) ) {
375 if( J
!= I
) { // for all J in nb(j) \ I
376 if( props
.logdomain
)
377 prod_j
+= newMessage( j
, J
.iter
);
379 prod_j
*= newMessage( j
, J
.iter
);
383 // multiply p with prod_j
384 for( size_t r
= 0; r
< p
.size(); ++r
) {
385 if( props
.logdomain
)
386 p
[r
] += prod_j
[ind
[r
]];
388 p
[r
] *= prod_j
[ind
[r
]];
394 Factor
BP::beliefV( size_t i
) const {
398 if( props
.logdomain
) {
404 return( Factor( var(i
), p
) );
408 Factor
BP::belief( const Var
&n
) const {
409 return( beliefV( findVar( n
) ) );
413 vector
<Factor
> BP::beliefs() const {
414 vector
<Factor
> result
;
415 for( size_t i
= 0; i
< nrVars(); ++i
)
416 result
.push_back( beliefV(i
) );
417 for( size_t I
= 0; I
< nrFactors(); ++I
)
418 result
.push_back( beliefF(I
) );
423 Factor
BP::belief( const VarSet
&ns
) const {
425 return belief( *(ns
.begin()) );
428 for( I
= 0; I
< nrFactors(); I
++ )
429 if( factor(I
).vars() >> ns
)
431 assert( I
!= nrFactors() );
432 return beliefF(I
).marginal(ns
);
437 Factor
BP::beliefF( size_t I
) const {
439 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
441 Factor
prod( factor(I
) );
442 foreach( const Neighbor
&j
, nbF(I
) ) {
443 foreach( const Neighbor
&J
, nbV(j
) ) {
444 if( J
!= I
) // for all J in nb(j) \ I
445 prod
*= Factor( var(j
), newMessage(j
, J
.iter
) );
448 return prod
.normalized();
450 /* OPTIMIZED VERSION */
453 calcBeliefF( I
, prod
);
455 if( props
.logdomain
) {
456 prod
-= prod
.maxVal();
461 Factor
result( factor(I
).vars(), prod
);
468 Real
BP::logZ() const {
470 for(size_t i
= 0; i
< nrVars(); ++i
)
471 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
472 for( size_t I
= 0; I
< nrFactors(); ++I
)
473 sum
-= dist( beliefF(I
), factor(I
), Prob::DISTKL
);
478 string
BP::identify() const {
479 return string(Name
) + printProperties();
483 void BP::init( const VarSet
&ns
) {
484 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
485 size_t ni
= findVar( *n
);
486 foreach( const Neighbor
&I
, nbV( ni
) ) {
487 double val
= props
.logdomain
? 0.0 : 1.0;
488 message( ni
, I
.iter
).fill( val
);
489 newMessage( ni
, I
.iter
).fill( val
);
490 if( props
.updates
== Properties::UpdateType::SEQMAX
)
491 updateResidual( ni
, I
.iter
, 0.0 );
497 void BP::updateMessage( size_t i
, size_t _I
) {
498 if( props
.damping
== 0.0 ) {
499 message(i
,_I
) = newMessage(i
,_I
);
500 if( props
.updates
== Properties::UpdateType::SEQMAX
)
501 updateResidual( i
, _I
, 0.0 );
503 message(i
,_I
) = (message(i
,_I
) ^ props
.damping
) * (newMessage(i
,_I
) ^ (1.0 - props
.damping
));
504 if( props
.updates
== Properties::UpdateType::SEQMAX
)
505 updateResidual( i
, _I
, dist( newMessage(i
,_I
), message(i
,_I
), Prob::DISTLINF
) );
510 void BP::updateResidual( size_t i
, size_t _I
, double r
) {
511 EdgeProp
* pEdge
= &_edges
[i
][_I
];
514 // rearrange look-up table (delete and reinsert new key)
515 _lut
.erase( _edge2lut
[i
][_I
] );
516 _edge2lut
[i
][_I
] = _lut
.insert( std::make_pair( r
, std::make_pair(i
, _I
) ) );
520 std::vector
<size_t> BP::findMaximum() const {
521 std::vector
<size_t> maximum( nrVars() );
522 std::vector
<bool> visitedVars( nrVars(), false );
523 std::vector
<bool> visitedFactors( nrFactors(), false );
524 std::stack
<size_t> scheduledFactors
;
525 for( size_t i
= 0; i
< nrVars(); ++i
) {
528 visitedVars
[i
] = true;
530 // Maximise with respect to variable i
532 calcBeliefV( i
, prod
);
533 maximum
[i
] = std::max_element( prod
.begin(), prod
.end() ) - prod
.begin();
535 foreach( const Neighbor
&I
, nbV(i
) )
536 if( !visitedFactors
[I
] )
537 scheduledFactors
.push(I
);
539 while( !scheduledFactors
.empty() ){
540 size_t I
= scheduledFactors
.top();
541 scheduledFactors
.pop();
542 if( visitedFactors
[I
] )
544 visitedFactors
[I
] = true;
546 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
547 bool allDetermined
= true;
548 foreach( const Neighbor
&j
, nbF(I
) )
549 if( !visitedVars
[j
.node
] ) {
550 allDetermined
= false;
556 // Calculate product of incoming messages on factor I
558 calcBeliefF( I
, prod2
);
560 // The allowed configuration is restrained according to the variables assigned so far:
561 // pick the argmax amongst the allowed states
562 Real maxProb
= std::numeric_limits
<Real
>::min();
563 State
maxState( factor(I
).vars() );
564 for( State
s( factor(I
).vars() ); s
.valid(); ++s
){
565 // First, calculate whether this state is consistent with variables that
566 // have been assigned already
567 bool allowedState
= true;
568 foreach( const Neighbor
&j
, nbF(I
) )
569 if( visitedVars
[j
.node
] && maximum
[j
.node
] != s(var(j
.node
)) ) {
570 allowedState
= false;
573 // If it is consistent, check if its probability is larger than what we have seen so far
574 if( allowedState
&& prod2
[s
] > maxProb
) {
581 foreach( const Neighbor
&j
, nbF(I
) ) {
582 if( visitedVars
[j
.node
] ) {
583 // We have already visited j earlier - hopefully our state is consistent
584 if( maximum
[j
.node
] != maxState(var(j
.node
)) && props
.verbose
>= 1 )
585 std::cerr
<< "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl
;
587 // We found a consistent state for variable j
588 visitedVars
[j
.node
] = true;
589 maximum
[j
.node
] = maxState( var(j
.node
) );
590 foreach( const Neighbor
&J
, nbV(j
) )
591 if( !visitedFactors
[J
] )
592 scheduledFactors
.push(J
);
601 } // end of namespace dai