Merge branch 'pletscher'
[libdai.git] / src / mr.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2007 Bastian Wemmenhove
8 * Copyright (C) 2007-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 #include <cstdio>
14 #include <ctime>
15 #include <cmath>
16 #include <cstdlib>
17 #include <dai/mr.h>
18 #include <dai/bp.h>
19 #include <dai/jtree.h>
20 #include <dai/util.h>
21
22
23 namespace dai {
24
25
26 using namespace std;
27
28
29 const char *MR::Name = "MR";
30
31
32 void MR::setProperties( const PropertySet &opts ) {
33 DAI_ASSERT( opts.hasKey("tol") );
34 DAI_ASSERT( opts.hasKey("verbose") );
35 DAI_ASSERT( opts.hasKey("updates") );
36 DAI_ASSERT( opts.hasKey("inits") );
37
38 props.tol = opts.getStringAs<double>("tol");
39 props.verbose = opts.getStringAs<size_t>("verbose");
40 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
41 props.inits = opts.getStringAs<Properties::InitType>("inits");
42 }
43
44
45 PropertySet MR::getProperties() const {
46 PropertySet opts;
47 opts.Set( "tol", props.tol );
48 opts.Set( "verbose", props.verbose );
49 opts.Set( "updates", props.updates );
50 opts.Set( "inits", props.inits );
51 return opts;
52 }
53
54
55 string MR::printProperties() const {
56 stringstream s( stringstream::out );
57 s << "[";
58 s << "tol=" << props.tol << ",";
59 s << "verbose=" << props.verbose << ",";
60 s << "updates=" << props.updates << ",";
61 s << "inits=" << props.inits << "]";
62 return s.str();
63 }
64
65
66 // init N, con, nb, tJ, theta
67 void MR::init(size_t Nin, double *_w, double *_th) {
68 size_t i,j;
69
70 N = Nin;
71
72 con.resize(N);
73 nb.resize(N);
74 tJ.resize(N);
75 for(i=0; i<N; i++ ) {
76 nb[i].resize(kmax);
77 tJ[i].resize(kmax);
78 con[i]=0;
79 for(j=0; j<N; j++ )
80 if( _w[i*N+j] != 0.0 ) {
81 nb[i][con[i]] = j;
82 tJ[i][con[i]] = tanh(_w[i*N+j]);
83 con[i]++;
84 }
85 }
86
87 theta.resize(N);
88 for(i=0; i<N; i++)
89 theta[i] = _th[i];
90 }
91
92
// Fills cors[cavity][i2][i] for every spin "cavity" and every pair (i2, i) of
// its neighbour slots, and returns the largest final deviation md seen over
// all inner fixed-point iterations.
//
// For each cavity spin the graph variables nb, tJ and con are temporarily
// rewritten so that the cavity spin is removed from the neighbour lists of
// its neighbours; the original nb, tJ, con and theta are restored before
// returning. The numbers at the start of each line below are residue of the
// web view this file was scraped from and are not part of the program text.
93 // calculate cors
94 double MR::init_cor_resp() {
95 size_t j,k,l, runx,i2;
96 double variab1, variab2;
97 double md, maxdev;
98 double thbJsite[kmax];
99 double xinter;
100 double rinter;
101 double res[kmax];
102 size_t s2;
103 size_t flag;
104 size_t concav;
// runs caps the number of sweeps of the do/while iteration below;
// eps is the factor applied to each deviation when updating xfield/rfield
105 size_t runs = 3000;
106 double eps = 0.2;
107 size_t cavity;
108
109 vector<vector<double> > tJ_org;
110 vector<vector<size_t> > nb_org;
111 vector<size_t> con_org;
112 vector<double> theta_org;
113
// xfield, rfield, devs and devs2 hold one entry per (spin, neighbour slot),
// addressed as kmax*spin+slot. Hfield, dev and avmag are not referenced
// again in this function.
114 vector<double> xfield(N*kmax,0.0);
115 vector<double> rfield(N*kmax,0.0);
116 vector<double> Hfield(N,0.0);
117 vector<double> devs(N*kmax,0.0);
118 vector<double> devs2(N*kmax,0.0);
119 vector<double> dev(N,0.0);
120 vector<double> avmag(N,0.0);
121
122 // save original tJ, nb
123 nb_org = nb;
124 tJ_org = tJ;
125 con_org = con;
126 theta_org = theta;
127
128 maxdev = 0.0;
129 for(cavity=0; cavity<N; cavity++){ // for each spin to be removed
// Start every cavity from the unmodified graph; concav remembers the
// original number of neighbours of the cavity spin because con[cavity]
// is zeroed below
130 con = con_org;
131 concav=con[cavity];
132
133 nb = nb_org;
134 tJ = tJ_org;
135
136 // Adapt the graph variables nb[], tJ[] and con[]
137 for(size_t i=0; i<con[cavity]; i++) {
138 size_t ij = nb[cavity][i];
139 flag=0;
140 j=0;
// Scan the neighbour list of ij for the cavity spin; once found, shift
// all later entries of nb[ij] and tJ[ij] one slot down to close the gap.
// NOTE(review): the loop only stops after cavity has been found in nb[ij],
// so it presumably relies on the neighbour relation being symmetric — confirm.
141 do{
142 if(nb[ij][j]==cavity){
143 while(j<(con[ij]-1)){
144 nb[ij][j]=nb[ij][j+1];
145 tJ[ij][j] = tJ[ij][j+1];
146 j++;
147 }
148 flag=1;
149 }
150 j++;
151 } while(flag==0);
152 }
// Each former neighbour has one connection less; the cavity spin is given
// zero connections, but nb[cavity] itself is left intact and is still read below
153 for(size_t i=0; i<con[cavity]; i++)
154 con[nb[cavity][i]]--;
155 con[cavity] = 0;
156
157
158 // Do everything starting from the new graph********
159
// kindex has to be rebuilt because neighbour slots have shifted
160 makekindex();
161 theta = theta_org;
162
// Random starting values 3.0*(2*rnd_uniform()-1.) for all x fields
163 for(size_t i=0; i<kmax*N; i++)
164 xfield[i] = 3.0*(2*rnd_uniform()-1.);
165
166 for(i2=0; i2<concav; i2++){ // Subsequently apply a field to each cavity spin ****************
167
168 s2 = nb[cavity][i2]; // identify the index of the cavity spin
// Set the response field on all current slots of s2 to 1; the other
// entries of rfield keep whatever value they had before
169 for(size_t i=0; i<con[s2]; i++)
170 rfield[kmax*s2+i] = 1.;
171
172 runx=0;
173 do { // From here start the response and belief propagation
174 runx++;
175 md=0.0;
176 for(k=0; k<N; k++){
177 if(k!=cavity) {
178 for(size_t i=0; i<con[k]; i++)
179 thbJsite[i] = tJ[k][i];
// For every slot l of spin k, combine the fields coming from all other
// slots j != l; xfield/rfield are updated in place, so later slots and
// spins in the same sweep already see the new values
180 for(l=0; l<con[k]; l++){
181 xinter = 1.;
182 rinter = 0.;
183 if(k==s2) rinter += 1.;
184 for(j=0; j<con[k]; j++)
185 if(j!=l){
186 variab2 = tanh(xfield[kmax*nb[k][j]+kindex[k][j]]);
187 variab1 = thbJsite[j]*variab2;
188 xinter *= (1.+variab1)/(1.-variab1);
189
190 rinter += thbJsite[j]*rfield[kmax*nb[k][j]+kindex[k][j]]*(1-variab2*variab2)/(1-variab1*variab1);
191 }
192
193 variab1 = 0.5*log(xinter);
194 xinter = variab1 + theta[k];
// Move xfield a fraction eps towards the new value and track the largest
// absolute deviation of this sweep in md
195 devs[kmax*k+l] = xinter-xfield[kmax*k+l];
196 xfield[kmax*k+l] = xfield[kmax*k+l]+devs[kmax*k+l]*eps;
197 if( fabs(devs[kmax*k+l]) > md )
198 md = fabs(devs[kmax*k+l]);
199
// Same damped update for the response field; md covers both fields
200 devs2[kmax*k+l] = rinter-rfield[kmax*k+l];
201 rfield[kmax*k+l] = rfield[kmax*k+l]+devs2[kmax*k+l]*eps;
202 if( fabs(devs2[kmax*k+l]) > md )
203 md = fabs(devs2[kmax*k+l]);
204 }
205 }
206 }
207 } while((md > props.tol)&&(runx<runs)); // Precision condition reached -> BP and RP finished
// Hitting the sweep limit is only reported (at verbose >= 2); the result
// of the last sweep is used regardless
208 if(runx==runs)
209 if( props.verbose >= 2 )
210 cerr << "init_cor_resp: Convergence not reached (md=" << md << ")..." << endl;
211 if(md > maxdev)
212 maxdev = md;
213
214 // compute the observables (i.e. magnetizations and responses)******
215
// The if(i!=i2) guards only the inner j loop: for i == i2 rinter stays 0,
// so res[i] becomes 0
216 for(size_t i=0; i<concav; i++){
217 rinter = 0.;
218 xinter = 1.;
219 if(i!=i2)
220 for(j=0; j<con[nb[cavity][i]]; j++){
221 variab2 = tanh(xfield[kmax*nb[nb[cavity][i]][j]+kindex[nb[cavity][i]][j]]);
222 variab1 = tJ[nb[cavity][i]][j]*variab2;
223 rinter += tJ[nb[cavity][i]][j]*rfield[kmax*nb[nb[cavity][i]][j]+kindex[nb[cavity][i]][j]]*(1-variab2*variab2)/(1-variab1*variab1);
224 xinter *= (1.+variab1)/(1.-variab1);
225 }
226 xinter = tanh(0.5*log(xinter)+theta[nb[cavity][i]]);
227 res[i] = rinter*(1-xinter*xinter);
228 }
229
230 // *******************
231
// Store the responses for this (cavity, i2); the entry of s2 with itself is 0
232 for(size_t i=0; i<concav; i++)
233 if(nb[cavity][i]!=s2)
234 // if(i!=i2)
235 cors[cavity][i2][i] = res[i];
236 else
237 cors[cavity][i2][i] = 0;
238 } // close for i2 = 0...concav
239 }
240
241 // restore nb, tJ, con
242 tJ = tJ_org;
243 nb = nb_org;
244 con = con_org;
245 theta = theta_org;
246
247 return maxdev;
248 }
249
250
251 double MR::T(size_t i, sub_nb A) {
252 // i is a variable index
253 // A is a subset of nb[i]
254 //
255 // calculate T{(i)}_A as defined in Rizzo&Montanari e-print (2.17)
256
257 sub_nb _nbi_min_A(con[i]);
258 _nbi_min_A.set();
259 _nbi_min_A &= ~A;
260
261 double res = theta[i];
262 for( size_t _j = 0; _j < _nbi_min_A.size(); _j++ )
263 if( _nbi_min_A.test(_j) )
264 res += atanh(tJ[i][_j] * M[i][_j]);
265 return tanh(res);
266 }
267
268
269 double MR::T(size_t i, size_t _j) {
270 sub_nb j(con[i]);
271 j.set(_j);
272 return T(i,j);
273 }
274
275
276 double MR::Omega(size_t i, size_t _j, size_t _l) {
277 sub_nb jl(con[i]);
278 jl.set(_j);
279 jl.set(_l);
280 double Tijl = T(i,jl);
281 return Tijl / (1.0 + tJ[i][_l] * M[i][_l] * Tijl);
282 }
283
284
285 double MR::Gamma(size_t i, size_t _j, size_t _l1, size_t _l2) {
286 sub_nb jll(con[i]);
287 jll.set(_j);
288 double Tij = T(i,jll);
289 jll.set(_l1);
290 jll.set(_l2);
291 double Tijll = T(i,jll);
292
293 return (Tijll - Tij) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Tijll + tJ[i][_l2] * M[i][_l2] * Tijll);
294 }
295
296
297 double MR::Gamma(size_t i, size_t _l1, size_t _l2) {
298 sub_nb ll(con[i]);
299 double Ti = T(i,ll);
300 ll.set(_l1);
301 ll.set(_l2);
302 double Till = T(i,ll);
303
304 return (Till - Ti) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Till + tJ[i][_l2] * M[i][_l2] * Till);
305 }
306
307
308 double MR::_tJ(size_t i, sub_nb A) {
309 // i is a variable index
310 // A is a subset of nb[i]
311 //
312 // calculate the product of all tJ[i][_j] for _j in A
313
314 sub_nb::size_type _j = A.find_first();
315 if( _j == sub_nb::npos )
316 return 1.0;
317 else
318 return tJ[i][_j] * _tJ(i, A.reset(_j));
319 }
320
321
322 double MR::appM(size_t i, sub_nb A) {
323 // i is a variable index
324 // A is a subset of nb[i]
325 //
326 // calculate the moment of variables in A from M and cors, neglecting higher order cumulants,
327 // defined as the sum over all partitions of A into subsets of cardinality two at most of the
328 // product of the cumulants (either first order, i.e. M, or second order, i.e. cors) of the
329 // entries of the partitions
330
331 sub_nb::size_type _j = A.find_first();
332 if( _j == sub_nb::npos )
333 return 1.0;
334 else {
335 sub_nb A_j(A); A_j.reset(_j);
336
337 double result = M[i][_j] * appM(i, A_j);
338 for( size_t _k = 0; _k < A_j.size(); _k++ )
339 if( A_j.test(_k) ) {
340 sub_nb A_jk(A_j); A_jk.reset(_k);
341 result += cors[i][_j][_k] * appM(i,A_jk);
342 }
343
344 return result;
345 }
346 }
347
348
349 void MR::sum_subs(size_t j, sub_nb A, double *sum_even, double *sum_odd) {
350 // j is a variable index
351 // A is a subset of nb[j]
352
353 // calculate sum over all even/odd subsets B of A of _tJ(j,B) appM(j,B)
354
355 *sum_even = 0.0;
356 *sum_odd = 0.0;
357
358 sub_nb B(A.size());
359 do {
360 if( B.count() % 2 )
361 *sum_odd += _tJ(j,B) * appM(j,B);
362 else
363 *sum_even += _tJ(j,B) * appM(j,B);
364
365 // calc next subset B
366 size_t bit = 0;
367 for( ; bit < A.size(); bit++ )
368 if( A.test(bit) ) {
369 if( B.test(bit) )
370 B.reset(bit);
371 else {
372 B.set(bit);
373 break;
374 }
375 }
376 } while (!B.none());
377 }
378
379
380 void MR::solvemcav() {
381 double sum_even, sum_odd;
382 double maxdev;
383 size_t maxruns = 1000;
384
385 makekindex();
386 for(size_t i=0; i<N; i++)
387 for(size_t _j=0; _j<con[i]; _j++)
388 M[i][_j]=0.1;
389
390 size_t run=0;
391 do {
392 maxdev=0.0;
393 run++;
394 for(size_t i=0; i<N; i++){ // for all i
395 for(size_t _j=0; _j<con[i]; _j++){ // for all j in N_i
396 size_t _i = kindex[i][_j];
397 size_t j = nb[i][_j];
398 DAI_ASSERT( nb[j][_i] == i );
399
400 double newM = 0.0;
401 if( props.updates == Properties::UpdateType::FULL ) {
402 // find indices in nb[j] that do not correspond with i
403 sub_nb _nbj_min_i(con[j]);
404 _nbj_min_i.set();
405 _nbj_min_i.reset(kindex[i][_j]);
406
407 // find indices in nb[i] that do not correspond with j
408 sub_nb _nbi_min_j(con[i]);
409 _nbi_min_j.set();
410 _nbi_min_j.reset(_j);
411
412 sum_subs(j, _nbj_min_i, &sum_even, &sum_odd);
413 newM = (tanh(theta[j]) * sum_even + sum_odd) / (sum_even + tanh(theta[j]) * sum_odd);
414
415 sum_subs(i, _nbi_min_j, &sum_even, &sum_odd);
416 double denom = sum_even + tanh(theta[i]) * sum_odd;
417 double numer = 0.0;
418 for(size_t _k=0; _k<con[i]; _k++) if(_k != _j) {
419 sub_nb _nbi_min_jk(_nbi_min_j);
420 _nbi_min_jk.reset(_k);
421 sum_subs(i, _nbi_min_jk, &sum_even, &sum_odd);
422 numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
423 }
424 newM -= numer / denom;
425 } else if( props.updates == Properties::UpdateType::LINEAR ) {
426 newM = T(j,_i);
427 for(size_t _l=0; _l<con[i]; _l++) if( _l != _j )
428 newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
429 for(size_t _l1=0; _l1<con[j]; _l1++) if( _l1 != _i )
430 for( size_t _l2=_l1+1; _l2<con[j]; _l2++) if( _l2 != _i)
431 newM += Gamma(j,_i,_l1,_l2) * tJ[j][_l1] * tJ[j][_l2] * cors[j][_l1][_l2];
432 }
433
434 double dev = newM - M[i][_j];
435 // dev *= 0.02;
436 if( fabs(dev) >= maxdev )
437 maxdev = fabs(dev);
438
439 newM = M[i][_j] + dev;
440 if( fabs(newM) > 1.0 )
441 newM = sign(newM);
442 M[i][_j] = newM;
443 }
444 }
445 } while((maxdev>props.tol)&&(run<maxruns));
446
447 _iters = run;
448 if( maxdev > _maxdiff )
449 _maxdiff = maxdev;
450
451 if(run==maxruns){
452 if( props.verbose >= 1 )
453 cerr << "solve_mcav: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
454 }
455 }
456
457
458 void MR::solveM() {
459 for(size_t i=0; i<N; i++) {
460 if( props.updates == Properties::UpdateType::FULL ) {
461 // find indices in nb[i]
462 sub_nb _nbi(con[i]);
463 _nbi.set();
464
465 // calc numerator1 and denominator1
466 double sum_even, sum_odd;
467 sum_subs(i, _nbi, &sum_even, &sum_odd);
468
469 Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
470
471 } else if( props.updates == Properties::UpdateType::LINEAR ) {
472 sub_nb empty(con[i]);
473 Mag[i] = T(i,empty);
474
475 for(size_t _l1=0; _l1<con[i]; _l1++)
476 for( size_t _l2=_l1+1; _l2<con[i]; _l2++)
477 Mag[i] += Gamma(i,_l1,_l2) * tJ[i][_l1] * tJ[i][_l2] * cors[i][_l1][_l2];
478 }
479 if(fabs(Mag[i])>1.)
480 Mag[i] = sign(Mag[i]);
481 }
482 }
483
484
485 void MR::init_cor() {
486 for( size_t i = 0; i < nrVars(); i++ ) {
487 vector<Factor> pairq;
488 if( props.inits == Properties::InitType::CLAMPING ) {
489 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", 1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
490 bpcav.makeCavity( i );
491 pairq = calcPairBeliefs( bpcav, delta(i), false );
492 } else if( props.inits == Properties::InitType::EXACT ) {
493 JTree jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
494 jtcav.makeCavity( i );
495 pairq = calcPairBeliefs( jtcav, delta(i), false );
496 }
497 for( size_t jk = 0; jk < pairq.size(); jk++ ) {
498 VarSet::const_iterator kit = pairq[jk].vars().begin();
499 size_t j = findVar( *(kit) );
500 size_t k = findVar( *(++kit) );
501 pairq[jk].normalize();
502 double cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
503 for( size_t _j = 0; _j < con[i]; _j++ ) if( nb[i][_j] == j )
504 for( size_t _k = 0; _k < con[i]; _k++ ) if( nb[i][_k] == k ) {
505 cors[i][_j][_k] = cor;
506 cors[i][_k][_j] = cor;
507 }
508 }
509 }
510 }
511
512
513 string MR::identify() const {
514 return string(Name) + printProperties();
515 }
516
517
518 double MR::run() {
519 if( supported ) {
520 if( props.verbose >= 1 )
521 cerr << "Starting " << identify() << "...";
522
523 double tic = toc();
524 // Diffs diffs(nrVars(), 1.0);
525
526 M.resize(N);
527 for(size_t i=0; i<N; i++)
528 M[i].resize(kmax);
529
530 cors.resize(N);
531 for(size_t i=0; i<N; i++)
532 cors[i].resize(kmax);
533 for(size_t i=0; i<N; i++)
534 for(size_t j=0; j<kmax; j++)
535 cors[i][j].resize(kmax);
536
537 kindex.resize(N);
538 for(size_t i=0; i<N; i++)
539 kindex[i].resize(kmax);
540
541 if( props.inits == Properties::InitType::RESPPROP ) {
542 double md = init_cor_resp();
543 if( md > _maxdiff )
544 _maxdiff = md;
545 } else if( props.inits == Properties::InitType::EXACT )
546 init_cor(); // FIXME no MaxDiff() calculation
547 else if( props.inits == Properties::InitType::CLAMPING )
548 init_cor(); // FIXME no MaxDiff() calculation
549
550 solvemcav();
551
552 Mag.resize(N);
553 solveM();
554
555 if( props.verbose >= 1 )
556 cerr << Name << " needed " << toc() - tic << " seconds." << endl;
557
558 return 0.0;
559 } else
560 return 1.0;
561 }
562
563
564 void MR::makekindex() {
565 for(size_t i=0; i<N; i++)
566 for(size_t j=0; j<con[i]; j++) {
567 size_t ij = nb[i][j]; // ij is the j'th neighbour of spin i
568 size_t k=0;
569 while( nb[ij][k] != i )
570 k++;
571 kindex[i][j] = k; // the j'th neighbour of spin i has spin i as its k'th neighbour
572 }
573 }
574
575
576 Factor MR::belief( const Var &n ) const {
577 if( supported ) {
578 size_t i = findVar( n );
579
580 Real x[2];
581 x[0] = 0.5 - Mag[i] / 2.0;
582 x[1] = 0.5 + Mag[i] / 2.0;
583
584 return Factor( n, x );
585 } else
586 return Factor();
587 }
588
589
590 vector<Factor> MR::beliefs() const {
591 vector<Factor> result;
592 for( size_t i = 0; i < nrVars(); i++ )
593 result.push_back( belief( var(i) ) );
594 return result;
595 }
596
597
598
599 MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), _maxdiff(0.0), _iters(0) {
600 setProperties( opts );
601
602 // check whether all vars in fg are binary
603 // check whether connectivity is <= kmax
604 for( size_t i = 0; i < fg.nrVars(); i++ )
605 if( (fg.var(i).states() > 2) || (fg.delta(i).size() > kmax) ) {
606 supported = false;
607 break;
608 }
609
610 if( !supported )
611 return;
612
613 // check whether all interactions are pairwise or single
614 for( size_t I = 0; I < fg.nrFactors(); I++ )
615 if( fg.factor(I).vars().size() > 2 ) {
616 supported = false;
617 break;
618 }
619
620 if( !supported )
621 return;
622
623 // create w and th
624 size_t Nin = fg.nrVars();
625
626 double *w = new double[Nin*Nin];
627 double *th = new double[Nin];
628
629 for( size_t i = 0; i < Nin; i++ ) {
630 th[i] = 0.0;
631 for( size_t j = 0; j < Nin; j++ )
632 w[i*Nin+j] = 0.0;
633 }
634
635 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
636 const Factor &psi = fg.factor(I);
637 if( psi.vars().size() == 1 ) {
638 size_t i = fg.findVar( *(psi.vars().begin()) );
639 th[i] += 0.5 * log(psi[1] / psi[0]);
640 } else if( psi.vars().size() == 2 ) {
641 size_t i = fg.findVar( *(psi.vars().begin()) );
642 VarSet::const_iterator jit = psi.vars().begin();
643 size_t j = fg.findVar( *(++jit) );
644
645 w[i*Nin+j] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
646 w[j*Nin+i] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
647
648 th[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
649 th[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
650 }
651 }
652
653 init(Nin, w, th);
654
655 delete th;
656 delete w;
657 }
658
659
660 } // end of namespace dai