Added Exceptions framework (and more)
[libdai.git] / src / hak.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <map>
23 #include <dai/hak.h>
24 #include <dai/util.h>
25 #include <dai/diffs.h>
26 #include <dai/exceptions.h>
27
28
29 namespace dai {
30
31
32 using namespace std;
33
34
/// Name of this inference algorithm; identify() prints it in front of the properties.
const char *HAK::Name = "HAK";
36
37
38 bool HAK::checkProperties() {
39 if( !HasProperty("tol") )
40 return false;
41 if (!HasProperty("maxiter") )
42 return false;
43 if (!HasProperty("verbose") )
44 return false;
45 if( !HasProperty("doubleloop") )
46 return false;
47 if( !HasProperty("clusters") )
48 return false;
49
50 ConvertPropertyTo<double>("tol");
51 ConvertPropertyTo<size_t>("maxiter");
52 ConvertPropertyTo<size_t>("verbose");
53 ConvertPropertyTo<bool>("doubleloop");
54 ConvertPropertyTo<ClustersType>("clusters");
55
56 if( HasProperty("loopdepth") )
57 ConvertPropertyTo<size_t>("loopdepth");
58 else if( Clusters() == ClustersType::LOOP )
59 return false;
60
61 return true;
62 }
63
64
65 void HAK::constructMessages() {
66 // Create outer beliefs
67 _Qa.clear();
68 _Qa.reserve(nrORs());
69 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
70 _Qa.push_back( Factor( OR(alpha).vars() ) );
71
72 // Create inner beliefs
73 _Qb.clear();
74 _Qb.reserve(nrIRs());
75 for( size_t beta = 0; beta < nrIRs(); beta++ )
76 _Qb.push_back( Factor( IR(beta) ) );
77
78 // Create messages
79 _muab.clear();
80 _muab.reserve( nrORs() );
81 _muba.clear();
82 _muba.reserve( nrORs() );
83 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
84 _muab.push_back( vector<Factor>() );
85 _muba.push_back( vector<Factor>() );
86 _muab[alpha].reserve( nbOR(alpha).size() );
87 _muba[alpha].reserve( nbOR(alpha).size() );
88 foreach( const Neighbor &beta, nbOR(alpha) ) {
89 _muab[alpha].push_back( Factor( IR(beta) ) );
90 _muba[alpha].push_back( Factor( IR(beta) ) );
91 }
92 }
93 }
94
95
96 HAK::HAK(const RegionGraph & rg, const Properties &opts) : DAIAlgRG(rg, opts) {
97 assert( checkProperties() );
98
99 constructMessages();
100 }
101
102
103 void HAK::findLoopClusters( const FactorGraph & fg, std::set<VarSet> &allcl, VarSet newcl, const Var & root, size_t length, VarSet vars ) {
104 for( VarSet::const_iterator in = vars.begin(); in != vars.end(); in++ ) {
105 VarSet ind = fg.delta( fg.findVar( *in ) );
106 if( (newcl.size()) >= 2 && (ind >> root) ) {
107 allcl.insert( newcl | *in );
108 }
109 else if( length > 1 )
110 findLoopClusters( fg, allcl, newcl | *in, root, length - 1, ind / newcl );
111 }
112 }
113
114
115 HAK::HAK(const FactorGraph & fg, const Properties &opts) : DAIAlgRG(opts) {
116 assert( checkProperties() );
117
118 vector<VarSet> cl;
119 if( Clusters() == ClustersType::MIN ) {
120 cl = fg.Cliques();
121 } else if( Clusters() == ClustersType::DELTA ) {
122 for( size_t i = 0; i < fg.nrVars(); i++ )
123 cl.push_back(fg.Delta(i));
124 } else if( Clusters() == ClustersType::LOOP ) {
125 cl = fg.Cliques();
126 set<VarSet> scl;
127 for( size_t i0 = 0; i0 < fg.nrVars(); i0++ ) {
128 VarSet i0d = fg.delta(i0);
129 if( LoopDepth() > 1 )
130 findLoopClusters( fg, scl, fg.var(i0), fg.var(i0), LoopDepth() - 1, fg.delta(i0) );
131 }
132 for( set<VarSet>::const_iterator c = scl.begin(); c != scl.end(); c++ )
133 cl.push_back(*c);
134 if( Verbose() >= 3 ) {
135 cout << "HAK uses the following clusters: " << endl;
136 for( vector<VarSet>::const_iterator cli = cl.begin(); cli != cl.end(); cli++ )
137 cout << *cli << endl;
138 }
139 } else
140 DAI_THROW(INTERNAL_ERROR);
141
142 RegionGraph rg(fg,cl);
143 RegionGraph::operator=(rg);
144 constructMessages();
145
146 if( Verbose() >= 3 )
147 cout << "HAK regiongraph: " << *this << endl;
148 }
149
150
151 string HAK::identify() const {
152 stringstream result (stringstream::out);
153 result << Name << GetProperties();
154 return result.str();
155 }
156
157
158 void HAK::init( const VarSet &ns ) {
159 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
160 if( alpha->vars().intersects( ns ) )
161 alpha->fill( 1.0 / alpha->states() );
162
163 for( size_t beta = 0; beta < nrIRs(); beta++ )
164 if( IR(beta).intersects( ns ) ) {
165 _Qb[beta].fill( 1.0 );
166 foreach( const Neighbor &alpha, nbIR(beta) ) {
167 size_t _beta = alpha.dual;
168 muab( alpha, _beta ).fill( 1.0 / IR(beta).states() );
169 muba( alpha, _beta ).fill( 1.0 / IR(beta).states() );
170 }
171 }
172 }
173
174
175 void HAK::init() {
176 assert( checkProperties() );
177
178 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
179 alpha->fill( 1.0 / alpha->states() );
180
181 for( vector<Factor>::iterator beta = _Qb.begin(); beta != _Qb.end(); beta++ )
182 beta->fill( 1.0 / beta->states() );
183
184 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
185 foreach( const Neighbor &beta, nbOR(alpha) ) {
186 size_t _beta = beta.iter;
187 muab( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
188 muba( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
189 }
190 }
191
192
/// Runs single-loop generalized belief propagation on the region graph.
/// Passes over all inner regions are repeated until MaxIter() passes have been
/// done or the maximum change in the single-variable beliefs is at most Tol().
/// \return the final maximum single-variable belief difference, or 1.0 if an
///         updated inner or outer belief contained NaNs (in that case the
///         function stops immediately and does not call updateMaxDiff())
double HAK::doGBP() {
    if( Verbose() >= 1 )
        cout << "Starting " << identify() << "...";
    if( Verbose() >= 3)
        cout << endl;

    double tic = toc();

    // Check whether counting numbers won't lead to problems
    // (the inner-belief update raises messages to the power 1 / (#neighbours + c_beta))
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        assert( nbIR(beta).size() + IR(beta).c() != 0.0 );

    // Keep old beliefs to check convergence
    vector<Factor> old_beliefs;
    old_beliefs.reserve( nrVars() );
    for( size_t i = 0; i < nrVars(); i++ )
        old_beliefs.push_back( belief( var(i) ) );

    // Differences in single node beliefs
    Diffs diffs(nrVars(), 1.0);

    size_t iter = 0;
    // do several passes over the network until maximum number of iterations has
    // been reached or until the maximum belief difference is smaller than tolerance
    for( iter = 0; iter < MaxIter() && diffs.maxDiff() > Tol(); iter++ ) {
        for( size_t beta = 0; beta < nrIRs(); beta++ ) {
            // Outer -> inner messages: marginal of Qa on beta, divided by the reverse message
            foreach( const Neighbor &alpha, nbIR(beta) ) {
                size_t _beta = alpha.dual;
                muab( alpha, _beta ) = _Qa[alpha].marginal(IR(beta)).divided_by( muba(alpha,_beta) );
            }

            // New inner belief: product of the incoming messages, each raised to
            // the power 1 / (#neighbours of beta + c_beta)
            Factor Qb_new;
            foreach( const Neighbor &alpha, nbIR(beta) ) {
                size_t _beta = alpha.dual;
                Qb_new *= muab(alpha,_beta) ^ (1 / (nbIR(beta).size() + IR(beta).c()));
            }

            Qb_new.normalize( _normtype );
            if( Qb_new.hasNaNs() ) {
                cout << "HAK::doGBP: Qb_new has NaNs!" << endl;
                return 1.0;
            }
            // _Qb[beta] = Qb_new.makeZero(1e-100); // damping?
            _Qb[beta] = Qb_new;

            foreach( const Neighbor &alpha, nbIR(beta) ) {
                size_t _beta = alpha.dual;

                // Inner -> outer message: new inner belief divided by the outer -> inner message
                muba(alpha,_beta) = _Qb[beta].divided_by( muab(alpha,_beta) );

                // New outer belief: region factor times all incoming inner -> outer
                // messages, raised to the power 1 / c_alpha
                Factor Qa_new = OR(alpha);
                foreach( const Neighbor &gamma, nbOR(alpha) )
                    Qa_new *= muba(alpha,gamma.iter);
                Qa_new ^= (1.0 / OR(alpha).c());
                Qa_new.normalize( _normtype );
                if( Qa_new.hasNaNs() ) {
                    cout << "HAK::doGBP: Qa_new has NaNs!" << endl;
                    return 1.0;
                }
                // _Qa[alpha] = Qa_new.makeZero(1e-100); // damping?
                _Qa[alpha] = Qa_new;
            }
        }

        // Calculate new single variable beliefs and compare with old ones
        for( size_t i = 0; i < nrVars(); i++ ) {
            Factor new_belief = belief( var( i ) );
            diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
            old_beliefs[i] = new_belief;
        }

        if( Verbose() >= 3 )
            cout << "HAK::doGBP: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
    }

    updateMaxDiff( diffs.maxDiff() );

    if( Verbose() >= 1 ) {
        if( diffs.maxDiff() > Tol() ) {
            if( Verbose() == 1 )
                cout << endl;
            cout << "HAK::doGBP: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
        } else {
            if( Verbose() >= 2 )
                cout << "HAK::doGBP: ";
            cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
        }
    }

    return diffs.maxDiff();
}
284
285
286 double HAK::doDoubleLoop() {
287 if( Verbose() >= 1 )
288 cout << "Starting " << identify() << "...";
289 if( Verbose() >= 3)
290 cout << endl;
291
292 double tic = toc();
293
294 // Save original outer regions
295 vector<FRegion> org_ORs = ORs;
296
297 // Save original inner counting numbers and set negative counting numbers to zero
298 vector<double> org_IR_cs( nrIRs(), 0.0 );
299 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
300 org_IR_cs[beta] = IR(beta).c();
301 if( IR(beta).c() < 0.0 )
302 IR(beta).c() = 0.0;
303 }
304
305 // Keep old beliefs to check convergence
306 vector<Factor> old_beliefs;
307 old_beliefs.reserve( nrVars() );
308 for( size_t i = 0; i < nrVars(); i++ )
309 old_beliefs.push_back( belief( var(i) ) );
310
311 // Differences in single node beliefs
312 Diffs diffs(nrVars(), 1.0);
313
314 size_t outer_maxiter = MaxIter();
315 double outer_tol = Tol();
316 size_t outer_verbose = Verbose();
317 double org_maxdiff = MaxDiff();
318
319 // Set parameters for inner loop
320 MaxIter( 5 );
321 Verbose( outer_verbose ? outer_verbose - 1 : 0 );
322
323 size_t outer_iter = 0;
324 for( outer_iter = 0; outer_iter < outer_maxiter && diffs.maxDiff() > outer_tol; outer_iter++ ) {
325 // Calculate new outer regions
326 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
327 OR(alpha) = org_ORs[alpha];
328 foreach( const Neighbor &beta, nbOR(alpha) )
329 OR(alpha) *= _Qb[beta] ^ ((IR(beta).c() - org_IR_cs[beta]) / nbIR(beta).size());
330 }
331
332 // Inner loop
333 if( isnan( doGBP() ) )
334 return 1.0;
335
336 // Calculate new single variable beliefs and compare with old ones
337 for( size_t i = 0; i < nrVars(); ++i ) {
338 Factor new_belief = belief( var( i ) );
339 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
340 old_beliefs[i] = new_belief;
341 }
342
343 if( Verbose() >= 3 )
344 cout << "HAK::doDoubleLoop: maxdiff " << diffs.maxDiff() << " after " << outer_iter+1 << " passes" << endl;
345 }
346
347 // restore _maxiter, _verbose and _maxdiff
348 MaxIter( outer_maxiter );
349 Verbose( outer_verbose );
350 MaxDiff( org_maxdiff );
351
352 updateMaxDiff( diffs.maxDiff() );
353
354 // Restore original outer regions
355 ORs = org_ORs;
356
357 // Restore original inner counting numbers
358 for( size_t beta = 0; beta < nrIRs(); ++beta )
359 IR(beta).c() = org_IR_cs[beta];
360
361 if( Verbose() >= 1 ) {
362 if( diffs.maxDiff() > Tol() ) {
363 if( Verbose() == 1 )
364 cout << endl;
365 cout << "HAK::doDoubleLoop: WARNING: not converged within " << outer_maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
366 } else {
367 if( Verbose() >= 3 )
368 cout << "HAK::doDoubleLoop: ";
369 cout << "converged in " << outer_iter << " passes (" << toc() - tic << " clocks)." << endl;
370 }
371 }
372
373 return diffs.maxDiff();
374 }
375
376
377 double HAK::run() {
378 if( DoubleLoop() )
379 return doDoubleLoop();
380 else
381 return doGBP();
382 }
383
384
385 Factor HAK::belief( const VarSet &ns ) const {
386 vector<Factor>::const_iterator beta;
387 for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
388 if( beta->vars() >> ns )
389 break;
390 if( beta != _Qb.end() )
391 return( beta->marginal(ns) );
392 else {
393 vector<Factor>::const_iterator alpha;
394 for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
395 if( alpha->vars() >> ns )
396 break;
397 assert( alpha != _Qa.end() );
398 return( alpha->marginal(ns) );
399 }
400 }
401
402
403 Factor HAK::belief( const Var &n ) const {
404 return belief( (VarSet)n );
405 }
406
407
408 vector<Factor> HAK::beliefs() const {
409 vector<Factor> result;
410 for( size_t beta = 0; beta < nrIRs(); beta++ )
411 result.push_back( Qb(beta) );
412 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
413 result.push_back( Qa(alpha) );
414 return result;
415 }
416
417
418 Complex HAK::logZ() const {
419 Complex sum = 0.0;
420 for( size_t beta = 0; beta < nrIRs(); beta++ )
421 sum += Complex(IR(beta).c()) * Qb(beta).entropy();
422 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
423 sum += Complex(OR(alpha).c()) * Qa(alpha).entropy();
424 sum += (OR(alpha).log0() * Qa(alpha)).totalSum();
425 }
426 return sum;
427 }
428
429
430 } // end of namespace dai