7270279a5e6753ffbe18e518ec7b6518b537c5c7
[libdai.git] / src / mr.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <cstdio>
10 #include <ctime>
11 #include <cmath>
12 #include <cstdlib>
13 #include <dai/mr.h>
14 #include <dai/bp.h>
15 #include <dai/jtree.h>
16 #include <dai/util.h>
17 #include <dai/bbp.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 void MR::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("tol") );
28 DAI_ASSERT( opts.hasKey("updates") );
29 DAI_ASSERT( opts.hasKey("inits") );
30
31 props.tol = opts.getStringAs<Real>("tol");
32 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
33 props.inits = opts.getStringAs<Properties::InitType>("inits");
34 if( opts.hasKey("verbose") )
35 props.verbose = opts.getStringAs<size_t>("verbose");
36 else
37 props.verbose = 0;
38 }
39
40
41 PropertySet MR::getProperties() const {
42 PropertySet opts;
43 opts.set( "tol", props.tol );
44 opts.set( "verbose", props.verbose );
45 opts.set( "updates", props.updates );
46 opts.set( "inits", props.inits );
47 return opts;
48 }
49
50
51 string MR::printProperties() const {
52 stringstream s( stringstream::out );
53 s << "[";
54 s << "tol=" << props.tol << ",";
55 s << "verbose=" << props.verbose << ",";
56 s << "updates=" << props.updates << ",";
57 s << "inits=" << props.inits << "]";
58 return s.str();
59 }
60
61
62 Real MR::T(size_t i, sub_nb A) {
63 sub_nb _nbi_min_A(G.nb(i).size());
64 _nbi_min_A.set();
65 _nbi_min_A &= ~A;
66
67 Real res = theta[i];
68 for( size_t _j = 0; _j < _nbi_min_A.size(); _j++ )
69 if( _nbi_min_A.test(_j) )
70 res += atanh(tJ[i][_j] * M[i][_j]);
71 return tanh(res);
72 }
73
74
75 Real MR::T(size_t i, size_t _j) {
76 sub_nb j(G.nb(i).size());
77 j.set(_j);
78 return T(i,j);
79 }
80
81
82 Real MR::Omega(size_t i, size_t _j, size_t _l) {
83 sub_nb jl(G.nb(i).size());
84 jl.set(_j);
85 jl.set(_l);
86 Real Tijl = T(i,jl);
87 return Tijl / (1.0 + tJ[i][_l] * M[i][_l] * Tijl);
88 }
89
90
91 Real MR::Gamma(size_t i, size_t _j, size_t _l1, size_t _l2) {
92 sub_nb jll(G.nb(i).size());
93 jll.set(_j);
94 Real Tij = T(i,jll);
95 jll.set(_l1);
96 jll.set(_l2);
97 Real Tijll = T(i,jll);
98
99 return (Tijll - Tij) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Tijll + tJ[i][_l2] * M[i][_l2] * Tijll);
100 }
101
102
103 Real MR::Gamma(size_t i, size_t _l1, size_t _l2) {
104 sub_nb ll(G.nb(i).size());
105 Real Ti = T(i,ll);
106 ll.set(_l1);
107 ll.set(_l2);
108 Real Till = T(i,ll);
109
110 return (Till - Ti) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Till + tJ[i][_l2] * M[i][_l2] * Till);
111 }
112
113
114 Real MR::_tJ(size_t i, sub_nb A) {
115 sub_nb::size_type _j = A.find_first();
116 if( _j == sub_nb::npos )
117 return 1.0;
118 else
119 return tJ[i][_j] * _tJ(i, A.reset(_j));
120 }
121
122
123 Real MR::appM(size_t i, sub_nb A) {
124 sub_nb::size_type _j = A.find_first();
125 if( _j == sub_nb::npos )
126 return 1.0;
127 else {
128 sub_nb A_j(A); A_j.reset(_j);
129
130 Real result = M[i][_j] * appM(i, A_j);
131 for( size_t _k = 0; _k < A_j.size(); _k++ )
132 if( A_j.test(_k) ) {
133 sub_nb A_jk(A_j); A_jk.reset(_k);
134 result += cors[i][_j][_k] * appM(i,A_jk);
135 }
136
137 return result;
138 }
139 }
140
141
142 void MR::sum_subs(size_t j, sub_nb A, Real *sum_even, Real *sum_odd) {
143 *sum_even = 0.0;
144 *sum_odd = 0.0;
145
146 sub_nb B(A.size());
147 do {
148 if( B.count() % 2 )
149 *sum_odd += _tJ(j,B) * appM(j,B);
150 else
151 *sum_even += _tJ(j,B) * appM(j,B);
152
153 // calc next subset B
154 size_t bit = 0;
155 for( ; bit < A.size(); bit++ )
156 if( A.test(bit) ) {
157 if( B.test(bit) )
158 B.reset(bit);
159 else {
160 B.set(bit);
161 break;
162 }
163 }
164 } while (!B.none());
165 }
166
167
168 void MR::propagateCavityFields() {
169 Real sum_even, sum_odd;
170 Real maxdev;
171 size_t maxruns = 1000;
172
173 for( size_t i = 0; i < G.nrNodes(); i++ )
174 bforeach( const Neighbor &j, G.nb(i) )
175 M[i][j.iter] = 0.1;
176
177 size_t run=0;
178 do {
179 maxdev=0.0;
180 run++;
181 for( size_t i = 0; i < G.nrNodes(); i++ ) {
182 bforeach( const Neighbor &j, G.nb(i) ) {
183 size_t _j = j.iter;
184 size_t _i = G.findNb(j,i);
185 DAI_ASSERT( G.nb(j,_i) == i );
186
187 Real newM = 0.0;
188 if( props.updates == Properties::UpdateType::FULL ) {
189 // find indices in nb(j) that do not correspond with i
190 sub_nb _nbj_min_i(G.nb(j).size());
191 _nbj_min_i.set();
192 _nbj_min_i.reset(_i);
193
194 // find indices in nb(i) that do not correspond with j
195 sub_nb _nbi_min_j(G.nb(i).size());
196 _nbi_min_j.set();
197 _nbi_min_j.reset(_j);
198
199 sum_subs(j, _nbj_min_i, &sum_even, &sum_odd);
200 newM = (tanh(theta[j]) * sum_even + sum_odd) / (sum_even + tanh(theta[j]) * sum_odd);
201
202 sum_subs(i, _nbi_min_j, &sum_even, &sum_odd);
203 Real denom = sum_even + tanh(theta[i]) * sum_odd;
204 Real numer = 0.0;
205 for(size_t _k=0; _k < G.nb(i).size(); _k++) if(_k != _j) {
206 sub_nb _nbi_min_jk(_nbi_min_j);
207 _nbi_min_jk.reset(_k);
208 sum_subs(i, _nbi_min_jk, &sum_even, &sum_odd);
209 numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
210 }
211 newM -= numer / denom;
212 } else if( props.updates == Properties::UpdateType::LINEAR ) {
213 newM = T(j,_i);
214 for(size_t _l=0; _l<G.nb(i).size(); _l++) if( _l != _j )
215 newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
216 for(size_t _l1=0; _l1<G.nb(j).size(); _l1++) if( _l1 != _i )
217 for( size_t _l2=_l1+1; _l2<G.nb(j).size(); _l2++) if( _l2 != _i)
218 newM += Gamma(j,_i,_l1,_l2) * tJ[j][_l1] * tJ[j][_l2] * cors[j][_l1][_l2];
219 }
220
221 Real dev = newM - M[i][_j];
222 // dev *= 0.02;
223 if( abs(dev) >= maxdev )
224 maxdev = abs(dev);
225
226 newM = M[i][_j] + dev;
227 if( abs(newM) > 1.0 )
228 newM = (newM > 0.0) ? 1.0 : -1.0;
229 M[i][_j] = newM;
230 }
231 }
232 } while((maxdev>props.tol)&&(run<maxruns));
233
234 _iters = run;
235 if( maxdev > _maxdiff )
236 _maxdiff = maxdev;
237
238 if(run==maxruns){
239 if( props.verbose >= 1 )
240 cerr << "MR::propagateCavityFields: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
241 }
242 }
243
244
245 void MR::calcMagnetizations() {
246 for( size_t i = 0; i < G.nrNodes(); i++ ) {
247 if( props.updates == Properties::UpdateType::FULL ) {
248 // find indices in nb(i)
249 sub_nb _nbi( G.nb(i).size() );
250 _nbi.set();
251
252 // calc numerator1 and denominator1
253 Real sum_even, sum_odd;
254 sum_subs(i, _nbi, &sum_even, &sum_odd);
255
256 Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
257
258 } else if( props.updates == Properties::UpdateType::LINEAR ) {
259 sub_nb empty( G.nb(i).size() );
260 Mag[i] = T(i,empty);
261
262 for( size_t _l1 = 0; _l1 < G.nb(i).size(); _l1++ )
263 for( size_t _l2 = _l1 + 1; _l2 < G.nb(i).size(); _l2++ )
264 Mag[i] += Gamma(i,_l1,_l2) * tJ[i][_l1] * tJ[i][_l2] * cors[i][_l1][_l2];
265 }
266 if( abs( Mag[i] ) > 1.0 )
267 Mag[i] = (Mag[i] > 0.0) ? 1.0 : -1.0;
268 }
269 }
270
271
272 Real MR::calcCavityCorrelations() {
273 Real md = 0.0;
274 for( size_t i = 0; i < nrVars(); i++ ) {
275 vector<Factor> pairq;
276 if( props.inits == Properties::InitType::EXACT ) {
277 JTree jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
278 jtcav.makeCavity( i );
279 pairq = calcPairBeliefs( jtcav, delta(i), false, true );
280 } else if( props.inits == Properties::InitType::CLAMPING ) {
281 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
282 bpcav.makeCavity( i );
283
284 pairq = calcPairBeliefs( bpcav, delta(i), false, true );
285 md = std::max( md, bpcav.maxDiff() );
286 } else if( props.inits == Properties::InitType::RESPPROP ) {
287 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
288 bpcav.makeCavity( i );
289 bpcav.makeCavity( i );
290 bpcav.init();
291 bpcav.run();
292
293 BBP bbp( &bpcav, PropertySet()("verbose",(size_t)0)("tol",(Real)1.0e-9)("maxiter",(size_t)10000)("damping",(Real)0.0)("updates",string("SEQ_MAX")) );
294 bforeach( const Neighbor &j, G.nb(i) ) {
295 // Create weights for magnetization of some spin
296 Prob p( 2, 0.0 );
297 p.set( 0, -1.0 );
298 p.set( 1, 1.0 );
299
300 // BBP cost function would be the magnetization of spin j
301 vector<Prob> b1_adj;
302 b1_adj.reserve( nrVars() );
303 for( size_t l = 0; l < nrVars(); l++ )
304 if( l == j )
305 b1_adj.push_back( p );
306 else
307 b1_adj.push_back( Prob( 2, 0.0 ) );
308 bbp.init_V( b1_adj );
309
310 // run BBP to estimate adjoints
311 bbp.run();
312
313 bforeach( const Neighbor &k, G.nb(i) ) {
314 if( k != j )
315 cors[i][j.iter][k.iter] = (bbp.adj_psi_V(k)[1] - bbp.adj_psi_V(k)[0]);
316 else
317 cors[i][j.iter][k.iter] = 0.0;
318 }
319 }
320 }
321
322 if( props.inits != Properties::InitType::RESPPROP ) {
323 for( size_t jk = 0; jk < pairq.size(); jk++ ) {
324 VarSet::const_iterator kit = pairq[jk].vars().begin();
325 size_t j = findVar( *(kit) );
326 size_t k = findVar( *(++kit) );
327 pairq[jk].normalize();
328 Real cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
329
330 size_t _j = G.findNb(i,j);
331 size_t _k = G.findNb(i,k);
332 cors[i][_j][_k] = cor;
333 cors[i][_k][_j] = cor;
334 }
335 }
336
337 }
338 return md;
339 }
340
341
342 Real MR::run() {
343 if( supported ) {
344 if( props.verbose >= 1 )
345 cerr << "Starting " << identify() << "...";
346
347 double tic = toc();
348
349 // approximate correlations of cavity spins
350 Real md = calcCavityCorrelations();
351 if( md > _maxdiff )
352 _maxdiff = md;
353
354 // solve messages
355 propagateCavityFields();
356
357 // calculate magnetizations
358 calcMagnetizations();
359
360 if( props.verbose >= 1 )
361 cerr << name() << " needed " << toc() - tic << " seconds." << endl;
362
363 return _maxdiff;
364 } else
365 return 1.0;
366 }
367
368
369 Factor MR::beliefV( size_t i ) const {
370 if( supported ) {
371 Real x[2];
372 x[0] = 0.5 - Mag[i] / 2.0;
373 x[1] = 0.5 + Mag[i] / 2.0;
374
375 return Factor( var(i), x );
376 } else
377 return Factor();
378 }
379
380
381 Factor MR::belief (const VarSet &ns) const {
382 if( ns.size() == 0 )
383 return Factor();
384 else if( ns.size() == 1 )
385 return beliefV( findVar( *(ns.begin()) ) );
386 else {
387 DAI_THROW(BELIEF_NOT_AVAILABLE);
388 return Factor();
389 }
390 }
391
392
393 vector<Factor> MR::beliefs() const {
394 vector<Factor> result;
395 for( size_t i = 0; i < nrVars(); i++ )
396 result.push_back( beliefV( i ) );
397 return result;
398 }
399
400
401 MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), _maxdiff(0.0), _iters(0) {
402 setProperties( opts );
403
404 size_t N = fg.nrVars();
405
406 // check whether all vars in fg are binary
407 for( size_t i = 0; i < N; i++ )
408 if( (fg.var(i).states() > 2) ) {
409 supported = false;
410 break;
411 }
412 if( !supported )
413 DAI_THROWE(NOT_IMPLEMENTED,"MR only supports binary variables");
414
415 // check whether all interactions are pairwise or single
416 // and construct Markov graph
417 G = GraphAL(N);
418 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
419 const Factor &psi = fg.factor(I);
420 if( psi.vars().size() > 2 ) {
421 supported = false;
422 break;
423 } else if( psi.vars().size() == 2 ) {
424 VarSet::const_iterator jit = psi.vars().begin();
425 size_t i = fg.findVar( *(jit) );
426 size_t j = fg.findVar( *(++jit) );
427 G.addEdge( i, j, false );
428 }
429 }
430 if( !supported )
431 DAI_THROWE(NOT_IMPLEMENTED,"MR does not support higher order interactions (only single and pairwise are supported)");
432
433 // construct theta
434 theta.clear();
435 theta.resize( N, 0.0 );
436
437 // construct tJ
438 tJ.resize( N );
439 for( size_t i = 0; i < N; i++ )
440 tJ[i].resize( G.nb(i).size(), 0.0 );
441
442 // initialize theta and tJ
443 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
444 const Factor &psi = fg.factor(I);
445 if( psi.vars().size() == 1 ) {
446 size_t i = fg.findVar( *(psi.vars().begin()) );
447 theta[i] += 0.5 * log(psi[1] / psi[0]);
448 } else if( psi.vars().size() == 2 ) {
449 VarSet::const_iterator jit = psi.vars().begin();
450 size_t i = fg.findVar( *(jit) );
451 size_t j = fg.findVar( *(++jit) );
452
453 Real w_ij = 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
454 tJ[i][G.findNb(i,j)] += w_ij;
455 tJ[j][G.findNb(j,i)] += w_ij;
456
457 theta[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
458 theta[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
459 }
460 }
461 for( size_t i = 0; i < N; i++ )
462 bforeach( const Neighbor &j, G.nb(i) )
463 tJ[i][j.iter] = tanh( tJ[i][j.iter] );
464
465 // construct M
466 M.resize( N );
467 for( size_t i = 0; i < N; i++ )
468 M[i].resize( G.nb(i).size() );
469
470 // construct cors
471 cors.resize( N );
472 for( size_t i = 0; i < N; i++ )
473 cors[i].resize( G.nb(i).size() );
474 for( size_t i = 0; i < N; i++ )
475 for( size_t _j = 0; _j < cors[i].size(); _j++ )
476 cors[i][_j].resize( G.nb(i).size() );
477
478 // construct Mag
479 Mag.resize( N );
480 }
481
482
483 } // end of namespace dai