aeaeaaea6094d0ec76954d0cbda7c9db9fc58bc9
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <algorithm>
27 #include <dai/bp.h>
28 #include <dai/diffs.h>
29 #include <dai/util.h>
30 #include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP";
40
41
42 void BP::setProperties( const PropertySet &opts ) {
43 assert( opts.hasKey("tol") );
44 assert( opts.hasKey("maxiter") );
45 assert( opts.hasKey("verbose") );
46 assert( opts.hasKey("logdomain") );
47 assert( opts.hasKey("updates") );
48
49 props.tol = opts.getStringAs<double>("tol");
50 props.maxiter = opts.getStringAs<size_t>("maxiter");
51 props.verbose = opts.getStringAs<size_t>("verbose");
52 props.logdomain = opts.getStringAs<bool>("logdomain");
53 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
54 }
55
56
57 PropertySet BP::getProperties() const {
58 PropertySet opts;
59 opts.Set( "tol", props.tol );
60 opts.Set( "maxiter", props.maxiter );
61 opts.Set( "verbose", props.verbose );
62 opts.Set( "logdomain", props.logdomain );
63 opts.Set( "updates", props.updates );
64 return opts;
65 }
66
67
68 string BP::printProperties() const {
69 stringstream s( stringstream::out );
70 s << "[";
71 s << "tol=" << props.tol << ",";
72 s << "maxiter=" << props.maxiter << ",";
73 s << "verbose=" << props.verbose << ",";
74 s << "logdomain=" << props.logdomain << ",";
75 s << "updates=" << props.updates << "]";
76 return s.str();
77 }
78
79
80 void BP::construct() {
81 // create edge properties
82 edges.clear();
83 edges.reserve( nrVars() );
84 for( size_t i = 0; i < nrVars(); ++i ) {
85 edges.push_back( vector<EdgeProp>() );
86 edges[i].reserve( nbV(i).size() );
87 foreach( const Neighbor &I, nbV(i) ) {
88 EdgeProp newEP;
89 newEP.message = Prob( var(i).states() );
90 newEP.newMessage = Prob( var(i).states() );
91
92 newEP.index.reserve( factor(I).states() );
93 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
94 newEP.index.push_back( k );
95
96 newEP.residual = 0.0;
97 edges[i].push_back( newEP );
98 }
99 }
100 }
101
102
103 void BP::init() {
104 for( size_t i = 0; i < nrVars(); ++i ) {
105 foreach( const Neighbor &I, nbV(i) ) {
106 if( props.logdomain ) {
107 message( i, I.iter ).fill( 0.0 );
108 newMessage( i, I.iter ).fill( 0.0 );
109 } else {
110 message( i, I.iter ).fill( 1.0 );
111 newMessage( i, I.iter ).fill( 1.0 );
112 }
113 }
114 }
115 }
116
117
118 void BP::findMaxResidual( size_t &i, size_t &_I ) {
119 i = 0;
120 _I = 0;
121 double maxres = residual( i, _I );
122 for( size_t j = 0; j < nrVars(); ++j )
123 foreach( const Neighbor &I, nbV(j) )
124 if( residual( j, I.iter ) > maxres ) {
125 i = j;
126 _I = I.iter;
127 maxres = residual( i, _I );
128 }
129 }
130
131
132 void BP::calcNewMessage( size_t i, size_t _I ) {
133 // calculate updated message I->i
134 size_t I = nbV(i,_I);
135
136 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
137
138 Factor prod( factor( I ) );
139 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
140 if( *j != i ) { // for all j in I \ i
141 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
142 if( *J != I ) { // for all J in nb(j) \ I
143 prod *= Factor( *j, message(*j,*J) );
144 Factor marg = prod.marginal(var(i));
145 */
146
147 Prob prod( factor(I).p() );
148 if( props.logdomain )
149 prod.takeLog();
150
151 // Calculate product of incoming messages and factor I
152 foreach( const Neighbor &j, nbF(I) ) {
153 if( j != i ) { // for all j in I \ i
154 size_t _I = j.dual;
155 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
156 const ind_t & ind = index(j, _I);
157
158 // prod_j will be the product of messages coming into j
159 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
160 foreach( const Neighbor &J, nbV(j) )
161 if( J != I ) { // for all J in nb(j) \ I
162 if( props.logdomain )
163 prod_j += message( j, J.iter );
164 else
165 prod_j *= message( j, J.iter );
166 }
167
168 // multiply prod with prod_j
169 for( size_t r = 0; r < prod.size(); ++r )
170 if( props.logdomain )
171 prod[r] += prod_j[ind[r]];
172 else
173 prod[r] *= prod_j[ind[r]];
174 }
175 }
176 if( props.logdomain ) {
177 prod -= prod.maxVal();
178 prod.takeExp();
179 }
180
181 // Marginalize onto i
182 Prob marg( var(i).states(), 0.0 );
183 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
184 const ind_t ind = index(i,_I);
185 for( size_t r = 0; r < prod.size(); ++r )
186 marg[ind[r]] += prod[r];
187 marg.normalize( Prob::NORMPROB );
188
189 // Store result
190 if( props.logdomain )
191 newMessage(i,_I) = marg.log();
192 else
193 newMessage(i,_I) = marg;
194 }
195
196
197 // BP::run does not check for NANs for performance reasons
198 // Somehow NaNs do not often occur in BP...
199 double BP::run() {
200 if( props.verbose >= 1 )
201 cout << "Starting " << identify() << "...";
202 if( props.verbose >= 3)
203 cout << endl;
204
205 double tic = toc();
206 Diffs diffs(nrVars(), 1.0);
207
208 vector<Edge> update_seq;
209
210 vector<Factor> old_beliefs;
211 old_beliefs.reserve( nrVars() );
212 for( size_t i = 0; i < nrVars(); ++i )
213 old_beliefs.push_back( beliefV(i) );
214
215 size_t iter = 0;
216 size_t nredges = nrEdges();
217
218 if( props.updates == Properties::UpdateType::SEQMAX ) {
219 // do the first pass
220 for( size_t i = 0; i < nrVars(); ++i )
221 foreach( const Neighbor &I, nbV(i) ) {
222 calcNewMessage( i, I.iter );
223 // calculate initial residuals
224 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
225 }
226 } else {
227 update_seq.reserve( nredges );
228 for( size_t i = 0; i < nrVars(); ++i )
229 foreach( const Neighbor &I, nbV(i) )
230 update_seq.push_back( Edge( i, I.iter ) );
231 }
232
233 // do several passes over the network until maximum number of iterations has
234 // been reached or until the maximum belief difference is smaller than tolerance
235 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; ++iter ) {
236 if( props.updates == Properties::UpdateType::SEQMAX ) {
237 // Residuals-BP by Koller et al.
238 for( size_t t = 0; t < nredges; ++t ) {
239 // update the message with the largest residual
240
241 size_t i, _I;
242 findMaxResidual( i, _I );
243 message( i, _I ) = newMessage( i, _I );
244 residual( i, _I ) = 0.0;
245
246 // I->i has been updated, which means that residuals for all
247 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
248 foreach( const Neighbor &J, nbV(i) ) {
249 if( J.iter != _I ) {
250 foreach( const Neighbor &j, nbF(J) ) {
251 size_t _J = j.dual;
252 if( j != i ) {
253 calcNewMessage( j, _J );
254 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
255 }
256 }
257 }
258 }
259 }
260 } else if( props.updates == Properties::UpdateType::PARALL ) {
261 // Parallel updates
262 for( size_t i = 0; i < nrVars(); ++i )
263 foreach( const Neighbor &I, nbV(i) )
264 calcNewMessage( i, I.iter );
265
266 for( size_t i = 0; i < nrVars(); ++i )
267 foreach( const Neighbor &I, nbV(i) )
268 message( i, I.iter ) = newMessage( i, I.iter );
269 } else {
270 // Sequential updates
271 if( props.updates == Properties::UpdateType::SEQRND )
272 random_shuffle( update_seq.begin(), update_seq.end() );
273
274 foreach( const Edge &e, update_seq ) {
275 calcNewMessage( e.first, e.second );
276 message( e.first, e.second ) = newMessage( e.first, e.second );
277 }
278 }
279
280 // calculate new beliefs and compare with old ones
281 for( size_t i = 0; i < nrVars(); ++i ) {
282 Factor nb( beliefV(i) );
283 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
284 old_beliefs[i] = nb;
285 }
286
287 if( props.verbose >= 3 )
288 cout << "BP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
289 }
290
291 if( diffs.maxDiff() > maxdiff )
292 maxdiff = diffs.maxDiff();
293
294 if( props.verbose >= 1 ) {
295 if( diffs.maxDiff() > props.tol ) {
296 if( props.verbose == 1 )
297 cout << endl;
298 cout << "BP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
299 } else {
300 if( props.verbose >= 3 )
301 cout << "BP::run: ";
302 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
303 }
304 }
305
306 return diffs.maxDiff();
307 }
308
309
310 Factor BP::beliefV( size_t i ) const {
311 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
312 foreach( const Neighbor &I, nbV(i) )
313 if( props.logdomain )
314 prod += newMessage( i, I.iter );
315 else
316 prod *= newMessage( i, I.iter );
317 if( props.logdomain ) {
318 prod -= prod.maxVal();
319 prod.takeExp();
320 }
321
322 prod.normalize( Prob::NORMPROB );
323 return( Factor( var(i), prod ) );
324 }
325
326
327 Factor BP::belief (const Var &n) const {
328 return( beliefV( findVar( n ) ) );
329 }
330
331
332 vector<Factor> BP::beliefs() const {
333 vector<Factor> result;
334 for( size_t i = 0; i < nrVars(); ++i )
335 result.push_back( beliefV(i) );
336 for( size_t I = 0; I < nrFactors(); ++I )
337 result.push_back( beliefF(I) );
338 return result;
339 }
340
341
342 Factor BP::belief( const VarSet &ns ) const {
343 if( ns.size() == 1 )
344 return belief( *(ns.begin()) );
345 else {
346 size_t I;
347 for( I = 0; I < nrFactors(); I++ )
348 if( factor(I).vars() >> ns )
349 break;
350 assert( I != nrFactors() );
351 return beliefF(I).marginal(ns);
352 }
353 }
354
355
356 Factor BP::beliefF (size_t I) const {
357 Prob prod( factor(I).p() );
358 if( props.logdomain )
359 prod.takeLog();
360
361 foreach( const Neighbor &j, nbF(I) ) {
362 size_t _I = j.dual;
363 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
364 const ind_t & ind = index(j, _I);
365
366 // prod_j will be the product of messages coming into j
367 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
368 foreach( const Neighbor &J, nbV(j) ) {
369 if( J != I ) { // for all J in nb(j) \ I
370 if( props.logdomain )
371 prod_j += newMessage( j, J.iter );
372 else
373 prod_j *= newMessage( j, J.iter );
374 }
375 }
376
377 // multiply prod with prod_j
378 for( size_t r = 0; r < prod.size(); ++r ) {
379 if( props.logdomain )
380 prod[r] += prod_j[ind[r]];
381 else
382 prod[r] *= prod_j[ind[r]];
383 }
384 }
385
386 if( props.logdomain ) {
387 prod -= prod.maxVal();
388 prod.takeExp();
389 }
390
391 Factor result( factor(I).vars(), prod );
392 result.normalize( Prob::NORMPROB );
393
394 return( result );
395
396 /* UNOPTIMIZED VERSION
397
398 Factor prod( factor(I) );
399 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
400 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
401 if( *J != I )
402 prod *= Factor( var(*i), newMessage(*i,*J)) );
403 }
404 return prod.normalize( Prob::NORMPROB );*/
405 }
406
407
408 Real BP::logZ() const {
409 Real sum = 0.0;
410 for(size_t i = 0; i < nrVars(); ++i )
411 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
412 for( size_t I = 0; I < nrFactors(); ++I )
413 sum -= KL_dist( beliefF(I), factor(I) );
414 return sum;
415 }
416
417
418 string BP::identify() const {
419 return string(Name) + printProperties();
420 }
421
422
423 void BP::init( const VarSet &ns ) {
424 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
425 size_t ni = findVar( *n );
426 foreach( const Neighbor &I, nbV( ni ) )
427 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
428 }
429 }
430
431
432 } // end of namespace dai