/*  Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
28 #include <dai/diffs.h>
30 #include <dai/properties.h>
39 const char *BP::Name
= "BP";
42 void BP::setProperties( const PropertySet
&opts
) {
43 assert( opts
.hasKey("tol") );
44 assert( opts
.hasKey("maxiter") );
45 assert( opts
.hasKey("verbose") );
46 assert( opts
.hasKey("logdomain") );
47 assert( opts
.hasKey("updates") );
49 props
.tol
= opts
.getStringAs
<double>("tol");
50 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
51 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
52 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
53 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
57 PropertySet
BP::getProperties() const {
59 opts
.Set( "tol", props
.tol
);
60 opts
.Set( "maxiter", props
.maxiter
);
61 opts
.Set( "verbose", props
.verbose
);
62 opts
.Set( "logdomain", props
.logdomain
);
63 opts
.Set( "updates", props
.updates
);
68 string
BP::printProperties() const {
69 stringstream
s( stringstream::out
);
71 s
<< "tol=" << props
.tol
<< ",";
72 s
<< "maxiter=" << props
.maxiter
<< ",";
73 s
<< "verbose=" << props
.verbose
<< ",";
74 s
<< "logdomain=" << props
.logdomain
<< ",";
75 s
<< "updates=" << props
.updates
<< "]";
80 void BP::construct() {
81 // create edge properties
83 edges
.reserve( nrVars() );
84 for( size_t i
= 0; i
< nrVars(); ++i
) {
85 edges
.push_back( vector
<EdgeProp
>() );
86 edges
[i
].reserve( nbV(i
).size() );
87 foreach( const Neighbor
&I
, nbV(i
) ) {
89 newEP
.message
= Prob( var(i
).states() );
90 newEP
.newMessage
= Prob( var(i
).states() );
92 newEP
.index
.reserve( factor(I
).states() );
93 for( IndexFor
k( var(i
), factor(I
).vars() ); k
>= 0; ++k
)
94 newEP
.index
.push_back( k
);
97 edges
[i
].push_back( newEP
);
104 for( size_t i
= 0; i
< nrVars(); ++i
) {
105 foreach( const Neighbor
&I
, nbV(i
) ) {
106 if( props
.logdomain
) {
107 message( i
, I
.iter
).fill( 0.0 );
108 newMessage( i
, I
.iter
).fill( 0.0 );
110 message( i
, I
.iter
).fill( 1.0 );
111 newMessage( i
, I
.iter
).fill( 1.0 );
118 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
121 double maxres
= residual( i
, _I
);
122 for( size_t j
= 0; j
< nrVars(); ++j
)
123 foreach( const Neighbor
&I
, nbV(j
) )
124 if( residual( j
, I
.iter
) > maxres
) {
127 maxres
= residual( i
, _I
);
132 void BP::calcNewMessage( size_t i
, size_t _I
) {
133 // calculate updated message I->i
134 size_t I
= nbV(i
,_I
);
136 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
138 Factor prod( factor( I ) );
139 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
140 if( *j != i ) { // for all j in I \ i
141 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
142 if( *J != I ) { // for all J in nb(j) \ I
143 prod *= Factor( *j, message(*j,*J) );
144 Factor marg = prod.marginal(var(i));
147 Prob
prod( factor(I
).p() );
148 if( props
.logdomain
)
151 // Calculate product of incoming messages and factor I
152 foreach( const Neighbor
&j
, nbF(I
) ) {
153 if( j
!= i
) { // for all j in I \ i
155 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
156 const ind_t
& ind
= index(j
, _I
);
158 // prod_j will be the product of messages coming into j
159 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
160 foreach( const Neighbor
&J
, nbV(j
) )
161 if( J
!= I
) { // for all J in nb(j) \ I
162 if( props
.logdomain
)
163 prod_j
+= message( j
, J
.iter
);
165 prod_j
*= message( j
, J
.iter
);
168 // multiply prod with prod_j
169 for( size_t r
= 0; r
< prod
.size(); ++r
)
170 if( props
.logdomain
)
171 prod
[r
] += prod_j
[ind
[r
]];
173 prod
[r
] *= prod_j
[ind
[r
]];
176 if( props
.logdomain
) {
177 prod
-= prod
.maxVal();
181 // Marginalize onto i
182 Prob
marg( var(i
).states(), 0.0 );
183 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
184 const ind_t ind
= index(i
,_I
);
185 for( size_t r
= 0; r
< prod
.size(); ++r
)
186 marg
[ind
[r
]] += prod
[r
];
187 marg
.normalize( Prob::NORMPROB
);
190 if( props
.logdomain
)
191 newMessage(i
,_I
) = marg
.log();
193 newMessage(i
,_I
) = marg
;
197 // BP::run does not check for NANs for performance reasons
198 // Somehow NaNs do not often occur in BP...
200 if( props
.verbose
>= 1 )
201 cout
<< "Starting " << identify() << "...";
202 if( props
.verbose
>= 3)
206 Diffs
diffs(nrVars(), 1.0);
208 vector
<Edge
> update_seq
;
210 vector
<Factor
> old_beliefs
;
211 old_beliefs
.reserve( nrVars() );
212 for( size_t i
= 0; i
< nrVars(); ++i
)
213 old_beliefs
.push_back( beliefV(i
) );
216 size_t nredges
= nrEdges();
218 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
220 for( size_t i
= 0; i
< nrVars(); ++i
)
221 foreach( const Neighbor
&I
, nbV(i
) ) {
222 calcNewMessage( i
, I
.iter
);
223 // calculate initial residuals
224 residual( i
, I
.iter
) = dist( newMessage( i
, I
.iter
), message( i
, I
.iter
), Prob::DISTLINF
);
227 update_seq
.reserve( nredges
);
228 for( size_t i
= 0; i
< nrVars(); ++i
)
229 foreach( const Neighbor
&I
, nbV(i
) )
230 update_seq
.push_back( Edge( i
, I
.iter
) );
233 // do several passes over the network until maximum number of iterations has
234 // been reached or until the maximum belief difference is smaller than tolerance
235 for( iter
=0; iter
< props
.maxiter
&& diffs
.maxDiff() > props
.tol
; ++iter
) {
236 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
237 // Residuals-BP by Koller et al.
238 for( size_t t
= 0; t
< nredges
; ++t
) {
239 // update the message with the largest residual
242 findMaxResidual( i
, _I
);
243 message( i
, _I
) = newMessage( i
, _I
);
244 residual( i
, _I
) = 0.0;
246 // I->i has been updated, which means that residuals for all
247 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
248 foreach( const Neighbor
&J
, nbV(i
) ) {
250 foreach( const Neighbor
&j
, nbF(J
) ) {
253 calcNewMessage( j
, _J
);
254 residual( j
, _J
) = dist( newMessage( j
, _J
), message( j
, _J
), Prob::DISTLINF
);
260 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
262 for( size_t i
= 0; i
< nrVars(); ++i
)
263 foreach( const Neighbor
&I
, nbV(i
) )
264 calcNewMessage( i
, I
.iter
);
266 for( size_t i
= 0; i
< nrVars(); ++i
)
267 foreach( const Neighbor
&I
, nbV(i
) )
268 message( i
, I
.iter
) = newMessage( i
, I
.iter
);
270 // Sequential updates
271 if( props
.updates
== Properties::UpdateType::SEQRND
)
272 random_shuffle( update_seq
.begin(), update_seq
.end() );
274 foreach( const Edge
&e
, update_seq
) {
275 calcNewMessage( e
.first
, e
.second
);
276 message( e
.first
, e
.second
) = newMessage( e
.first
, e
.second
);
280 // calculate new beliefs and compare with old ones
281 for( size_t i
= 0; i
< nrVars(); ++i
) {
282 Factor
nb( beliefV(i
) );
283 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
287 if( props
.verbose
>= 3 )
288 cout
<< "BP::run: maxdiff " << diffs
.maxDiff() << " after " << iter
+1 << " passes" << endl
;
291 if( diffs
.maxDiff() > maxdiff
)
292 maxdiff
= diffs
.maxDiff();
294 if( props
.verbose
>= 1 ) {
295 if( diffs
.maxDiff() > props
.tol
) {
296 if( props
.verbose
== 1 )
298 cout
<< "BP::run: WARNING: not converged within " << props
.maxiter
<< " passes (" << toc() - tic
<< " clocks)...final maxdiff:" << diffs
.maxDiff() << endl
;
300 if( props
.verbose
>= 3 )
302 cout
<< "converged in " << iter
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
306 return diffs
.maxDiff();
310 Factor
BP::beliefV( size_t i
) const {
311 Prob
prod( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
312 foreach( const Neighbor
&I
, nbV(i
) )
313 if( props
.logdomain
)
314 prod
+= newMessage( i
, I
.iter
);
316 prod
*= newMessage( i
, I
.iter
);
317 if( props
.logdomain
) {
318 prod
-= prod
.maxVal();
322 prod
.normalize( Prob::NORMPROB
);
323 return( Factor( var(i
), prod
) );
327 Factor
BP::belief (const Var
&n
) const {
328 return( beliefV( findVar( n
) ) );
332 vector
<Factor
> BP::beliefs() const {
333 vector
<Factor
> result
;
334 for( size_t i
= 0; i
< nrVars(); ++i
)
335 result
.push_back( beliefV(i
) );
336 for( size_t I
= 0; I
< nrFactors(); ++I
)
337 result
.push_back( beliefF(I
) );
342 Factor
BP::belief( const VarSet
&ns
) const {
344 return belief( *(ns
.begin()) );
347 for( I
= 0; I
< nrFactors(); I
++ )
348 if( factor(I
).vars() >> ns
)
350 assert( I
!= nrFactors() );
351 return beliefF(I
).marginal(ns
);
356 Factor
BP::beliefF (size_t I
) const {
357 Prob
prod( factor(I
).p() );
358 if( props
.logdomain
)
361 foreach( const Neighbor
&j
, nbF(I
) ) {
363 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
364 const ind_t
& ind
= index(j
, _I
);
366 // prod_j will be the product of messages coming into j
367 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
368 foreach( const Neighbor
&J
, nbV(j
) ) {
369 if( J
!= I
) { // for all J in nb(j) \ I
370 if( props
.logdomain
)
371 prod_j
+= newMessage( j
, J
.iter
);
373 prod_j
*= newMessage( j
, J
.iter
);
377 // multiply prod with prod_j
378 for( size_t r
= 0; r
< prod
.size(); ++r
) {
379 if( props
.logdomain
)
380 prod
[r
] += prod_j
[ind
[r
]];
382 prod
[r
] *= prod_j
[ind
[r
]];
386 if( props
.logdomain
) {
387 prod
-= prod
.maxVal();
391 Factor
result( factor(I
).vars(), prod
);
392 result
.normalize( Prob::NORMPROB
);
396 /* UNOPTIMIZED VERSION
398 Factor prod( factor(I) );
399 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
400 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
402 prod *= Factor( var(*i), newMessage(*i,*J)) );
404 return prod.normalize( Prob::NORMPROB );*/
408 Real
BP::logZ() const {
410 for(size_t i
= 0; i
< nrVars(); ++i
)
411 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
412 for( size_t I
= 0; I
< nrFactors(); ++I
)
413 sum
-= KL_dist( beliefF(I
), factor(I
) );
418 string
BP::identify() const {
419 return string(Name
) + printProperties();
423 void BP::init( const VarSet
&ns
) {
424 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
425 size_t ni
= findVar( *n
);
426 foreach( const Neighbor
&I
, nbV( ni
) )
427 message( ni
, I
.iter
).fill( props
.logdomain
? 0.0 : 1.0 );
432 } // end of namespace dai