1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28 #include <dai/diffs.h>
30 #include <dai/properties.h>
// Identifier string of this inference method; BP::identify() streams it
// ahead of the property list.
39 const char *BP::Name
= "BP";
// Validates that the properties "updates", "tol", "maxiter" and "verbose"
// are present, then converts each one from its raw form to a typed value
// (double, size_t, size_t and UpdateType respectively).
// NOTE(review): the statement that should follow each HasProperty() test,
// the final return and the closing brace are not present in this copy; as
// it stands the four if-statements nest into one another. Presumably each
// failed test returns false -- confirm against the full source.
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
45 if( !HasProperty("tol") )
47 if (!HasProperty("maxiter") )
49 if (!HasProperty("verbose") )
// Convert the stored property values to their typed representations.
52 ConvertPropertyTo
<double>("tol");
53 ConvertPropertyTo
<size_t>("maxiter");
54 ConvertPropertyTo
<size_t>("verbose");
55 ConvertPropertyTo
<UpdateType
>("updates");
// Rebuilds the per-edge state. For every variable i and each neighbouring
// factor I it appends an EdgeProp holding:
//   - message / newMessage: the current and pending message I->i, each
//     sized to the number of states of var(i);
//   - index: for every joint state k of factor I's variables, the
//     corresponding state of var(i).
// NOTE(review): the declaration of newEP, any clearing of `edges` and the
// closing braces are not present in this copy.
61 void BP::Regenerate() {
62 // create edge properties
64 edges
.reserve( nrVars() );
65 for( size_t i
= 0; i
< nrVars(); ++i
) {
66 edges
.push_back( vector
<EdgeProp
>() );
67 edges
[i
].reserve( nbV(i
).size() );
68 foreach( const Neighbor
&I
, nbV(i
) ) {
// Message buffers have one entry per state of variable i.
70 newEP
.message
= Prob( var(i
).states() );
71 newEP
.newMessage
= Prob( var(i
).states() );
// Precompute index[k] = state of var(i) within joint state k of factor I,
// one entry per joint state (reserve uses factor(I).stateSpace()).
73 newEP
.index
.reserve( factor(I
).stateSpace() );
74 for( Index
k( var(i
), factor(I
).vars() ); k
>= 0; ++k
)
75 newEP
.index
.push_back( k
);
78 edges
[i
].push_back( newEP
);
// Body of the parameterless initialisation routine. Its signature line is
// not present in this copy (presumably BP::init() -- verify).
// It asserts that the properties are valid, then resets every message and
// every pending message to all-ones.
85 assert( checkProperties() );
86 for( size_t i
= 0; i
< nrVars(); ++i
) {
87 foreach( const Neighbor
&I
, nbV(i
) ) {
88 message( i
, I
.iter
).fill( 1.0 );
89 newMessage( i
, I
.iter
).fill( 1.0 );
// Finds the edge with the largest residual and reports it through the
// reference parameters: i = variable index, _I = position of the factor
// in nbV(i).
// NOTE(review): i and _I are read below before any visible assignment,
// and the body of the if (which should record the new arg-max) is
// incomplete in this copy. Presumably i/_I are first reset to 0 and then
// set to (j, I.iter) on a new maximum -- confirm against the full source.
95 void BP::findMaxResidual( size_t &i
, size_t &_I
) {
98 double maxres
= residual( i
, _I
);
// Linear scan over all (variable, neighbouring factor) edges.
99 for( size_t j
= 0; j
< nrVars(); ++j
)
100 foreach( const Neighbor
&I
, nbV(j
) )
101 if( residual( j
, I
.iter
) > maxres
) {
104 maxres
= residual( i
, _I
);
// Computes the updated message from factor I (the _I'th neighbour of
// variable i) to variable i, and stores it in newMessage(i,_I).
// message(i,_I) itself is only read here, never written.
// Steps:
//   1. Start from the factor table.
//   2. Multiply in, for every other variable j of the factor, the product
//      of the messages j receives from its other factors.
//   3. Sum out everything except var(i), then normalize.
109 void BP::calcNewMessage( size_t i
, size_t _I
) {
110 // calculate updated message I->i
// Global index of the factor (nbV(i,_I) presumably yields i's _I'th
// factor neighbour).
111 size_t I
= nbV(i
,_I
);
// NOTE(review): the '*/' that should close the following commented-out
// reference implementation is not present in this copy, so as written the
// comment also swallows the optimized code below.
113 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
115 Factor prod( factor( I ) );
116 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
117 if( *j != i ) { // for all j in I \ i
118 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
119 if( *J != I ) { // for all J in nb(j) \ I
120 prod *= Factor( *j, message(*j,*J) );
121 Factor marg = prod.marginal(var(i));
// Working table over the joint states of factor I, initialised from the
// factor's values.
124 Prob
prod( factor(I
).p() );
126 // Calculate product of incoming messages and factor I
127 foreach( const Neighbor
&j
, nbF(I
) ) {
128 if( j
!= i
) { // for all j in I \ i
130 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
131 const ind_t
& ind
= index(j
, _I
);
133 // prod_j will be the product of messages coming into j
134 Prob
prod_j( var(j
).states() );
135 foreach( const Neighbor
&J
, nbV(j
) ) {
136 if( J
!= I
) // for all J in nb(j) \ I
137 prod_j
*= message( j
, J
.iter
);
// NOTE(review): the braces closing the loops/if above are not present in
// this copy.
140 // multiply prod with prod_j
141 for( size_t r
= 0; r
< prod
.size(); ++r
)
142 prod
[r
] *= prod_j
[ind
[r
]];
146 // Marginalize onto i
147 Prob
marg( var(i
).states(), 0.0 );
148 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
// NOTE(review): unlike the `const ind_t &` binding used for j above, this
// binds by value and copies the whole index table on every call; a const
// reference would avoid the copy.
149 const ind_t ind
= index(i
,_I
);
// Sum every joint state of factor I into the state of var(i) it maps to.
150 for( size_t r
= 0; r
< prod
.size(); ++r
)
151 marg
[ind
[r
]] += prod
[r
];
// _normtype is a member declared outside this copy.
152 marg
.normalize( _normtype
);
// Publish as the pending message; the caller decides when to copy it into
// message(i,_I).
155 newMessage(i
,_I
) = marg
;
// Main BP iteration (presumably the body of BP::run()). Its signature,
// opening brace, the timer start that defines `tic`, and several
// verbosity guards are not present in this copy.
// Each pass updates messages according to Updates():
//   - SEQMAX: residual BP, one message at a time;
//   - PARALL: all messages at once;
//   - otherwise: sequential, in a fixed or (SEQRND) shuffled order.
// Afterwards it measures the largest change of any single-variable belief,
// and stops after MaxIter() passes or once that change is at most Tol().
159 // BP::run does not check for NANs for performance reasons
160 // Somehow NaNs do not often occur in BP...
// NOTE(review): no verbosity guard is visible around this banner.
163 cout
<< "Starting " << identify() << "...";
// Recent per-variable belief changes. Constructed with (nrVars(), 1.0),
// presumably nrVars() entries of 1.0, so diffs.max() > Tol() holds on the
// first pass whenever Tol() < 1.
168 Diffs
diffs(nrVars(), 1.0);
// An edge is (variable index, position of the factor in nbV(variable)),
// the same pair calcNewMessage() takes.
170 typedef pair
<size_t,size_t> Edge
;
171 vector
<Edge
> update_seq
;
// Beliefs used as the reference point for the convergence test.
173 vector
<Factor
> old_beliefs
;
174 old_beliefs
.reserve( nrVars() );
175 for( size_t i
= 0; i
< nrVars(); ++i
)
176 old_beliefs
.push_back( beliefV(i
) );
// Number of edges; one SEQMAX pass performs this many single-message updates.
179 size_t nredges
= nrEdges();
// SEQMAX setup: compute every candidate message once and record its
// residual, i.e. the L-infinity distance to the current message.
181 if( Updates() == UpdateType::SEQMAX
) {
183 for( size_t i
= 0; i
< nrVars(); ++i
)
184 foreach( const Neighbor
&I
, nbV(i
) ) {
185 calcNewMessage( i
, I
.iter
);
186 // calculate initial residuals
187 residual( i
, I
.iter
) = dist( newMessage( i
, I
.iter
), message( i
, I
.iter
), Prob::DISTLINF
);
// Default edge order for the sequential schedules: by variable, then by
// neighbour position.
190 update_seq
.reserve( nredges
);
191 for( size_t i
= 0; i
< nrVars(); ++i
)
192 foreach( const Neighbor
&I
, nbV(i
) )
193 update_seq
.push_back( Edge( i
, I
.iter
) );
196 // do several passes over the network until maximum number of iterations has
197 // been reached or until the maximum belief difference is smaller than tolerance
198 for( iter
=0; iter
< MaxIter() && diffs
.max() > Tol(); ++iter
) {
199 if( Updates() == UpdateType::SEQMAX
) {
200 // Residuals-BP by Koller et al.
201 for( size_t t
= 0; t
< nredges
; ++t
) {
202 // update the message with the largest residual
// NOTE(review): the declarations of i and _I are not present in this copy.
205 findMaxResidual( i
, _I
);
206 message( i
, _I
) = newMessage( i
, _I
);
207 residual( i
, _I
) = 0.0;
209 // I->i has been updated, which means that residuals for all
210 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
// NOTE(review): the exclusions J != I and j != i described above, the
// declaration of _J (presumably J's position in nbV(j)) and the closing
// braces are not present in this copy.
211 foreach( const Neighbor
&J
, nbV(i
) ) {
213 foreach( const Neighbor
&j
, nbF(J
) ) {
216 calcNewMessage( j
, _J
);
217 residual( j
, _J
) = dist( newMessage( j
, _J
), message( j
, _J
), Prob::DISTLINF
);
223 } else if( Updates() == UpdateType::PARALL
) {
// Parallel schedule: compute every new message from the current ones
// first, then accept them all at once.
225 for( size_t i
= 0; i
< nrVars(); ++i
)
226 foreach( const Neighbor
&I
, nbV(i
) )
227 calcNewMessage( i
, I
.iter
);
229 for( size_t i
= 0; i
< nrVars(); ++i
)
230 foreach( const Neighbor
&I
, nbV(i
) )
231 message( i
, I
.iter
) = newMessage( i
, I
.iter
);
// NOTE(review): the `} else {` introducing the sequential branch is not
// present in this copy.
233 // Sequential updates
// NOTE(review): std::random_shuffle is deprecated in C++14 and removed in
// C++17; std::shuffle with an explicit random engine is the replacement.
234 if( Updates() == UpdateType::SEQRND
)
235 random_shuffle( update_seq
.begin(), update_seq
.end() );
// Each edge's new message is accepted immediately, so later updates in the
// same pass already see it.
237 foreach( const Edge
&e
, update_seq
) {
238 calcNewMessage( e
.first
, e
.second
);
239 message( e
.first
, e
.second
) = newMessage( e
.first
, e
.second
);
243 // calculate new beliefs and compare with old ones
// NOTE(review): no line in this copy stores nb back into old_beliefs[i].
// If that assignment is really absent, every pass is compared with the
// initial beliefs rather than the previous pass -- confirm against the
// full source.
244 for( size_t i
= 0; i
< nrVars(); ++i
) {
245 Factor
nb( beliefV(i
) );
246 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
// Per-pass progress report; no verbosity guard is visible in this copy.
251 cout
<< "BP::run: maxdiff " << diffs
.max() << " after " << iter
+1 << " passes" << endl
;
// Record the final maximum belief change (updateMaxDiff is defined
// outside this copy).
254 updateMaxDiff( diffs
.max() );
// Final report: warn if the tolerance was not reached, otherwise report
// the number of passes. `tic` is not declared in this copy.
256 if( Verbose() >= 1 ) {
257 if( diffs
.max() > Tol() ) {
260 cout
<< "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic
<< " clocks)...final maxdiff:" << diffs
.max() << endl
;
// NOTE(review): the `} else {` between the two reports is not present in
// this copy.
264 cout
<< "converged in " << iter
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
// Belief of variable i: the product of the pending messages
// newMessage(i,.) from all neighbouring factors, normalized with
// Prob::NORMPROB and returned as a Factor over var(i).
// NOTE(review): this reads newMessage rather than message. Under the
// SEQMAX schedule, newMessage entries hold recomputed candidates that
// have not yet been accepted -- confirm this is intended.
272 Factor
BP::beliefV( size_t i
) const {
// Accumulator sized to var(i)'s state count (assumed to start as all-ones).
273 Prob
prod( var(i
).states() );
274 foreach( const Neighbor
&I
, nbV(i
) )
275 prod
*= newMessage( i
, I
.iter
);
277 prod
.normalize( Prob::NORMPROB
);
278 return( Factor( var(i
), prod
) );
// Belief of variable n: resolves n to its index with findVar() and
// delegates to beliefV().
282 Factor
BP::belief (const Var
&n
) const {
283 return( beliefV( findVar( n
) ) );
// Collects all beliefs: first one per variable (in variable order), then
// one per factor (in factor order).
// NOTE(review): the return statement and closing brace are not present in
// this copy.
287 vector
<Factor
> BP::beliefs() const {
288 vector
<Factor
> result
;
289 for( size_t i
= 0; i
< nrVars(); ++i
)
290 result
.push_back( beliefV(i
) );
291 for( size_t I
= 0; I
< nrFactors(); ++I
)
292 result
.push_back( beliefF(I
) );
// Belief over a set of variables ns. The visible logic has two paths:
//   - delegate to the single-variable belief of ns's first element;
//   - otherwise, marginalize onto ns the belief of the first factor whose
//     variable set satisfies `vars() >> ns` (presumably "contains ns").
// NOTE(review): not present in this copy:
//   - the guard selecting the first path (presumably ns.size() == 1);
//   - the `else`;
//   - the declaration of I;
//   - a `break` after a match.
// Without that break, I always equals nrFactors() after the loop and the
// assert would fire -- confirm against the full source.
297 Factor
BP::belief( const VarSet
&ns
) const {
299 return belief( *(ns
.begin()) );
302 for( I
= 0; I
< nrFactors(); I
++ )
303 if( factor(I
).vars() >> ns
)
305 assert( I
!= nrFactors() );
306 return beliefF(I
).marginal(ns
);
// Belief over the variables of factor I. Starts from the factor table and
// multiplies in, for each variable j of the factor, the product of the
// pending messages newMessage(j,.) from j's other factors. The result is
// normalized with Prob::NORMPROB.
// NOTE(review): `_I` used below is not declared in this copy (presumably
// I's position in nbV(j)); the return statement and closing braces are
// likewise missing.
311 Factor
BP::beliefF (size_t I
) const {
312 Prob
prod( factor(I
).p() );
314 foreach( const Neighbor
&j
, nbF(I
) ) {
316 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
317 const ind_t
& ind
= index(j
, _I
);
319 // prod_j will be the product of messages coming into j
320 Prob
prod_j( var(j
).states() );
321 foreach( const Neighbor
&J
, nbV(j
) ) {
322 if( J
!= I
) // for all J in nb(j) \ I
323 prod_j
*= newMessage( j
, J
.iter
);
326 // multiply prod with prod_j
327 for( size_t r
= 0; r
< prod
.size(); ++r
)
328 prod
[r
] *= prod_j
[ind
[r
]];
// Wrap the accumulated table as a Factor over the factor's variables and
// normalize it.
331 Factor
result( factor(I
).vars(), prod
);
332 result
.normalize( Prob::NORMPROB
);
// Commented-out reference implementation, kept for readability; it is not
// compiled.
336 /* UNOPTIMIZED VERSION
338 Factor prod( factor(I) );
339 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
340 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
342 prod *= Factor( var(*i), newMessage(*i,*J)) );
344 return prod.normalize( Prob::NORMPROB );*/
// Approximate log partition function:
//   sum_i (1 - |nbV(i)|) * entropy(b_i)  -  sum_I KL_dist(b_I, f_I)
// where b_i = beliefV(i), b_I = beliefF(I) and f_I = factor(I).
// Assuming entropy() = -sum p log p and KL_dist(p,q) = sum p log(p/q),
// this is the negated Bethe free energy.
// NOTE(review): the declaration/initialisation of `sum` and the return
// statement are not present in this copy.
348 Complex
BP::logZ() const {
350 for(size_t i
= 0; i
< nrVars(); ++i
)
351 sum
+= Complex(1.0 - nbV(i
).size()) * beliefV(i
).entropy();
352 for( size_t I
= 0; I
< nrFactors(); ++I
)
353 sum
-= KL_dist( beliefF(I
), factor(I
) );
// Builds a description of this instance: Name followed by the streamed
// GetProperties().
// NOTE(review): the statement returning result.str() (or equivalent) is
// not present in this copy.
358 string
BP::identify() const {
359 stringstream
result (stringstream::out
);
360 result
<< Name
<< GetProperties();
// Re-initialises to all-ones the messages flowing into each variable of ns
// (from every neighbouring factor of that variable).
// NOTE(review): only message() is reset here, whereas the parameterless
// initialisation also fills newMessage(). beliefV()/beliefF() read
// newMessage, so stale pending messages may survive -- confirm intended.
365 void BP::init( const VarSet
&ns
) {
366 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
367 size_t ni
= findVar( *n
);
368 foreach( const Neighbor
&I
, nbV( ni
) )
369 message( ni
, I
.iter
).fill( 1.0 );
374 } // end of namespace dai