Improved HAK (added 'maxtime' property)
[libdai.git] / src / mr.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2007 Bastian Wemmenhove
8 * Copyright (C) 2007-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 #include <cstdio>
14 #include <ctime>
15 #include <cmath>
16 #include <cstdlib>
17 #include <dai/mr.h>
18 #include <dai/bp.h>
19 #include <dai/jtree.h>
20 #include <dai/util.h>
21 #include <dai/bbp.h>
22
23
24 namespace dai {
25
26
27 using namespace std;
28
29
30 const char *MR::Name = "MR";
31
32
33 void MR::setProperties( const PropertySet &opts ) {
34 DAI_ASSERT( opts.hasKey("tol") );
35 DAI_ASSERT( opts.hasKey("updates") );
36 DAI_ASSERT( opts.hasKey("inits") );
37
38 props.tol = opts.getStringAs<Real>("tol");
39 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
40 props.inits = opts.getStringAs<Properties::InitType>("inits");
41 if( opts.hasKey("verbose") )
42 props.verbose = opts.getStringAs<size_t>("verbose");
43 else
44 props.verbose = 0;
45 }
46
47
48 PropertySet MR::getProperties() const {
49 PropertySet opts;
50 opts.set( "tol", props.tol );
51 opts.set( "verbose", props.verbose );
52 opts.set( "updates", props.updates );
53 opts.set( "inits", props.inits );
54 return opts;
55 }
56
57
58 string MR::printProperties() const {
59 stringstream s( stringstream::out );
60 s << "[";
61 s << "tol=" << props.tol << ",";
62 s << "verbose=" << props.verbose << ",";
63 s << "updates=" << props.updates << ",";
64 s << "inits=" << props.inits << "]";
65 return s.str();
66 }
67
68
69 Real MR::T(size_t i, sub_nb A) {
70 sub_nb _nbi_min_A(G.nb(i).size());
71 _nbi_min_A.set();
72 _nbi_min_A &= ~A;
73
74 Real res = theta[i];
75 for( size_t _j = 0; _j < _nbi_min_A.size(); _j++ )
76 if( _nbi_min_A.test(_j) )
77 res += atanh(tJ[i][_j] * M[i][_j]);
78 return tanh(res);
79 }
80
81
82 Real MR::T(size_t i, size_t _j) {
83 sub_nb j(G.nb(i).size());
84 j.set(_j);
85 return T(i,j);
86 }
87
88
89 Real MR::Omega(size_t i, size_t _j, size_t _l) {
90 sub_nb jl(G.nb(i).size());
91 jl.set(_j);
92 jl.set(_l);
93 Real Tijl = T(i,jl);
94 return Tijl / (1.0 + tJ[i][_l] * M[i][_l] * Tijl);
95 }
96
97
98 Real MR::Gamma(size_t i, size_t _j, size_t _l1, size_t _l2) {
99 sub_nb jll(G.nb(i).size());
100 jll.set(_j);
101 Real Tij = T(i,jll);
102 jll.set(_l1);
103 jll.set(_l2);
104 Real Tijll = T(i,jll);
105
106 return (Tijll - Tij) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Tijll + tJ[i][_l2] * M[i][_l2] * Tijll);
107 }
108
109
110 Real MR::Gamma(size_t i, size_t _l1, size_t _l2) {
111 sub_nb ll(G.nb(i).size());
112 Real Ti = T(i,ll);
113 ll.set(_l1);
114 ll.set(_l2);
115 Real Till = T(i,ll);
116
117 return (Till - Ti) / (1.0 + tJ[i][_l1] * tJ[i][_l2] * M[i][_l1] * M[i][_l2] + tJ[i][_l1] * M[i][_l1] * Till + tJ[i][_l2] * M[i][_l2] * Till);
118 }
119
120
121 Real MR::_tJ(size_t i, sub_nb A) {
122 sub_nb::size_type _j = A.find_first();
123 if( _j == sub_nb::npos )
124 return 1.0;
125 else
126 return tJ[i][_j] * _tJ(i, A.reset(_j));
127 }
128
129
130 Real MR::appM(size_t i, sub_nb A) {
131 sub_nb::size_type _j = A.find_first();
132 if( _j == sub_nb::npos )
133 return 1.0;
134 else {
135 sub_nb A_j(A); A_j.reset(_j);
136
137 Real result = M[i][_j] * appM(i, A_j);
138 for( size_t _k = 0; _k < A_j.size(); _k++ )
139 if( A_j.test(_k) ) {
140 sub_nb A_jk(A_j); A_jk.reset(_k);
141 result += cors[i][_j][_k] * appM(i,A_jk);
142 }
143
144 return result;
145 }
146 }
147
148
149 void MR::sum_subs(size_t j, sub_nb A, Real *sum_even, Real *sum_odd) {
150 *sum_even = 0.0;
151 *sum_odd = 0.0;
152
153 sub_nb B(A.size());
154 do {
155 if( B.count() % 2 )
156 *sum_odd += _tJ(j,B) * appM(j,B);
157 else
158 *sum_even += _tJ(j,B) * appM(j,B);
159
160 // calc next subset B
161 size_t bit = 0;
162 for( ; bit < A.size(); bit++ )
163 if( A.test(bit) ) {
164 if( B.test(bit) )
165 B.reset(bit);
166 else {
167 B.set(bit);
168 break;
169 }
170 }
171 } while (!B.none());
172 }
173
174
175 void MR::propagateCavityFields() {
176 Real sum_even, sum_odd;
177 Real maxdev;
178 size_t maxruns = 1000;
179
180 for( size_t i = 0; i < G.nrNodes(); i++ )
181 foreach( const Neighbor &j, G.nb(i) )
182 M[i][j.iter] = 0.1;
183
184 size_t run=0;
185 do {
186 maxdev=0.0;
187 run++;
188 for( size_t i = 0; i < G.nrNodes(); i++ ) {
189 foreach( const Neighbor &j, G.nb(i) ) {
190 size_t _j = j.iter;
191 size_t _i = G.findNb(j,i);
192 DAI_ASSERT( G.nb(j,_i) == i );
193
194 Real newM = 0.0;
195 if( props.updates == Properties::UpdateType::FULL ) {
196 // find indices in nb(j) that do not correspond with i
197 sub_nb _nbj_min_i(G.nb(j).size());
198 _nbj_min_i.set();
199 _nbj_min_i.reset(_i);
200
201 // find indices in nb(i) that do not correspond with j
202 sub_nb _nbi_min_j(G.nb(i).size());
203 _nbi_min_j.set();
204 _nbi_min_j.reset(_j);
205
206 sum_subs(j, _nbj_min_i, &sum_even, &sum_odd);
207 newM = (tanh(theta[j]) * sum_even + sum_odd) / (sum_even + tanh(theta[j]) * sum_odd);
208
209 sum_subs(i, _nbi_min_j, &sum_even, &sum_odd);
210 Real denom = sum_even + tanh(theta[i]) * sum_odd;
211 Real numer = 0.0;
212 for(size_t _k=0; _k < G.nb(i).size(); _k++) if(_k != _j) {
213 sub_nb _nbi_min_jk(_nbi_min_j);
214 _nbi_min_jk.reset(_k);
215 sum_subs(i, _nbi_min_jk, &sum_even, &sum_odd);
216 numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
217 }
218 newM -= numer / denom;
219 } else if( props.updates == Properties::UpdateType::LINEAR ) {
220 newM = T(j,_i);
221 for(size_t _l=0; _l<G.nb(i).size(); _l++) if( _l != _j )
222 newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
223 for(size_t _l1=0; _l1<G.nb(j).size(); _l1++) if( _l1 != _i )
224 for( size_t _l2=_l1+1; _l2<G.nb(j).size(); _l2++) if( _l2 != _i)
225 newM += Gamma(j,_i,_l1,_l2) * tJ[j][_l1] * tJ[j][_l2] * cors[j][_l1][_l2];
226 }
227
228 Real dev = newM - M[i][_j];
229 // dev *= 0.02;
230 if( abs(dev) >= maxdev )
231 maxdev = abs(dev);
232
233 newM = M[i][_j] + dev;
234 if( abs(newM) > 1.0 )
235 newM = (newM > 0.0) ? 1.0 : -1.0;
236 M[i][_j] = newM;
237 }
238 }
239 } while((maxdev>props.tol)&&(run<maxruns));
240
241 _iters = run;
242 if( maxdev > _maxdiff )
243 _maxdiff = maxdev;
244
245 if(run==maxruns){
246 if( props.verbose >= 1 )
247 cerr << "MR::propagateCavityFields: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
248 }
249 }
250
251
252 void MR::calcMagnetizations() {
253 for( size_t i = 0; i < G.nrNodes(); i++ ) {
254 if( props.updates == Properties::UpdateType::FULL ) {
255 // find indices in nb(i)
256 sub_nb _nbi( G.nb(i).size() );
257 _nbi.set();
258
259 // calc numerator1 and denominator1
260 Real sum_even, sum_odd;
261 sum_subs(i, _nbi, &sum_even, &sum_odd);
262
263 Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
264
265 } else if( props.updates == Properties::UpdateType::LINEAR ) {
266 sub_nb empty( G.nb(i).size() );
267 Mag[i] = T(i,empty);
268
269 for( size_t _l1 = 0; _l1 < G.nb(i).size(); _l1++ )
270 for( size_t _l2 = _l1 + 1; _l2 < G.nb(i).size(); _l2++ )
271 Mag[i] += Gamma(i,_l1,_l2) * tJ[i][_l1] * tJ[i][_l2] * cors[i][_l1][_l2];
272 }
273 if( abs( Mag[i] ) > 1.0 )
274 Mag[i] = (Mag[i] > 0.0) ? 1.0 : -1.0;
275 }
276 }
277
278
279 Real MR::calcCavityCorrelations() {
280 Real md = 0.0;
281 for( size_t i = 0; i < nrVars(); i++ ) {
282 vector<Factor> pairq;
283 if( props.inits == Properties::InitType::EXACT ) {
284 JTree jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
285 jtcav.makeCavity( i );
286 pairq = calcPairBeliefs( jtcav, delta(i), false, true );
287 } else if( props.inits == Properties::InitType::CLAMPING ) {
288 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
289 bpcav.makeCavity( i );
290
291 pairq = calcPairBeliefs( bpcav, delta(i), false, true );
292 md = std::max( md, bpcav.maxDiff() );
293 } else if( props.inits == Properties::InitType::RESPPROP ) {
294 BP bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
295 bpcav.makeCavity( i );
296 bpcav.makeCavity( i );
297 bpcav.init();
298 bpcav.run();
299
300 BBP bbp( &bpcav, PropertySet()("verbose",(size_t)0)("tol",(Real)1.0e-9)("maxiter",(size_t)10000)("damping",(Real)0.0)("updates",string("SEQ_MAX")) );
301 foreach( const Neighbor &j, G.nb(i) ) {
302 // Create weights for magnetization of some spin
303 Prob p( 2, 0.0 );
304 p.set( 0, -1.0 );
305 p.set( 1, 1.0 );
306
307 // BBP cost function would be the magnetization of spin j
308 vector<Prob> b1_adj;
309 b1_adj.reserve( nrVars() );
310 for( size_t l = 0; l < nrVars(); l++ )
311 if( l == j )
312 b1_adj.push_back( p );
313 else
314 b1_adj.push_back( Prob( 2, 0.0 ) );
315 bbp.init_V( b1_adj );
316
317 // run BBP to estimate adjoints
318 bbp.run();
319
320 foreach( const Neighbor &k, G.nb(i) ) {
321 if( k != j )
322 cors[i][j.iter][k.iter] = (bbp.adj_psi_V(k)[1] - bbp.adj_psi_V(k)[0]);
323 else
324 cors[i][j.iter][k.iter] = 0.0;
325 }
326 }
327 }
328
329 if( props.inits != Properties::InitType::RESPPROP ) {
330 for( size_t jk = 0; jk < pairq.size(); jk++ ) {
331 VarSet::const_iterator kit = pairq[jk].vars().begin();
332 size_t j = findVar( *(kit) );
333 size_t k = findVar( *(++kit) );
334 pairq[jk].normalize();
335 Real cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
336
337 size_t _j = G.findNb(i,j);
338 size_t _k = G.findNb(i,k);
339 cors[i][_j][_k] = cor;
340 cors[i][_k][_j] = cor;
341 }
342 }
343
344 }
345 return md;
346 }
347
348
349 string MR::identify() const {
350 return string(Name) + printProperties();
351 }
352
353
354 Real MR::run() {
355 if( supported ) {
356 if( props.verbose >= 1 )
357 cerr << "Starting " << identify() << "...";
358
359 double tic = toc();
360
361 // approximate correlations of cavity spins
362 Real md = calcCavityCorrelations();
363 if( md > _maxdiff )
364 _maxdiff = md;
365
366 // solve messages
367 propagateCavityFields();
368
369 // calculate magnetizations
370 calcMagnetizations();
371
372 if( props.verbose >= 1 )
373 cerr << Name << " needed " << toc() - tic << " seconds." << endl;
374
375 return _maxdiff;
376 } else
377 return 1.0;
378 }
379
380
381 Factor MR::beliefV( size_t i ) const {
382 if( supported ) {
383 Real x[2];
384 x[0] = 0.5 - Mag[i] / 2.0;
385 x[1] = 0.5 + Mag[i] / 2.0;
386
387 return Factor( var(i), x );
388 } else
389 return Factor();
390 }
391
392
393 Factor MR::belief (const VarSet &ns) const {
394 if( ns.size() == 0 )
395 return Factor();
396 else if( ns.size() == 1 )
397 return beliefV( findVar( *(ns.begin()) ) );
398 else {
399 DAI_THROW(BELIEF_NOT_AVAILABLE);
400 return Factor();
401 }
402 }
403
404
405 vector<Factor> MR::beliefs() const {
406 vector<Factor> result;
407 for( size_t i = 0; i < nrVars(); i++ )
408 result.push_back( beliefV( i ) );
409 return result;
410 }
411
412
413 MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), _maxdiff(0.0), _iters(0) {
414 setProperties( opts );
415
416 size_t N = fg.nrVars();
417
418 // check whether all vars in fg are binary
419 for( size_t i = 0; i < N; i++ )
420 if( (fg.var(i).states() > 2) ) {
421 supported = false;
422 break;
423 }
424 if( !supported )
425 DAI_THROWE(NOT_IMPLEMENTED,"MR only supports binary variables");
426
427 // check whether all interactions are pairwise or single
428 // and construct Markov graph
429 G = GraphAL(N);
430 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
431 const Factor &psi = fg.factor(I);
432 if( psi.vars().size() > 2 ) {
433 supported = false;
434 break;
435 } else if( psi.vars().size() == 2 ) {
436 VarSet::const_iterator jit = psi.vars().begin();
437 size_t i = fg.findVar( *(jit) );
438 size_t j = fg.findVar( *(++jit) );
439 G.addEdge( i, j, false );
440 }
441 }
442 if( !supported )
443 DAI_THROWE(NOT_IMPLEMENTED,"MR does not support higher order interactions (only single and pairwise are supported)");
444
445 // construct theta
446 theta.clear();
447 theta.resize( N, 0.0 );
448
449 // construct tJ
450 tJ.resize( N );
451 for( size_t i = 0; i < N; i++ )
452 tJ[i].resize( G.nb(i).size(), 0.0 );
453
454 // initialize theta and tJ
455 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
456 const Factor &psi = fg.factor(I);
457 if( psi.vars().size() == 1 ) {
458 size_t i = fg.findVar( *(psi.vars().begin()) );
459 theta[i] += 0.5 * log(psi[1] / psi[0]);
460 } else if( psi.vars().size() == 2 ) {
461 VarSet::const_iterator jit = psi.vars().begin();
462 size_t i = fg.findVar( *(jit) );
463 size_t j = fg.findVar( *(++jit) );
464
465 Real w_ij = 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1]));
466 tJ[i][G.findNb(i,j)] += w_ij;
467 tJ[j][G.findNb(j,i)] += w_ij;
468
469 theta[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
470 theta[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
471 }
472 }
473 for( size_t i = 0; i < N; i++ )
474 foreach( const Neighbor &j, G.nb(i) )
475 tJ[i][j.iter] = tanh( tJ[i][j.iter] );
476
477 // construct M
478 M.resize( N );
479 for( size_t i = 0; i < N; i++ )
480 M[i].resize( G.nb(i).size() );
481
482 // construct cors
483 cors.resize( N );
484 for( size_t i = 0; i < N; i++ )
485 cors[i].resize( G.nb(i).size() );
486 for( size_t i = 0; i < N; i++ )
487 for( size_t _j = 0; _j < cors[i].size(); _j++ )
488 cors[i][_j].resize( G.nb(i).size() );
489
490 // construct Mag
491 Mag.resize( N );
492 }
493
494
495 } // end of namespace dai