Miscellaneous changes:
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <algorithm>
27 #include <dai/bp.h>
28 #include <dai/diffs.h>
29 #include <dai/util.h>
30 #include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP"; // algorithm name; identify() prints it in front of the property list
40
41
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
44 return false;
45 if( !HasProperty("tol") )
46 return false;
47 if (!HasProperty("maxiter") )
48 return false;
49 if (!HasProperty("verbose") )
50 return false;
51 if (!HasProperty("logdomain") )
52 return false;
53
54 ConvertPropertyTo<double>("tol");
55 ConvertPropertyTo<size_t>("maxiter");
56 ConvertPropertyTo<size_t>("verbose");
57 ConvertPropertyTo<UpdateType>("updates");
58 ConvertPropertyTo<bool>("logdomain");
59 logDomain = GetPropertyAs<bool>("logdomain");
60
61 return true;
62 }
63
64
65 void BP::create() {
66 // create edge properties
67 edges.clear();
68 edges.reserve( nrVars() );
69 for( size_t i = 0; i < nrVars(); ++i ) {
70 edges.push_back( vector<EdgeProp>() );
71 edges[i].reserve( nbV(i).size() );
72 foreach( const Neighbor &I, nbV(i) ) {
73 EdgeProp newEP;
74 newEP.message = Prob( var(i).states() );
75 newEP.newMessage = Prob( var(i).states() );
76
77 newEP.index.reserve( factor(I).states() );
78 for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
79 newEP.index.push_back( k );
80
81 newEP.residual = 0.0;
82 edges[i].push_back( newEP );
83 }
84 }
85 }
86
87
88 void BP::init() {
89 assert( checkProperties() );
90 for( size_t i = 0; i < nrVars(); ++i ) {
91 foreach( const Neighbor &I, nbV(i) ) {
92 if( logDomain ) {
93 message( i, I.iter ).fill( 0.0 );
94 newMessage( i, I.iter ).fill( 0.0 );
95 } else {
96 message( i, I.iter ).fill( 1.0 );
97 newMessage( i, I.iter ).fill( 1.0 );
98 }
99 }
100 }
101 }
102
103
104 void BP::findMaxResidual( size_t &i, size_t &_I ) {
105 i = 0;
106 _I = 0;
107 double maxres = residual( i, _I );
108 for( size_t j = 0; j < nrVars(); ++j )
109 foreach( const Neighbor &I, nbV(j) )
110 if( residual( j, I.iter ) > maxres ) {
111 i = j;
112 _I = I.iter;
113 maxres = residual( i, _I );
114 }
115 }
116
117
118 void BP::calcNewMessage( size_t i, size_t _I ) {
119 // calculate updated message I->i
120 size_t I = nbV(i,_I);
121
122 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
123
124 Factor prod( factor( I ) );
125 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
126 if( *j != i ) { // for all j in I \ i
127 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
128 if( *J != I ) { // for all J in nb(j) \ I
129 prod *= Factor( *j, message(*j,*J) );
130 Factor marg = prod.marginal(var(i));
131 */
132
133 Prob prod( factor(I).p() );
134 if( logDomain )
135 prod.takeLog();
136
137 // Calculate product of incoming messages and factor I
138 foreach( const Neighbor &j, nbF(I) ) {
139 if( j != i ) { // for all j in I \ i
140 size_t _I = j.dual;
141 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
142 const ind_t & ind = index(j, _I);
143
144 // prod_j will be the product of messages coming into j
145 Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 );
146 foreach( const Neighbor &J, nbV(j) )
147 if( J != I ) { // for all J in nb(j) \ I
148 if( logDomain )
149 prod_j += message( j, J.iter );
150 else
151 prod_j *= message( j, J.iter );
152 }
153
154 // multiply prod with prod_j
155 for( size_t r = 0; r < prod.size(); ++r )
156 if( logDomain )
157 prod[r] += prod_j[ind[r]];
158 else
159 prod[r] *= prod_j[ind[r]];
160 }
161 }
162 if( logDomain ) {
163 prod -= prod.maxVal();
164 prod.takeExp();
165 }
166
167 // Marginalize onto i
168 Prob marg( var(i).states(), 0.0 );
169 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
170 const ind_t ind = index(i,_I);
171 for( size_t r = 0; r < prod.size(); ++r )
172 marg[ind[r]] += prod[r];
173 marg.normalize( _normtype );
174
175 // Store result
176 if( logDomain )
177 newMessage(i,_I) = marg.log();
178 else
179 newMessage(i,_I) = marg;
180 }
181
182
183 // BP::run does not check for NANs for performance reasons
184 // Somehow NaNs do not often occur in BP...
185 double BP::run() {
186 if( Verbose() >= 1 )
187 cout << "Starting " << identify() << "...";
188 if( Verbose() >= 3)
189 cout << endl;
190
191 double tic = toc();
192 Diffs diffs(nrVars(), 1.0);
193
194 typedef pair<size_t,size_t> Edge;
195 vector<Edge> update_seq;
196
197 vector<Factor> old_beliefs;
198 old_beliefs.reserve( nrVars() );
199 for( size_t i = 0; i < nrVars(); ++i )
200 old_beliefs.push_back( beliefV(i) );
201
202 size_t iter = 0;
203 size_t nredges = nrEdges();
204
205 if( Updates() == UpdateType::SEQMAX ) {
206 // do the first pass
207 for( size_t i = 0; i < nrVars(); ++i )
208 foreach( const Neighbor &I, nbV(i) ) {
209 calcNewMessage( i, I.iter );
210 // calculate initial residuals
211 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
212 }
213 } else {
214 update_seq.reserve( nredges );
215 for( size_t i = 0; i < nrVars(); ++i )
216 foreach( const Neighbor &I, nbV(i) )
217 update_seq.push_back( Edge( i, I.iter ) );
218 }
219
220 // do several passes over the network until maximum number of iterations has
221 // been reached or until the maximum belief difference is smaller than tolerance
222 for( iter=0; iter < MaxIter() && diffs.maxDiff() > Tol(); ++iter ) {
223 if( Updates() == UpdateType::SEQMAX ) {
224 // Residuals-BP by Koller et al.
225 for( size_t t = 0; t < nredges; ++t ) {
226 // update the message with the largest residual
227
228 size_t i, _I;
229 findMaxResidual( i, _I );
230 message( i, _I ) = newMessage( i, _I );
231 residual( i, _I ) = 0.0;
232
233 // I->i has been updated, which means that residuals for all
234 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
235 foreach( const Neighbor &J, nbV(i) ) {
236 if( J.iter != _I ) {
237 foreach( const Neighbor &j, nbF(J) ) {
238 size_t _J = j.dual;
239 if( j != i ) {
240 calcNewMessage( j, _J );
241 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
242 }
243 }
244 }
245 }
246 }
247 } else if( Updates() == UpdateType::PARALL ) {
248 // Parallel updates
249 for( size_t i = 0; i < nrVars(); ++i )
250 foreach( const Neighbor &I, nbV(i) )
251 calcNewMessage( i, I.iter );
252
253 for( size_t i = 0; i < nrVars(); ++i )
254 foreach( const Neighbor &I, nbV(i) )
255 message( i, I.iter ) = newMessage( i, I.iter );
256 } else {
257 // Sequential updates
258 if( Updates() == UpdateType::SEQRND )
259 random_shuffle( update_seq.begin(), update_seq.end() );
260
261 foreach( const Edge &e, update_seq ) {
262 calcNewMessage( e.first, e.second );
263 message( e.first, e.second ) = newMessage( e.first, e.second );
264 }
265 }
266
267 // calculate new beliefs and compare with old ones
268 for( size_t i = 0; i < nrVars(); ++i ) {
269 Factor nb( beliefV(i) );
270 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
271 old_beliefs[i] = nb;
272 }
273
274 if( Verbose() >= 3 )
275 cout << "BP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
276 }
277
278 updateMaxDiff( diffs.maxDiff() );
279
280 if( Verbose() >= 1 ) {
281 if( diffs.maxDiff() > Tol() ) {
282 if( Verbose() == 1 )
283 cout << endl;
284 cout << "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
285 } else {
286 if( Verbose() >= 3 )
287 cout << "BP::run: ";
288 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
289 }
290 }
291
292 return diffs.maxDiff();
293 }
294
295
296 Factor BP::beliefV( size_t i ) const {
297 Prob prod( var(i).states(), logDomain ? 0.0 : 1.0 );
298 foreach( const Neighbor &I, nbV(i) )
299 if( logDomain )
300 prod += newMessage( i, I.iter );
301 else
302 prod *= newMessage( i, I.iter );
303 if( logDomain ) {
304 prod -= prod.maxVal();
305 prod.takeExp();
306 }
307
308 prod.normalize( Prob::NORMPROB );
309 return( Factor( var(i), prod ) );
310 }
311
312
313 Factor BP::belief (const Var &n) const {
314 return( beliefV( findVar( n ) ) );
315 }
316
317
318 vector<Factor> BP::beliefs() const {
319 vector<Factor> result;
320 for( size_t i = 0; i < nrVars(); ++i )
321 result.push_back( beliefV(i) );
322 for( size_t I = 0; I < nrFactors(); ++I )
323 result.push_back( beliefF(I) );
324 return result;
325 }
326
327
328 Factor BP::belief( const VarSet &ns ) const {
329 if( ns.size() == 1 )
330 return belief( *(ns.begin()) );
331 else {
332 size_t I;
333 for( I = 0; I < nrFactors(); I++ )
334 if( factor(I).vars() >> ns )
335 break;
336 assert( I != nrFactors() );
337 return beliefF(I).marginal(ns);
338 }
339 }
340
341
342 Factor BP::beliefF (size_t I) const {
343 Prob prod( factor(I).p() );
344 if( logDomain )
345 prod.takeLog();
346
347 foreach( const Neighbor &j, nbF(I) ) {
348 size_t _I = j.dual;
349 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
350 const ind_t & ind = index(j, _I);
351
352 // prod_j will be the product of messages coming into j
353 Prob prod_j( var(j).states(), logDomain ? 0.0 : 1.0 );
354 foreach( const Neighbor &J, nbV(j) ) {
355 if( J != I ) { // for all J in nb(j) \ I
356 if( logDomain )
357 prod_j += newMessage( j, J.iter );
358 else
359 prod_j *= newMessage( j, J.iter );
360 }
361 }
362
363 // multiply prod with prod_j
364 for( size_t r = 0; r < prod.size(); ++r ) {
365 if( logDomain )
366 prod[r] += prod_j[ind[r]];
367 else
368 prod[r] *= prod_j[ind[r]];
369 }
370 }
371
372 if( logDomain ) {
373 prod -= prod.maxVal();
374 prod.takeExp();
375 }
376
377 Factor result( factor(I).vars(), prod );
378 result.normalize( Prob::NORMPROB );
379
380 return( result );
381
382 /* UNOPTIMIZED VERSION
383
384 Factor prod( factor(I) );
385 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
386 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
387 if( *J != I )
388 prod *= Factor( var(*i), newMessage(*i,*J)) );
389 }
390 return prod.normalize( Prob::NORMPROB );*/
391 }
392
393
394 Complex BP::logZ() const {
395 Complex sum = 0.0;
396 for(size_t i = 0; i < nrVars(); ++i )
397 sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
398 for( size_t I = 0; I < nrFactors(); ++I )
399 sum -= KL_dist( beliefF(I), factor(I) );
400 return sum;
401 }
402
403
404 string BP::identify() const {
405 stringstream result (stringstream::out);
406 result << Name << GetProperties();
407 return result.str();
408 }
409
410
411 void BP::init( const VarSet &ns ) {
412 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
413 size_t ni = findVar( *n );
414 foreach( const Neighbor &I, nbV( ni ) )
415 message( ni, I.iter ).fill( logDomain ? 0.0 : 1.0 );
416 }
417 }
418
419
420 } // end of namespace dai