Improved WeightedGraph code and added unit tests
[libdai.git] / src / mr.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2007 Bastian Wemmenhove
8 * Copyright (C) 2007-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 #include <cstdio>
14 #include <ctime>
15 #include <cmath>
16 #include <cstdlib>
17 #include <dai/mr.h>
18 #include <dai/bp.h>
19 #include <dai/jtree.h>
20 #include <dai/util.h>
21 #include <dai/bbp.h>
22
23
24 namespace dai {
25
26
27 using namespace std;
28
29
30 const char *MR::Name = "MR";
31
32
33 void MR::setProperties( const PropertySet &opts ) {
34 DAI_ASSERT( opts.hasKey("tol") );
35 DAI_ASSERT( opts.hasKey("verbose") );
36 DAI_ASSERT( opts.hasKey("updates") );
37 DAI_ASSERT( opts.hasKey("inits") );
38
39 props.tol = opts.getStringAs<Real>("tol");
40 props.verbose = opts.getStringAs<size_t>("verbose");
41 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
42 props.inits = opts.getStringAs<Properties::InitType>("inits");
43 }
44
45
46 PropertySet MR::getProperties() const {
47 PropertySet opts;
48 opts.Set( "tol", props.tol );
49 opts.Set( "verbose", props.verbose );
50 opts.Set( "updates", props.updates );
51 opts.Set( "inits", props.inits );
52 return opts;
53 }
54
55
56 string MR::printProperties() const {
57 stringstream s( stringstream::out );
58 s << "[";
59 s << "tol=" << props.tol << ",";
60 s << "verbose=" << props.verbose << ",";
61 s << "updates=" << props.updates << ",";
62 s << "inits=" << props.inits << "]";
63 return s.str();
64 }
65
66
67 Real MR::T(size_t i, sub_nb A) {
68 sub_nb _nbi_min_A(G.nb(i).size());
69 _nbi_min_A.set();
70 _nbi_min_A &= ~A;
71
72 Real res = theta[i];
73 for( size_t _j = 0; _j < _nbi_min_A.size(); _j++ )
74 if( _nbi_min_A.test(_j) )
75 res += atanh(tJ[i][_j] * M[i][_j]);
76 return tanh(res);
77 }
78
79
80 Real MR::T(size_t i, size_t _j) {
81 sub_nb j(G.nb(i).size());
82 j.set(_j);
83 return T(i,j);
84 }
85
86
87 Real MR::Omega(size_t i, size_t _j, size_t _l) {
88 sub_nb jl(G.nb(i).size());
89 jl.set(_j);
90 jl.set(_l);
91 Real Tijl = T(i,jl);
92 return Tijl / (1.0 + tJ[i][_l] * M[i][_l] * Tijl);
93 }
94
95
96 Real MR::Gamma(size_t i, size_t _j, size_t _l1, size_t _l2) {
97 sub_nb jll(G.nb(i).size());
98 jll.set(_j);
99 Real Tij = T(i,jll);
100 jll.set(_l1);
101 jll.set(_l2);
102 Real Tijll = T(i,jll);
103
104 return (Tijll - Tij) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Tijll + tJ[i][_l2] * M[i][_l2] * Tijll);
105 }
106
107
108 Real MR::Gamma(size_t i, size_t _l1, size_t _l2) {
109 sub_nb ll(G.nb(i).size());
110 Real Ti = T(i,ll);
111 ll.set(_l1);
112 ll.set(_l2);
113 Real Till = T(i,ll);
114
115 return (Till - Ti) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Till + tJ[i][_l2] * M[i][_l2] * Till);
116 }
117
118
119 Real MR::_tJ(size_t i, sub_nb A) {
120 sub_nb::size_type _j = A.find_first();
121 if( _j == sub_nb::npos )
122 return 1.0;
123 else
124 return tJ[i][_j] * _tJ(i, A.reset(_j));
125 }
126
127
128 Real MR::appM(size_t i, sub_nb A) {
129 sub_nb::size_type _j = A.find_first();
130 if( _j == sub_nb::npos )
131 return 1.0;
132 else {
133 sub_nb A_j(A); A_j.reset(_j);
134
135 Real result = M[i][_j] * appM(i, A_j);
136 for( size_t _k = 0; _k < A_j.size(); _k++ )
137 if( A_j.test(_k) ) {
138 sub_nb A_jk(A_j); A_jk.reset(_k);
139 result += cors[i][_j][_k] * appM(i,A_jk);
140 }
141
142 return result;
143 }
144 }
145
146
147 void MR::sum_subs(size_t j, sub_nb A, Real *sum_even, Real *sum_odd) {
148 *sum_even = 0.0;
149 *sum_odd = 0.0;
150
151 sub_nb B(A.size());
152 do {
153 if( B.count() % 2 )
154 *sum_odd += _tJ(j,B) * appM(j,B);
155 else
156 *sum_even += _tJ(j,B) * appM(j,B);
157
158 // calc next subset B
159 size_t bit = 0;
160 for( ; bit < A.size(); bit++ )
161 if( A.test(bit) ) {
162 if( B.test(bit) )
163 B.reset(bit);
164 else {
165 B.set(bit);
166 break;
167 }
168 }
169 } while (!B.none());
170 }
171
172
173 void MR::propagateCavityFields() {
174 Real sum_even, sum_odd;
175 Real maxdev;
176 size_t maxruns = 1000;
177
178 for( size_t i = 0; i < G.nrNodes(); i++ )
179 foreach( const Neighbor &j, G.nb(i) )
180 M[i][j.iter] = 0.1;
181
182 size_t run=0;
183 do {
184 maxdev=0.0;
185 run++;
186 for( size_t i = 0; i < G.nrNodes(); i++ ) {
187 foreach( const Neighbor &j, G.nb(i) ) {
188 size_t _j = j.iter;
189 size_t _i = G.findNb(j,i);
190 DAI_ASSERT( G.nb(j,_i) == i );
191
192 Real newM = 0.0;
193 if( props.updates == Properties::UpdateType::FULL ) {
194 // find indices in nb(j) that do not correspond with i
195 sub_nb _nbj_min_i(G.nb(j).size());
196 _nbj_min_i.set();
197 _nbj_min_i.reset(_i);
198
199 // find indices in nb(i) that do not correspond with j
200 sub_nb _nbi_min_j(G.nb(i).size());
201 _nbi_min_j.set();
202 _nbi_min_j.reset(_j);
203
204 sum_subs(j, _nbj_min_i, &sum_even, &sum_odd);
205 newM = (tanh(theta[j]) * sum_even + sum_odd) / (sum_even + tanh(theta[j]) * sum_odd);
206
207 sum_subs(i, _nbi_min_j, &sum_even, &sum_odd);
208 Real denom = sum_even + tanh(theta[i]) * sum_odd;
209 Real numer = 0.0;
210 for(size_t _k=0; _k < G.nb(i).size(); _k++) if(_k != _j) {
211 sub_nb _nbi_min_jk(_nbi_min_j);
212 _nbi_min_jk.reset(_k);
213 sum_subs(i, _nbi_min_jk, &sum_even, &sum_odd);
214 numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
215 }
216 newM -= numer / denom;
217 } else if( props.updates == Properties::UpdateType::LINEAR ) {
218 newM = T(j,_i);
219 for(size_t _l=0; _l<G.nb(i).size(); _l++) if( _l != _j )
220 newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
221 for(size_t _l1=0; _l1<G.nb(j).size(); _l1++) if( _l1 != _i )
222 for( size_t _l2=_l1+1; _l2<G.nb(j).size(); _l2++) if( _l2 != _i)
223 newM += Gamma(j,_i,_l1,_l2) * tJ[j][_l1] * tJ[j][_l2] * cors[j][_l1][_l2];
224 }
225
226 Real dev = newM - M[i][_j];
227 // dev *= 0.02;
228 if( abs(dev) >= maxdev )
229 maxdev = abs(dev);
230
231 newM = M[i][_j] + dev;
232 if( abs(newM) > 1.0 )
233 newM = (newM > 0.0) ? 1.0 : -1.0;
234 M[i][_j] = newM;
235 }
236 }
237 } while((maxdev>props.tol)&&(run<maxruns));
238
239 _iters = run;
240 if( maxdev > _maxdiff )
241 _maxdiff = maxdev;
242
243 if(run==maxruns){
244 if( props.verbose >= 1 )
245 cerr << "MR::propagateCavityFields: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
246 }
247 }
248
249
250 void MR::calcMagnetizations() {
251 for( size_t i = 0; i < G.nrNodes(); i++ ) {
252 if( props.updates == Properties::UpdateType::FULL ) {
253 // find indices in nb(i)
254 sub_nb _nbi( G.nb(i).size() );
255 _nbi.set();
256
257 // calc numerator1 and denominator1
258 Real sum_even, sum_odd;
259 sum_subs(i, _nbi, &sum_even, &sum_odd);
260
261 Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
262
263 } else if( props.updates == Properties::UpdateType::LINEAR ) {
264 sub_nb empty( G.nb(i).size() );
265 Mag[i] = T(i,empty);
266
267 for( size_t _l1 = 0; _l1 < G.nb(i).size(); _l1++ )
268 for( size_t _l2 = _l1 + 1; _l2 < G.nb(i).size(); _l2++ )
269 Mag[i] += Gamma(i,_l1,_l2) * tJ[i][_l1] * tJ[i][_l2] * cors[i][_l1][_l2];
270 }
271 if( abs( Mag[i] ) > 1.0 )
272 Mag[i] = (Mag[i] > 0.0) ? 1.0 : -1.0;
273 }
274 }
275
276
277 Real MR::calcCavityCorrelations() {
278 Real md = 0.0;
279 for( size_t i = 0; i < nrVars(); i++ ) {
280 vector<Factor> pairq;
281 if( props.inits == Properties::InitType::EXACT ) {
282 JTree jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
283 jtcav.makeCavity( i );
284 pairq = calcPairBeliefs( jtcav, delta(i), false, true );
285 } else if( props.inits == Properties::InitType::CLAMPING ) {
286 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
287 bpcav.makeCavity( i );
288
289 pairq = calcPairBeliefs( bpcav, delta(i), false, true );
290 md = std::max( md, bpcav.maxDiff() );
291 } else if( props.inits == Properties::InitType::RESPPROP ) {
292 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
293 bpcav.makeCavity( i );
294 bpcav.makeCavity( i );
295 bpcav.init();
296 bpcav.run();
297
298 BBP bbp( &bpcav, PropertySet()("verbose",(size_t)0)("tol",(Real)1.0e-9)("maxiter",(size_t)10000)("damping",(Real)0.0)("updates",string("SEQ_MAX")) );
299 foreach( const Neighbor &j, G.nb(i) ) {
300 // Create weights for magnetization of some spin
301 Prob p( 2, 0.0 );
302 p[0] = -1.0;
303 p[1] = 1.0;
304
305 // BBP cost function would be the magnetization of spin j
306 vector<Prob> b1_adj;
307 b1_adj.reserve( nrVars() );
308 for( size_t l = 0; l < nrVars(); l++ )
309 if( l == j )
310 b1_adj.push_back( p );
311 else
312 b1_adj.push_back( Prob( 2, 0.0 ) );
313 bbp.init_V( b1_adj );
314
315 // run BBP to estimate adjoints
316 bbp.run();
317
318 foreach( const Neighbor &k, G.nb(i) ) {
319 if( k != j )
320 cors[i][j.iter][k.iter] = (bbp.adj_psi_V(k)[1] - bbp.adj_psi_V(k)[0]);
321 else
322 cors[i][j.iter][k.iter] = 0.0;
323 }
324 }
325 }
326
327 if( props.inits != Properties::InitType::RESPPROP ) {
328 for( size_t jk = 0; jk < pairq.size(); jk++ ) {
329 VarSet::const_iterator kit = pairq[jk].vars().begin();
330 size_t j = findVar( *(kit) );
331 size_t k = findVar( *(++kit) );
332 pairq[jk].normalize();
333 Real cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
334
335 size_t _j = G.findNb(i,j);
336 size_t _k = G.findNb(i,k);
337 cors[i][_j][_k] = cor;
338 cors[i][_k][_j] = cor;
339 }
340 }
341
342 }
343 return md;
344 }
345
346
347 string MR::identify() const {
348 return string(Name) + printProperties();
349 }
350
351
352 Real MR::run() {
353 if( supported ) {
354 if( props.verbose >= 1 )
355 cerr << "Starting " << identify() << "...";
356
357 double tic = toc();
358
359 // approximate correlations of cavity spins
360 Real md = calcCavityCorrelations();
361 if( md > _maxdiff )
362 _maxdiff = md;
363
364 // solve messages
365 propagateCavityFields();
366
367 // calculate magnetizations
368 calcMagnetizations();
369
370 if( props.verbose >= 1 )
371 cerr << Name << " needed " << toc() - tic << " seconds." << endl;
372
373 return _maxdiff;
374 } else
375 return 1.0;
376 }
377
378
379 Factor MR::beliefV( size_t i ) const {
380 if( supported ) {
381 Prob x(2);
382 x[0] = 0.5 - Mag[i] / 2.0;
383 x[1] = 0.5 + Mag[i] / 2.0;
384
385 return Factor( var(i), x );
386 } else
387 return Factor();
388 }
389
390
391 vector<Factor> MR::beliefs() const {
392 vector<Factor> result;
393 for( size_t i = 0; i < nrVars(); i++ )
394 result.push_back( beliefV( i ) );
395 return result;
396 }
397
398
399 MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), _maxdiff(0.0), _iters(0) {
400 setProperties( opts );
401
402 size_t N = fg.nrVars();
403
404 // check whether all vars in fg are binary
405 for( size_t i = 0; i < N; i++ )
406 if( (fg.var(i).states() > 2) ) {
407 supported = false;
408 break;
409 }
410 if( !supported )
411 DAI_THROWE(NOT_IMPLEMENTED,"MR only supports binary variables");
412
413 // check whether all interactions are pairwise or single
414 // and construct Markov graph
415 G = GraphAL(N);
416 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
417 const Factor &psi = fg.factor(I);
418 if( psi.vars().size() > 2 ) {
419 supported = false;
420 break;
421 } else if( psi.vars().size() == 2 ) {
422 VarSet::const_iterator jit = psi.vars().begin();
423 size_t i = fg.findVar( *(jit) );
424 size_t j = fg.findVar( *(++jit) );
425 G.addEdge( i, j, false );
426 }
427 }
428 if( !supported )
429 DAI_THROWE(NOT_IMPLEMENTED,"MR does not support higher order interactions (only single and pairwise are supported)");
430
431 // construct theta
432 theta.clear();
433 theta.resize( N, 0.0 );
434
435 // construct tJ
436 tJ.resize( N );
437 for( size_t i = 0; i < N; i++ )
438 tJ[i].resize( G.nb(i).size(), 0.0 );
439
440 // initialize theta and tJ
441 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
442 const Factor &psi = fg.factor(I);
443 if( psi.vars().size() == 1 ) {
444 size_t i = fg.findVar( *(psi.vars().begin()) );
445 theta[i] += 0.5 * log(psi[1] / psi[0]);
446 } else if( psi.vars().size() == 2 ) {
447 VarSet::const_iterator jit = psi.vars().begin();
448 size_t i = fg.findVar( *(jit) );
449 size_t j = fg.findVar( *(++jit) );
450
451 Real w_ij = 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
452 tJ[i][G.findNb(i,j)] += w_ij;
453 tJ[j][G.findNb(j,i)] += w_ij;
454
455 theta[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
456 theta[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
457 }
458 }
459 for( size_t i = 0; i < N; i++ )
460 foreach( const Neighbor &j, G.nb(i) )
461 tJ[i][j.iter] = tanh( tJ[i][j.iter] );
462
463 // construct M
464 M.resize( N );
465 for( size_t i = 0; i < N; i++ )
466 M[i].resize( G.nb(i).size() );
467
468 // construct cors
469 cors.resize( N );
470 for( size_t i = 0; i < N; i++ )
471 cors[i].resize( G.nb(i).size() );
472 for( size_t i = 0; i < N; i++ )
473 for( size_t _j = 0; _j < cors[i].size(); _j++ )
474 cors[i][_j].resize( G.nb(i).size() );
475
476 // construct Mag
477 Mag.resize( N );
478 }
479
480
481 } // end of namespace dai