Updated copyrights
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <iostream>
24 #include <sstream>
25 #include <map>
26 #include <set>
27 #include <algorithm>
28 #include <dai/bp.h>
29 #include <dai/diffs.h>
30 #include <dai/util.h>
31 #include <dai/properties.h>
32
33
34 namespace dai {
35
36
37 using namespace std;
38
39
40 const char *BP::Name = "BP";
41
42
43 void BP::setProperties( const PropertySet &opts ) {
44 assert( opts.hasKey("tol") );
45 assert( opts.hasKey("maxiter") );
46 assert( opts.hasKey("logdomain") );
47 assert( opts.hasKey("updates") );
48
49 props.tol = opts.getStringAs<double>("tol");
50 props.maxiter = opts.getStringAs<size_t>("maxiter");
51 props.logdomain = opts.getStringAs<bool>("logdomain");
52 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
53
54 if( opts.hasKey("verbose") )
55 props.verbose = opts.getStringAs<size_t>("verbose");
56 else
57 props.verbose = 0;
58 if( opts.hasKey("damping") )
59 props.damping = opts.getStringAs<double>("damping");
60 else
61 props.damping = 0.0;
62 if( opts.hasKey("inference") )
63 props.inference = opts.getStringAs<Properties::InfType>("inference");
64 else
65 props.inference = Properties::InfType::SUMPROD;
66 }
67
68
69 PropertySet BP::getProperties() const {
70 PropertySet opts;
71 opts.Set( "tol", props.tol );
72 opts.Set( "maxiter", props.maxiter );
73 opts.Set( "verbose", props.verbose );
74 opts.Set( "logdomain", props.logdomain );
75 opts.Set( "updates", props.updates );
76 opts.Set( "damping", props.damping );
77 opts.Set( "inference", props.inference );
78 return opts;
79 }
80
81
82 string BP::printProperties() const {
83 stringstream s( stringstream::out );
84 s << "[";
85 s << "tol=" << props.tol << ",";
86 s << "maxiter=" << props.maxiter << ",";
87 s << "verbose=" << props.verbose << ",";
88 s << "logdomain=" << props.logdomain << ",";
89 s << "updates=" << props.updates << ",";
90 s << "damping=" << props.damping << ",";
91 s << "inference=" << props.inference << "]";
92 return s.str();
93 }
94
95
96 void BP::construct() {
97 // create edge properties
98 _edges.clear();
99 _edges.reserve( nrVars() );
100 for( size_t i = 0; i < nrVars(); ++i ) {
101 _edges.push_back( vector<EdgeProp>() );
102 _edges[i].reserve( nbV(i).size() );
103 foreach( const Neighbor &I, nbV(i) ) {
104 EdgeProp newEP;
105 newEP.message = Prob( var(i).states() );
106 newEP.newMessage = Prob( var(i).states() );
107
108 newEP.index.reserve( factor(I).states() );
109 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
110 newEP.index.push_back( k );
111
112 newEP.residual = 0.0;
113 _edges[i].push_back( newEP );
114 }
115 }
116 }
117
118
119 void BP::init() {
120 double c = props.logdomain ? 0.0 : 1.0;
121 for( size_t i = 0; i < nrVars(); ++i ) {
122 foreach( const Neighbor &I, nbV(i) ) {
123 message( i, I.iter ).fill( c );
124 newMessage( i, I.iter ).fill( c );
125 }
126 }
127 }
128
129
130 void BP::findMaxResidual( size_t &i, size_t &_I ) {
131 i = 0;
132 _I = 0;
133 double maxres = residual( i, _I );
134 for( size_t j = 0; j < nrVars(); ++j )
135 foreach( const Neighbor &I, nbV(j) )
136 if( residual( j, I.iter ) > maxres ) {
137 i = j;
138 _I = I.iter;
139 maxres = residual( i, _I );
140 }
141 }
142
143
144 void BP::calcNewMessage( size_t i, size_t _I ) {
145 // calculate updated message I->i
146 size_t I = nbV(i,_I);
147
148 if( 0 == 1 ) {
149 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
150 Factor prod( factor( I ) );
151 foreach( const Neighbor &j, nbF(I) )
152 if( j != i ) { // for all j in I \ i
153 foreach( const Neighbor &J, nbV(j) )
154 if( J != I ) { // for all J in nb(j) \ I
155 prod *= Factor( var(j), message(j, J.iter) );
156 }
157 }
158 newMessage(i,_I) = prod.marginal( var(i) ).p();
159 } else {
160 /* OPTIMIZED VERSION */
161 Prob prod( factor(I).p() );
162 if( props.logdomain )
163 prod.takeLog();
164
165 // Calculate product of incoming messages and factor I
166 foreach( const Neighbor &j, nbF(I) ) {
167 if( j != i ) { // for all j in I \ i
168 size_t _I = j.dual;
169 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
170 const ind_t &ind = index(j, _I);
171
172 // prod_j will be the product of messages coming into j
173 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
174 foreach( const Neighbor &J, nbV(j) )
175 if( J != I ) { // for all J in nb(j) \ I
176 if( props.logdomain )
177 prod_j += message( j, J.iter );
178 else
179 prod_j *= message( j, J.iter );
180 }
181
182 // multiply prod with prod_j
183 for( size_t r = 0; r < prod.size(); ++r )
184 if( props.logdomain )
185 prod[r] += prod_j[ind[r]];
186 else
187 prod[r] *= prod_j[ind[r]];
188 }
189 }
190 if( props.logdomain ) {
191 prod -= prod.maxVal();
192 prod.takeExp();
193 }
194
195 // Marginalize onto i
196 Prob marg( var(i).states(), 0.0 );
197 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
198 const ind_t ind = index(i,_I);
199 if( props.inference == Properties::InfType::SUMPROD )
200 for( size_t r = 0; r < prod.size(); ++r )
201 marg[ind[r]] += prod[r];
202 else
203 for( size_t r = 0; r < prod.size(); ++r )
204 if( prod[r] > marg[ind[r]] )
205 marg[ind[r]] = prod[r];
206 marg.normalize();
207
208 // Store result
209 if( props.logdomain )
210 newMessage(i,_I) = marg.log();
211 else
212 newMessage(i,_I) = marg;
213 }
214 }
215
216
217 // BP::run does not check for NANs for performance reasons
218 // Somehow NaNs do not often occur in BP...
219 double BP::run() {
220 if( props.verbose >= 1 )
221 cout << "Starting " << identify() << "...";
222 if( props.verbose >= 3)
223 cout << endl;
224
225 double tic = toc();
226 Diffs diffs(nrVars(), 1.0);
227
228 vector<Edge> update_seq;
229
230 vector<Factor> old_beliefs;
231 old_beliefs.reserve( nrVars() );
232 for( size_t i = 0; i < nrVars(); ++i )
233 old_beliefs.push_back( beliefV(i) );
234
235 size_t nredges = nrEdges();
236
237 if( props.updates == Properties::UpdateType::SEQMAX ) {
238 // do the first pass
239 for( size_t i = 0; i < nrVars(); ++i )
240 foreach( const Neighbor &I, nbV(i) ) {
241 calcNewMessage( i, I.iter );
242 // calculate initial residuals
243 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
244 }
245 } else {
246 update_seq.reserve( nredges );
247 for( size_t i = 0; i < nrVars(); ++i )
248 foreach( const Neighbor &I, nbV(i) )
249 update_seq.push_back( Edge( i, I.iter ) );
250 }
251
252 // do several passes over the network until maximum number of iterations has
253 // been reached or until the maximum belief difference is smaller than tolerance
254 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
255 if( props.updates == Properties::UpdateType::SEQMAX ) {
256 // Residuals-BP by Koller et al.
257 for( size_t t = 0; t < nredges; ++t ) {
258 // update the message with the largest residual
259 size_t i, _I;
260 findMaxResidual( i, _I );
261 updateMessage( i, _I );
262
263 // I->i has been updated, which means that residuals for all
264 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
265 foreach( const Neighbor &J, nbV(i) ) {
266 if( J.iter != _I ) {
267 foreach( const Neighbor &j, nbF(J) ) {
268 size_t _J = j.dual;
269 if( j != i ) {
270 calcNewMessage( j, _J );
271 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
272 }
273 }
274 }
275 }
276 }
277 } else if( props.updates == Properties::UpdateType::PARALL ) {
278 // Parallel updates
279 for( size_t i = 0; i < nrVars(); ++i )
280 foreach( const Neighbor &I, nbV(i) )
281 calcNewMessage( i, I.iter );
282
283 for( size_t i = 0; i < nrVars(); ++i )
284 foreach( const Neighbor &I, nbV(i) )
285 updateMessage( i, I.iter );
286 } else {
287 // Sequential updates
288 if( props.updates == Properties::UpdateType::SEQRND )
289 random_shuffle( update_seq.begin(), update_seq.end() );
290
291 foreach( const Edge &e, update_seq ) {
292 calcNewMessage( e.first, e.second );
293 updateMessage( e.first, e.second );
294 }
295 }
296
297 // calculate new beliefs and compare with old ones
298 for( size_t i = 0; i < nrVars(); ++i ) {
299 Factor nb( beliefV(i) );
300 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
301 old_beliefs[i] = nb;
302 }
303
304 if( props.verbose >= 3 )
305 cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
306 }
307
308 if( diffs.maxDiff() > _maxdiff )
309 _maxdiff = diffs.maxDiff();
310
311 if( props.verbose >= 1 ) {
312 if( diffs.maxDiff() > props.tol ) {
313 if( props.verbose == 1 )
314 cout << endl;
315 cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
316 } else {
317 if( props.verbose >= 3 )
318 cout << Name << "::run: ";
319 cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
320 }
321 }
322
323 return diffs.maxDiff();
324 }
325
326
327 Factor BP::beliefV( size_t i ) const {
328 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
329 foreach( const Neighbor &I, nbV(i) )
330 if( props.logdomain )
331 prod += newMessage( i, I.iter );
332 else
333 prod *= newMessage( i, I.iter );
334 if( props.logdomain ) {
335 prod -= prod.maxVal();
336 prod.takeExp();
337 }
338
339 prod.normalize();
340 return( Factor( var(i), prod ) );
341 }
342
343
344 Factor BP::belief (const Var &n) const {
345 return( beliefV( findVar( n ) ) );
346 }
347
348
349 vector<Factor> BP::beliefs() const {
350 vector<Factor> result;
351 for( size_t i = 0; i < nrVars(); ++i )
352 result.push_back( beliefV(i) );
353 for( size_t I = 0; I < nrFactors(); ++I )
354 result.push_back( beliefF(I) );
355 return result;
356 }
357
358
359 Factor BP::belief( const VarSet &ns ) const {
360 if( ns.size() == 1 )
361 return belief( *(ns.begin()) );
362 else {
363 size_t I;
364 for( I = 0; I < nrFactors(); I++ )
365 if( factor(I).vars() >> ns )
366 break;
367 assert( I != nrFactors() );
368 return beliefF(I).marginal(ns);
369 }
370 }
371
372
373 Factor BP::beliefF (size_t I) const {
374 if( 0 == 1 ) {
375 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
376
377 Factor prod( factor(I) );
378 foreach( const Neighbor &j, nbF(I) ) {
379 foreach( const Neighbor &J, nbV(j) ) {
380 if( J != I ) // for all J in nb(j) \ I
381 prod *= Factor( var(j), newMessage(j, J.iter) );
382 }
383 }
384 return prod.normalized();
385 } else {
386 /* OPTIMIZED VERSION */
387 Prob prod( factor(I).p() );
388 if( props.logdomain )
389 prod.takeLog();
390
391 foreach( const Neighbor &j, nbF(I) ) {
392 size_t _I = j.dual;
393 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
394 const ind_t & ind = index(j, _I);
395
396 // prod_j will be the product of messages coming into j
397 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
398 foreach( const Neighbor &J, nbV(j) ) {
399 if( J != I ) { // for all J in nb(j) \ I
400 if( props.logdomain )
401 prod_j += newMessage( j, J.iter );
402 else
403 prod_j *= newMessage( j, J.iter );
404 }
405 }
406
407 // multiply prod with prod_j
408 for( size_t r = 0; r < prod.size(); ++r ) {
409 if( props.logdomain )
410 prod[r] += prod_j[ind[r]];
411 else
412 prod[r] *= prod_j[ind[r]];
413 }
414 }
415
416 if( props.logdomain ) {
417 prod -= prod.maxVal();
418 prod.takeExp();
419 }
420
421 Factor result( factor(I).vars(), prod );
422 result.normalize();
423
424 return( result );
425 }
426 }
427
428
429 Real BP::logZ() const {
430 Real sum = 0.0;
431 for(size_t i = 0; i < nrVars(); ++i )
432 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
433 for( size_t I = 0; I < nrFactors(); ++I )
434 sum -= KL_dist( beliefF(I), factor(I) );
435 return sum;
436 }
437
438
439 string BP::identify() const {
440 return string(Name) + printProperties();
441 }
442
443
444 void BP::init( const VarSet &ns ) {
445 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
446 size_t ni = findVar( *n );
447 foreach( const Neighbor &I, nbV( ni ) )
448 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
449 }
450 }
451
452
453 } // end of namespace dai