3b7477bff3a1df12207e2c111c741aa9a43d4b07
[libdai.git] / src / bp.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <algorithm>
27 #include <dai/bp.h>
28 #include <dai/diffs.h>
29 #include <dai/util.h>
30 #include <dai/properties.h>
31
32
33 namespace dai {
34
35
36 using namespace std;
37
38
// Name of this inference algorithm; identify() streams it ahead of GetProperties().
39 const char *BP::Name = "BP";
40
41
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
44 return false;
45 if( !HasProperty("tol") )
46 return false;
47 if (!HasProperty("maxiter") )
48 return false;
49 if (!HasProperty("verbose") )
50 return false;
51
52 ConvertPropertyTo<double>("tol");
53 ConvertPropertyTo<size_t>("maxiter");
54 ConvertPropertyTo<size_t>("verbose");
55 ConvertPropertyTo<UpdateType>("updates");
56
57 return true;
58 }
59
60
61 void BP::Regenerate() {
62 DAIAlgFG::Regenerate();
63
64 // clear messages
65 _messages.clear();
66 _messages.reserve(nr_edges());
67
68 // clear indices
69 _indices.clear();
70 _indices.reserve(nr_edges());
71
72 // create messages and indices
73 for( vector<_edge_t>::const_iterator iI=edges().begin(); iI!=edges().end(); ++iI ) {
74 _messages.push_back( Prob( var(iI->first).states() ) );
75
76 vector<size_t> ind( factor(iI->second).stateSpace(), 0 );
77 Index i (var(iI->first), factor(iI->second).vars() );
78 for( size_t j = 0; i >= 0; ++i,++j )
79 ind[j] = i;
80 _indices.push_back( ind );
81 }
82
83 // create new_messages
84 _newmessages = _messages;
85 }
86
87
88 void BP::init() {
89 assert( checkProperties() );
90 for( vector<Prob>::iterator mij = _messages.begin(); mij != _messages.end(); ++mij )
91 mij->fill(1.0 / mij->size());
92 _newmessages = _messages;
93 }
94
95
96 void BP::calcNewMessage (size_t iI) {
97 // calculate updated message I->i
98 size_t i = edge(iI).first;
99 size_t I = edge(iI).second;
100
101 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
102
103 Factor prod( factor( I ) );
104 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
105 if( *j != i ) { // for all j in I \ i
106 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
107 if( *J != I ) { // for all J in nb(j) \ I
108 prod *= Factor( *j, message(*j,*J) );
109 Factor marg = prod.marginal(var(i));
110 */
111
112 Prob prod( factor(I).p() );
113
114 // Calculate product of incoming messages and factor I
115 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j )
116 if( *j != i ) { // for all j in I \ i
117 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
118 _ind_t* ind = &(index(*j,I));
119
120 // prod_j will be the product of messages coming into j
121 Prob prod_j( var(*j).states() );
122 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J )
123 if( *J != I ) // for all J in nb(j) \ I
124 prod_j *= message(*j,*J);
125
126 // multiply prod with prod_j
127 for( size_t r = 0; r < prod.size(); ++r )
128 prod[r] *= prod_j[(*ind)[r]];
129 }
130
131 // Marginalize onto i
132 Prob marg( var(i).states(), 0.0 );
133 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
134 _ind_t* ind = &(index(i,I));
135 for( size_t r = 0; r < prod.size(); ++r )
136 marg[(*ind)[r]] += prod[r];
137 marg.normalize( _normtype );
138
139 // Store result
140 _newmessages[iI] = marg;
141 }
142
143
144 // BP::run does not check for NANs for performance reasons
145 // Somehow NaNs do not often occur in BP...
146 double BP::run() {
147 if( Verbose() >= 1 )
148 cout << "Starting " << identify() << "...";
149 if( Verbose() >= 3)
150 cout << endl;
151
152 clock_t tic = toc();
153 Diffs diffs(nrVars(), 1.0);
154
155 vector<size_t> edge_seq;
156 vector<double> residuals;
157
158 vector<Factor> old_beliefs;
159 old_beliefs.reserve( nrVars() );
160 for( size_t i = 0; i < nrVars(); ++i )
161 old_beliefs.push_back(belief1(i));
162
163 size_t iter = 0;
164
165 if( Updates() == UpdateType::SEQMAX ) {
166 // do the first pass
167 for(size_t iI = 0; iI < nr_edges(); ++iI )
168 calcNewMessage(iI);
169
170 // calculate initial residuals
171 residuals.reserve(nr_edges());
172 for( size_t iI = 0; iI < nr_edges(); ++iI )
173 residuals.push_back( dist( _newmessages[iI], _messages[iI], Prob::DISTLINF ) );
174 } else {
175 edge_seq.reserve( nr_edges() );
176 for( size_t i = 0; i < nr_edges(); ++i )
177 edge_seq.push_back( i );
178 }
179
180 // do several passes over the network until maximum number of iterations has
181 // been reached or until the maximum belief difference is smaller than tolerance
182 for( iter=0; iter < MaxIter() && diffs.max() > Tol(); ++iter ) {
183 if( Updates() == UpdateType::SEQMAX ) {
184 // Residuals-BP by Koller et al.
185 for( size_t t = 0; t < nr_edges(); ++t ) {
186 // update the message with the largest residual
187 size_t iI = max_element(residuals.begin(), residuals.end()) - residuals.begin();
188 _messages[iI] = _newmessages[iI];
189 residuals[iI] = 0;
190
191 // I->i has been updated, which means that residuals for all
192 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
193 size_t i = edge(iI).first;
194 size_t I = edge(iI).second;
195 for( _nb_cit J = nb1(i).begin(); J != nb1(i).end(); ++J )
196 if( *J != I )
197 for( _nb_cit j = nb2(*J).begin(); j != nb2(*J).end(); ++j )
198 if( *j != i ) {
199 size_t jJ = VV2E(*j,*J);
200 calcNewMessage(jJ);
201 residuals[jJ] = dist( _newmessages[jJ], _messages[jJ], Prob::DISTLINF );
202 }
203 }
204 } else if( Updates() == UpdateType::PARALL ) {
205 // Parallel updates
206 for( size_t t = 0; t < nr_edges(); ++t )
207 calcNewMessage(t);
208
209 for( size_t t = 0; t < nr_edges(); ++t )
210 _messages[t] = _newmessages[t];
211 } else {
212 // Sequential updates
213 if( Updates() == UpdateType::SEQRND )
214 random_shuffle( edge_seq.begin(), edge_seq.end() );
215
216 for( size_t t = 0; t < nr_edges(); ++t ) {
217 size_t k = edge_seq[t];
218 calcNewMessage(k);
219 _messages[k] = _newmessages[k];
220 }
221 }
222
223 // calculate new beliefs and compare with old ones
224 for( size_t i = 0; i < nrVars(); ++i ) {
225 Factor nb( belief1(i) );
226 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
227 old_beliefs[i] = nb;
228 }
229
230 if( Verbose() >= 3 )
231 cout << "BP::run: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
232 }
233
234 updateMaxDiff( diffs.max() );
235
236 if( Verbose() >= 1 ) {
237 if( diffs.max() > Tol() ) {
238 if( Verbose() == 1 )
239 cout << endl;
240 cout << "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
241 } else {
242 if( Verbose() >= 3 )
243 cout << "BP::run: ";
244 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
245 }
246 }
247
248 return diffs.max();
249 }
250
251
252 Factor BP::belief1( size_t i ) const {
253 Prob prod( var(i).states() );
254 for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); ++I )
255 prod *= newMessage(i,*I);
256
257 prod.normalize( Prob::NORMPROB );
258 return( Factor( var(i), prod ) );
259 }
260
261
262 Factor BP::belief (const Var &n) const {
263 return( belief1( findVar( n ) ) );
264 }
265
266
267 vector<Factor> BP::beliefs() const {
268 vector<Factor> result;
269 for( size_t i = 0; i < nrVars(); ++i )
270 result.push_back( belief1(i) );
271 for( size_t I = 0; I < nrFactors(); ++I )
272 result.push_back( belief2(I) );
273 return result;
274 }
275
276
277 Factor BP::belief( const VarSet &ns ) const {
278 if( ns.size() == 1 )
279 return belief( *(ns.begin()) );
280 else {
281 size_t I;
282 for( I = 0; I < nrFactors(); I++ )
283 if( factor(I).vars() >> ns )
284 break;
285 assert( I != nrFactors() );
286 return belief2(I).marginal(ns);
287 }
288 }
289
290
291 Factor BP::belief2 (size_t I) const {
292 Prob prod( factor(I).p() );
293
294 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j ) {
295 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
296 const _ind_t *ind = &(index(*j, I));
297
298 // prod_j will be the product of messages coming into j
299 Prob prod_j( var(*j).states() );
300 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J )
301 if( *J != I ) // for all J in nb(j) \ I
302 prod_j *= newMessage(*j,*J);
303
304 // multiply prod with prod_j
305 for( size_t r = 0; r < prod.size(); ++r )
306 prod[r] *= prod_j[(*ind)[r]];
307 }
308
309 Factor result( factor(I).vars(), prod );
310 result.normalize( Prob::NORMPROB );
311
312 return( result );
313
314 /* UNOPTIMIZED VERSION
315
316 Factor prod( factor(I) );
317 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
318 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
319 if( *J != I )
320 prod *= Factor( var(*i), newMessage(*i,*J)) );
321 }
322 return prod.normalize( Prob::NORMPROB );*/
323 }
324
325
326 Complex BP::logZ() const {
327 Complex sum = 0.0;
328 for(size_t i = 0; i < nrVars(); ++i )
329 sum += Complex(1.0 - nb1(i).size()) * belief1(i).entropy();
330 for( size_t I = 0; I < nrFactors(); ++I )
331 sum -= KL_dist( belief2(I), factor(I) );
332 return sum;
333 }
334
335
336 string BP::identify() const {
337 stringstream result (stringstream::out);
338 result << Name << GetProperties();
339 return result.str();
340 }
341
342
343 void BP::init( const VarSet &ns ) {
344 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
345 size_t ni = findVar( *n );
346 for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); ++I )
347 message(ni,*I).fill( 1.0 );
348 }
349 }
350
351
352 } // end of namespace dai