Generalized VarSet to "template<typename T> small_set<T>"
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/bp.h>
#include <dai/diffs.h>
#include <dai/util.h>
#include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP";
40
41
42 void BP::setProperties( const PropertySet &opts ) {
43 assert( opts.hasKey("tol") );
44 assert( opts.hasKey("maxiter") );
45 assert( opts.hasKey("logdomain") );
46 assert( opts.hasKey("updates") );
47
48 props.tol = opts.getStringAs<double>("tol");
49 props.maxiter = opts.getStringAs<size_t>("maxiter");
50 props.logdomain = opts.getStringAs<bool>("logdomain");
51 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
52
53 if( opts.hasKey("verbose") )
54 props.verbose = opts.getStringAs<size_t>("verbose");
55 else
56 props.verbose = 0;
57 if( opts.hasKey("damping") )
58 props.damping = opts.getStringAs<double>("damping");
59 else
60 props.damping = 0.0;
61 if( opts.hasKey("inference") )
62 props.inference = opts.getStringAs<Properties::InfType>("inference");
63 else
64 props.inference = Properties::InfType::SUMPROD;
65 }
66
67
68 PropertySet BP::getProperties() const {
69 PropertySet opts;
70 opts.Set( "tol", props.tol );
71 opts.Set( "maxiter", props.maxiter );
72 opts.Set( "verbose", props.verbose );
73 opts.Set( "logdomain", props.logdomain );
74 opts.Set( "updates", props.updates );
75 opts.Set( "damping", props.damping );
76 opts.Set( "inference", props.inference );
77 return opts;
78 }
79
80
81 string BP::printProperties() const {
82 stringstream s( stringstream::out );
83 s << "[";
84 s << "tol=" << props.tol << ",";
85 s << "maxiter=" << props.maxiter << ",";
86 s << "verbose=" << props.verbose << ",";
87 s << "logdomain=" << props.logdomain << ",";
88 s << "updates=" << props.updates << ",";
89 s << "damping=" << props.damping << ",";
90 s << "inference=" << props.inference << "]";
91 return s.str();
92 }
93
94
95 void BP::construct() {
96 // create edge properties
97 _edges.clear();
98 _edges.reserve( nrVars() );
99 for( size_t i = 0; i < nrVars(); ++i ) {
100 _edges.push_back( vector<EdgeProp>() );
101 _edges[i].reserve( nbV(i).size() );
102 foreach( const Neighbor &I, nbV(i) ) {
103 EdgeProp newEP;
104 newEP.message = Prob( var(i).states() );
105 newEP.newMessage = Prob( var(i).states() );
106
107 newEP.index.reserve( factor(I).states() );
108 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
109 newEP.index.push_back( k );
110
111 newEP.residual = 0.0;
112 _edges[i].push_back( newEP );
113 }
114 }
115 }
116
117
118 void BP::init() {
119 double c = props.logdomain ? 0.0 : 1.0;
120 for( size_t i = 0; i < nrVars(); ++i ) {
121 foreach( const Neighbor &I, nbV(i) ) {
122 message( i, I.iter ).fill( c );
123 newMessage( i, I.iter ).fill( c );
124 }
125 }
126 }
127
128
129 void BP::findMaxResidual( size_t &i, size_t &_I ) {
130 i = 0;
131 _I = 0;
132 double maxres = residual( i, _I );
133 for( size_t j = 0; j < nrVars(); ++j )
134 foreach( const Neighbor &I, nbV(j) )
135 if( residual( j, I.iter ) > maxres ) {
136 i = j;
137 _I = I.iter;
138 maxres = residual( i, _I );
139 }
140 }
141
142
143 void BP::calcNewMessage( size_t i, size_t _I ) {
144 // calculate updated message I->i
145 size_t I = nbV(i,_I);
146
147 if( 0 == 1 ) {
148 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
149 Factor prod( factor( I ) );
150 foreach( const Neighbor &j, nbF(I) )
151 if( j != i ) { // for all j in I \ i
152 foreach( const Neighbor &J, nbV(j) )
153 if( J != I ) { // for all J in nb(j) \ I
154 prod *= Factor( var(j), message(j, J.iter) );
155 }
156 }
157 newMessage(i,_I) = prod.marginal( var(i) ).p();
158 } else {
159 /* OPTIMIZED VERSION */
160 Prob prod( factor(I).p() );
161 if( props.logdomain )
162 prod.takeLog();
163
164 // Calculate product of incoming messages and factor I
165 foreach( const Neighbor &j, nbF(I) ) {
166 if( j != i ) { // for all j in I \ i
167 size_t _I = j.dual;
168 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
169 const ind_t &ind = index(j, _I);
170
171 // prod_j will be the product of messages coming into j
172 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
173 foreach( const Neighbor &J, nbV(j) )
174 if( J != I ) { // for all J in nb(j) \ I
175 if( props.logdomain )
176 prod_j += message( j, J.iter );
177 else
178 prod_j *= message( j, J.iter );
179 }
180
181 // multiply prod with prod_j
182 for( size_t r = 0; r < prod.size(); ++r )
183 if( props.logdomain )
184 prod[r] += prod_j[ind[r]];
185 else
186 prod[r] *= prod_j[ind[r]];
187 }
188 }
189 if( props.logdomain ) {
190 prod -= prod.maxVal();
191 prod.takeExp();
192 }
193
194 // Marginalize onto i
195 Prob marg( var(i).states(), 0.0 );
196 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
197 const ind_t ind = index(i,_I);
198 if( props.inference == Properties::InfType::SUMPROD )
199 for( size_t r = 0; r < prod.size(); ++r )
200 marg[ind[r]] += prod[r];
201 else
202 for( size_t r = 0; r < prod.size(); ++r )
203 if( prod[r] > marg[ind[r]] )
204 marg[ind[r]] = prod[r];
205 marg.normalize();
206
207 // Store result
208 if( props.logdomain )
209 newMessage(i,_I) = marg.log();
210 else
211 newMessage(i,_I) = marg;
212 }
213 }
214
215
// BP::run does not check for NANs for performance reasons
// Somehow NaNs do not often occur in BP...
//
// Runs belief propagation for at most props.maxiter passes, stopping early once
// diffs.maxDiff() (built from the per-variable belief changes) is <= props.tol.
// The message schedule is selected by props.updates:
//   SEQMAX - residual BP: after an initial computation of all new messages and
//            their residuals, every pass performs nrEdges() updates, each on the
//            edge returned by findMaxResidual(), then recomputes the residuals of
//            the messages that depend on the updated one;
//   PARALL - all new messages are computed first, then all are committed;
//   other  - sequential: each edge is computed and committed in turn, in a fixed
//            order that is reshuffled at the start of every pass for SEQRND.
// Side effects: updates the messages, sets _iters to the number of passes
// performed, raises _maxdiff to the final maxdiff if that is larger, and writes
// progress to cout depending on props.verbose.
// Returns diffs.maxDiff() after the last pass.
double BP::run() {
    if( props.verbose >= 1 )
        cout << "Starting " << identify() << "...";
    if( props.verbose >= 3)
        cout << endl;

    double tic = toc();
    // NOTE(review): presumably a window of the last nrVars() belief differences,
    // seeded with 1.0 so the loop below is entered; confirm against dai/diffs.h
    Diffs diffs(nrVars(), 1.0);

    // Edge order for the non-SEQMAX schedules
    vector<Edge> update_seq;

    // Beliefs from the previous pass, used for the convergence test
    vector<Factor> old_beliefs;
    old_beliefs.reserve( nrVars() );
    for( size_t i = 0; i < nrVars(); ++i )
        old_beliefs.push_back( beliefV(i) );

    size_t nredges = nrEdges();

    if( props.updates == Properties::UpdateType::SEQMAX ) {
        // do the first pass
        for( size_t i = 0; i < nrVars(); ++i )
            foreach( const Neighbor &I, nbV(i) ) {
                calcNewMessage( i, I.iter );
                // calculate initial residuals
                residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
            }
    } else {
        update_seq.reserve( nredges );
        for( size_t i = 0; i < nrVars(); ++i )
            foreach( const Neighbor &I, nbV(i) )
                update_seq.push_back( Edge( i, I.iter ) );
    }

    // do several passes over the network until maximum number of iterations has
    // been reached or until the maximum belief difference is smaller than tolerance
    for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; ++_iters ) {
        if( props.updates == Properties::UpdateType::SEQMAX ) {
            // Residuals-BP by Koller et al.
            for( size_t t = 0; t < nredges; ++t ) {
                // update the message with the largest residual
                // NOTE(review): resetting residual(i,_I) after the update is
                // presumably done inside updateMessage(); confirm in dai/bp.h
                size_t i, _I;
                findMaxResidual( i, _I );
                updateMessage( i, _I );

                // I->i has been updated, which means that residuals for all
                // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
                foreach( const Neighbor &J, nbV(i) ) {
                    if( J.iter != _I ) {
                        foreach( const Neighbor &j, nbF(J) ) {
                            // _J: position of J among the neighbours of j
                            size_t _J = j.dual;
                            if( j != i ) {
                                calcNewMessage( j, _J );
                                residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
                            }
                        }
                    }
                }
            }
        } else if( props.updates == Properties::UpdateType::PARALL ) {
            // Parallel updates: every new message is computed from the old
            // messages before any of them is committed
            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    calcNewMessage( i, I.iter );

            for( size_t i = 0; i < nrVars(); ++i )
                foreach( const Neighbor &I, nbV(i) )
                    updateMessage( i, I.iter );
        } else {
            // Sequential updates
            // NOTE(review): std::random_shuffle is deprecated in C++14 and removed
            // in C++17; migrating needs an explicit random generator.
            if( props.updates == Properties::UpdateType::SEQRND )
                random_shuffle( update_seq.begin(), update_seq.end() );

            foreach( const Edge &e, update_seq ) {
                calcNewMessage( e.first, e.second );
                updateMessage( e.first, e.second );
            }
        }

        // calculate new beliefs and compare with old ones
        for( size_t i = 0; i < nrVars(); ++i ) {
            Factor nb( beliefV(i) );
            diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
            old_beliefs[i] = nb;
        }

        if( props.verbose >= 3 )
            cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
    }

    if( diffs.maxDiff() > _maxdiff )
        _maxdiff = diffs.maxDiff();

    if( props.verbose >= 1 ) {
        if( diffs.maxDiff() > props.tol ) {
            if( props.verbose == 1 )
                cout << endl;
            cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
        } else {
            if( props.verbose >= 3 )
                cout << Name << "::run: ";
            cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
        }
    }

    return diffs.maxDiff();
}
324
325
326 Factor BP::beliefV( size_t i ) const {
327 Prob prod( var(i).states(), props.logdomain ? 0.0 : 1.0 );
328 foreach( const Neighbor &I, nbV(i) )
329 if( props.logdomain )
330 prod += newMessage( i, I.iter );
331 else
332 prod *= newMessage( i, I.iter );
333 if( props.logdomain ) {
334 prod -= prod.maxVal();
335 prod.takeExp();
336 }
337
338 prod.normalize();
339 return( Factor( var(i), prod ) );
340 }
341
342
343 Factor BP::belief (const Var &n) const {
344 return( beliefV( findVar( n ) ) );
345 }
346
347
348 vector<Factor> BP::beliefs() const {
349 vector<Factor> result;
350 for( size_t i = 0; i < nrVars(); ++i )
351 result.push_back( beliefV(i) );
352 for( size_t I = 0; I < nrFactors(); ++I )
353 result.push_back( beliefF(I) );
354 return result;
355 }
356
357
358 Factor BP::belief( const VarSet &ns ) const {
359 if( ns.size() == 1 )
360 return belief( *(ns.begin()) );
361 else {
362 size_t I;
363 for( I = 0; I < nrFactors(); I++ )
364 if( factor(I).vars() >> ns )
365 break;
366 assert( I != nrFactors() );
367 return beliefF(I).marginal(ns);
368 }
369 }
370
371
372 Factor BP::beliefF (size_t I) const {
373 if( 0 == 1 ) {
374 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
375
376 Factor prod( factor(I) );
377 foreach( const Neighbor &j, nbF(I) ) {
378 foreach( const Neighbor &J, nbV(j) ) {
379 if( J != I ) // for all J in nb(j) \ I
380 prod *= Factor( var(j), newMessage(j, J.iter) );
381 }
382 }
383 return prod.normalized();
384 } else {
385 /* OPTIMIZED VERSION */
386 Prob prod( factor(I).p() );
387 if( props.logdomain )
388 prod.takeLog();
389
390 foreach( const Neighbor &j, nbF(I) ) {
391 size_t _I = j.dual;
392 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
393 const ind_t & ind = index(j, _I);
394
395 // prod_j will be the product of messages coming into j
396 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
397 foreach( const Neighbor &J, nbV(j) ) {
398 if( J != I ) { // for all J in nb(j) \ I
399 if( props.logdomain )
400 prod_j += newMessage( j, J.iter );
401 else
402 prod_j *= newMessage( j, J.iter );
403 }
404 }
405
406 // multiply prod with prod_j
407 for( size_t r = 0; r < prod.size(); ++r ) {
408 if( props.logdomain )
409 prod[r] += prod_j[ind[r]];
410 else
411 prod[r] *= prod_j[ind[r]];
412 }
413 }
414
415 if( props.logdomain ) {
416 prod -= prod.maxVal();
417 prod.takeExp();
418 }
419
420 Factor result( factor(I).vars(), prod );
421 result.normalize();
422
423 return( result );
424 }
425 }
426
427
428 Real BP::logZ() const {
429 Real sum = 0.0;
430 for(size_t i = 0; i < nrVars(); ++i )
431 sum += (1.0 - nbV(i).size()) * beliefV(i).entropy();
432 for( size_t I = 0; I < nrFactors(); ++I )
433 sum -= KL_dist( beliefF(I), factor(I) );
434 return sum;
435 }
436
437
438 string BP::identify() const {
439 return string(Name) + printProperties();
440 }
441
442
443 void BP::init( const VarSet &ns ) {
444 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
445 size_t ni = findVar( *n );
446 foreach( const Neighbor &I, nbV( ni ) )
447 message( ni, I.iter ).fill( props.logdomain ? 0.0 : 1.0 );
448 }
449 }
450
451
452 } // end of namespace dai