Rewrote implementation of response propagation in MR
[libdai.git] / src / mr.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2007 Bastian Wemmenhove
8 * Copyright (C) 2007-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 #include <cstdio>
14 #include <ctime>
15 #include <cmath>
16 #include <cstdlib>
17 #include <dai/mr.h>
18 #include <dai/bp.h>
19 #include <dai/jtree.h>
20 #include <dai/util.h>
21 #include <dai/bbp.h>
22
23
24 namespace dai {
25
26
27 using namespace std;
28
29
30 const char *MR::Name = "MR";
31
32
33 void MR::setProperties( const PropertySet &opts ) {
34 DAI_ASSERT( opts.hasKey("tol") );
35 DAI_ASSERT( opts.hasKey("verbose") );
36 DAI_ASSERT( opts.hasKey("updates") );
37 DAI_ASSERT( opts.hasKey("inits") );
38
39 props.tol = opts.getStringAs<Real>("tol");
40 props.verbose = opts.getStringAs<size_t>("verbose");
41 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
42 props.inits = opts.getStringAs<Properties::InitType>("inits");
43 }
44
45
46 PropertySet MR::getProperties() const {
47 PropertySet opts;
48 opts.Set( "tol", props.tol );
49 opts.Set( "verbose", props.verbose );
50 opts.Set( "updates", props.updates );
51 opts.Set( "inits", props.inits );
52 return opts;
53 }
54
55
56 string MR::printProperties() const {
57 stringstream s( stringstream::out );
58 s << "[";
59 s << "tol=" << props.tol << ",";
60 s << "verbose=" << props.verbose << ",";
61 s << "updates=" << props.updates << ",";
62 s << "inits=" << props.inits << "]";
63 return s.str();
64 }
65
66
67 void MR::init(size_t Nin, Real *_w, Real *_th) {
68 size_t i,j;
69
70 N = Nin;
71
72 con.resize(N);
73 nb.resize(N);
74 tJ.resize(N);
75 for(i=0; i<N; i++ ) {
76 nb[i].resize(kmax);
77 tJ[i].resize(kmax);
78 con[i]=0;
79 for(j=0; j<N; j++ )
80 if( _w[i*N+j] != 0.0 ) {
81 nb[i][con[i]] = j;
82 tJ[i][con[i]] = tanh(_w[i*N+j]);
83 con[i]++;
84 }
85 }
86
87 theta.resize(N);
88 for(i=0; i<N; i++)
89 theta[i] = _th[i];
90 }
91
92
/// Legacy response-propagation estimate of the cavity correlations
/// cors[cavity][_i2][_k].
///
/// For every variable \c cavity:
///  1. The graph (nb, tJ, con) is rebuilt from saved copies. Then \c cavity is
///     removed from its neighbours' lists and con[cavity] is set to 0.
///  2. kindex is recomputed for that reduced graph.
///  3. Cavity fields (xfield) are initialised randomly.
///  4. For each cavity neighbour i2, a unit perturbation is applied at i2.
///     Then cavity fields xfield and their responses rfield are iterated
///     jointly, with damping, until the largest change md <= props.tol or
///     maxIters sweeps have run.
///  5. The responses are converted into correlations between i2 and each
///     other cavity neighbour k.
///
/// nb, tJ and con are restored from the saved copies before returning. kindex
/// is NOT restored; it still describes the last reduced graph. solvemcav()
/// recomputes it via makekindex().
///
/// \return the largest final md over all (cavity, i2) pairs
///         (0.0 if no cavity has neighbours).
Real MR::init_cor_resp_old() {
    Real md, maxdev;

    size_t maxIters = 3000;
    Real damping = 0.1;   // fraction of each computed change that is applied per sweep

    vector<vector<Real> > tJ_org;
    vector<vector<size_t> > nb_org;
    vector<size_t> con_org;

    // save original tJ, nb, con (they are overwritten per cavity below)
    nb_org = nb;
    tJ_org = tJ;
    con_org = con;

    maxdev = 0.0;
    for( size_t cavity = 0; cavity < N; cavity++ ) { // for each spin to be removed
        con = con_org;
        size_t concav = con[cavity];   // degree of the cavity spin in the ORIGINAL graph

        nb = nb_org;
        tJ = tJ_org;

        // Remove "cavity" from the neighbour list of each of its neighbours ij:
        // find its slot and shift the remaining entries of nb[ij]/tJ[ij] one
        // position to the left. Assumes the graph is symmetric, i.e. cavity
        // occurs in nb[ij] (otherwise this search would run past con[ij]).
        for(size_t i=0; i<con[cavity]; i++) {
            size_t ij = nb[cavity][i];
            size_t flag=0;
            size_t j=0;
            do{
                if(nb[ij][j]==cavity){
                    while(j<(con[ij]-1)){
                        nb[ij][j]=nb[ij][j+1];
                        tJ[ij][j] = tJ[ij][j+1];
                        j++;
                    }
                    flag=1;
                }
                j++;
            } while(flag==0);
        }
        // each neighbour lost one edge; the cavity itself is now isolated
        // (nb[cavity] is left intact so it can still be used to enumerate them)
        for(size_t i=0; i<con[cavity]; i++)
            con[nb[cavity][i]]--;
        con[cavity] = 0;

        // From here on, work on the reduced (cavity) graph

        makekindex();

        // xfield[kmax*k + _l] : cavity field of spin k with its _l'th neighbour removed
        // rfield[kmax*k + _l] : response of that cavity field to the perturbation at i2
        vector<Real> xfield(N*kmax,0.0);
        vector<Real> rfield(N*kmax,0.0);

        // random initial cavity fields; roughly in [-3, 3] assuming rnd_uniform()
        // returns values in [0, 1) -- TODO confirm
        for( size_t i = 0; i < kmax*N; i++ )
            xfield[i] = 3.0 * (2.0 * rnd_uniform() - 1.);

        for( size_t _i2 = 0; _i2 < concav; _i2++ ) { // Apply a unit field perturbation to each cavity neighbour in turn
            size_t i2 = nb[cavity][_i2];

            // the perturbed spin's own cavity fields respond with slope 1
            // (xfield/rfield from the previous i2 are otherwise reused as a warm start)
            for( size_t _i = 0; _i < con[i2]; _i++ )
                rfield[kmax*i2 + _i] = 1.0;

            size_t iters = 0;
            do { // joint belief propagation (xfield) and response propagation (rfield)
                iters++;
                md = 0.0;
                for( size_t k = 0; k < N; k++ ){
                    if( k != cavity ) {
                        for( size_t _l = 0; _l < con[k]; _l++ ){
                            Real xinter = theta[k];
                            Real rinter = 0.;
                            if( k == i2 )
                                rinter += 1.;   // direct effect of the unit perturbation at i2
                            for( size_t _j = 0; _j < con[k]; _j++ )
                                if( _j != _l ) {
                                    size_t ind = kmax*nb[k][_j] + kindex[k][_j]; // index of cavity field "j \ k"
                                    Real variab2 = tanh( xfield[ind] );
                                    Real variab1 = tJ[k][_j] * variab2;
                                    xinter += atanh( variab1 );
                                    // chain rule: d/dx atanh(t*tanh(x)) = t*(1 - tanh(x)^2) / (1 - (t*tanh(x))^2)
                                    rinter += tJ[k][_j] * rfield[ind] * (1.0 - variab2*variab2) / (1.0 - variab1*variab1);
                                }

                            size_t ind = kmax * k + _l; // index of cavity field "k \ l"

                            // damped update of xfield; md tracks the largest undamped change
                            Real devs = xinter - xfield[ind];
                            xfield[ind] += devs * damping;
                            if( fabs(devs) > md )
                                md = fabs( devs );

                            // damped update of rfield
                            Real devs2 = rinter - rfield[ind];
                            rfield[ind] += devs2 * damping;
                            if( fabs(devs2) > md )
                                md = fabs(devs2);
                        }
                    }
                }
            } while( (md > props.tol) && (iters < maxIters) ); // stop on tolerance or iteration cap
            if( iters == maxIters )
                if( props.verbose >= 2 )
                    cerr << "init_cor_resp_old: Convergence not reached (md=" << md << ")..." << endl;

            if( md > maxdev )
                maxdev = md;

            // Convert the responses into correlations between i2 and every other
            // cavity neighbour k. Here the full field of k (no neighbour excluded)
            // is used; rinter*(1 - tanh(xinter)^2) is the response of tanh(xinter).

            for( size_t _k = 0; _k < concav; _k++ ) {
                size_t k = nb[cavity][_k];
                Real rinter = 0.;
                Real xinter = theta[k];
                if( _k != _i2 )
                    for( size_t _j = 0; _j < con[k]; _j++ ) {
                        size_t ind = kmax*nb[k][_j] + kindex[k][_j]; // index of cavity field "j \ k"
                        Real variab2 = tanh( xfield[ind] );
                        Real variab1 = tJ[k][_j] * variab2;
                        xinter += atanh( variab1 );
                        rinter += tJ[k][_j] * rfield[ind] * (1.0 - variab2*variab2) / (1.0 - variab1*variab1);
                    }

                // the self-correlation slot (k == i2) is set to 0
                if( k != i2 )
                    cors[cavity][_i2][_k] = rinter * (1.0 - tanh(xinter) * tanh(xinter));
                else
                    cors[cavity][_i2][_k] = 0;
            }
        } // close for _i2 = 0...concav
    }

    // restore nb, tJ, con (kindex is left for the last reduced graph)
    tJ = tJ_org;
    nb = nb_org;
    con = con_org;

    return maxdev;
}
227
228
229 Real MR::init_cor_resp() {
230 for( size_t i = 0; i < nrVars(); i++ ) {
231 vector<Factor> pairq;
232 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
233 bpcav.makeCavity( i );
234 bpcav.init();
235 bpcav.run();
236
237 BBP bbp( &bpcav, PropertySet()("verbose",(size_t)0)("tol",(Real)1.0e-9)("maxiter",(size_t)10000)("damping",(Real)0.0)("updates",string("SEQ_MAX")) );
238 for( size_t _j = 0; _j < con[i]; _j++ ) {
239 size_t j = nb[i][_j];
240
241 // Create weights for magnetization of some spin
242 Prob p( 2, 0.0 );
243 p[0] = -1.0;
244 p[1] = 1.0;
245
246 // BBP cost function would be the magnetization of spin j
247 vector<Prob> b1_adj;
248 b1_adj.reserve( nrVars() );
249 for( size_t l = 0; l < nrVars(); l++ )
250 if( l == j )
251 b1_adj.push_back( p );
252 else
253 b1_adj.push_back( Prob( 2, 0.0 ) );
254 bbp.init_V( b1_adj );
255
256 // run BBP to estimate adjoints
257 bbp.run();
258
259 for( size_t _k = 0; _k < con[i]; _k++ ) {
260 size_t k = nb[i][_k];
261 if( k != j )
262 cors[i][_j][_k] = (bbp.adj_psi_V(k)[1] - bbp.adj_psi_V(k)[0]);
263 else
264 cors[i][_j][_k] = 0.0;
265 }
266 }
267 }
268
269 return 0.0;
270 }
271
272
273 Real MR::T(size_t i, sub_nb A) {
274 sub_nb _nbi_min_A(con[i]);
275 _nbi_min_A.set();
276 _nbi_min_A &= ~A;
277
278 Real res = theta[i];
279 for( size_t _j = 0; _j < _nbi_min_A.size(); _j++ )
280 if( _nbi_min_A.test(_j) )
281 res += atanh(tJ[i][_j] * M[i][_j]);
282 return tanh(res);
283 }
284
285
286 Real MR::T(size_t i, size_t _j) {
287 sub_nb j(con[i]);
288 j.set(_j);
289 return T(i,j);
290 }
291
292
293 Real MR::Omega(size_t i, size_t _j, size_t _l) {
294 sub_nb jl(con[i]);
295 jl.set(_j);
296 jl.set(_l);
297 Real Tijl = T(i,jl);
298 return Tijl / (1.0 + tJ[i][_l] * M[i][_l] * Tijl);
299 }
300
301
302 Real MR::Gamma(size_t i, size_t _j, size_t _l1, size_t _l2) {
303 sub_nb jll(con[i]);
304 jll.set(_j);
305 Real Tij = T(i,jll);
306 jll.set(_l1);
307 jll.set(_l2);
308 Real Tijll = T(i,jll);
309
310 return (Tijll - Tij) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Tijll + tJ[i][_l2] * M[i][_l2] * Tijll);
311 }
312
313
314 Real MR::Gamma(size_t i, size_t _l1, size_t _l2) {
315 sub_nb ll(con[i]);
316 Real Ti = T(i,ll);
317 ll.set(_l1);
318 ll.set(_l2);
319 Real Till = T(i,ll);
320
321 return (Till - Ti) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Till + tJ[i][_l2] * M[i][_l2] * Till);
322 }
323
324
325 Real MR::_tJ(size_t i, sub_nb A) {
326 sub_nb::size_type _j = A.find_first();
327 if( _j == sub_nb::npos )
328 return 1.0;
329 else
330 return tJ[i][_j] * _tJ(i, A.reset(_j));
331 }
332
333
334 Real MR::appM(size_t i, sub_nb A) {
335 sub_nb::size_type _j = A.find_first();
336 if( _j == sub_nb::npos )
337 return 1.0;
338 else {
339 sub_nb A_j(A); A_j.reset(_j);
340
341 Real result = M[i][_j] * appM(i, A_j);
342 for( size_t _k = 0; _k < A_j.size(); _k++ )
343 if( A_j.test(_k) ) {
344 sub_nb A_jk(A_j); A_jk.reset(_k);
345 result += cors[i][_j][_k] * appM(i,A_jk);
346 }
347
348 return result;
349 }
350 }
351
352
353 void MR::sum_subs(size_t j, sub_nb A, Real *sum_even, Real *sum_odd) {
354 *sum_even = 0.0;
355 *sum_odd = 0.0;
356
357 sub_nb B(A.size());
358 do {
359 if( B.count() % 2 )
360 *sum_odd += _tJ(j,B) * appM(j,B);
361 else
362 *sum_even += _tJ(j,B) * appM(j,B);
363
364 // calc next subset B
365 size_t bit = 0;
366 for( ; bit < A.size(); bit++ )
367 if( A.test(bit) ) {
368 if( B.test(bit) )
369 B.reset(bit);
370 else {
371 B.set(bit);
372 break;
373 }
374 }
375 } while (!B.none());
376 }
377
378
379 void MR::solvemcav() {
380 Real sum_even, sum_odd;
381 Real maxdev;
382 size_t maxruns = 1000;
383
384 makekindex();
385 for(size_t i=0; i<N; i++)
386 for(size_t _j=0; _j<con[i]; _j++)
387 M[i][_j]=0.1;
388
389 size_t run=0;
390 do {
391 maxdev=0.0;
392 run++;
393 for(size_t i=0; i<N; i++){ // for all i
394 for(size_t _j=0; _j<con[i]; _j++){ // for all j in N_i
395 size_t _i = kindex[i][_j];
396 size_t j = nb[i][_j];
397 DAI_ASSERT( nb[j][_i] == i );
398
399 Real newM = 0.0;
400 if( props.updates == Properties::UpdateType::FULL ) {
401 // find indices in nb[j] that do not correspond with i
402 sub_nb _nbj_min_i(con[j]);
403 _nbj_min_i.set();
404 _nbj_min_i.reset(kindex[i][_j]);
405
406 // find indices in nb[i] that do not correspond with j
407 sub_nb _nbi_min_j(con[i]);
408 _nbi_min_j.set();
409 _nbi_min_j.reset(_j);
410
411 sum_subs(j, _nbj_min_i, &sum_even, &sum_odd);
412 newM = (tanh(theta[j]) * sum_even + sum_odd) / (sum_even + tanh(theta[j]) * sum_odd);
413
414 sum_subs(i, _nbi_min_j, &sum_even, &sum_odd);
415 Real denom = sum_even + tanh(theta[i]) * sum_odd;
416 Real numer = 0.0;
417 for(size_t _k=0; _k<con[i]; _k++) if(_k != _j) {
418 sub_nb _nbi_min_jk(_nbi_min_j);
419 _nbi_min_jk.reset(_k);
420 sum_subs(i, _nbi_min_jk, &sum_even, &sum_odd);
421 numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
422 }
423 newM -= numer / denom;
424 } else if( props.updates == Properties::UpdateType::LINEAR ) {
425 newM = T(j,_i);
426 for(size_t _l=0; _l<con[i]; _l++) if( _l != _j )
427 newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
428 for(size_t _l1=0; _l1<con[j]; _l1++) if( _l1 != _i )
429 for( size_t _l2=_l1+1; _l2<con[j]; _l2++) if( _l2 != _i)
430 newM += Gamma(j,_i,_l1,_l2) * tJ[j][_l1] * tJ[j][_l2] * cors[j][_l1][_l2];
431 }
432
433 Real dev = newM - M[i][_j];
434 // dev *= 0.02;
435 if( fabs(dev) >= maxdev )
436 maxdev = fabs(dev);
437
438 newM = M[i][_j] + dev;
439 if( fabs(newM) > 1.0 )
440 newM = sign(newM);
441 M[i][_j] = newM;
442 }
443 }
444 } while((maxdev>props.tol)&&(run<maxruns));
445
446 _iters = run;
447 if( maxdev > _maxdiff )
448 _maxdiff = maxdev;
449
450 if(run==maxruns){
451 if( props.verbose >= 1 )
452 cerr << "solve_mcav: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
453 }
454 }
455
456
457 void MR::solveM() {
458 for(size_t i=0; i<N; i++) {
459 if( props.updates == Properties::UpdateType::FULL ) {
460 // find indices in nb[i]
461 sub_nb _nbi(con[i]);
462 _nbi.set();
463
464 // calc numerator1 and denominator1
465 Real sum_even, sum_odd;
466 sum_subs(i, _nbi, &sum_even, &sum_odd);
467
468 Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
469
470 } else if( props.updates == Properties::UpdateType::LINEAR ) {
471 sub_nb empty(con[i]);
472 Mag[i] = T(i,empty);
473
474 for(size_t _l1=0; _l1<con[i]; _l1++)
475 for( size_t _l2=_l1+1; _l2<con[i]; _l2++)
476 Mag[i] += Gamma(i,_l1,_l2) * tJ[i][_l1] * tJ[i][_l2] * cors[i][_l1][_l2];
477 }
478 if(fabs(Mag[i])>1.)
479 Mag[i] = sign(Mag[i]);
480 }
481 }
482
483
484 Real MR::init_cor() {
485 Real md = 0.0;
486 for( size_t i = 0; i < nrVars(); i++ ) {
487 vector<Factor> pairq;
488 if( props.inits == Properties::InitType::CLAMPING ) {
489 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
490 bpcav.makeCavity( i );
491 pairq = calcPairBeliefs( bpcav, delta(i), false, true );
492 md = std::max( md, bpcav.maxDiff() );
493 } else if( props.inits == Properties::InitType::EXACT ) {
494 JTree jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
495 jtcav.makeCavity( i );
496 pairq = calcPairBeliefs( jtcav, delta(i), false, true );
497 }
498 for( size_t jk = 0; jk < pairq.size(); jk++ ) {
499 VarSet::const_iterator kit = pairq[jk].vars().begin();
500 size_t j = findVar( *(kit) );
501 size_t k = findVar( *(++kit) );
502 pairq[jk].normalize();
503 Real cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
504 for( size_t _j = 0; _j < con[i]; _j++ ) if( nb[i][_j] == j )
505 for( size_t _k = 0; _k < con[i]; _k++ ) if( nb[i][_k] == k ) {
506 cors[i][_j][_k] = cor;
507 cors[i][_k][_j] = cor;
508 }
509 }
510 }
511 return md;
512 }
513
514
515 string MR::identify() const {
516 return string(Name) + printProperties();
517 }
518
519
520 Real MR::run() {
521 if( supported ) {
522 if( props.verbose >= 1 )
523 cerr << "Starting " << identify() << "...";
524
525 double tic = toc();
526
527 M.resize(N);
528 for(size_t i=0; i<N; i++)
529 M[i].resize(kmax);
530
531 cors.resize(N);
532 for(size_t i=0; i<N; i++)
533 cors[i].resize(kmax);
534 for(size_t i=0; i<N; i++)
535 for(size_t j=0; j<kmax; j++)
536 cors[i][j].resize(kmax);
537
538 kindex.resize(N);
539 for(size_t i=0; i<N; i++)
540 kindex[i].resize(kmax);
541
542 if( props.inits == Properties::InitType::RESPPROP ) {
543 Real md = init_cor_resp();
544 if( md > _maxdiff )
545 _maxdiff = md;
546 } else if( props.inits == Properties::InitType::RESPPROPOLD ) {
547 Real md = init_cor_resp_old();
548 if( md > _maxdiff )
549 _maxdiff = md;
550 } else if( props.inits == Properties::InitType::EXACT ) {
551 Real md = init_cor();
552 if( md > _maxdiff )
553 _maxdiff = md;
554 }
555 else if( props.inits == Properties::InitType::CLAMPING ) {
556 Real md = init_cor();
557 if( md > _maxdiff )
558 _maxdiff = md;
559 }
560
561 solvemcav();
562
563 Mag.resize(N);
564 solveM();
565
566 if( props.verbose >= 1 )
567 cerr << Name << " needed " << toc() - tic << " seconds." << endl;
568
569 return _maxdiff;
570 } else
571 return 1.0;
572 }
573
574
575 void MR::makekindex() {
576 for(size_t i=0; i<N; i++)
577 for(size_t j=0; j<con[i]; j++) {
578 size_t ij = nb[i][j]; // ij is the j'th neighbour of spin i
579 size_t k=0;
580 while( nb[ij][k] != i )
581 k++;
582 kindex[i][j] = k; // the j'th neighbour of spin i has spin i as its k'th neighbour
583 }
584 }
585
586
587 Factor MR::beliefV( size_t i ) const {
588 if( supported ) {
589 Prob x(2);
590 x[0] = 0.5 - Mag[i] / 2.0;
591 x[1] = 0.5 + Mag[i] / 2.0;
592
593 return Factor( var(i), x );
594 } else
595 return Factor();
596 }
597
598
599 vector<Factor> MR::beliefs() const {
600 vector<Factor> result;
601 for( size_t i = 0; i < nrVars(); i++ )
602 result.push_back( belief( var(i) ) );
603 return result;
604 }
605
606
607
608 MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), _maxdiff(0.0), _iters(0) {
609 setProperties( opts );
610
611 // check whether all vars in fg are binary
612 // check whether connectivity is <= kmax
613 for( size_t i = 0; i < fg.nrVars(); i++ )
614 if( (fg.var(i).states() > 2) || (fg.delta(i).size() > kmax) ) {
615 supported = false;
616 break;
617 }
618
619 if( !supported )
620 DAI_THROWE(NOT_IMPLEMENTED,"MR only supports binary variables with low connectivity");
621
622 // check whether all interactions are pairwise or single
623 for( size_t I = 0; I < fg.nrFactors(); I++ )
624 if( fg.factor(I).vars().size() > 2 ) {
625 supported = false;
626 break;
627 }
628
629 if( !supported )
630 DAI_THROWE(NOT_IMPLEMENTED,"MR does not support higher order interactions (only single and pairwise are supported)");
631
632 // create w and th
633 size_t Nin = fg.nrVars();
634
635 Real *w = new Real[Nin*Nin];
636 Real *th = new Real[Nin];
637
638 for( size_t i = 0; i < Nin; i++ ) {
639 th[i] = 0.0;
640 for( size_t j = 0; j < Nin; j++ )
641 w[i*Nin+j] = 0.0;
642 }
643
644 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
645 const Factor &psi = fg.factor(I);
646 if( psi.vars().size() == 1 ) {
647 size_t i = fg.findVar( *(psi.vars().begin()) );
648 th[i] += 0.5 * log(psi[1] / psi[0]);
649 } else if( psi.vars().size() == 2 ) {
650 size_t i = fg.findVar( *(psi.vars().begin()) );
651 VarSet::const_iterator jit = psi.vars().begin();
652 size_t j = fg.findVar( *(++jit) );
653
654 w[i*Nin+j] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
655 w[j*Nin+i] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
656
657 th[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
658 th[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
659 }
660 }
661
662 init(Nin, w, th);
663
664 delete th;
665 delete w;
666 }
667
668
669 } // end of namespace dai