[Frederik Eaton] Added BP_Dual, BBP and CBP algorithms
[libdai.git] / src / bp_dual.cpp
1
2 #include <iostream>
3 #include <sstream>
4 #include <map>
5 #include <set>
6 #include <algorithm>
7
8 #include <dai/bbp.h>
9 //#include <dai/diffs.h>
10 #include <dai/util.h>
11 //#include "stlutil.h"
12 #include <dai/properties.h>
13
14 namespace dai {
15
16 using namespace std;
17
18 const char *BP_dual::Name = "BP_dual";
19
20 const char *BP_dual::PropertyList[] = {"tol","maxiter","updates","verbose"};
21
22 void BP_dual::setProperties( const PropertySet &opts ) {
23 // DAI_DMSG("in BP_dual::setProperties");
24 // DAI_PV(opts);
25
26 bool die=false;
27 foreach(const char *p, PropertyList) {
28 if( !opts.hasKey(p) ) {
29 cerr << "BP_dual: missing property " << p << endl;
30 die=true;
31 }
32 }
33 if(die) throw "BP_dual: Couldn't set properties";
34
35 props.tol = opts.getStringAs<double>("tol");
36 props.maxiter = opts.getStringAs<size_t>("maxiter");
37 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
38 props.verbose = opts.getStringAs<size_t>("verbose");
39
40 // DAI_PV(printProperties());
41 }
42
43 PropertySet BP_dual::getProperties() const {
44 PropertySet opts;
45 opts.Set( "tol", props.tol );
46 opts.Set( "maxiter", props.maxiter );
47 opts.Set( "updates", props.updates );
48 opts.Set( "verbose", props.verbose );
49 return opts;
50 }
51
52 std::string BP_dual::printProperties() const {
53 stringstream s( stringstream::out );
54 s << "[";
55 s << "tol=" << props.tol << ",";
56 s << "maxiter=" << props.maxiter << ",";
57 s << "updates=" << props.updates << ",";
58 s << "verbose=" << props.verbose;
59 s << "]";
60 return s.str();
61 }
62
63 // void BP_dual::checkProperties() {
64 // const char *props[] = {"updates","tol","maxiter","verbose"};
65 // for(size_t i=0; i<sizeof(props)/sizeof(*props); i++) {
66 // if(!HasProperty(props[i]))
67 // die("BP_dual: Missing property \"%s\"", props[i]);
68 // }
69
70 // ConvertPropertyTo<double>("tol");
71 // ConvertPropertyTo<size_t>("maxiter");
72 // ConvertPropertyTo<size_t>("verbose");
73 // ConvertPropertyTo<UpdateType>("updates");
74 // }
75
76 void BP_dual::RegenerateIndices() {
77 _indices.clear();
78 _indices.reserve(nr_edges());
79
80 foreach(Edge iI, edges()) {
81 vector<size_t> ind( factor(iI.second).states(), 0 );
82 IndexFor i (var(iI.first), factor(iI.second).vars() );
83 for( size_t j = 0; i >= 0; ++i,++j )
84 ind[j] = i;
85 _indices.push_back( ind );
86 }
87 }
88
89 void BP_dual::RegenerateMessages() {
90 _msgs.Zn.resize(nr_edges(),1.0);
91 _msgs.Zm.resize(nr_edges(),1.0);
92
93 // clear messages
94 _msgs.m.clear();
95 _msgs.m.reserve(nr_edges());
96 _msgs.n.clear();
97 _msgs.n.reserve(nr_edges());
98
99 // create messages and indices
100 foreach(Edge iI, edges()) {
101 // initialize to uniform distributions
102 _msgs.m.push_back( Prob( var(iI.first).states() ) );
103 _msgs.n.push_back( Prob( var(iI.first).states() ) );
104 }
105
106 // create new_messages
107 _new_msgs = _msgs;
108 }
109
110 void BP_dual::RegenerateBeliefs() {
111 _beliefs.b1.clear();
112 _beliefs.b1.reserve(nrVars());
113 _beliefs.Zb1.resize(nrVars(), 1.0);
114 _beliefs.b2.clear();
115 _beliefs.b2.reserve(nrFactors());
116 _beliefs.Zb2.resize(nrFactors(), 1.0);
117
118 for(size_t i=0; i<nrVars(); i++) {
119 _beliefs.b1.push_back( Prob( var(i).states() ).setUniform() );
120 }
121 for(size_t I=0; I<nrFactors(); I++) {
122 _beliefs.b2.push_back( Prob( factor(I).states() ).setUniform() );
123 }
124 }
125
126 // called by constructor, called before 'init'
127 void BP_dual::Regenerate() {
128
129 indexEdges(); // so we can use compatibility interface
130
131 // DAIAlgFG::Regenerate(); // located in BipartiteGraph
132
133 RegenerateIndices();
134 RegenerateMessages();
135 RegenerateBeliefs();
136
137 _maxdiff = 0;
138 _iters = 0;
139 }
140
141 void BP_dual::CalcBelief1(size_t i) {
142 Prob prod( var(i).states(), 1.0 );
143 foreach(size_t I, nbV(i)) {
144 prod *= newMsgM(I,i);
145 }
146 _beliefs.Zb1[i] = prod.normalize();
147 _beliefs.b1[i] = prod;
148 }
149
150 void BP_dual::CalcBelief2(size_t I) {
151 Prob prod( factor(I).p() );
152 foreach(size_t j, nbF(I)) {
153 const _ind_t *ind = &(index(j, I));
154 for(size_t r=0; r<prod.size(); r++) {
155 Prob n(newMsgN(j,I));
156 prod[r] *= n[(*ind)[r]];
157 }
158 }
159 _beliefs.Zb2[I] = prod.normalize();
160 _beliefs.b2[I] = prod;
161 }
162
163 // called after run()
164 void BP_dual::CalcBeliefs() {
165 for(size_t i=0; i<nrVars(); i++) {
166 // calculate b_i
167 CalcBelief1(i);
168 }
169 for(size_t I=0; I<nrFactors(); I++) {
170 // calculate b_I
171 CalcBelief2(I);
172 }
173 }
174
175 void BP_dual::calcNewM(size_t iI) {
176 // calculate updated message I->i
177 size_t i = edge(iI).first;
178 size_t I = edge(iI).second;
179
180 Prob prod( factor(I).p() );
181
182 foreach(size_t j, nbF(I)) {
183 if( j != i ) { // for all j in I \ i
184 _ind_t* ind = &(index(j,I));
185 Prob n(msgN(j,I));
186 for( size_t r = 0; r < prod.size(); r++ )
187 prod[r] *= n[(*ind)[r]];
188 }
189 }
190
191 // Marginalize onto i
192 Prob marg( var(i).states(), 0.0 );
193 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
194 _ind_t* ind = &(index(i,I));
195 for( size_t r = 0; r < prod.size(); r++ )
196 marg[(*ind)[r]] += prod[r];
197
198 _new_msgs.Zm[iI] = marg.normalize();
199 _new_msgs.m[iI] = marg;
200 }
201
202 void BP_dual::calcNewN(size_t iI) {
203 // XXX optimize
204 // calculate updated message i->I
205 size_t i = edge(iI).first;
206 size_t I = edge(iI).second;
207
208 Prob prod(var(i).states(), 1.0);
209 foreach(size_t J, nbV(i)) {
210 if(J != I) { // for all J in i \ I
211 prod *= msgM(J,i);
212 }
213 }
214 _new_msgs.Zn[iI] = prod.normalize();
215 _new_msgs.n[iI] = prod;
216 }
217
218 void BP_dual::upMsgM(size_t iI) {
219 _msgs.m[iI] = _new_msgs.m[iI];
220 _msgs.Zm[iI] = _new_msgs.Zm[iI];
221 }
222
223 void BP_dual::upMsgN(size_t iI) {
224 _msgs.n[iI] = _new_msgs.n[iI];
225 _msgs.Zn[iI] = _new_msgs.Zn[iI];
226 }
227
228 double BP_dual::run() {
229 DAI_IFVERB(1, "Starting " << identify() << "..." << endl);
230
231 double tic = toc();
232 // for some reason we need 2* here, where orig BP doesn't
233 Diffs diffs(2*nrVars(), 1.0);
234
235 vector<size_t> edge_seq;
236 vector<double> residuals;
237
238 vector<Factor> old_beliefs;
239 old_beliefs.reserve( nrVars() );
240 for( size_t i = 0; i < nrVars(); i++ ) {
241 CalcBelief1(i);
242 old_beliefs.push_back(belief1(i));
243 }
244
245 size_t iter = 0;
246
247 if( Updates() == UpdateType::SEQMAX ) {
248 // do the first pass
249 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
250 calcNewM(iI);
251 calcNewN(iI);
252 }
253
254 // calculate initial residuals
255 residuals.reserve(nr_edges());
256 for( size_t iI = 0; iI < nr_edges(); iI++ )
257 residuals.push_back( dist( _new_msgs.m[iI], _msgs.m[iI], Prob::DISTLINF ) );
258 } else {
259 edge_seq.reserve( nr_edges() );
260 for( size_t i = 0; i < nr_edges(); i++ )
261 edge_seq.push_back( i );
262 }
263
264 // do several passes over the network until maximum number of iterations has
265 // been reached or until the maximum belief difference is smaller than tolerance
266 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
267 if( Updates() == UpdateType::SEQMAX ) {
268 // Residuals-BP by Koller et al.
269 for( size_t t = 0; t < nr_edges(); t++ ) {
270 // update the message with the largest residual
271 size_t iI = max_element(residuals.begin(), residuals.end()) - residuals.begin();
272 upMsgM(iI);
273 residuals[iI] = 0;
274
275 // I->i has been updated, which means that residuals for all
276 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
277 size_t i = edge(iI).first;
278 size_t I = edge(iI).second;
279 foreach(size_t J, nbV(i)) {
280 if(J != I) {
281 size_t iJ = VV2E(i,J);
282 calcNewN(iJ);
283 upMsgN(iJ);
284 foreach(size_t j, nbF(J)) {
285 if(j != i) {
286 size_t jJ = VV2E(j,J);
287 calcNewM(jJ);
288 residuals[jJ] = dist( _new_msgs.m[jJ], _msgs.m[jJ], Prob::DISTLINF );
289 }
290 }
291 }
292 }
293 }
294 } else if( Updates() == UpdateType::PARALL ) {
295 // Parallel updates
296 for( size_t t = 0; t < nr_edges(); t++ ) {
297 calcNewM(t);
298 calcNewN(t);
299 }
300 if(0) {
301 for(size_t t=0; t<nr_edges(); t++) {
302 upMsgM(t); upMsgN(t);
303 }
304 } else {
305 _msgs = _new_msgs;
306 }
307 } else {
308 // Sequential updates
309 if( Updates() == UpdateType::SEQRND )
310 random_shuffle( edge_seq.begin(), edge_seq.end() );
311
312 foreach(size_t k, edge_seq) {
313 calcNewM(k);
314 calcNewN(k);
315 upMsgM(k);
316 upMsgN(k);
317 }
318 }
319
320 // calculate new beliefs and compare with old ones
321 for( size_t i = 0; i < nrVars(); i++ ) {
322 CalcBelief1(i);
323 Factor nb( belief1(i) );
324 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
325 old_beliefs[i] = nb;
326 }
327
328 DAI_IFVERB(3,"BP_dual::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl);
329
330 _iters++;
331 }
332
333 updateMaxDiff( diffs.maxDiff() );
334
335 if( props.verbose >= 1 ) {
336 if( diffs.maxDiff() > props.tol ) {
337 DAI_IFVERB(1, endl << "BP_dual::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl);
338 } else {
339 DAI_IFVERB(3, "BP_dual::run: converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl);
340 }
341 }
342
343 CalcBeliefs();
344
345 return diffs.maxDiff();
346 }
347
348 string BP_dual::identify() const {
349 return string(Name) + printProperties();
350 }
351
352 vector<Factor> BP_dual::beliefs() const {
353 vector<Factor> result;
354 for( size_t i = 0; i < nrVars(); i++ )
355 result.push_back( belief1(i) );
356 for( size_t I = 0; I < nrFactors(); I++ )
357 result.push_back( belief2(I) );
358 return result;
359 }
360
361 Factor BP_dual::belief( const VarSet &ns ) const {
362 if( ns.size() == 1 )
363 return belief( *(ns.begin()) );
364 else {
365 size_t I;
366 for( I = 0; I < nrFactors(); I++ )
367 if( factor(I).vars() >> ns )
368 break;
369 assert( I != nrFactors() );
370 return belief2(I).marginal(ns);
371 }
372 }
373
374 Real BP_dual::logZ() const {
375 Real sum = 0.0;
376 for(size_t i = 0; i < nrVars(); i++ )
377 sum += Real(1.0 - nbV(i).size()) * belief1(i).entropy();
378 for( size_t I = 0; I < nrFactors(); I++ )
379 sum -= dist( belief2(I), factor(I), Prob::DISTKL );
380 return sum;
381 }
382
383 // reset only messages to/from certain variables
384 void BP_dual::init(const VarSet &ns) {
385 _iters=0;
386 foreach(Var n, ns) {
387 size_t ni = findVar(n);
388 size_t st = n.states();
389 foreach(Neighbor I, nbV(ni)) {
390 msgM(I.node,ni).fill(1.0/st);
391 zM(I.node,ni) = 1.0;
392 msgN(ni,I.node).fill(1.0/st);
393 zN(ni,I.node) = 1.0;
394 }
395 }
396 }
397
398 void BP_dual::init() {
399 _iters=0;
400 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
401 _msgs.m[iI].setUniform();
402 _msgs.Zm[iI] = 1;
403 _msgs.n[iI].setUniform();
404 _msgs.Zn[iI] = 1;
405 }
406 _new_msgs = _msgs;
407 }
408
409
410 void BP_dual::init(const vector<size_t>& state) {
411 _iters=0;
412 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
413 size_t i = edge(iI).first;
414 _msgs.m[iI].fill(0.1);
415 _msgs.m[iI][state[i]]=1;
416 _msgs.Zm[iI] = _msgs.m[iI].normalize();
417 _msgs.n[iI].fill(0.1);
418 _msgs.n[iI][state[i]]=1;
419 _msgs.Zn[iI] = _msgs.n[iI].normalize();
420 }
421 _new_msgs = _msgs;
422 }
423
424 void _clamp(FactorGraph &g, const Var & n, const vector<size_t> &is ) {
425 Factor mask_n(n,0.0);
426
427 foreach(size_t i, is) { assert( i <= n.states() ); mask_n[i] = 1.0; }
428
429 for( size_t I = 0; I < g.nrFactors(); I++ )
430 if( g.factor(I).vars() && n )
431 g.factor(I) *= mask_n;
432 }
433
434 // clamp a factor to have one of a set of values
435 void _clampFactor(FactorGraph &g, size_t I, const vector<size_t> &is) {
436 size_t st = g.factor(I).states();
437 Prob mask_n(st,0.0);
438
439 foreach(size_t i, is) { assert( i <= st ); mask_n[i] = 1.0; }
440
441 g.factor(I).p() *= mask_n;
442 }
443
444 } // end of namespace dai