Cleanups of matlab, BP; small improvement of utils/fginfo
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/bp.h>
#include <dai/util.h>
#include <dai/properties.h>
31
32
33 namespace dai {
34
35
using namespace std;


/// Name of this inference algorithm. identify() prefixes its output with it,
/// and run() uses it as the prefix of its diagnostic messages.
const char *BP::Name = "BP";


/// Compile-time switch tested by construct(), calcNewMessage() and beliefF():
/// nonzero selects the optimized implementations (and makes construct()
/// precompute the IndexFor tables stored in EdgeProp::index); zero selects the
/// simple, unoptimized reference implementations.
#define DAI_BP_FAST 1
43
44
45 void BP::setProperties( const PropertySet &opts ) {
46 assert( opts.hasKey("tol") );
47 assert( opts.hasKey("maxiter") );
48 assert( opts.hasKey("logdomain") );
49 assert( opts.hasKey("updates") );
50
51 props.tol = opts.getStringAs<double>("tol");
52 props.maxiter = opts.getStringAs<size_t>("maxiter");
53 props.logdomain = opts.getStringAs<bool>("logdomain");
54 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
55
56 if( opts.hasKey("verbose") )
57 props.verbose = opts.getStringAs<size_t>("verbose");
58 else
59 props.verbose = 0;
60 if( opts.hasKey("damping") )
61 props.damping = opts.getStringAs<double>("damping");
62 else
63 props.damping = 0.0;
64 if( opts.hasKey("inference") )
65 props.inference = opts.getStringAs<Properties::InfType>("inference");
66 else
67 props.inference = Properties::InfType::SUMPROD;
68 }
69
70
71 PropertySet BP::getProperties() const {
72 PropertySet opts;
73 opts.Set( "tol", props.tol );
74 opts.Set( "maxiter", props.maxiter );
75 opts.Set( "verbose", props.verbose );
76 opts.Set( "logdomain", props.logdomain );
77 opts.Set( "updates", props.updates );
78 opts.Set( "damping", props.damping );
79 opts.Set( "inference", props.inference );
80 return opts;
81 }
82
83
84 string BP::printProperties() const {
85 stringstream s( stringstream::out );
86 s << "[";
87 s << "tol=" << props.tol << ",";
88 s << "maxiter=" << props.maxiter << ",";
89 s << "verbose=" << props.verbose << ",";
90 s << "logdomain=" << props.logdomain << ",";
91 s << "updates=" << props.updates << ",";
92 s << "damping=" << props.damping << ",";
93 s << "inference=" << props.inference << "]";
94 return s.str();
95 }
96
97
98 void BP::construct() {
99 // create edge properties
100 _edges.clear();
101 _edges.reserve( nrVars() );
102 for( size_t i = 0; i < nrVars(); ++i ) {
103 _edges.push_back( vector<EdgeProp>() );
104 _edges[i].reserve( nbV(i).size() );
105 foreach( const Neighbor &I, nbV(i) ) {
106 EdgeProp newEP;
107 newEP.message = Prob( var(i).states() );
108 newEP.newMessage = Prob( var(i).states() );
109
110 if( DAI_BP_FAST ) {
111 newEP.index.reserve( factor(I).states() );
112 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
113 newEP.index.push_back( k );
114 }
115
116 newEP.residual = 0.0;
117 _edges[i].push_back( newEP );
118 }
119 }
120 }
121
122
123 void BP::init() {
124 double c = props.logdomain ? 0.0 : 1.0;
125 for( size_t i = 0; i < nrVars(); ++i ) {
126 foreach( const Neighbor &I, nbV(i) ) {
127 message( i, I.iter ).fill( c );
128 newMessage( i, I.iter ).fill( c );
129 }
130 }
131 }
132
133
134 void BP::findMaxResidual( size_t &i, size_t &_I ) {
135 i = 0;
136 _I = 0;
137 double maxres = residual( i, _I );
138 for( size_t j = 0; j < nrVars(); ++j )
139 foreach( const Neighbor &I, nbV(j) )
140 if( residual( j, I.iter ) > maxres ) {
141 i = j;
142 _I = I.iter;
143 maxres = residual( i, _I );
144 }
145 }
146
147
148 void BP::calcNewMessage( size_t i, size_t _I ) {
149 // calculate updated message I->i
150 size_t I = nbV(i,_I);
151
152 if( !DAI_BP_FAST ) {
153 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
154 Factor prod( factor( I ) );
155 foreach( const Neighbor &j, nbF(I) )
156 if( j != i ) { // for all j in I \ i
157 foreach( const Neighbor &J, nbV(j) )
158 if( J != I ) { // for all J in nb(j) \ I
159 prod *= Factor( var(j), message(j, J.iter) );
160 }
161 }
162 newMessage(i,_I) = prod.marginal( var(i) ).p();
163 } else {
164 /* OPTIMIZED VERSION */
165 Prob prod( factor(I).p() );
166 if( props.logdomain )
167 prod.takeLog();
168
169 // Calculate product of incoming messages and factor I
170 foreach( const Neighbor &j, nbF(I) ) {
171 if( j != i ) { // for all j in I \ i
172 size_t _I = j.dual;
173 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
174 const ind_t &ind = index(j, _I);
175
176 // prod_j will be the product of messages coming into j
177 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
178 foreach( const Neighbor &J, nbV(j) )
179 if( J != I ) { // for all J in nb(j) \ I
180 if( props.logdomain )
181 prod_j += message( j, J.iter );
182 else
183 prod_j *= message( j, J.iter );
184 }
185
186 // multiply prod with prod_j
187 for( size_t r = 0; r < prod.size(); ++r )
188 if( props.logdomain )
189 prod[r] += prod_j[ind[r]];
190 else
191 prod[r] *= prod_j[ind[r]];
192 }
193 }
194 if( props.logdomain ) {
195 prod -= prod.maxVal();
196 prod.takeExp();
197 }
198
199 // Marginalize onto i
200 Prob marg( var(i).states(), 0.0 );
201 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
202 const ind_t ind = index(i,_I);
203 if( props.inference == Properties::InfType::SUMPROD )
204 for( size_t r = 0; r < prod.size(); ++r )
205 marg[ind[r]] += prod[r];
206 else
207 for( size_t r = 0; r < prod.size(); ++r )
208 if( prod[r] > marg[ind[r]] )
209 marg[ind[r]] = prod[r];
210 marg.normalize();
211
212 // Store result
213 if( props.logdomain )
214 newMessage(i,_I) = marg.log();
215 else
216 newMessage(i,_I) = marg;
217 }
218 }
219
220
221 // BP::run does not check for NANs for performance reasons
222 // Somehow NaNs do not often occur in BP...
223 double BP::run() {
224 if( props.verbose >= 1 )
225 cout << "Starting " << identify() << "...";
226 if( props.verbose >= 3)
227 cout << endl;
228
229 double tic = toc();
230 Diffs diffs(nrVars(), 1.0);
231
232 vector<Edge> update_seq;
233
234 vector<Factor> old_beliefs;
235 old_beliefs.reserve( nrVars() );
236 for( size_t i = 0; i < nrVars(); ++i )
237 old_beliefs.push_back( beliefV(i) );
238
239 size_t nredges = nrEdges();
240
241 if( props.updates == Properties::UpdateType::SEQMAX ) {
242 // do the first pass
243 for( size_t i = 0; i < nrVars(); ++i )
244 foreach( const Neighbor &I, nbV(i) ) {
245 calcNewMessage( i, I.iter );
246 // calculate initial residuals
247 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
248 }
249 } else {
250 update_seq.reserve( nredges );
251 for( size_t i = 0; i < nrVars(); ++i )
252 foreach( const Neighbor &I, nbV(i) )
253 update_seq.push_back( Edge( i, I.iter ) );
254 }
255
256 // do several passes over the network until maximum number of iterations has
257 // been reached or until the maximum belief difference is smaller than tolerance
258 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
259 if( props.updates == Properties::UpdateType::SEQMAX ) {
260 // Residuals-BP by Koller et al.
261 for( size_t t = 0; t < nredges; ++t ) {
262 // update the message with the largest residual
263 size_t i, _I;
264 findMaxResidual( i, _I );
265 updateMessage( i, _I );
266
267 // I->i has been updated, which means that residuals for all
268 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
269 foreach( const Neighbor &J, nbV(i) ) {
270 if( J.iter != _I ) {
271 foreach( const Neighbor &j, nbF(J) ) {
272 size_t _J = j.dual;
273 if( j != i ) {
274 calcNewMessage( j, _J );
275 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
276 }
277 }
278 }
279 }
280 }
281 } else if( props.updates == Properties::UpdateType::PARALL ) {
282 // Parallel updates
283 for( size_t i = 0; i < nrVars(); ++i )
284 foreach( const Neighbor &I, nbV(i) )
285 calcNewMessage( i, I.iter );
286
287 for( size_t i = 0; i < nrVars(); ++i )
288 foreach( const Neighbor &I, nbV(i) )
289 updateMessage( i, I.iter );
290 } else {
291 // Sequential updates
292 if( props.updates == Properties::UpdateType::SEQRND )
293 random_shuffle( update_seq.begin(), update_seq.end() );
294
295 foreach( const Edge &e, update_seq ) {
296 calcNewMessage( e.first, e.second );
297 updateMessage( e.first, e.second );
298 }
299 }
300
301 // calculate new beliefs and compare with old ones
302 for( size_t i = 0; i < nrVars(); ++i ) {
303 Factor nb( beliefV(i) );
304 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
305 old_beliefs[i] = nb;
306 }
307
308 if( props.verbose >= 3 )
309 cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
310 }
311
312 if( diffs.maxDiff() > _maxdiff )
313 _maxdiff = diffs.maxDiff();
314
315 if( props.verbose >= 1 ) {
316 if( diffs.maxDiff() > props.tol ) {
317 if( props.verbose == 1 )
318 cout << endl;
319 cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
320 } else {
321 if( props.verbose >= 3 )
322 cout << Name << "::run: ";
323 cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
324 }
325 }
326
327 return diffs.maxDiff();
328 }
329
330
331 Factor BP::beliefV( size_t i ) const {
332 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
333 foreach( const Neighbor &I, nbV(i) )
334 if( props.logdomain )
335 prod += newMessage( i, I.iter );
336 else
337 prod *= newMessage( i, I.iter );
338 if( props.logdomain ) {
339 prod -= prod.maxVal();
340 prod.takeExp();
341 }
342
343 prod.normalize();
344 return( Factor( var(i), prod ) );
345 }
346
347
348 Factor BP::belief (const Var &n) const {
349 return( beliefV( findVar( n ) ) );
350 }
351
352
353 vector<Factor> BP::beliefs() const {
354 vector<Factor> result;
355 for( size_t i = 0; i < nrVars(); ++i )
356 result.push_back( beliefV(i) );
357 for( size_t I = 0; I < nrFactors(); ++I )
358 result.push_back( beliefF(I) );
359 return result;
360 }
361
362
363 Factor BP::belief( const VarSet &ns ) const {
364 if( ns.size() == 1 )
365 return belief( *(ns.begin()) );
366 else {
367 size_t I;
368 for( I = 0; I < nrFactors(); I++ )
369 if( factor(I).vars() >> ns )
370 break;
371 assert( I != nrFactors() );
372 return beliefF(I).marginal(ns);
373 }
374 }
375
376
377 Factor BP::beliefF (size_t I) const {
378 if( !DAI_BP_FAST ) {
379 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
380
381 Factor prod( factor(I) );
382 foreach( const Neighbor &j, nbF(I) ) {
383 foreach( const Neighbor &J, nbV(j) ) {
384 if( J != I ) // for all J in nb(j) \ I
385 prod *= Factor( var(j), newMessage(j, J.iter) );
386 }
387 }
388 return prod.normalized();
389 } else {
390 /* OPTIMIZED VERSION */
391 Prob prod( factor(I).p() );
392 if( props.logdomain )
393 prod.takeLog();
394
395 foreach( const Neighbor &j, nbF(I) ) {
396 size_t _I = j.dual;
397 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
398 const ind_t & ind = index(j, _I);
399
400 // prod_j will be the product of messages coming into j
401 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
402 foreach( const Neighbor &J, nbV(j) ) {
403 if( J != I ) { // for all J in nb(j) \ I
404 if( props.logdomain )
405 prod_j += newMessage( j, J.iter );
406 else
407 prod_j *= newMessage( j, J.iter );
408 }
409 }
410
411 // multiply prod with prod_j
412 for( size_t r = 0; r < prod.size(); ++r ) {
413 if( props.logdomain )
414 prod[r] += prod_j[ind[r]];
415 else
416 prod[r] *= prod_j[ind[r]];
417 }
418 }
419
420 if( props.logdomain ) {
421 prod -= prod.maxVal();
422 prod.takeExp();
423 }
424
425 Factor result( factor(I).vars(), prod );
426 result.normalize();
427
428 return( result );
429 }
430 }
431
432
433 Real BP::logZ() const {
434 Real sum = 0.0;
435 for(size_t i = 0; i < nrVars(); ++i )
436 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
437 for( size_t I = 0; I < nrFactors(); ++I )
438 sum -= dist( beliefF(I), factor(I), Prob::DISTKL );
439 return sum;
440 }
441
442
443 string BP::identify() const {
444 return string(Name) + printProperties();
445 }
446
447
448 void BP::init( const VarSet &ns ) {
449 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
450 size_t ni = findVar( *n );
451 foreach( const Neighbor &I, nbV( ni ) )
452 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
453 }
454 }
455
456
457 } // end of namespace dai