7c8d81a85b88da6fcdc91a710b4266a66fd33bfe
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 This file is part of libDAI.
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
30 #include <dai/properties.h>
// Short name of this inference algorithm. identify() prefixes the printed
// property list with it, and run() uses it as the prefix of its log messages.
39 const char *BP::Name
= "BP";
// Reads the BP options from `opts` into `props`.
//   Required keys (asserted): "tol", "maxiter", "logdomain", "updates".
//   Optional keys (read only when present): "verbose", "damping", "inference".
// Values are converted with PropertySet::getStringAs<T>().
// NOTE(review): assert() compiles away under NDEBUG, so a missing required key
// is not diagnosed here in release builds; consider an explicit check.
42 void BP::setProperties( const PropertySet
&opts
) {
43 assert( opts
.hasKey("tol") );
44 assert( opts
.hasKey("maxiter") );
45 assert( opts
.hasKey("logdomain") );
46 assert( opts
.hasKey("updates") );
// Required settings.
48 props
.tol
= opts
.getStringAs
<double>("tol");
49 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
50 props
.logdomain
= opts
.getStringAs
<bool>("logdomain");
51 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
// Optional settings. The defaults used when a key is absent are not visible
// in this view.
53 if( opts
.hasKey("verbose") )
54 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
57 if( opts
.hasKey("damping") )
58 props
.damping
= opts
.getStringAs
<double>("damping");
61 if( opts
.hasKey("inference") )
62 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
// NOTE(review): as shown in this view, the next assignment follows the
// "inference" check unconditionally and would overwrite any value read from
// opts. Presumably an `else` (not visible here) makes SUMPROD only the
// default. Confirm against the full file.
64 props
.inference
= Properties::InfType::SUMPROD
;
// Exports the current `props` as a PropertySet. It uses the same keys that
// setProperties() reads: "tol", "maxiter", "verbose", "logdomain", "updates",
// "damping", "inference".
68 PropertySet
BP::getProperties() const {
70 opts
.Set( "tol", props
.tol
);
71 opts
.Set( "maxiter", props
.maxiter
);
72 opts
.Set( "verbose", props
.verbose
);
73 opts
.Set( "logdomain", props
.logdomain
);
74 opts
.Set( "updates", props
.updates
);
75 opts
.Set( "damping", props
.damping
);
76 opts
.Set( "inference", props
.inference
);
// Formats the current `props` as a comma-separated list of key=value pairs,
// terminated by "]" (e.g. "...,damping=<d>,inference=<inf>]").
// identify() appends this string to Name.
// NOTE(review): no std::boolalpha is set on the stream here, so logdomain
// presumably prints as 0/1.
81 string
BP::printProperties() const {
82 stringstream
s( stringstream::out
);
84 s
<< "tol=" << props
.tol
<< ",";
85 s
<< "maxiter=" << props
.maxiter
<< ",";
86 s
<< "verbose=" << props
.verbose
<< ",";
87 s
<< "logdomain=" << props
.logdomain
<< ",";
88 s
<< "updates=" << props
.updates
<< ",";
89 s
<< "damping=" << props
.damping
<< ",";
90 s
<< "inference=" << props
.inference
<< "]";
// Allocates the per-edge state: one EdgeProp for every (variable i,
// neighbouring factor I) pair. They are appended in nbV(i) order, so
// _edges[i][n] belongs to the n-th neighbour of i. That is the I.iter index
// used by message()/newMessage()/residual() elsewhere.
95 void BP::construct() {
96 // create edge properties: one EdgeProp per (variable i, neighbouring factor I)
98 _edges
.reserve( nrVars() );
99 for( size_t i
= 0; i
< nrVars(); ++i
) {
100 _edges
.push_back( vector
<EdgeProp
>() );
101 _edges
[i
].reserve( nbV(i
).size() );
102 foreach( const Neighbor
&I
, nbV(i
) ) {
// Both message buffers are sized to the number of states of var(i).
104 newEP
.message
= Prob( var(i
).states() );
105 newEP
.newMessage
= Prob( var(i
).states() );
// Precompute the index map: index[k] is the state of var(i) that
// corresponds to joint state k of factor(I).
107 newEP
.index
.reserve( factor(I
).states() );
108 for( IndexFor
k( var(i
), factor(I
).vars() ); k
>= 0; ++k
)
109 newEP
.index
.push_back( k
);
// No pending change yet.
111 newEP
.residual
= 0.0;
112 _edges
[i
].push_back( newEP
);
// NOTE(review): the enclosing function header is not visible in this view.
// This appears to be the body of the parameterless message initializer.
// It resets every message and newMessage on every edge to a constant: 0 in
// the log domain, 1 otherwise. This gives uniform, unnormalized messages.
119 double c
= props
.logdomain
? 0.0 : 1.0;
120 for( size_t i
= 0; i
< nrVars(); ++i
) {
121 foreach( const Neighbor
&I
, nbV(i
) ) {
122 message( i
, I
.iter
).fill( c
);
123 newMessage( i
, I
.iter
).fill( c
);
// Scans all edges (j, I.iter) for the largest residual. i and _I are in/out
// reference parameters and presumably receive the argmax edge.
// NOTE(review): in this view, maxres is seeded from residual(i, _I) with the
// caller-supplied values, and the lines that reset i/_I beforehand and
// assign j / I.iter to them inside the if are not shown. Confirm they exist.
// Without them, the recomputation of maxres below never changes its value.
// Because the comparison is a strict '>', an earlier edge wins ties.
129 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
132 double maxres
= residual( i
, _I
);
133 for( size_t j
= 0; j
< nrVars(); ++j
)
134 foreach( const Neighbor
&I
, nbV(j
) )
135 if( residual( j
, I
.iter
) > maxres
) {
138 maxres
= residual( i
, _I
);
// Computes the factor-to-variable message I -> i, where I is the _I-th
// neighbour of variable i, and stores it in newMessage(i, _I).
// The visible lines only write newMessage; message(i, _I) is left unchanged.
// Two implementations follow:
//   - a simple one built from Factor products and marginal();
//   - an optimized one on raw Prob tables using the precomputed index maps.
// The preprocessor lines that select between them are presumably present
// but not visible in this view.
143 void BP::calcNewMessage( size_t i
, size_t _I
) {
144 // calculate updated message I->i
145 size_t I
= nbV(i
,_I
);
148 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
// prod = factor(I) * prod_{j in I\i} prod_{J in nb(j)\I} message(J -> j)
149 Factor
prod( factor( I
) );
150 foreach( const Neighbor
&j
, nbF(I
) )
151 if( j
!= i
) { // for all j in I \ i
152 foreach( const Neighbor
&J
, nbV(j
) )
153 if( J
!= I
) { // for all J in nb(j) \ I
154 prod
*= Factor( var(j
), message(j
, J
.iter
) );
// Sum out every variable except var(i).
157 newMessage(i
,_I
) = prod
.marginal( var(i
) ).p();
159 /* OPTIMIZED VERSION */
// Start from factor I's table. The statement guarded by the logdomain
// check below is not visible in this view.
160 Prob
prod( factor(I
).p() );
161 if( props
.logdomain
)
164 // Calculate product of incoming messages and factor I
165 foreach( const Neighbor
&j
, nbF(I
) ) {
166 if( j
!= i
) { // for all j in I \ i
// NOTE(review): `_I` on the next statement is presumably a shadowing local
// declared on a line not visible here (the position of I in nb(j)). The
// function parameter _I is the position of I in nb(i) and would be the
// wrong index for index(j, .).
168 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
169 const ind_t
&ind
= index(j
, _I
);
171 // prod_j will be the product of messages coming into j
// (in the log domain: their sum, starting from 0 instead of 1)
172 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
173 foreach( const Neighbor
&J
, nbV(j
) )
174 if( J
!= I
) { // for all J in nb(j) \ I
175 if( props
.logdomain
)
176 prod_j
+= message( j
, J
.iter
);
178 prod_j
*= message( j
, J
.iter
);
181 // multiply prod with prod_j
// (add in the log domain); ind broadcasts prod_j over all joint states of I
182 for( size_t r
= 0; r
< prod
.size(); ++r
)
183 if( props
.logdomain
)
184 prod
[r
] += prod_j
[ind
[r
]];
186 prod
[r
] *= prod_j
[ind
[r
]];
// In the log domain, shift the values so that the maximum is 0. This
// presumably keeps a later exponentiation in range; the follow-up lines are
// not visible in this view.
189 if( props
.logdomain
) {
190 prod
-= prod
.maxVal();
// Sum-product accumulates sums per state of var(i). The max-product
// branch below keeps the maximum per state instead.
194 // Marginalize onto i
195 Prob
marg( var(i
).states(), 0.0 );
196 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
// NOTE(review): this copies the index vector on every call. Binding it as
// `const ind_t &ind`, like the index map in the product loop above, would
// avoid the copy.
197 const ind_t ind
= index(i
,_I
);
198 if( props
.inference
== Properties::InfType::SUMPROD
)
199 for( size_t r
= 0; r
< prod
.size(); ++r
)
200 marg
[ind
[r
]] += prod
[r
];
202 for( size_t r
= 0; r
< prod
.size(); ++r
)
203 if( prod
[r
] > marg
[ind
[r
]] )
204 marg
[ind
[r
]] = prod
[r
];
// Store the new message, converted back to logs when working in the log
// domain.
208 if( props
.logdomain
)
209 newMessage(i
,_I
) = marg
.log();
211 newMessage(i
,_I
) = marg
;
// Main loop of BP::run(); the function header is not visible in this view.
// It performs update passes according to props.updates until
// diffs.maxDiff() <= props.tol or props.maxiter passes are done, and returns
// the final diffs.maxDiff().
216 // BP::run does not check for NaNs for performance reasons
217 // (in practice NaNs rarely occur in BP)
219 if( props
.verbose
>= 1 )
220 cout
<< "Starting " << identify() << "...";
221 if( props
.verbose
>= 3)
// diffs is constructed with nrVars() and 1.0, presumably a window of the
// last nrVars() belief changes, all initialized to 1.0. The loop below
// therefore starts as long as tol < 1.0.
225 Diffs
diffs(nrVars(), 1.0);
227 vector
<Edge
> update_seq
;
// Beliefs from before the current pass, used for the convergence test.
229 vector
<Factor
> old_beliefs
;
230 old_beliefs
.reserve( nrVars() );
231 for( size_t i
= 0; i
< nrVars(); ++i
)
232 old_beliefs
.push_back( beliefV(i
) );
// One SEQMAX pass performs nredges single-message updates.
234 size_t nredges
= nrEdges();
// SEQMAX (residual BP) needs every candidate message and its residual
// (the L-infinity distance to the current message) up front.
236 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
238 for( size_t i
= 0; i
< nrVars(); ++i
)
239 foreach( const Neighbor
&I
, nbV(i
) ) {
240 calcNewMessage( i
, I
.iter
);
241 // calculate initial residuals
242 residual( i
, I
.iter
) = dist( newMessage( i
, I
.iter
), message( i
, I
.iter
), Prob::DISTLINF
);
// Edge list for the sequential schedules, in variable order. It is
// presumably built only when updates != SEQMAX (the else is not visible).
245 update_seq
.reserve( nredges
);
246 for( size_t i
= 0; i
< nrVars(); ++i
)
247 foreach( const Neighbor
&I
, nbV(i
) )
248 update_seq
.push_back( Edge( i
, I
.iter
) );
251 // do several passes over the network until maximum number of iterations has
252 // been reached or until the maximum belief difference is smaller than tolerance
253 for( _iters
=0; _iters
< props
.maxiter
&& diffs
.maxDiff() > props
.tol
; ++_iters
) {
254 if( props
.updates
== Properties::UpdateType::SEQMAX
) {
255 // Residuals-BP by Koller et al.
256 for( size_t t
= 0; t
< nredges
; ++t
) {
257 // update the message with the largest residual
259 findMaxResidual( i
, _I
);
260 updateMessage( i
, _I
);
262 // I->i has been updated, which means that residuals for all
263 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
264 foreach( const Neighbor
&J
, nbV(i
) ) {
266 foreach( const Neighbor
&j
, nbF(J
) ) {
// _J (and the exclusions of J == I and j == i) come from lines not visible
// in this view; _J is presumably the position of J in nb(j).
269 calcNewMessage( j
, _J
);
270 residual( j
, _J
) = dist( newMessage( j
, _J
), message( j
, _J
), Prob::DISTLINF
);
276 } else if( props
.updates
== Properties::UpdateType::PARALL
) {
// Parallel schedule: first compute every new message from the current
// messages, then commit them all.
278 for( size_t i
= 0; i
< nrVars(); ++i
)
279 foreach( const Neighbor
&I
, nbV(i
) )
280 calcNewMessage( i
, I
.iter
);
282 for( size_t i
= 0; i
< nrVars(); ++i
)
283 foreach( const Neighbor
&I
, nbV(i
) )
284 updateMessage( i
, I
.iter
);
286 // Sequential updates
// SEQRND visits the edges in a new random order on every pass.
// NOTE(review): if this resolves to std::random_shuffle, note that it was
// deprecated in C++14 and removed in C++17. std::shuffle with an explicit
// engine is the replacement.
287 if( props
.updates
== Properties::UpdateType::SEQRND
)
288 random_shuffle( update_seq
.begin(), update_seq
.end() );
290 foreach( const Edge
&e
, update_seq
) {
291 calcNewMessage( e
.first
, e
.second
);
292 updateMessage( e
.first
, e
.second
);
296 // calculate new beliefs and compare with old ones
297 for( size_t i
= 0; i
< nrVars(); ++i
) {
298 Factor
nb( beliefV(i
) );
299 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
303 if( props
.verbose
>= 3 )
304 cout
<< Name
<< "::run: maxdiff " << diffs
.maxDiff() << " after " << _iters
+1 << " passes" << endl
;
// _maxdiff is only ever increased here, so it holds the largest final
// maxdiff seen so far.
307 if( diffs
.maxDiff() > _maxdiff
)
308 _maxdiff
= diffs
.maxDiff();
// Final report: a warning if tol was not reached, otherwise the number of
// passes used. `tic` is presumably the start time taken with toc(); its
// declaration is not visible in this view.
310 if( props
.verbose
>= 1 ) {
311 if( diffs
.maxDiff() > props
.tol
) {
312 if( props
.verbose
== 1 )
314 cout
<< Name
<< "::run: WARNING: not converged within " << props
.maxiter
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << diffs
.maxDiff() << endl
;
316 if( props
.verbose
>= 3 )
317 cout
<< Name
<< "::run: ";
318 cout
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
322 return diffs
.maxDiff();
// Belief of variable i. It is the product of all incoming messages
// newMessage(i, I) over I in nb(i), or their sum in the log domain. It is
// returned as a Factor over var(i).
// Note that it reads newMessage(), not message().
// In the log domain the values are shifted so that their maximum is 0. The
// lines that follow that shift (presumably conversion back from logs and
// normalization) are not visible in this view.
326 Factor
BP::beliefV( size_t i
) const {
327 Prob
prod( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
328 foreach( const Neighbor
&I
, nbV(i
) )
329 if( props
.logdomain
)
330 prod
+= newMessage( i
, I
.iter
);
332 prod
*= newMessage( i
, I
.iter
);
333 if( props
.logdomain
) {
334 prod
-= prod
.maxVal();
339 return( Factor( var(i
), prod
) );
// Belief of variable n. It looks up n's index with findVar() and delegates
// to beliefV().
343 Factor
BP::belief (const Var
&n
) const {
344 return( beliefV( findVar( n
) ) );
// All beliefs: first every variable belief in variable order, then every
// factor belief in factor order.
// NOTE(review): result could reserve nrVars() + nrFactors() up front to avoid
// reallocations.
348 vector
<Factor
> BP::beliefs() const {
349 vector
<Factor
> result
;
350 for( size_t i
= 0; i
< nrVars(); ++i
)
351 result
.push_back( beliefV(i
) );
352 for( size_t I
= 0; I
< nrFactors(); ++I
)
353 result
.push_back( beliefF(I
) );
// Belief over the variable set ns.
// - The single-variable shortcut below presumably applies when ns has exactly
//   one element; its condition is not visible in this view.
// - Otherwise the first factor I with factor(I).vars() >> ns (presumably
//   "contains ns") is used, and its belief is marginalized onto ns.
// NOTE(review): "no such factor" is only caught by assert(). Under NDEBUG,
// beliefF(nrFactors()) would be called with an out-of-range index.
358 Factor
BP::belief( const VarSet
&ns
) const {
360 return belief( *(ns
.begin()) );
363 for( I
= 0; I
< nrFactors(); I
++ )
364 if( factor(I
).vars() >> ns
)
366 assert( I
!= nrFactors() );
367 return beliefF(I
).marginal(ns
);
// Belief of factor I: factor(I) times, for every variable j in I, the
// product of newMessage(j, J) over J in nb(j)\I.
// Two implementations follow:
//   - a simple Factor-based one, returned normalized;
//   - an optimized one on the raw Prob table.
// The preprocessor lines that select between them are not visible in this
// view.
372 Factor
BP::beliefF (size_t I
) const {
374 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
376 Factor
prod( factor(I
) );
377 foreach( const Neighbor
&j
, nbF(I
) ) {
378 foreach( const Neighbor
&J
, nbV(j
) ) {
379 if( J
!= I
) // for all J in nb(j) \ I
380 prod
*= Factor( var(j
), newMessage(j
, J
.iter
) );
383 return prod
.normalized();
385 /* OPTIMIZED VERSION */
// Start from factor I's table. The statement guarded by the logdomain
// check below is not visible in this view.
386 Prob
prod( factor(I
).p() );
387 if( props
.logdomain
)
390 foreach( const Neighbor
&j
, nbF(I
) ) {
// NOTE(review): `_I` is not a parameter of this function. It is presumably
// a local declared on a line not visible here (the position of I in nb(j)).
392 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
393 const ind_t
& ind
= index(j
, _I
);
395 // prod_j will be the product of messages coming into j
// (in the log domain: their sum, starting from 0 instead of 1)
396 Prob
prod_j( var(j
).states(), props
.logdomain
? 0.0 : 1.0 );
397 foreach( const Neighbor
&J
, nbV(j
) ) {
398 if( J
!= I
) { // for all J in nb(j) \ I
399 if( props
.logdomain
)
400 prod_j
+= newMessage( j
, J
.iter
);
402 prod_j
*= newMessage( j
, J
.iter
);
406 // multiply prod with prod_j
// (add in the log domain); ind broadcasts prod_j over all joint states of I
407 for( size_t r
= 0; r
< prod
.size(); ++r
) {
408 if( props
.logdomain
)
409 prod
[r
] += prod_j
[ind
[r
]];
411 prod
[r
] *= prod_j
[ind
[r
]];
// In the log domain, shift the values so that the maximum is 0. The
// follow-up lines are not visible in this view.
415 if( props
.logdomain
) {
416 prod
-= prod
.maxVal();
420 Factor
result( factor(I
).vars(), prod
);
// Approximate log partition function from the current beliefs:
//   logZ ~= sum_i (1 - |nb(i)|) * H(b_i)  -  sum_I KL(b_I || f_I)
// This matches the negative Bethe free energy, assuming
// dist(p, q, Prob::DISTKL) computes KL(p || q).
// The declaration and initialization of `sum`, and the return, are not
// visible in this view.
428 Real
BP::logZ() const {
// (1.0 - size()) is evaluated in double, so there is no unsigned
// wrap-around for variables with more than one neighbour.
430 for(size_t i
= 0; i
< nrVars(); ++i
)
431 sum
+= (1.0 - nbV(i
).size()) * beliefV(i
).entropy();
432 for( size_t I
= 0; I
< nrFactors(); ++I
)
433 sum
-= dist( beliefF(I
), factor(I
), Prob::DISTKL
);
// Human-readable identification: Name followed by printProperties().
438 string
BP::identify() const {
439 return string(Name
) + printProperties();
// Resets the messages on all edges of the variables in ns to a constant:
// 0 in the log domain, 1 otherwise.
// Only message() is reset here; newMessage() is left unchanged.
// NOTE(review): beliefV()/beliefF() read newMessage(), so beliefs still
// reflect the old values until new messages are computed. Confirm this is
// intended.
443 void BP::init( const VarSet
&ns
) {
444 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
445 size_t ni
= findVar( *n
);
446 foreach( const Neighbor
&I
, nbV( ni
) )
447 message( ni
, I
.iter
).fill( props
.logdomain
? 0.0 : 1.0 );
452 } // end of namespace dai