[Frederik Eaton] Added BP_Dual, BBP and CBP algorithms
[libdai.git] / src / cbp.cpp
1 #include <iostream>
2 #include <sstream>
3 #include <map>
4 #include <set>
5 #include <algorithm>
6
7 #include <dai/util.h>
8 #include <dai/properties.h>
9
10 #include <dai/bp.h>
11 #include <dai/cbp.h>
12 #include <dai/bbp.h>
13
14 using namespace std;
15
16 namespace dai {
17
18 const char *CBP::Name = "CBP";
19
20 const char *CBP::PropertyList[] = {"updates","tol","rec_tol","maxiter","verbose","max_levels","min_max_adj","choose","clamp","recursion","bbp_cfn","rand_seed"};
21
22 #define rnd_multi(x) rnd_int(0,(x)-1)
23
24 void CBP::setProperties(const PropertySet &opts) {
25 // DAI_DMSG("in CBP::setProperties");
26 // DAI_PV(opts);
27 foreach(const char* p, PropertyList) {
28 if(!opts.hasKey(p)) {
29 // XXX probably leaks pointer?
30 throw (string("CBP: Missing property ")+p).c_str();
31 }
32 }
33
34 props.tol = opts.getStringAs<double>("tol");
35 props.rec_tol = opts.getStringAs<double>("rec_tol");
36 props.maxiter = opts.getStringAs<size_t>("maxiter");
37 props.verbose = opts.getStringAs<size_t>("verbose");
38 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
39 props.max_levels = opts.getStringAs<size_t>("max_levels");
40 props.min_max_adj = opts.getStringAs<double>("min_max_adj");
41 props.choose = opts.getStringAs<Properties::ChooseMethodType>("choose");
42 props.recursion = opts.getStringAs<Properties::RecurseType>("recursion");
43 props.clamp = opts.getStringAs<Properties::ClampType>("clamp");
44 props.bbp_cfn = opts.getStringAs<bbp_cfn_t>("bbp_cfn");
45 props.rand_seed = opts.getStringAs<size_t>("rand_seed");
46 }
47
48 PropertySet CBP::getProperties() const {
49 PropertySet opts;
50 opts.Set( "tol", props.tol );
51 opts.Set( "rec_tol", props.rec_tol );
52 opts.Set( "maxiter", props.maxiter );
53 opts.Set( "verbose", props.verbose );
54 opts.Set( "updates", props.updates );
55 opts.Set( "max_levels", props.max_levels );
56 opts.Set( "min_max_adj", props.min_max_adj );
57 opts.Set( "choose", props.choose );
58 opts.Set( "recursion", props.recursion );
59 opts.Set( "clamp", props.clamp );
60 opts.Set( "bbp_cfn", props.bbp_cfn );
61 opts.Set( "rand_seed", props.rand_seed );
62 return opts;
63 }
64
65 std::string CBP::printProperties() const {
66 stringstream s( stringstream::out );
67 s << "[";
68 s << "tol=" << props.tol << ",";
69 s << "rec_tol=" << props.rec_tol << ",";
70 s << "maxiter=" << props.maxiter << ",";
71 s << "verbose=" << props.verbose << ",";
72 s << "updates=" << props.updates << ",";
73 s << "max_levels=" << props.max_levels << ",";
74 s << "min_max_adj=" << props.min_max_adj << ",";
75 s << "choose=" << props.choose << ",";
76 s << "recursion=" << props.recursion << ",";
77 s << "clamp=" << props.clamp << ",";
78 s << "bbp_cfn=" << props.bbp_cfn << ",";
79 s << "rand_seed=" << props.rand_seed << ",";
80 s << "]";
81 return s.str();
82 }
83
84 void CBP::construct() {
85 // DAIAlgFG::Regenerate();
86 indexEdges();
87
88 _beliefs1.clear(); _beliefs1.reserve(nrVars());
89 for( size_t i = 0; i < nrVars(); i++ )
90 _beliefs1.push_back( Factor(var(i)).normalized() );
91
92 _beliefs2.clear(); _beliefs2.reserve(nrFactors());
93 for( size_t I = 0; I < nrFactors(); I++ ) {
94 Factor f = factor(I);
95 f.fill(1); f.normalize();
96 _beliefs2.push_back(f);
97 }
98
99 // to compute average level
100 _sum_level = 0;
101 _num_leaves = 0;
102
103 _maxdiff = 0;
104 _iters = 0;
105 }
106
107 static
108 vector<Factor> mixBeliefs(Real p, vector<Factor> b, vector<Factor> c) {
109 vector<Factor> out;
110 assert(b.size()==c.size());
111 out.reserve(b.size());
112 Real pc = 1-p;
113 for(size_t i=0; i<b.size(); i++) {
114 // XXX probably already normalized
115 out.push_back(b[i].normalized()*p+
116 c[i].normalized()*pc);
117 }
118 return out;
119 }
120
121 double CBP::run() {
122 // BP bp(getBP());
123 // InfAlg *bp = newInfAlg( GetPropertyAs<string>("bp_alg"), *this, GetPropertyAs<Properties>("bp_opts") );
124 size_t seed = props.rand_seed;
125 if(seed>0) rnd_seed(seed);
126
127 BP_dual bp_dual(getBP_dual());
128 bp_dual.init();
129 bp_dual.run();
130 _iters += bp_dual.Iterations();
131
132 vector<Factor> beliefs_out;
133 Real lz_out;
134 runRecurse(bp_dual, bp_dual.logZ(),
135 vector<size_t>(0), set<size_t>(),
136 _num_leaves, _sum_level,
137 lz_out, beliefs_out);
138 if(props.verbose>=1) {
139 cerr << "CBP average levels = " << (_sum_level/_num_leaves) << ", leaves = " << _num_leaves << endl;
140 }
141 setBeliefs(beliefs_out, lz_out);
142 return 0;
143 }
144
145 BP CBP::getBP() {
146 PropertySet bpProps;
147 bpProps.Set("updates", string("PARALL"));
148 bpProps.Set("tol", props.tol);
149 bpProps.Set("maxiter", props.maxiter);
150 bpProps.Set("verbose", oneLess(props.verbose));
151 BP bp(*this,bpProps);
152 bp.init();
153 return bp;
154 }
155
156 BP_dual CBP::getBP_dual() {
157 PropertySet bpProps;
158 bpProps.Set("updates", string("PARALL"));
159 bpProps.Set("tol", props.tol);
160 bpProps.Set("maxiter", props.maxiter);
161 bpProps.Set("verbose", oneLess(props.verbose));
162 // cerr << "In getBP_dual" << endl;
163 // DAI_PV(bpProps);
164 BP_dual bp_dual(*this,bpProps);
165 return bp_dual;
166 }
167
/// Returns, in increasing order, every state in 0 .. n_states-1 that does
/// not occur in \a xis.
///
/// \a xis must be sorted in increasing order, free of duplicates and contain
/// only values below n_states; otherwise the size assertion fails.
vector<size_t> complement(vector<size_t>& xis, size_t n_states) {
    vector<size_t> missing;
    size_t cursor = 0;      // position in xis; only ever moves forward

    for(size_t state = 0; state < n_states; ++state) {
        // step over entries of xis that are smaller than the current state
        for(; cursor < xis.size() && xis[cursor] < state; ++cursor) {
        }
        const bool listed = (cursor < xis.size()) && !(xis[cursor] > state);
        if(!listed) {
            missing.push_back(state);
        }
    }

    assert( xis.size()+missing.size() == n_states );
    return missing;
}
178
179 Real max(Real x,Real y) { return x>y?x:y; }
180
181 Real unSoftMax(Real lz, Real cmp_lz) {
182 double m = max(lz, cmp_lz);
183 lz -= m; cmp_lz -= m;
184 double p = exp(lz)/(exp(lz)+exp(cmp_lz));
185 return p;
186 }
187
188 Real logSumExp(Real lz, Real cmp_lz) {
189 double m = max(lz, cmp_lz);
190 lz -= m; cmp_lz -= m;
191 return m+log(exp(lz)+exp(cmp_lz));
192 }
193
194 Real dist(const vector<Factor>& b1, const vector<Factor>& b2, size_t nv) {
195 Real d=0.0;
196 for(size_t k=0; k<nv; k++) {
197 d += dist( b1[k], b2[k], Prob::DISTLINF );
198 }
199 return d;
200 }
201
202
203 void CBP::runRecurse(BP_dual &bp_dual,
204 double orig_logZ,
205 vector<size_t> clamped_vars_list,
206 set<size_t> clamped_vars,
207 size_t &num_leaves,
208 double &sum_level,
209 Real &lz_out, vector<Factor>& beliefs_out) {
210 // choose a variable/states to clamp:
211 size_t i;
212 vector<size_t> xis;
213 Real maxVar=0.0;
214 bool found;
215 bool clampVar = (Clamping()==Properties::ClampType::CLAMP_VAR);
216
217 // XXX fix to just pass orig_logZ
218
219 if(Recursion()==Properties::RecurseType::REC_LOGZ && recTol()>0 &&
220 exp(bp_dual.logZ()-orig_logZ) < recTol()) {
221 found = false;
222 } else {
223 found = chooseNextClampVar(bp_dual,
224 clamped_vars_list,
225 clamped_vars,
226 i, xis, &maxVar);
227 }
228
229 if(!found) {
230 num_leaves++;
231 sum_level += clamped_vars_list.size();
232 beliefs_out = bp_dual.beliefs();
233 lz_out = bp_dual.logZ();
234 return;
235 }
236
237 if(clampVar) {
238 foreach(size_t xi, xis) { assert(/*0<=xi &&*/ xi<var(i).states()); }
239 } else {
240 foreach(size_t xI, xis) { assert(/*0<=xI &&*/ xI<factor(i).states()); }
241 }
242 // - otherwise, clamp and recurse, saving margin estimates for each
243 // clamp setting. afterwards, combine estimates.
244
245 // compute complement of 'xis'
246 vector<size_t> cmp_xis=complement(xis, clampVar?var(i).states():factor(i).states());
247
248 // XXX could do this more efficiently with a nesting version of
249 // saveProbs/undoProbs
250 Real lz; vector<Factor> b;
251 BP_dual bp_dual_c(bp_dual);
252 if(clampVar) {
253 _clamp((FactorGraph&)bp_dual_c, var(i), xis);
254 bp_dual_c.init(var(i));
255 } else {
256 _clampFactor((FactorGraph&)bp_dual_c, i, xis);
257 bp_dual_c.init(factor(i).vars());
258 }
259 bp_dual_c.run();
260 _iters += bp_dual_c.Iterations();
261
262 lz = bp_dual_c.logZ();
263 b = bp_dual_c.beliefs();
264
265 Real cmp_lz; vector<Factor> cmp_b;
266 BP_dual cmp_bp_dual_c(bp_dual);
267 if(clampVar) {
268 _clamp(cmp_bp_dual_c,var(i),cmp_xis);
269 cmp_bp_dual_c.init(var(i));
270 } else {
271 _clampFactor(cmp_bp_dual_c,i,cmp_xis);
272 cmp_bp_dual_c.init(factor(i).vars());
273 }
274 cmp_bp_dual_c.run();
275 _iters += cmp_bp_dual_c.Iterations();
276
277 cmp_lz = cmp_bp_dual_c.logZ();
278 cmp_b = cmp_bp_dual_c.beliefs();
279
280 double p = unSoftMax(lz, cmp_lz);
281 Real bp__d=0.0;
282
283 if(Recursion()==Properties::RecurseType::REC_BDIFF && recTol() > 0) {
284 vector<Factor> combined_b(mixBeliefs(p,b,cmp_b));
285 Real new_lz = logSumExp(lz,cmp_lz);
286 bp__d = dist(bp_dual.beliefs(),combined_b,nrVars());
287 if(exp(new_lz-orig_logZ)*bp__d < recTol()) {
288 num_leaves++;
289 sum_level += clamped_vars_list.size();
290 beliefs_out = combined_b;
291 lz_out = new_lz;
292 return;
293 }
294 }
295
296 // either we are not doing REC_BDIFF or the distance was large
297 // enough to recurse:
298
299 runRecurse(bp_dual_c, orig_logZ,
300 clamped_vars_list,
301 clamped_vars,
302 num_leaves, sum_level, lz, b);
303 runRecurse(cmp_bp_dual_c, orig_logZ,
304 clamped_vars_list,
305 clamped_vars,
306 num_leaves, sum_level, cmp_lz, cmp_b);
307
308 p = unSoftMax(lz,cmp_lz);
309
310 beliefs_out = mixBeliefs(p, b, cmp_b);
311 lz_out = logSumExp(lz,cmp_lz);
312
313 if(props.verbose>=2) {
314 Real d = dist( bp_dual.beliefs(), beliefs_out, nrVars() );
315 cerr << "Distance (clamping " << i << "): " << d;
316 if(Recursion()==Properties::RecurseType::REC_BDIFF)
317 cerr << "; bp_dual predicted " << bp__d;
318 cerr << "; max adjoint = " << maxVar << "; level = " << clamped_vars_list.size() << endl;
319 }
320 }
321
322 // 'xis' must be sorted
323 bool CBP::chooseNextClampVar(BP_dual& bp,
324 vector<size_t> &clamped_vars_list,
325 set<size_t> &clamped_vars,
326 size_t &i, vector<size_t> &xis, Real *maxVarOut) {
327 Real tiny=1.0e-14;
328 if(props.verbose>=3) {
329 cerr << "clamped_vars_list" << clamped_vars_list << endl;
330 }
331 if(clamped_vars_list.size() >= maxClampLevel()) {
332 return false;
333 }
334 if(ChooseMethod()==Properties::ChooseMethodType::CHOOSE_RANDOM) {
335 if(Clamping()==Properties::ClampType::CLAMP_VAR) {
336 int t=0, t1=100;
337 do {
338 i = rnd_multi(nrVars());
339 t++;
340 } while(abs(bp.belief1(i).p().max()-1) < tiny &&
341 t < t1);
342 if(t==t1) {
343 return false;
344 // die("Too many levels requested in CBP");
345 }
346 // only pick probable values for variable
347 size_t xi;
348 do {
349 xi = rnd_multi(var(i).states());
350 t++;
351 } while(bp.belief1(i).p()[xi] < tiny && t<t1);
352 assert(t<t1);
353 xis.resize(1, xi);
354 // assert(!_clamped_vars.count(i)); // not true for >2-ary variables
355 DAI_IFVERB(2, endl<<"CHOOSE_RANDOM chose variable "<<i<<" state "<<xis[0]<<endl);
356 } else {
357 int t=0, t1=100;
358 do {
359 i = rnd_multi(nrFactors());
360 t++;
361 } while(abs(bp.belief2(i).p().max()-1) < tiny &&
362 t < t1);
363 if(t==t1) {
364 return false;
365 // die("Too many levels requested in CBP");
366 }
367 // only pick probable values for variable
368 size_t xi;
369 do {
370 xi = rnd_multi(factor(i).states());
371 t++;
372 } while(bp.belief2(i).p()[xi] < tiny && t<t1);
373 assert(t<t1);
374 xis.resize(1, xi);
375 // assert(!_clamped_vars.count(i)); // not true for >2-ary variables
376 DAI_IFVERB(2, endl<<"CHOOSE_RANDOM chose factor "<<i<<" state "<<xis[0]<<endl);
377 }
378 } else if(ChooseMethod()==Properties::ChooseMethodType::CHOOSE_BP_ALL) {
379 // try clamping each variable manually
380 assert(Clamping()==Properties::ClampType::CLAMP_VAR);
381 Real max_diff=0.0;
382 int win_k=-1, win_xk=-1;
383 for(size_t k=0; k<nrVars(); k++) {
384 for(size_t xk=0; xk<var(k).states(); xk++) {
385 if(bp.belief1(k)[xk]<tiny) continue;
386 BP_dual bp1(bp);
387 bp1.clamp(var(k), xk);
388 bp1.init(var(k));
389 bp1.run();
390 Real diff=0;
391 for(size_t j=0; j<nrVars(); j++) {
392 diff += dist(bp.belief1(j), bp1.belief1(j), Prob::DISTL1);
393 }
394 if(diff>max_diff) {
395 max_diff=diff; win_k=k; win_xk=xk;
396 }
397 }
398 }
399 assert(win_k>=0); assert(win_xk>=0);
400 i = win_k; xis.resize(1, win_xk);
401 } else if(ChooseMethod()==Properties::ChooseMethodType::CHOOSE_BBP) {
402 Real mvo; if(!maxVarOut) maxVarOut = &mvo;
403 bool clampVar = (Clamping()==Properties::ClampType::CLAMP_VAR);
404 pair<size_t, size_t> cv =
405 bbpFindClampVar(bp,
406 clampVar,
407 clamped_vars_list.size(),
408 BBP_cost_function(),getProperties(),maxVarOut);
409
410 // if slope isn't big enough then don't clamp
411 if(*maxVarOut < minMaxAdj()) return false;
412
413 size_t xi=cv.second;
414 i = cv.first;
415 #define VAR_INFO (clampVar?"variable ":"factor ") \
416 << i << " state " << xi \
417 << " (p=" << (clampVar?bp.belief1(i)[xi]:bp.belief2(i)[xi]) \
418 << ", entropy = " << (clampVar?bp.belief1(i):bp.belief2(i)).entropy() \
419 << ", maxVar = "<< *maxVarOut << ")"
420 Prob b = (clampVar?bp.belief1(i).p():bp.belief2(i).p());
421 if(b[xi] < tiny) {
422 cerr << "Warning, bbpFindClampVar found unlikely "
423 << VAR_INFO << endl;
424 return false;
425 }
426 if(abs(b[xi]-1) < tiny) {
427 cerr << "Warning, bbpFindClampVar found overly likely "
428 << VAR_INFO << endl;
429 return false;
430 }
431
432 xis.resize(1,xi);
433 if(clampVar) {
434 assert(/*0<=xi &&*/ xi<var(i).states());
435 } else {
436 assert(/*0<=xi &&*/ xi<factor(i).states());
437 }
438 DAI_IFVERB(2, "CHOOSE_BBP (num clamped = " << clamped_vars_list.size()
439 << ") chose " << i << " state " << xi << endl);
440 } else {
441 abort();
442 }
443 clamped_vars_list.push_back(i);
444 clamped_vars.insert(i);
445 return true;
446 }
447
448 // void CBP::clamp(const Var & n, size_t i) {
449 // FactorGraph::clamp(n,i);
450 // _clamped_vars.insert(findVar(n));
451 // _clamped_vars_list.push_back(findVar(n));
452 // }
453
454 // void CBP::clamp(const Var & n, const vector<size_t> &is) {
455 // FactorGraph::clamp(n,is);
456 // _clamped_vars.insert(findVar(n));
457 // _clamped_vars_list.push_back(findVar(n));
458 // }
459
460 void CBP::printDebugInfo() {
461 DAI_PV(_beliefs1);
462 DAI_PV(_beliefs2);
463 DAI_PV(_logZ);
464 }
465
466 //----------------------------------------------------------------
467
/// When true, bbpFindClampVar() runs numericBBPTest and uses a tolerance of
/// 1.0e-11 instead of 1.0e-3.
bool doBBPTest = false;

/// When true, bbpFindClampVar() calls makeBBPGraph once, the first time it is
/// invoked with numClamped equal to bbpGraphLevel, and then clears this flag.
bool doBBPGraph = false;

/// Number of clamped variables at which the BBP graph data is written.
size_t bbpGraphLevel = 3;

/// Nonzero: bbpFindClampVar() re-initializes its BP_dual from a Gibbs state.
#define BPP_INIT_GIBBS 1
473
474 /// function which takes a factor graph as input, runs Gibbs and BP_dual,
475 /// creates and runs a BBP object, finds best variable, returns
476 /// (variable,state) pair for clamping
477 // pair<size_t, size_t> bbpFindClampVar(const CBP &fg, bbp_cfn_t cfn, const Properties &props, Real *maxVarOut) {
478 pair<size_t, size_t> bbpFindClampVar(BP_dual &in_bp_dual, bool clampVar,
479 size_t numClamped, bbp_cfn_t cfn, const PropertySet &props, Real *maxVarOut) {
480 #if BPP_INIT_GIBBS
481 vector<size_t> state = getGibbsState(in_bp_dual, 100);
482 in_bp_dual.init(state);
483 in_bp_dual.run();
484 #endif
485
486 Real ourTol = doBBPTest ? 1.0e-11 : 1.0e-3;
487 if(0) {
488 PropertySet bp_Props;
489 bp_Props.Set("updates", string("PARALL"));
490 // bp_Props.Set("tol", props.GetAs<double>("tol"));
491 bp_Props.Set("tol", ourTol);
492 bp_Props.Set("maxiter", props.GetAs<size_t>("maxiter"));
493 bp_Props.Set("verbose", oneLess(props.GetAs<size_t>("verbose")));
494 // bp_Props.ConvertTo<BP_dual::UpdateType>("updates");
495 // DAI_PV(bp_Props.GetAs<BP_dual::UpdateType>("updates"));
496 BP_dual bp_dual(in_bp_dual, bp_Props);
497 #if BPP_INIT_GIBBS
498 bp_dual.init(state);
499 #endif
500 bp_dual.run();
501 }
502
503 if(doBBPGraph && numClamped == bbpGraphLevel) {
504 cerr << "Writing BBP graph data" << endl;
505 makeBBPGraph(in_bp_dual,cfn);
506 doBBPGraph=false; // only do it once
507 cerr << "Done writing BBP graph data" << endl;
508 }
509 if(doBBPTest) {
510 double err = numericBBPTest(in_bp_dual, cfn, /*bbp tol*/ ourTol, /*h*/ 1.0e-5);
511 cerr << "Error from numericBBPTest: " << err << endl;
512 }
513 Real tic1=toc();
514 BBP bbp(in_bp_dual);
515 bbp.maxIter() = props.GetAs<size_t>("maxiter");
516 #if BPP_INIT_GIBBS
517 gibbsInitBBPCostFnAdj(bbp, in_bp_dual, cfn, &state);
518 #else
519 gibbsInitBBPCostFnAdj(bbp, in_bp_dual, cfn, NULL);
520 #endif
521 Real tic2=toc();
522 bbp.run(ourTol);
523 if(props.GetAs<size_t>("verbose") >= 3) {
524 cerr << "BBP took " << toc()-tic1 << " seconds (BBP.run = " << toc()-tic2 << " seconds), "
525 << bbp.doneIters() << " iterations" << endl;
526 }
527
528 // find and return the (variable,state) with the largest adj_psi_1
529 size_t argmax_var=0;
530 size_t argmax_var_state=0;
531 Real max_var=0;
532 if(clampVar) {
533 for(size_t i=0; i<in_bp_dual.nrVars(); i++) {
534 Prob adj_psi_1 = bbp.adj_psi_1(i);
535 if(0) {
536 // helps to account for amount of movement possible in variable
537 // i's beliefs? seems not..
538 adj_psi_1 *= in_bp_dual.belief1(i).entropy();
539 }
540 // try to compensate for effect on same variable (doesn't work)
541 // adj_psi_1[gibbs.state()[i]] -= bp_dual.belief1(i)[gibbs.state()[i]]/10;
542 pair<size_t,Real> argmax_state = adj_psi_1.argmax();
543
544 if(i==0 || argmax_state.second>max_var) {
545 argmax_var = i;
546 max_var = argmax_state.second;
547 argmax_var_state = argmax_state.first;
548 }
549 }
550 assert(/*0 <= argmax_var_state &&*/
551 argmax_var_state < in_bp_dual.var(argmax_var).states());
552 } else {
553 for(size_t I=0; I<in_bp_dual.nrFactors(); I++) {
554 Prob adj_psi_2 = bbp.adj_psi_2(I);
555 if(0) {
556 // helps to account for amount of movement possible in variable
557 // i's beliefs? seems not..
558 adj_psi_2 *= in_bp_dual.belief2(I).entropy();
559 }
560 // try to compensate for effect on same variable (doesn't work)
561 // adj_psi_1[gibbs.state()[i]] -= bp_dual.belief1(i)[gibbs.state()[i]]/10;
562 pair<size_t,Real> argmax_state = adj_psi_2.argmax();
563
564 if(I==0 || argmax_state.second>max_var) {
565 argmax_var = I;
566 max_var = argmax_state.second;
567 argmax_var_state = argmax_state.first;
568 }
569 }
570 assert(/*0 <= argmax_var_state &&*/
571 argmax_var_state < in_bp_dual.factor(argmax_var).states());
572 }
573 if(maxVarOut) *maxVarOut = max_var;
574 return make_pair(argmax_var,argmax_var_state);
575 }
576
577 } // end of namespace dai