Merge branch 'no-edges2'
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <algorithm>
#include <stdexcept>
#include <dai/bp.h>
#include <dai/diffs.h>
#include <dai/util.h>
#include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
39 const char *BP::Name = "BP";
40
41
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
44 return false;
45 if( !HasProperty("tol") )
46 return false;
47 if (!HasProperty("maxiter") )
48 return false;
49 if (!HasProperty("verbose") )
50 return false;
51
52 ConvertPropertyTo<double>("tol");
53 ConvertPropertyTo<size_t>("maxiter");
54 ConvertPropertyTo<size_t>("verbose");
55 ConvertPropertyTo<UpdateType>("updates");
56
57 return true;
58 }
59
60
61 void BP::Regenerate() {
62 // create edge properties
63 edges.clear();
64 edges.reserve( nrVars() );
65 for( size_t i = 0; i < nrVars(); ++i ) {
66 edges.push_back( vector<EdgeProp>() );
67 edges[i].reserve( nbV(i).size() );
68 foreach( const Neighbor &I, nbV(i) ) {
69 EdgeProp newEP;
70 newEP.message = Prob( var(i).states() );
71 newEP.newMessage = Prob( var(i).states() );
72
73 newEP.index.reserve( factor(I).stateSpace() );
74 for( Index k( var(i), factor(I).vars() ); k >= 0; ++k )
75 newEP.index.push_back( k );
76
77 newEP.residual = 0.0;
78 edges[i].push_back( newEP );
79 }
80 }
81 }
82
83
84 void BP::init() {
85 assert( checkProperties() );
86 for( size_t i = 0; i < nrVars(); ++i ) {
87 foreach( const Neighbor &I, nbV(i) ) {
88 message( i, I.iter ).fill( 1.0 );
89 newMessage( i, I.iter ).fill( 1.0 );
90 }
91 }
92 }
93
94
95 void BP::findMaxResidual( size_t &i, size_t &_I ) {
96 i = 0;
97 _I = 0;
98 double maxres = residual( i, _I );
99 for( size_t j = 0; j < nrVars(); ++j )
100 foreach( const Neighbor &I, nbV(j) )
101 if( residual( j, I.iter ) > maxres ) {
102 i = j;
103 _I = I.iter;
104 maxres = residual( i, _I );
105 }
106 }
107
108
109 void BP::calcNewMessage( size_t i, size_t _I ) {
110 // calculate updated message I->i
111 size_t I = nbV(i,_I);
112
113 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
114
115 Factor prod( factor( I ) );
116 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
117 if( *j != i ) { // for all j in I \ i
118 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
119 if( *J != I ) { // for all J in nb(j) \ I
120 prod *= Factor( *j, message(*j,*J) );
121 Factor marg = prod.marginal(var(i));
122 */
123
124 Prob prod( factor(I).p() );
125
126 // Calculate product of incoming messages and factor I
127 foreach( const Neighbor &j, nbF(I) ) {
128 if( j != i ) { // for all j in I \ i
129 size_t _I = j.dual;
130 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
131 const ind_t & ind = index(j, _I);
132
133 // prod_j will be the product of messages coming into j
134 Prob prod_j( var(j).states() );
135 foreach( const Neighbor &J, nbV(j) ) {
136 if( J != I ) // for all J in nb(j) \ I
137 prod_j *= message( j, J.iter );
138 }
139
140 // multiply prod with prod_j
141 for( size_t r = 0; r < prod.size(); ++r )
142 prod[r] *= prod_j[ind[r]];
143 }
144 }
145
146 // Marginalize onto i
147 Prob marg( var(i).states(), 0.0 );
148 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
149 const ind_t ind = index(i,_I);
150 for( size_t r = 0; r < prod.size(); ++r )
151 marg[ind[r]] += prod[r];
152 marg.normalize( _normtype );
153
154 // Store result
155 newMessage(i,_I) = marg;
156 }
157
158
159 // BP::run does not check for NANs for performance reasons
160 // Somehow NaNs do not often occur in BP...
161 double BP::run() {
162 if( Verbose() >= 1 )
163 cout << "Starting " << identify() << "...";
164 if( Verbose() >= 3)
165 cout << endl;
166
167 clock_t tic = toc();
168 Diffs diffs(nrVars(), 1.0);
169
170 typedef pair<size_t,size_t> Edge;
171 vector<Edge> update_seq;
172
173 vector<Factor> old_beliefs;
174 old_beliefs.reserve( nrVars() );
175 for( size_t i = 0; i < nrVars(); ++i )
176 old_beliefs.push_back( beliefV(i) );
177
178 size_t iter = 0;
179 size_t nredges = nrEdges();
180
181 if( Updates() == UpdateType::SEQMAX ) {
182 // do the first pass
183 for( size_t i = 0; i < nrVars(); ++i )
184 foreach( const Neighbor &I, nbV(i) ) {
185 calcNewMessage( i, I.iter );
186 // calculate initial residuals
187 residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
188 }
189 } else {
190 update_seq.reserve( nredges );
191 for( size_t i = 0; i < nrVars(); ++i )
192 foreach( const Neighbor &I, nbV(i) )
193 update_seq.push_back( Edge( i, I.iter ) );
194 }
195
196 // do several passes over the network until maximum number of iterations has
197 // been reached or until the maximum belief difference is smaller than tolerance
198 for( iter=0; iter < MaxIter() && diffs.max() > Tol(); ++iter ) {
199 if( Updates() == UpdateType::SEQMAX ) {
200 // Residuals-BP by Koller et al.
201 for( size_t t = 0; t < nredges; ++t ) {
202 // update the message with the largest residual
203
204 size_t i, _I;
205 findMaxResidual( i, _I );
206 message( i, _I ) = newMessage( i, _I );
207 residual( i, _I ) = 0.0;
208
209 // I->i has been updated, which means that residuals for all
210 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
211 foreach( const Neighbor &J, nbV(i) ) {
212 if( J.iter != _I ) {
213 foreach( const Neighbor &j, nbF(J) ) {
214 size_t _J = j.dual;
215 if( j != i ) {
216 calcNewMessage( j, _J );
217 residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
218 }
219 }
220 }
221 }
222 }
223 } else if( Updates() == UpdateType::PARALL ) {
224 // Parallel updates
225 for( size_t i = 0; i < nrVars(); ++i )
226 foreach( const Neighbor &I, nbV(i) )
227 calcNewMessage( i, I.iter );
228
229 for( size_t i = 0; i < nrVars(); ++i )
230 foreach( const Neighbor &I, nbV(i) )
231 message( i, I.iter ) = newMessage( i, I.iter );
232 } else {
233 // Sequential updates
234 if( Updates() == UpdateType::SEQRND )
235 random_shuffle( update_seq.begin(), update_seq.end() );
236
237 foreach( const Edge &e, update_seq ) {
238 calcNewMessage( e.first, e.second );
239 message( e.first, e.second ) = newMessage( e.first, e.second );
240 }
241 }
242
243 // calculate new beliefs and compare with old ones
244 for( size_t i = 0; i < nrVars(); ++i ) {
245 Factor nb( beliefV(i) );
246 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
247 old_beliefs[i] = nb;
248 }
249
250 if( Verbose() >= 3 )
251 cout << "BP::run: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
252 }
253
254 updateMaxDiff( diffs.max() );
255
256 if( Verbose() >= 1 ) {
257 if( diffs.max() > Tol() ) {
258 if( Verbose() == 1 )
259 cout << endl;
260 cout << "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
261 } else {
262 if( Verbose() >= 3 )
263 cout << "BP::run: ";
264 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
265 }
266 }
267
268 return diffs.max();
269 }
270
271
272 Factor BP::beliefV( size_t i ) const {
273 Prob prod( var(i).states() );
274 foreach( const Neighbor &I, nbV(i) )
275 prod *= newMessage( i, I.iter );
276
277 prod.normalize( Prob::NORMPROB );
278 return( Factor( var(i), prod ) );
279 }
280
281
282 Factor BP::belief (const Var &n) const {
283 return( beliefV( findVar( n ) ) );
284 }
285
286
287 vector<Factor> BP::beliefs() const {
288 vector<Factor> result;
289 for( size_t i = 0; i < nrVars(); ++i )
290 result.push_back( beliefV(i) );
291 for( size_t I = 0; I < nrFactors(); ++I )
292 result.push_back( beliefF(I) );
293 return result;
294 }
295
296
297 Factor BP::belief( const VarSet &ns ) const {
298 if( ns.size() == 1 )
299 return belief( *(ns.begin()) );
300 else {
301 size_t I;
302 for( I = 0; I < nrFactors(); I++ )
303 if( factor(I).vars() >> ns )
304 break;
305 assert( I != nrFactors() );
306 return beliefF(I).marginal(ns);
307 }
308 }
309
310
311 Factor BP::beliefF (size_t I) const {
312 Prob prod( factor(I).p() );
313
314 foreach( const Neighbor &j, nbF(I) ) {
315 size_t _I = j.dual;
316 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
317 const ind_t & ind = index(j, _I);
318
319 // prod_j will be the product of messages coming into j
320 Prob prod_j( var(j).states() );
321 foreach( const Neighbor &J, nbV(j) ) {
322 if( J != I ) // for all J in nb(j) \ I
323 prod_j *= newMessage( j, J.iter );
324 }
325
326 // multiply prod with prod_j
327 for( size_t r = 0; r < prod.size(); ++r )
328 prod[r] *= prod_j[ind[r]];
329 }
330
331 Factor result( factor(I).vars(), prod );
332 result.normalize( Prob::NORMPROB );
333
334 return( result );
335
336 /* UNOPTIMIZED VERSION
337
338 Factor prod( factor(I) );
339 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
340 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
341 if( *J != I )
342 prod *= Factor( var(*i), newMessage(*i,*J)) );
343 }
344 return prod.normalize( Prob::NORMPROB );*/
345 }
346
347
348 Complex BP::logZ() const {
349 Complex sum = 0.0;
350 for(size_t i = 0; i < nrVars(); ++i )
351 sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
352 for( size_t I = 0; I < nrFactors(); ++I )
353 sum -= KL_dist( beliefF(I), factor(I) );
354 return sum;
355 }
356
357
358 string BP::identify() const {
359 stringstream result (stringstream::out);
360 result << Name << GetProperties();
361 return result.str();
362 }
363
364
365 void BP::init( const VarSet &ns ) {
366 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
367 size_t ni = findVar( *n );
368 foreach( const Neighbor &I, nbV( ni ) )
369 message( ni, I.iter ).fill( 1.0 );
370 }
371 }
372
373
374 } // end of namespace dai