1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
28 #include <dai/diffs.h>
30 #include <dai/properties.h>
// Name of this inference algorithm; identify() streams it in front of the
// property set.
39 const char *BP::Name
= "BP";
// Validates and normalizes the properties used by BP.
// Tests for the presence of "updates", "tol", "maxiter", "verbose" and
// "logdomain", converts each of them to its typed representation, and stores
// the value of "logdomain" in logDomain.
// NOTE(review): the statement guarded by each HasProperty test and the final
// return are not visible in this extract; presumably a missing property
// makes the function return false -- confirm against the full file.
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
45 if( !HasProperty("tol") )
47 if (!HasProperty("maxiter") )
49 if (!HasProperty("verbose") )
51 if (!HasProperty("logdomain") )
// Convert the stored property values to the types the accessors expect.
54 ConvertPropertyTo
<double>("tol");
55 ConvertPropertyTo
<size_t>("maxiter");
56 ConvertPropertyTo
<size_t>("verbose");
57 ConvertPropertyTo
<UpdateType
>("updates");
58 ConvertPropertyTo
<bool>("logdomain");
// Keep the log-domain flag at hand; it is consulted by the message and
// belief computations elsewhere in this file.
59 logDomain
= GetPropertyAs
<bool>("logdomain");
// Builds the per-edge data in `edges`: one inner vector per variable i,
// holding one EdgeProp per neighbor I in nbV(i).
// NOTE(review): the enclosing function header and the declaration of newEP
// are not visible in this extract.
66 // create edge properties
68 edges
.reserve( nrVars() );
69 for( size_t i
= 0; i
< nrVars(); ++i
) {
70 edges
.push_back( vector
<EdgeProp
>() );
71 edges
[i
].reserve( nbV(i
).size() );
72 foreach( const Neighbor
&I
, nbV(i
) ) {
// Both message buffers get one entry per state of variable i.
74 newEP
.message
= Prob( var(i
).states() );
75 newEP
.newMessage
= Prob( var(i
).states() );
// Precompute the index table: the successive values of IndexFor k are
// appended until k becomes negative. Space is reserved for one entry per
// state of factor I.
77 newEP
.index
.reserve( factor(I
).states() );
78 for( IndexFor
k( var(i
), factor(I
).vars() ); k
>= 0; ++k
)
79 newEP
.index
.push_back( k
);
82 edges
[i
].push_back( newEP
);
// Resets message and newMessage on every edge (i, I.iter) to a constant.
// NOTE(review): the enclosing function header is not visible in this extract.
// NOTE(review): checkProperties() is evaluated inside assert(); if this file
// is built with NDEBUG the call disappears, and with it the property
// conversion and the refresh of logDomain that checkProperties() performs.
89 assert( checkProperties() );
90 for( size_t i
= 0; i
< nrVars(); ++i
) {
91 foreach( const Neighbor
&I
, nbV(i
) ) {
// Two alternative fills follow (0.0 and 1.0). The condition choosing between
// them is not visible here; presumably it is logDomain, matching the
// `logDomain ? 0.0 : 1.0` used in init( const VarSet & ) -- confirm.
93 message( i
, I
.iter
).fill( 0.0 );
94 newMessage( i
, I
.iter
).fill( 0.0 );
96 message( i
, I
.iter
).fill( 1.0 );
97 newMessage( i
, I
.iter
).fill( 1.0 );
// Searches all edges for a residual larger than the current maximum.
// i and _I are passed by reference and also read as the starting edge:
// maxres is initialised from residual( i, _I ).
// NOTE(review): the statements that initialise i and _I, and those that
// assign them inside the `if` below, are not visible in this extract;
// presumably i and _I end up identifying the edge with the largest
// residual -- confirm against the full file.
104 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
107 double maxres
= residual( i
, _I
);
// Scan every variable j and every neighbor I of j.
108 for( size_t j
= 0; j
< nrVars(); ++j
)
109 foreach( const Neighbor
&I
, nbV(j
) )
110 if( residual( j
, I
.iter
) > maxres
) {
// Refresh the running maximum from the edge currently held in (i, _I).
113 maxres
= residual( i
, _I
);
// Computes the updated message from factor I to variable i and stores it in
// newMessage( i, _I ), where I = nbV( i, _I ).
// Starting from the table of factor I, the messages arriving at every other
// variable j of I (from all factors except I) are folded in, the result is
// summed onto the states of i and normalized with _normtype.
// Both an additive (+=) and a multiplicative (*=) form of each accumulation
// appear below.
// NOTE(review): the conditions selecting between the two forms are not
// visible in this extract; presumably they test logDomain -- confirm.
118 void BP::calcNewMessage( size_t i
, size_t _I
) {
119 // calculate updated message I->i
120 size_t I
= nbV(i
,_I
);
122 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
124 Factor prod( factor( I ) );
125 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
126 if( *j != i ) { // for all j in I \ i
127 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
128 if( *J != I ) { // for all J in nb(j) \ I
129 prod *= Factor( *j, message(*j,*J) );
130 Factor marg = prod.marginal(var(i));
// NOTE(review): the terminator of the block comment above is not visible in
// this extract.
133 Prob
prod( factor(I
).p() );
137 // Calculate product of incoming messages and factor I
138 foreach( const Neighbor
&j
, nbF(I
) ) {
139 if( j
!= i
) { // for all j in I \ i
// NOTE(review): as shown, _I here is the function parameter (the position of
// I among the neighbors of i), yet it is used to look up the index table of
// j. The embedded listing numbers jump from 139 to 141, so a statement is
// missing at this point; presumably it introduced a j-specific _I -- confirm.
141 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
142 const ind_t
& ind
= index(j
, _I
);
144 // prod_j will be the product of messages coming into j
145 Prob
prod_j( var(j
).states(), logDomain
? 0.0 : 1.0 );
146 foreach( const Neighbor
&J
, nbV(j
) )
147 if( J
!= I
) { // for all J in nb(j) \ I
149 prod_j
+= message( j
, J
.iter
);
151 prod_j
*= message( j
, J
.iter
);
154 // multiply prod with prod_j
155 for( size_t r
= 0; r
< prod
.size(); ++r
)
157 prod
[r
] += prod_j
[ind
[r
]];
159 prod
[r
] *= prod_j
[ind
[r
]];
// Shift by the largest entry.
// NOTE(review): the guard around this statement and whatever follows it are
// not visible; presumably this is the log-domain path -- confirm.
163 prod
-= prod
.maxVal();
167 // Marginalize onto i
168 Prob
marg( var(i
).states(), 0.0 );
// NOTE(review): unlike the lookup inside the loop above, `ind` is declared by
// value here, so the index table is copied on every call; binding it as
// `const ind_t &` would avoid the copy.
169 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
170 const ind_t ind
= index(i
,_I
);
171 for( size_t r
= 0; r
< prod
.size(); ++r
)
172 marg
[ind
[r
]] += prod
[r
];
173 marg
.normalize( _normtype
);
// Store the result, either as marg.log() or as marg itself.
// NOTE(review): the selecting condition is not visible; presumably
// logDomain -- confirm.
177 newMessage(i
,_I
) = marg
.log();
179 newMessage(i
,_I
) = marg
;
// Main message-passing loop (the log messages below identify it as BP::run).
// Repeats passes over the network until MaxIter() passes have been made or
// diffs.maxDiff() is no larger than Tol(), then returns diffs.maxDiff().
// NOTE(review): the function header is not visible in this extract, and tic,
// i, _I and _J are used below without a visible declaration.
183 // BP::run does not check for NANs for performance reasons
184 // Somehow NaNs do not often occur in BP...
187 cout
<< "Starting " << identify() << "...";
// diffs collects the per-variable belief changes; its maxDiff() drives the
// convergence test of the main loop.
192 Diffs
diffs(nrVars(), 1.0);
194 typedef pair
<size_t,size_t> Edge
;
195 vector
<Edge
> update_seq
;
// Snapshot of the current variable beliefs, used as the reference when new
// beliefs are compared after each pass.
197 vector
<Factor
> old_beliefs
;
198 old_beliefs
.reserve( nrVars() );
199 for( size_t i
= 0; i
< nrVars(); ++i
)
200 old_beliefs
.push_back( beliefV(i
) );
203 size_t nredges
= nrEdges();
// Residual scheduling: compute every new message once and record how far it
// is from the current message (L-infinity distance).
205 if( Updates() == UpdateType::SEQMAX
) {
207 for( size_t i
= 0; i
< nrVars(); ++i
)
208 foreach( const Neighbor
&I
, nbV(i
) ) {
209 calcNewMessage( i
, I
.iter
);
210 // calculate initial residuals
211 residual( i
, I
.iter
) = dist( newMessage( i
, I
.iter
), message( i
, I
.iter
), Prob::DISTLINF
);
// Collect every edge (i, I.iter) into update_seq; the sequential-update loop
// further down walks this list.
214 update_seq
.reserve( nredges
);
215 for( size_t i
= 0; i
< nrVars(); ++i
)
216 foreach( const Neighbor
&I
, nbV(i
) )
217 update_seq
.push_back( Edge( i
, I
.iter
) );
220 // do several passes over the network until maximum number of iterations has
221 // been reached or until the maximum belief difference is smaller than tolerance
222 for( iter
=0; iter
< MaxIter() && diffs
.maxDiff() > Tol(); ++iter
) {
223 if( Updates() == UpdateType::SEQMAX
) {
224 // Residuals-BP by Koller et al.
// One pass performs nredges single-edge updates.
225 for( size_t t
= 0; t
< nredges
; ++t
) {
226 // update the message with the largest residual
229 findMaxResidual( i
, _I
);
230 message( i
, _I
) = newMessage( i
, _I
);
231 residual( i
, _I
) = 0.0;
233 // I->i has been updated, which means that residuals for all
234 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
// NOTE(review): the tests that exclude I and i, and the computation of _J,
// are not visible in this extract.
235 foreach( const Neighbor
&J
, nbV(i
) ) {
237 foreach( const Neighbor
&j
, nbF(J
) ) {
240 calcNewMessage( j
, _J
);
241 residual( j
, _J
) = dist( newMessage( j
, _J
), message( j
, _J
), Prob::DISTLINF
);
// Parallel updates: first compute all new messages from the current ones,
// then copy all of them into place.
247 } else if( Updates() == UpdateType::PARALL
) {
249 for( size_t i
= 0; i
< nrVars(); ++i
)
250 foreach( const Neighbor
&I
, nbV(i
) )
251 calcNewMessage( i
, I
.iter
);
253 for( size_t i
= 0; i
< nrVars(); ++i
)
254 foreach( const Neighbor
&I
, nbV(i
) )
255 message( i
, I
.iter
) = newMessage( i
, I
.iter
);
257 // Sequential updates
// NOTE(review): if this resolves to std::random_shuffle, note that it was
// deprecated in C++14 and removed in C++17 (std::shuffle is the replacement).
258 if( Updates() == UpdateType::SEQRND
)
259 random_shuffle( update_seq
.begin(), update_seq
.end() );
// Each edge is recomputed and applied immediately, in update_seq order.
261 foreach( const Edge
&e
, update_seq
) {
262 calcNewMessage( e
.first
, e
.second
);
263 message( e
.first
, e
.second
) = newMessage( e
.first
, e
.second
);
267 // calculate new beliefs and compare with old ones
// NOTE(review): whether old_beliefs[i] is refreshed after the comparison is
// not visible in this extract.
268 for( size_t i
= 0; i
< nrVars(); ++i
) {
269 Factor
nb( beliefV(i
) );
270 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
// Per-pass progress trace.
// NOTE(review): the condition guarding this trace is not visible here.
275 cout
<< "BP::run: maxdiff " << diffs
.maxDiff() << " after " << iter
+1 << " passes" << endl
;
278 updateMaxDiff( diffs
.maxDiff() );
// Final report: a warning when the tolerance was not reached, otherwise the
// number of passes used.
280 if( Verbose() >= 1 ) {
281 if( diffs
.maxDiff() > Tol() ) {
284 cout
<< "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic
<< " clocks)...final maxdiff:" << diffs
.maxDiff() << endl
;
288 cout
<< "converged in " << iter
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
292 return diffs
.maxDiff();
// Returns the belief of variable i as a Factor over var(i).
// Combines newMessage( i, I.iter ) over all neighbors I of i, normalizes the
// result with Prob::NORMPROB and wraps it in a Factor.
// Both an additive (+=) and a multiplicative (*=) accumulation appear below.
// NOTE(review): the condition selecting between them, and the guard around
// the maxVal() shift, are not visible in this extract; presumably they test
// logDomain -- confirm.
296 Factor
BP::beliefV( size_t i
) const {
// Start from 0.0 when logDomain is set, otherwise from 1.0.
297 Prob
prod( var(i
).states(), logDomain
? 0.0 : 1.0 );
298 foreach( const Neighbor
&I
, nbV(i
) )
300 prod
+= newMessage( i
, I
.iter
);
302 prod
*= newMessage( i
, I
.iter
);
304 prod
-= prod
.maxVal();
308 prod
.normalize( Prob::NORMPROB
);
309 return( Factor( var(i
), prod
) );
// Returns the belief of variable n: looks up its index with findVar() and
// delegates to beliefV().
313 Factor
BP::belief (const Var
&n
) const {
314 return( beliefV( findVar( n
) ) );
// Collects all beliefs into one vector: beliefV(i) for every variable first,
// followed by beliefF(I) for every factor.
// NOTE(review): the return statement is not visible in this extract;
// presumably `result` is returned.
318 vector
<Factor
> BP::beliefs() const {
319 vector
<Factor
> result
;
320 for( size_t i
= 0; i
< nrVars(); ++i
)
321 result
.push_back( beliefV(i
) );
322 for( size_t I
= 0; I
< nrFactors(); ++I
)
323 result
.push_back( beliefF(I
) );
// Returns a belief over the variable set ns.
// One path returns the single-variable belief of the first element of ns;
// the other looks for a factor I for which `factor(I).vars() >> ns` holds
// and returns beliefF(I) marginalized onto ns.
// NOTE(review): the guard for the first path, the declaration of I and the
// body of the `if` inside the loop are not visible in this extract;
// presumably the first path handles a one-element ns and the loop stops at
// the first matching factor -- confirm.
328 Factor
BP::belief( const VarSet
&ns
) const {
330 return belief( *(ns
.begin()) );
333 for( I
= 0; I
< nrFactors(); I
++ )
334 if( factor(I
).vars() >> ns
)
// Fires when the loop ran through all factors, i.e. I reached nrFactors().
336 assert( I
!= nrFactors() );
337 return beliefF(I
).marginal(ns
);
// Returns the belief of factor I as a Factor over factor(I).vars().
// Starting from the table of factor I, folds in, for every variable j of I,
// the newMessage values arriving at j from all factors other than I, and
// normalizes the result with Prob::NORMPROB.
// Both an additive (+=) and a multiplicative (*=) form of each accumulation
// appear below.
// NOTE(review): the conditions selecting between them are not visible in
// this extract; presumably they test logDomain -- confirm. _I is used
// without a visible declaration, and the return statement is not visible
// either.
342 Factor
BP::beliefF (size_t I
) const {
343 Prob
prod( factor(I
).p() );
347 foreach( const Neighbor
&j
, nbF(I
) ) {
349 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
350 const ind_t
& ind
= index(j
, _I
);
352 // prod_j will be the product of messages coming into j
353 Prob
prod_j( var(j
).states(), logDomain
? 0.0 : 1.0 );
354 foreach( const Neighbor
&J
, nbV(j
) ) {
355 if( J
!= I
) { // for all J in nb(j) \ I
357 prod_j
+= newMessage( j
, J
.iter
);
359 prod_j
*= newMessage( j
, J
.iter
);
363 // multiply prod with prod_j
364 for( size_t r
= 0; r
< prod
.size(); ++r
) {
366 prod
[r
] += prod_j
[ind
[r
]];
368 prod
[r
] *= prod_j
[ind
[r
]];
// Shift by the largest entry.
// NOTE(review): the guard around this statement and whatever follows it are
// not visible here.
373 prod
-= prod
.maxVal();
377 Factor
result( factor(I
).vars(), prod
);
378 result
.normalize( Prob::NORMPROB
);
382 /* UNOPTIMIZED VERSION
384 Factor prod( factor(I) );
385 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
386 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
388 prod *= Factor( var(*i), newMessage(*i,*J)) );
390 return prod.normalize( Prob::NORMPROB );*/
// Accumulates the terms of logZ() from the current beliefs:
// for every variable i, (1 - number of neighbors of i) times the entropy of
// beliefV(i) is added; for every factor I, KL_dist( beliefF(I), factor(I) )
// is subtracted.
// NOTE(review): the declaration of `sum` and the return statement are not
// visible in this extract; presumably `sum` starts at zero and is returned.
394 Complex
BP::logZ() const {
396 for(size_t i
= 0; i
< nrVars(); ++i
)
397 sum
+= Complex(1.0 - nbV(i
).size()) * beliefV(i
).entropy();
398 for( size_t I
= 0; I
< nrFactors(); ++I
)
399 sum
-= KL_dist( beliefF(I
), factor(I
) );
// Builds an identification string by streaming Name followed by
// GetProperties() into an output stringstream.
// NOTE(review): the return statement is not visible in this extract;
// presumably the stream contents are returned.
404 string
BP::identify() const {
405 stringstream
result (stringstream::out
);
406 result
<< Name
<< GetProperties();
// Resets the messages of the variables in ns.
// For every variable n in ns, its index ni is looked up with findVar() and
// message( ni, I.iter ) is filled, for each neighbor I of ni, with 0.0 when
// logDomain is set and 1.0 otherwise.
// NOTE(review): only message() is reset here, whereas the fill loop that
// follows `assert( checkProperties() )` earlier in this file resets both
// message() and newMessage(); confirm that leaving newMessage() untouched
// is intended.
411 void BP::init( const VarSet
&ns
) {
412 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
413 size_t ni
= findVar( *n
);
414 foreach( const Neighbor
&I
, nbV( ni
) )
415 message( ni
, I
.iter
).fill( logDomain
? 0.0 : 1.0 );
420 } // end of namespace dai