Merged bp.h and bp.cpp from SVN head
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <algorithm>
27 #include <dai/bp.h>
28 #include <dai/diffs.h>
29 #include <dai/util.h>
30 #include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP";
40
41
42 void BP::setProperties( const PropertySet &opts ) {
43 assert( opts.hasKey("tol") );
44 assert( opts.hasKey("maxiter") );
45 assert( opts.hasKey("logdomain") );
46 assert( opts.hasKey("updates") );
47
48 props.tol = opts.getStringAs<double>("tol");
49 props.maxiter = opts.getStringAs<size_t>("maxiter");
50 props.logdomain = opts.getStringAs<bool>("logdomain");
51 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
52
53 if( opts.hasKey("verbose") )
54 props.verbose = opts.getStringAs<size_t>("verbose");
55 else
56 props.verbose = 0;
57 if( opts.hasKey("damping") )
58 props.damping = opts.getStringAs<double>("damping");
59 else
60 props.damping = 0.0;
61 }
62
63
64 PropertySet BP::getProperties() const {
65 PropertySet opts;
66 opts.Set( "tol", props.tol );
67 opts.Set( "maxiter", props.maxiter );
68 opts.Set( "verbose", props.verbose );
69 opts.Set( "logdomain", props.logdomain );
70 opts.Set( "updates", props.updates );
71 opts.Set( "damping", props.damping );
72 return opts;
73 }
74
75
76 string BP::printProperties() const {
77 stringstream s( stringstream::out );
78 s << "[";
79 s << "tol=" << props.tol << ",";
80 s << "maxiter=" << props.maxiter << ",";
81 s << "verbose=" << props.verbose << ",";
82 s << "logdomain=" << props.logdomain << ",";
83 s << "updates=" << props.updates << ",";
84 s << "damping=" << props.damping << "]";
85 return s.str();
86 }
87
88
89 void BP::construct() {
90 // create edge properties
91 _edges.clear();
92 _edges.reserve( nrVars() );
93 for( size_t i = 0; i < nrVars(); ++i ) {
94 _edges.push_back( vector<EdgeProp>() );
95 _edges[i].reserve( nbV(i).size() );
96 foreach( const Neighbor &I, nbV(i) ) {
97 EdgeProp newEP;
98 newEP.message = Prob( var(i).states() );
99 newEP.newMessage = Prob( var(i).states() );
100
101 newEP.index.reserve( factor(I).states() );
102 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
103 newEP.index.push_back( k );
104
105 newEP.residual = 0.0;
106 _edges[i].push_back( newEP );
107 }
108 }
109 }
110
111
112 void BP::init() {
113 double c = props.logdomain ? 0.0 : 1.0;
114 for( size_t i = 0; i < nrVars(); ++i ) {
115 foreach( const Neighbor &I, nbV(i) ) {
116 message( i, I.iter ).fill( c );
117 newMessage( i, I.iter ).fill( c );
118 }
119 }
120 }
121
122
123 void BP::findMaxResidual( size_t &i, size_t &_I ) {
124 i = 0;
125 _I = 0;
126 double maxres = residual( i, _I );
127 for( size_t j = 0; j < nrVars(); ++j )
128 foreach( const Neighbor &I, nbV(j) )
129 if( residual( j, I.iter ) > maxres ) {
130 i = j;
131 _I = I.iter;
132 maxres = residual( i, _I );
133 }
134 }
135
136
137 void BP::calcNewMessage( size_t i, size_t _I ) {
138 // calculate updated message I->i
139 size_t I = nbV(i,_I);
140
141 if( 0 == 1 ) {
142 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
143 Factor prod( factor( I ) );
144 foreach( const Neighbor &j, nbF(I) )
145 if( j != i ) { // for all j in I \ i
146 foreach( const Neighbor &J, nbV(j) )
147 if( J != I ) { // for all J in nb(j) \ I
148 prod *= Factor( var(j), message(j, J.iter) );
149 }
150 }
151 newMessage(i,_I) = prod.marginal( var(i) ).p();
152 } else {
153 /* OPTIMIZED VERSION */
154 Prob prod( factor(I).p() );
155 if( props.logdomain )
156 prod.takeLog();
157
158 // Calculate product of incoming messages and factor I
159 foreach( const Neighbor &j, nbF(I) ) {
160 if( j != i ) { // for all j in I \ i
161 size_t _I = j.dual;
162 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
163 const ind_t &ind = index(j, _I);
164
165 // prod_j will be the product of messages coming into j
166 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
167 foreach( const Neighbor &J, nbV(j) )
168 if( J != I ) { // for all J in nb(j) \ I
169 if( props.logdomain )
170 prod_j += message( j, J.iter );
171 else
172 prod_j *= message( j, J.iter );
173 }
174
175 // multiply prod with prod_j
176 for( size_t r = 0; r < prod.size(); ++r )
177 if( props.logdomain )
178 prod[r] += prod_j[ind[r]];
179 else
180 prod[r] *= prod_j[ind[r]];
181 }
182 }
183 if( props.logdomain ) {
184 prod -= prod.maxVal();
185 prod.takeExp();
186 }
187
188 // Marginalize onto i
189 Prob marg( var(i).states(), 0.0 );
190 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
191 const ind_t ind = index(i,_I);
192 for( size_t r = 0; r < prod.size(); ++r )
193 marg[ind[r]] += prod[r];
194 marg.normalize();
195
196 // Store result
197 if( props.logdomain )
198 newMessage(i,_I) = marg.log();
199 else
200 newMessage(i,_I) = marg;
201 }
202 }
203
204
205 // BP::run does not check for NANs for performance reasons
206 // Somehow NaNs do not often occur in BP...
207 double BP::run() {
208 if( props.verbose >= 1 )
209 cout << "Starting " << identify() << "...";
210 if( props.verbose >= 3)
211 cout << endl;
212
213 double tic = toc();
214 Diffs diffs(nrVars(), 1.0);
215
216 vector<Edge> update_seq;
217
218 vector<Factor> old_beliefs;
219 old_beliefs.reserve( nrVars() );
220 for( size_t i = 0; i < nrVars(); ++i )
221 old_beliefs.push_back( beliefV(i) );
222
223 size_t nredges = nrEdges();
224
225 if( props.updates == Properties::UpdateType::SEQMAX ) {
226 // do the first pass
227 for( size_t i = 0; i < nrVars(); ++i )
228 foreach( const Neighbor &I, nbV(i) ) {
229 calcNewMessage( i, I.iter );
230 // calculate initial residuals
231 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
232 }
233 } else {
234 update_seq.reserve( nredges );
235 for( size_t i = 0; i < nrVars(); ++i )
236 foreach( const Neighbor &I, nbV(i) )
237 update_seq.push_back( Edge( i, I.iter ) );
238 }
239
240 // do several passes over the network until maximum number of iterations has
241 // been reached or until the maximum belief difference is smaller than tolerance
242 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
243 if( props.updates == Properties::UpdateType::SEQMAX ) {
244 // Residuals-BP by Koller et al.
245 for( size_t t = 0; t < nredges; ++t ) {
246 // update the message with the largest residual
247 size_t i, _I;
248 findMaxResidual( i, _I );
249 updateMessage( i, _I );
250 residual( i, _I ) = 0.0;
251
252 // I->i has been updated, which means that residuals for all
253 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
254 foreach( const Neighbor &J, nbV(i) ) {
255 if( J.iter != _I ) {
256 foreach( const Neighbor &j, nbF(J) ) {
257 size_t _J = j.dual;
258 if( j != i ) {
259 calcNewMessage( j, _J );
260 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
261 }
262 }
263 }
264 }
265 }
266 } else if( props.updates == Properties::UpdateType::PARALL ) {
267 // Parallel updates
268 for( size_t i = 0; i < nrVars(); ++i )
269 foreach( const Neighbor &I, nbV(i) )
270 calcNewMessage( i, I.iter );
271
272 for( size_t i = 0; i < nrVars(); ++i )
273 foreach( const Neighbor &I, nbV(i) )
274 updateMessage( i, I.iter );
275 } else {
276 // Sequential updates
277 if( props.updates == Properties::UpdateType::SEQRND )
278 random_shuffle( update_seq.begin(), update_seq.end() );
279
280 foreach( const Edge &e, update_seq ) {
281 calcNewMessage( e.first, e.second );
282 updateMessage( e.first, e.second );
283 }
284 }
285
286 // calculate new beliefs and compare with old ones
287 for( size_t i = 0; i < nrVars(); ++i ) {
288 Factor nb( beliefV(i) );
289 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
290 old_beliefs[i] = nb;
291 }
292
293 if( props.verbose >= 3 )
294 cout << "BP::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
295 }
296
297 if( diffs.maxDiff() > _maxdiff )
298 _maxdiff = diffs.maxDiff();
299
300 if( props.verbose >= 1 ) {
301 if( diffs.maxDiff() > props.tol ) {
302 if( props.verbose == 1 )
303 cout << endl;
304 cout << "BP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
305 } else {
306 if( props.verbose >= 3 )
307 cout << "BP::run: ";
308 cout << "converged in " << _iters << " passes (" << toc() - tic << " clocks)." << endl;
309 }
310 }
311
312 return diffs.maxDiff();
313 }
314
315
316 Factor BP::beliefV( size_t i ) const {
317 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
318 foreach( const Neighbor &I, nbV(i) )
319 if( props.logdomain )
320 prod += newMessage( i, I.iter );
321 else
322 prod *= newMessage( i, I.iter );
323 if( props.logdomain ) {
324 prod -= prod.maxVal();
325 prod.takeExp();
326 }
327
328 prod.normalize();
329 return( Factor( var(i), prod ) );
330 }
331
332
333 Factor BP::belief (const Var &n) const {
334 return( beliefV( findVar( n ) ) );
335 }
336
337
338 vector<Factor> BP::beliefs() const {
339 vector<Factor> result;
340 for( size_t i = 0; i < nrVars(); ++i )
341 result.push_back( beliefV(i) );
342 for( size_t I = 0; I < nrFactors(); ++I )
343 result.push_back( beliefF(I) );
344 return result;
345 }
346
347
348 Factor BP::belief( const VarSet &ns ) const {
349 if( ns.size() == 1 )
350 return belief( *(ns.begin()) );
351 else {
352 size_t I;
353 for( I = 0; I < nrFactors(); I++ )
354 if( factor(I).vars() >> ns )
355 break;
356 assert( I != nrFactors() );
357 return beliefF(I).marginal(ns);
358 }
359 }
360
361
362 Factor BP::beliefF (size_t I) const {
363 if( 0 == 1 ) {
364 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
365
366 Factor prod( factor(I) );
367 foreach( const Neighbor &j, nbF(I) ) {
368 foreach( const Neighbor &J, nbV(j) ) {
369 if( J != I ) // for all J in nb(j) \ I
370 prod *= Factor( var(j), newMessage(j, J.iter) );
371 }
372 }
373 return prod.normalized();
374 } else {
375 /* OPTIMIZED VERSION */
376 Prob prod( factor(I).p() );
377 if( props.logdomain )
378 prod.takeLog();
379
380 foreach( const Neighbor &j, nbF(I) ) {
381 size_t _I = j.dual;
382 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
383 const ind_t & ind = index(j, _I);
384
385 // prod_j will be the product of messages coming into j
386 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
387 foreach( const Neighbor &J, nbV(j) ) {
388 if( J != I ) { // for all J in nb(j) \ I
389 if( props.logdomain )
390 prod_j += newMessage( j, J.iter );
391 else
392 prod_j *= newMessage( j, J.iter );
393 }
394 }
395
396 // multiply prod with prod_j
397 for( size_t r = 0; r < prod.size(); ++r ) {
398 if( props.logdomain )
399 prod[r] += prod_j[ind[r]];
400 else
401 prod[r] *= prod_j[ind[r]];
402 }
403 }
404
405 if( props.logdomain ) {
406 prod -= prod.maxVal();
407 prod.takeExp();
408 }
409
410 Factor result( factor(I).vars(), prod );
411 result.normalize();
412
413 return( result );
414 }
415 }
416
417
418 Real BP::logZ() const {
419 Real sum = 0.0;
420 for(size_t i = 0; i < nrVars(); ++i )
421 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
422 for( size_t I = 0; I < nrFactors(); ++I )
423 sum -= KL_dist( beliefF(I), factor(I) );
424 return sum;
425 }
426
427
428 string BP::identify() const {
429 return string(Name) + printProperties();
430 }
431
432
433 void BP::init( const VarSet &ns ) {
434 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
435 size_t ni = findVar( *n );
436 foreach( const Neighbor &I, nbV( ni ) )
437 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
438 }
439 }
440
441
442 } // end of namespace dai