Some improvements to the new BP_dual/BBP/CBP code to make it compile
[libdai.git] / src / bp_dual.cpp
1
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>

#include <dai/bbp.h>
//#include <dai/diffs.h>
#include <dai/util.h>
//#include "stlutil.h"
#include <dai/properties.h>
13
14 namespace dai {
15
16 using namespace std;
17
18 const char *BP_dual::Name = "BP_dual";
19
20 const char *BP_dual::PropertyList[] = {"tol","maxiter","updates","verbose"};
21
22 void BP_dual::setProperties( const PropertySet &opts ) {
23 // DAI_DMSG("in BP_dual::setProperties");
24 // DAI_PV(opts);
25
26 bool die=false;
27 foreach(const char *p, PropertyList) {
28 if( !opts.hasKey(p) ) {
29 cerr << "BP_dual: missing property " << p << endl;
30 die=true;
31 }
32 }
33 if(die)
34 DAI_THROW(NOT_ALL_PROPERTIES_SPECIFIED);
35
36 props.tol = opts.getStringAs<double>("tol");
37 props.maxiter = opts.getStringAs<size_t>("maxiter");
38 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
39 props.verbose = opts.getStringAs<size_t>("verbose");
40
41 // DAI_PV(printProperties());
42 }
43
44 PropertySet BP_dual::getProperties() const {
45 PropertySet opts;
46 opts.Set( "tol", props.tol );
47 opts.Set( "maxiter", props.maxiter );
48 opts.Set( "updates", props.updates );
49 opts.Set( "verbose", props.verbose );
50 return opts;
51 }
52
53 std::string BP_dual::printProperties() const {
54 stringstream s( stringstream::out );
55 s << "[";
56 s << "tol=" << props.tol << ",";
57 s << "maxiter=" << props.maxiter << ",";
58 s << "updates=" << props.updates << ",";
59 s << "verbose=" << props.verbose;
60 s << "]";
61 return s.str();
62 }
63
64 // void BP_dual::checkProperties() {
65 // const char *props[] = {"updates","tol","maxiter","verbose"};
66 // for(size_t i=0; i<sizeof(props)/sizeof(*props); i++) {
67 // if(!HasProperty(props[i]))
68 // die("BP_dual: Missing property \"%s\"", props[i]);
69 // }
70
71 // ConvertPropertyTo<double>("tol");
72 // ConvertPropertyTo<size_t>("maxiter");
73 // ConvertPropertyTo<size_t>("verbose");
74 // ConvertPropertyTo<UpdateType>("updates");
75 // }
76
77 void BP_dual::RegenerateIndices() {
78 _indices.clear();
79 _indices.reserve(nr_edges());
80
81 foreach(Edge iI, edges()) {
82 vector<size_t> ind( factor(iI.second).states(), 0 );
83 IndexFor i (var(iI.first), factor(iI.second).vars() );
84 for( size_t j = 0; i >= 0; ++i,++j )
85 ind[j] = i;
86 _indices.push_back( ind );
87 }
88 }
89
90 void BP_dual::RegenerateMessages() {
91 _msgs.Zn.resize(nr_edges(),1.0);
92 _msgs.Zm.resize(nr_edges(),1.0);
93
94 // clear messages
95 _msgs.m.clear();
96 _msgs.m.reserve(nr_edges());
97 _msgs.n.clear();
98 _msgs.n.reserve(nr_edges());
99
100 // create messages and indices
101 foreach(Edge iI, edges()) {
102 // initialize to uniform distributions
103 _msgs.m.push_back( Prob( var(iI.first).states() ) );
104 _msgs.n.push_back( Prob( var(iI.first).states() ) );
105 }
106
107 // create new_messages
108 _new_msgs = _msgs;
109 }
110
111 void BP_dual::RegenerateBeliefs() {
112 _beliefs.b1.clear();
113 _beliefs.b1.reserve(nrVars());
114 _beliefs.Zb1.resize(nrVars(), 1.0);
115 _beliefs.b2.clear();
116 _beliefs.b2.reserve(nrFactors());
117 _beliefs.Zb2.resize(nrFactors(), 1.0);
118
119 for(size_t i=0; i<nrVars(); i++) {
120 _beliefs.b1.push_back( Prob( var(i).states() ).setUniform() );
121 }
122 for(size_t I=0; I<nrFactors(); I++) {
123 _beliefs.b2.push_back( Prob( factor(I).states() ).setUniform() );
124 }
125 }
126
127 // called by constructor, called before 'init'
128 void BP_dual::Regenerate() {
129
130 indexEdges(); // so we can use compatibility interface
131
132 // DAIAlgFG::Regenerate(); // located in BipartiteGraph
133
134 RegenerateIndices();
135 RegenerateMessages();
136 RegenerateBeliefs();
137
138 _maxdiff = 0;
139 _iters = 0;
140 }
141
142 void BP_dual::CalcBelief1(size_t i) {
143 Prob prod( var(i).states(), 1.0 );
144 foreach(size_t I, nbV(i)) {
145 prod *= newMsgM(I,i);
146 }
147 _beliefs.Zb1[i] = prod.normalize();
148 _beliefs.b1[i] = prod;
149 }
150
151 void BP_dual::CalcBelief2(size_t I) {
152 Prob prod( factor(I).p() );
153 foreach(size_t j, nbF(I)) {
154 const _ind_t *ind = &(index(j, I));
155 for(size_t r=0; r<prod.size(); r++) {
156 Prob n(newMsgN(j,I));
157 prod[r] *= n[(*ind)[r]];
158 }
159 }
160 _beliefs.Zb2[I] = prod.normalize();
161 _beliefs.b2[I] = prod;
162 }
163
164 // called after run()
165 void BP_dual::CalcBeliefs() {
166 for(size_t i=0; i<nrVars(); i++) {
167 // calculate b_i
168 CalcBelief1(i);
169 }
170 for(size_t I=0; I<nrFactors(); I++) {
171 // calculate b_I
172 CalcBelief2(I);
173 }
174 }
175
176 void BP_dual::calcNewM(size_t iI) {
177 // calculate updated message I->i
178 size_t i = edge(iI).first;
179 size_t I = edge(iI).second;
180
181 Prob prod( factor(I).p() );
182
183 foreach(size_t j, nbF(I)) {
184 if( j != i ) { // for all j in I \ i
185 _ind_t* ind = &(index(j,I));
186 Prob n(msgN(j,I));
187 for( size_t r = 0; r < prod.size(); r++ )
188 prod[r] *= n[(*ind)[r]];
189 }
190 }
191
192 // Marginalize onto i
193 Prob marg( var(i).states(), 0.0 );
194 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
195 _ind_t* ind = &(index(i,I));
196 for( size_t r = 0; r < prod.size(); r++ )
197 marg[(*ind)[r]] += prod[r];
198
199 _new_msgs.Zm[iI] = marg.normalize();
200 _new_msgs.m[iI] = marg;
201 }
202
203 void BP_dual::calcNewN(size_t iI) {
204 // XXX optimize
205 // calculate updated message i->I
206 size_t i = edge(iI).first;
207 size_t I = edge(iI).second;
208
209 Prob prod(var(i).states(), 1.0);
210 foreach(size_t J, nbV(i)) {
211 if(J != I) { // for all J in i \ I
212 prod *= msgM(J,i);
213 }
214 }
215 _new_msgs.Zn[iI] = prod.normalize();
216 _new_msgs.n[iI] = prod;
217 }
218
219 void BP_dual::upMsgM(size_t iI) {
220 _msgs.m[iI] = _new_msgs.m[iI];
221 _msgs.Zm[iI] = _new_msgs.Zm[iI];
222 }
223
224 void BP_dual::upMsgN(size_t iI) {
225 _msgs.n[iI] = _new_msgs.n[iI];
226 _msgs.Zn[iI] = _new_msgs.Zn[iI];
227 }
228
229 double BP_dual::run() {
230 DAI_IFVERB(1, "Starting " << identify() << "..." << endl);
231
232 double tic = toc();
233 // for some reason we need 2* here, where orig BP doesn't
234 Diffs diffs(2*nrVars(), 1.0);
235
236 vector<size_t> edge_seq;
237 vector<double> residuals;
238
239 vector<Factor> old_beliefs;
240 old_beliefs.reserve( nrVars() );
241 for( size_t i = 0; i < nrVars(); i++ ) {
242 CalcBelief1(i);
243 old_beliefs.push_back(belief1(i));
244 }
245
246 size_t iter = 0;
247
248 if( Updates() == UpdateType::SEQMAX ) {
249 // do the first pass
250 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
251 calcNewM(iI);
252 calcNewN(iI);
253 }
254
255 // calculate initial residuals
256 residuals.reserve(nr_edges());
257 for( size_t iI = 0; iI < nr_edges(); iI++ )
258 residuals.push_back( dist( _new_msgs.m[iI], _msgs.m[iI], Prob::DISTLINF ) );
259 } else {
260 edge_seq.reserve( nr_edges() );
261 for( size_t i = 0; i < nr_edges(); i++ )
262 edge_seq.push_back( i );
263 }
264
265 // do several passes over the network until maximum number of iterations has
266 // been reached or until the maximum belief difference is smaller than tolerance
267 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
268 if( Updates() == UpdateType::SEQMAX ) {
269 // Residuals-BP by Koller et al.
270 for( size_t t = 0; t < nr_edges(); t++ ) {
271 // update the message with the largest residual
272 size_t iI = max_element(residuals.begin(), residuals.end()) - residuals.begin();
273 upMsgM(iI);
274 residuals[iI] = 0;
275
276 // I->i has been updated, which means that residuals for all
277 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
278 size_t i = edge(iI).first;
279 size_t I = edge(iI).second;
280 foreach(size_t J, nbV(i)) {
281 if(J != I) {
282 size_t iJ = VV2E(i,J);
283 calcNewN(iJ);
284 upMsgN(iJ);
285 foreach(size_t j, nbF(J)) {
286 if(j != i) {
287 size_t jJ = VV2E(j,J);
288 calcNewM(jJ);
289 residuals[jJ] = dist( _new_msgs.m[jJ], _msgs.m[jJ], Prob::DISTLINF );
290 }
291 }
292 }
293 }
294 }
295 } else if( Updates() == UpdateType::PARALL ) {
296 // Parallel updates
297 for( size_t t = 0; t < nr_edges(); t++ ) {
298 calcNewM(t);
299 calcNewN(t);
300 }
301 if(0) {
302 for(size_t t=0; t<nr_edges(); t++) {
303 upMsgM(t); upMsgN(t);
304 }
305 } else {
306 _msgs = _new_msgs;
307 }
308 } else {
309 // Sequential updates
310 if( Updates() == UpdateType::SEQRND )
311 random_shuffle( edge_seq.begin(), edge_seq.end() );
312
313 foreach(size_t k, edge_seq) {
314 calcNewM(k);
315 calcNewN(k);
316 upMsgM(k);
317 upMsgN(k);
318 }
319 }
320
321 // calculate new beliefs and compare with old ones
322 for( size_t i = 0; i < nrVars(); i++ ) {
323 CalcBelief1(i);
324 Factor nb( belief1(i) );
325 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
326 old_beliefs[i] = nb;
327 }
328
329 DAI_IFVERB(3,"BP_dual::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl);
330
331 _iters++;
332 }
333
334 updateMaxDiff( diffs.maxDiff() );
335
336 if( props.verbose >= 1 ) {
337 if( diffs.maxDiff() > props.tol ) {
338 DAI_IFVERB(1, endl << "BP_dual::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl);
339 } else {
340 DAI_IFVERB(3, "BP_dual::run: converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl);
341 }
342 }
343
344 CalcBeliefs();
345
346 return diffs.maxDiff();
347 }
348
349 string BP_dual::identify() const {
350 return string(Name) + printProperties();
351 }
352
353 vector<Factor> BP_dual::beliefs() const {
354 vector<Factor> result;
355 for( size_t i = 0; i < nrVars(); i++ )
356 result.push_back( belief1(i) );
357 for( size_t I = 0; I < nrFactors(); I++ )
358 result.push_back( belief2(I) );
359 return result;
360 }
361
362 Factor BP_dual::belief( const VarSet &ns ) const {
363 if( ns.size() == 1 )
364 return belief( *(ns.begin()) );
365 else {
366 size_t I;
367 for( I = 0; I < nrFactors(); I++ )
368 if( factor(I).vars() >> ns )
369 break;
370 assert( I != nrFactors() );
371 return belief2(I).marginal(ns);
372 }
373 }
374
375 Real BP_dual::logZ() const {
376 Real sum = 0.0;
377 for(size_t i = 0; i < nrVars(); i++ )
378 sum += Real(1.0 - nbV(i).size()) * belief1(i).entropy();
379 for( size_t I = 0; I < nrFactors(); I++ )
380 sum -= dist( belief2(I), factor(I), Prob::DISTKL );
381 return sum;
382 }
383
384 // reset only messages to/from certain variables
385 void BP_dual::init(const VarSet &ns) {
386 _iters=0;
387 foreach(Var n, ns) {
388 size_t ni = findVar(n);
389 size_t st = n.states();
390 foreach(Neighbor I, nbV(ni)) {
391 msgM(I.node,ni).fill(1.0/st);
392 zM(I.node,ni) = 1.0;
393 msgN(ni,I.node).fill(1.0/st);
394 zN(ni,I.node) = 1.0;
395 }
396 }
397 }
398
399 void BP_dual::init() {
400 _iters=0;
401 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
402 _msgs.m[iI].setUniform();
403 _msgs.Zm[iI] = 1;
404 _msgs.n[iI].setUniform();
405 _msgs.Zn[iI] = 1;
406 }
407 _new_msgs = _msgs;
408 }
409
410
411 void BP_dual::init(const vector<size_t>& state) {
412 _iters=0;
413 for(size_t iI = 0; iI < nr_edges(); iI++ ) {
414 size_t i = edge(iI).first;
415 _msgs.m[iI].fill(0.1);
416 _msgs.m[iI][state[i]]=1;
417 _msgs.Zm[iI] = _msgs.m[iI].normalize();
418 _msgs.n[iI].fill(0.1);
419 _msgs.n[iI][state[i]]=1;
420 _msgs.Zn[iI] = _msgs.n[iI].normalize();
421 }
422 _new_msgs = _msgs;
423 }
424
425 void _clamp(FactorGraph &g, const Var & n, const vector<size_t> &is ) {
426 Factor mask_n(n,0.0);
427
428 foreach(size_t i, is) { assert( i <= n.states() ); mask_n[i] = 1.0; }
429
430 for( size_t I = 0; I < g.nrFactors(); I++ )
431 if( g.factor(I).vars().contains( n ) )
432 g.factor(I) *= mask_n;
433 }
434
435 // clamp a factor to have one of a set of values
436 void _clampFactor(FactorGraph &g, size_t I, const vector<size_t> &is) {
437 size_t st = g.factor(I).states();
438 Prob mask_n(st,0.0);
439
440 foreach(size_t i, is) { assert( i <= st ); mask_n[i] = 1.0; }
441
442 g.factor(I).p() *= mask_n;
443 }
444
445 } // end of namespace dai