Merge branch 'vaskeEmFix' of git://disco.cse.ucsc.edu/libDAI into mergeVaskeEmFix
[libdai.git] / src / hak.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <map>
24 #include <dai/hak.h>
25 #include <dai/util.h>
26 #include <dai/exceptions.h>
27
28
29 namespace dai {
30
31
32 using namespace std;
33
34
35 const char *HAK::Name = "HAK";
36
37
38 void HAK::setProperties( const PropertySet &opts ) {
39 assert( opts.hasKey("tol") );
40 assert( opts.hasKey("maxiter") );
41 assert( opts.hasKey("verbose") );
42 assert( opts.hasKey("doubleloop") );
43 assert( opts.hasKey("clusters") );
44
45 props.tol = opts.getStringAs<double>("tol");
46 props.maxiter = opts.getStringAs<size_t>("maxiter");
47 props.verbose = opts.getStringAs<size_t>("verbose");
48 props.doubleloop = opts.getStringAs<bool>("doubleloop");
49 props.clusters = opts.getStringAs<Properties::ClustersType>("clusters");
50
51 if( opts.hasKey("loopdepth") )
52 props.loopdepth = opts.getStringAs<size_t>("loopdepth");
53 else
54 assert( props.clusters != Properties::ClustersType::LOOP );
55 if( opts.hasKey("damping") )
56 props.damping = opts.getStringAs<double>("damping");
57 else
58 props.damping = 0.0;
59 }
60
61
62 PropertySet HAK::getProperties() const {
63 PropertySet opts;
64 opts.Set( "tol", props.tol );
65 opts.Set( "maxiter", props.maxiter );
66 opts.Set( "verbose", props.verbose );
67 opts.Set( "doubleloop", props.doubleloop );
68 opts.Set( "clusters", props.clusters );
69 opts.Set( "loopdepth", props.loopdepth );
70 opts.Set( "damping", props.damping );
71 return opts;
72 }
73
74
75 string HAK::printProperties() const {
76 stringstream s( stringstream::out );
77 s << "[";
78 s << "tol=" << props.tol << ",";
79 s << "maxiter=" << props.maxiter << ",";
80 s << "verbose=" << props.verbose << ",";
81 s << "doubleloop=" << props.doubleloop << ",";
82 s << "clusters=" << props.clusters << ",";
83 s << "loopdepth=" << props.loopdepth << ",";
84 s << "damping=" << props.damping << "]";
85 return s.str();
86 }
87
88
89 void HAK::constructMessages() {
90 // Create outer beliefs
91 _Qa.clear();
92 _Qa.reserve(nrORs());
93 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
94 _Qa.push_back( Factor( OR(alpha).vars() ) );
95
96 // Create inner beliefs
97 _Qb.clear();
98 _Qb.reserve(nrIRs());
99 for( size_t beta = 0; beta < nrIRs(); beta++ )
100 _Qb.push_back( Factor( IR(beta) ) );
101
102 // Create messages
103 _muab.clear();
104 _muab.reserve( nrORs() );
105 _muba.clear();
106 _muba.reserve( nrORs() );
107 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
108 _muab.push_back( vector<Factor>() );
109 _muba.push_back( vector<Factor>() );
110 _muab[alpha].reserve( nbOR(alpha).size() );
111 _muba[alpha].reserve( nbOR(alpha).size() );
112 foreach( const Neighbor &beta, nbOR(alpha) ) {
113 _muab[alpha].push_back( Factor( IR(beta) ) );
114 _muba[alpha].push_back( Factor( IR(beta) ) );
115 }
116 }
117 }
118
119
120 HAK::HAK( const RegionGraph &rg, const PropertySet &opts ) : DAIAlgRG(rg), _Qa(), _Qb(), _muab(), _muba(), _maxdiff(0.0), _iters(0U), props() {
121 setProperties( opts );
122
123 constructMessages();
124 }
125
126
127 void HAK::findLoopClusters( const FactorGraph & fg, std::set<VarSet> &allcl, VarSet newcl, const Var & root, size_t length, VarSet vars ) {
128 for( VarSet::const_iterator in = vars.begin(); in != vars.end(); in++ ) {
129 VarSet ind = fg.delta( fg.findVar( *in ) );
130 if( (newcl.size()) >= 2 && ind.contains( root ) ) {
131 allcl.insert( newcl | *in );
132 }
133 else if( length > 1 )
134 findLoopClusters( fg, allcl, newcl | *in, root, length - 1, ind / newcl );
135 }
136 }
137
138
139 HAK::HAK(const FactorGraph & fg, const PropertySet &opts) : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), _maxdiff(0.0), _iters(0U), props() {
140 setProperties( opts );
141
142 vector<VarSet> cl;
143 if( props.clusters == Properties::ClustersType::MIN ) {
144 cl = fg.Cliques();
145 } else if( props.clusters == Properties::ClustersType::DELTA ) {
146 for( size_t i = 0; i < fg.nrVars(); i++ )
147 cl.push_back(fg.Delta(i));
148 } else if( props.clusters == Properties::ClustersType::LOOP ) {
149 cl = fg.Cliques();
150 set<VarSet> scl;
151 for( size_t i0 = 0; i0 < fg.nrVars(); i0++ ) {
152 VarSet i0d = fg.delta(i0);
153 if( props.loopdepth > 1 )
154 findLoopClusters( fg, scl, fg.var(i0), fg.var(i0), props.loopdepth - 1, fg.delta(i0) );
155 }
156 for( set<VarSet>::const_iterator c = scl.begin(); c != scl.end(); c++ )
157 cl.push_back(*c);
158 if( props.verbose >= 3 ) {
159 cerr << Name << " uses the following clusters: " << endl;
160 for( vector<VarSet>::const_iterator cli = cl.begin(); cli != cl.end(); cli++ )
161 cerr << *cli << endl;
162 }
163 } else
164 DAI_THROW(UNKNOWN_ENUM_VALUE);
165
166 RegionGraph rg(fg,cl);
167 RegionGraph::operator=(rg);
168 constructMessages();
169
170 if( props.verbose >= 3 )
171 cerr << Name << " regiongraph: " << *this << endl;
172 }
173
174
175 string HAK::identify() const {
176 return string(Name) + printProperties();
177 }
178
179
180 void HAK::init( const VarSet &ns ) {
181 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
182 if( alpha->vars().intersects( ns ) )
183 alpha->fill( 1.0 / alpha->states() );
184
185 for( size_t beta = 0; beta < nrIRs(); beta++ )
186 if( IR(beta).intersects( ns ) ) {
187 _Qb[beta].fill( 1.0 );
188 foreach( const Neighbor &alpha, nbIR(beta) ) {
189 size_t _beta = alpha.dual;
190 muab( alpha, _beta ).fill( 1.0 );
191 muba( alpha, _beta ).fill( 1.0 );
192 }
193 }
194 }
195
196
197 void HAK::init() {
198 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
199 alpha->fill( 1.0 / alpha->states() );
200
201 for( vector<Factor>::iterator beta = _Qb.begin(); beta != _Qb.end(); beta++ )
202 beta->fill( 1.0 / beta->states() );
203
204 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
205 foreach( const Neighbor &beta, nbOR(alpha) ) {
206 size_t _beta = beta.iter;
207 muab( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
208 muba( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
209 }
210 }
211
212
213 double HAK::doGBP() {
214 if( props.verbose >= 1 )
215 cerr << "Starting " << identify() << "...";
216 if( props.verbose >= 3)
217 cerr << endl;
218
219 double tic = toc();
220
221 // Check whether counting numbers won't lead to problems
222 for( size_t beta = 0; beta < nrIRs(); beta++ )
223 assert( nbIR(beta).size() + IR(beta).c() != 0.0 );
224
225 // Keep old beliefs to check convergence
226 vector<Factor> old_beliefs;
227 old_beliefs.reserve( nrVars() );
228 for( size_t i = 0; i < nrVars(); i++ )
229 old_beliefs.push_back( belief( var(i) ) );
230
231 // Differences in single node beliefs
232 Diffs diffs(nrVars(), 1.0);
233
234 // do several passes over the network until maximum number of iterations has
235 // been reached or until the maximum belief difference is smaller than tolerance
236 for( _iters = 0; _iters < props.maxiter && diffs.maxDiff() > props.tol; _iters++ ) {
237 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
238 foreach( const Neighbor &alpha, nbIR(beta) ) {
239 size_t _beta = alpha.dual;
240 muab( alpha, _beta ) = _Qa[alpha].marginal(IR(beta)) / muba(alpha,_beta);
241 /* TODO: INVESTIGATE THIS PROBLEM
242 *
243 * In some cases, the muab's can have very large entries because the muba's have very
244 * small entries. This may cause NANs later on (e.g., multiplying large quantities may
245 * result in +inf; normalization then tries to calculate inf / inf which is NAN).
246 * A fix of this problem would consist in normalizing the messages muab.
247 * However, it is not obvious whether this is a real solution, because it has a
248 * negative performance impact and the NAN's seem to be a symptom of a fundamental
249 * numerical unstability.
250 */
251 muab(alpha,_beta).normalize();
252 }
253
254 Factor Qb_new;
255 foreach( const Neighbor &alpha, nbIR(beta) ) {
256 size_t _beta = alpha.dual;
257 Qb_new *= muab(alpha,_beta) ^ (1 / (nbIR(beta).size() + IR(beta).c()));
258 }
259
260 Qb_new.normalize();
261 if( Qb_new.hasNaNs() ) {
262 // TODO: WHAT TO DO IN THIS CASE?
263 cerr << Name << "::doGBP: Qb_new has NaNs!" << endl;
264 return 1.0;
265 }
266 /* TODO: WHAT IS THE PURPOSE OF THE FOLLOWING CODE?
267 *
268 * _Qb[beta] = Qb_new.makeZero(1e-100);
269 */
270
271 if( props.doubleloop || props.damping == 0.0 )
272 _Qb[beta] = Qb_new; // no damping for double loop
273 else
274 _Qb[beta] = (Qb_new^(1.0 - props.damping)) * (_Qb[beta]^props.damping);
275
276 foreach( const Neighbor &alpha, nbIR(beta) ) {
277 size_t _beta = alpha.dual;
278 muba(alpha,_beta) = _Qb[beta] / muab(alpha,_beta);
279
280 /* TODO: INVESTIGATE WHETHER THIS HACK (INVENTED BY KEES) TO PREVENT NANS MAKES SENSE
281 *
282 * muba(beta,*alpha).makePositive(1e-100);
283 *
284 */
285
286 Factor Qa_new = OR(alpha);
287 foreach( const Neighbor &gamma, nbOR(alpha) )
288 Qa_new *= muba(alpha,gamma.iter);
289 Qa_new ^= (1.0 / OR(alpha).c());
290 Qa_new.normalize();
291 if( Qa_new.hasNaNs() ) {
292 cerr << Name << "::doGBP: Qa_new has NaNs!" << endl;
293 return 1.0;
294 }
295 /* TODO: WHAT IS THE PURPOSE OF THE FOLLOWING CODE?
296 *
297 * _Qb[beta] = Qb_new.makeZero(1e-100);
298 */
299
300 if( props.doubleloop || props.damping == 0.0 )
301 _Qa[alpha] = Qa_new; // no damping for double loop
302 else
303 // FIXME: GEOMETRIC DAMPING IS SLOW!
304 _Qa[alpha] = (Qa_new^(1.0 - props.damping)) * (_Qa[alpha]^props.damping);
305 }
306 }
307
308 // Calculate new single variable beliefs and compare with old ones
309 for( size_t i = 0; i < nrVars(); i++ ) {
310 Factor new_belief = belief( var( i ) );
311 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
312 old_beliefs[i] = new_belief;
313 }
314
315 if( props.verbose >= 3 )
316 cerr << Name << "::doGBP: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
317 }
318
319 if( diffs.maxDiff() > _maxdiff )
320 _maxdiff = diffs.maxDiff();
321
322 if( props.verbose >= 1 ) {
323 if( diffs.maxDiff() > props.tol ) {
324 if( props.verbose == 1 )
325 cerr << endl;
326 cerr << Name << "::doGBP: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
327 } else {
328 if( props.verbose >= 2 )
329 cerr << Name << "::doGBP: ";
330 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
331 }
332 }
333
334 return diffs.maxDiff();
335 }
336
337
338 double HAK::doDoubleLoop() {
339 if( props.verbose >= 1 )
340 cerr << "Starting " << identify() << "...";
341 if( props.verbose >= 3)
342 cerr << endl;
343
344 double tic = toc();
345
346 // Save original outer regions
347 vector<FRegion> org_ORs = ORs;
348
349 // Save original inner counting numbers and set negative counting numbers to zero
350 vector<double> org_IR_cs( nrIRs(), 0.0 );
351 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
352 org_IR_cs[beta] = IR(beta).c();
353 if( IR(beta).c() < 0.0 )
354 IR(beta).c() = 0.0;
355 }
356
357 // Keep old beliefs to check convergence
358 vector<Factor> old_beliefs;
359 old_beliefs.reserve( nrVars() );
360 for( size_t i = 0; i < nrVars(); i++ )
361 old_beliefs.push_back( belief( var(i) ) );
362
363 // Differences in single node beliefs
364 Diffs diffs(nrVars(), 1.0);
365
366 size_t outer_maxiter = props.maxiter;
367 double outer_tol = props.tol;
368 size_t outer_verbose = props.verbose;
369 double org_maxdiff = _maxdiff;
370
371 // Set parameters for inner loop
372 props.maxiter = 5;
373 props.verbose = outer_verbose ? outer_verbose - 1 : 0;
374
375 size_t outer_iter = 0;
376 size_t total_iter = 0;
377 for( outer_iter = 0; outer_iter < outer_maxiter && diffs.maxDiff() > outer_tol; outer_iter++ ) {
378 // Calculate new outer regions
379 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
380 OR(alpha) = org_ORs[alpha];
381 foreach( const Neighbor &beta, nbOR(alpha) )
382 OR(alpha) *= _Qb[beta] ^ ((IR(beta).c() - org_IR_cs[beta]) / nbIR(beta).size());
383 }
384
385 // Inner loop
386 if( isnan( doGBP() ) )
387 return 1.0;
388
389 // Calculate new single variable beliefs and compare with old ones
390 for( size_t i = 0; i < nrVars(); ++i ) {
391 Factor new_belief = belief( var( i ) );
392 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
393 old_beliefs[i] = new_belief;
394 }
395
396 total_iter += Iterations();
397
398 if( props.verbose >= 3 )
399 cerr << Name << "::doDoubleLoop: maxdiff " << diffs.maxDiff() << " after " << total_iter << " passes" << endl;
400 }
401
402 // restore _maxiter, _verbose and _maxdiff
403 props.maxiter = outer_maxiter;
404 props.verbose = outer_verbose;
405 _maxdiff = org_maxdiff;
406
407 _iters = total_iter;
408 if( diffs.maxDiff() > _maxdiff )
409 _maxdiff = diffs.maxDiff();
410
411 // Restore original outer regions
412 ORs = org_ORs;
413
414 // Restore original inner counting numbers
415 for( size_t beta = 0; beta < nrIRs(); ++beta )
416 IR(beta).c() = org_IR_cs[beta];
417
418 if( props.verbose >= 1 ) {
419 if( diffs.maxDiff() > props.tol ) {
420 if( props.verbose == 1 )
421 cerr << endl;
422 cerr << Name << "::doDoubleLoop: WARNING: not converged within " << outer_maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
423 } else {
424 if( props.verbose >= 3 )
425 cerr << Name << "::doDoubleLoop: ";
426 cerr << "converged in " << total_iter << " passes (" << toc() - tic << " seconds)." << endl;
427 }
428 }
429
430 return diffs.maxDiff();
431 }
432
433
434 double HAK::run() {
435 if( props.doubleloop )
436 return doDoubleLoop();
437 else
438 return doGBP();
439 }
440
441
442 Factor HAK::belief( const VarSet &ns ) const {
443 vector<Factor>::const_iterator beta;
444 for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
445 if( beta->vars() >> ns )
446 break;
447 if( beta != _Qb.end() )
448 return( beta->marginal(ns) );
449 else {
450 vector<Factor>::const_iterator alpha;
451 for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
452 if( alpha->vars() >> ns )
453 break;
454 assert( alpha != _Qa.end() );
455 return( alpha->marginal(ns) );
456 }
457 }
458
459
460 Factor HAK::belief( const Var &n ) const {
461 return belief( (VarSet)n );
462 }
463
464
465 vector<Factor> HAK::beliefs() const {
466 vector<Factor> result;
467 for( size_t beta = 0; beta < nrIRs(); beta++ )
468 result.push_back( Qb(beta) );
469 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
470 result.push_back( Qa(alpha) );
471 return result;
472 }
473
474
475 Real HAK::logZ() const {
476 Real s = 0.0;
477 for( size_t beta = 0; beta < nrIRs(); beta++ )
478 s += IR(beta).c() * Qb(beta).entropy();
479 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
480 s += OR(alpha).c() * Qa(alpha).entropy();
481 s += (OR(alpha).log(true) * Qa(alpha)).sum();
482 }
483 return s;
484 }
485
486
487 } // end of namespace dai