Partial adoption of contributions by Giuseppe:
[libdai.git] / src / hak.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <map>
23 #include <dai/hak.h>
24 #include <dai/util.h>
25 #include <dai/diffs.h>
26
27
28 namespace dai {
29
30
31 using namespace std;
32
33
34 const char *HAK::Name = "HAK";
35
36
37 bool HAK::checkProperties() {
38 if( !HasProperty("tol") )
39 return false;
40 if (!HasProperty("maxiter") )
41 return false;
42 if (!HasProperty("verbose") )
43 return false;
44 if( !HasProperty("doubleloop") )
45 return false;
46 if( !HasProperty("clusters") )
47 return false;
48
49 ConvertPropertyTo<double>("tol");
50 ConvertPropertyTo<size_t>("maxiter");
51 ConvertPropertyTo<size_t>("verbose");
52 ConvertPropertyTo<bool>("doubleloop");
53 ConvertPropertyTo<ClustersType>("clusters");
54
55 if( HasProperty("loopdepth") )
56 ConvertPropertyTo<size_t>("loopdepth");
57 else if( Clusters() == ClustersType::LOOP )
58 return false;
59
60 return true;
61 }
62
63
64 void HAK::constructMessages() {
65 // Create outer beliefs
66 _Qa.clear();
67 _Qa.reserve(nr_ORs());
68 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
69 _Qa.push_back( Factor( OR(alpha).vars() ) );
70
71 // Create inner beliefs
72 _Qb.clear();
73 _Qb.reserve(nr_IRs());
74 for( size_t beta = 0; beta < nr_IRs(); beta++ )
75 _Qb.push_back( Factor( IR(beta) ) );
76
77 // Create messages
78 _muab.clear();
79 _muab.reserve(nr_Redges());
80 _muba.clear();
81 _muba.reserve(nr_Redges());
82 for( vector<R_edge_t>::const_iterator ab = Redges().begin(); ab != Redges().end(); ab++ ) {
83 _muab.push_back( Factor( IR(ab->second) ) );
84 _muba.push_back( Factor( IR(ab->second) ) );
85 }
86 }
87
88
89 HAK::HAK(const RegionGraph & rg, const Properties &opts) : DAIAlgRG(rg, opts) {
90 assert( checkProperties() );
91
92 constructMessages();
93 }
94
95
96 void HAK::findLoopClusters( const FactorGraph & fg, set<VarSet> &allcl, VarSet newcl, const Var & root, size_t length, VarSet vars ) {
97 for( VarSet::const_iterator in = vars.begin(); in != vars.end(); in++ ) {
98 VarSet ind = fg.delta( *in );
99 if( (newcl.size()) >= 2 && (ind >> root) ) {
100 allcl.insert( newcl | *in );
101 }
102 else if( length > 1 )
103 findLoopClusters( fg, allcl, newcl | *in, root, length - 1, ind / newcl );
104 }
105 }
106
107
108 HAK::HAK(const FactorGraph & fg, const Properties &opts) : DAIAlgRG(opts) {
109 assert( checkProperties() );
110
111 vector<VarSet> cl;
112 if( Clusters() == ClustersType::MIN ) {
113 cl = fg.Cliques();
114 } else if( Clusters() == ClustersType::DELTA ) {
115 for( size_t i = 0; i < fg.nrVars(); i++ )
116 cl.push_back(fg.Delta(fg.var(i)));
117 } else if( Clusters() == ClustersType::LOOP ) {
118 cl = fg.Cliques();
119 set<VarSet> scl;
120 for( vector<Var>::const_iterator i0 = fg.vars().begin(); i0 != fg.vars().end(); i0++ ) {
121 VarSet i0d = fg.delta(*i0);
122 if( LoopDepth() > 1 )
123 findLoopClusters( fg, scl, *i0, *i0, LoopDepth() - 1, fg.delta(*i0) );
124 }
125 for( set<VarSet>::const_iterator c = scl.begin(); c != scl.end(); c++ )
126 cl.push_back(*c);
127 if( Verbose() >= 3 ) {
128 cout << "HAK uses the following clusters: " << endl;
129 for( vector<VarSet>::const_iterator cli = cl.begin(); cli != cl.end(); cli++ )
130 cout << *cli << endl;
131 }
132 } else
133 throw "Invalid Clusters type";
134
135 RegionGraph rg(fg,cl);
136 RegionGraph::operator=(rg);
137 constructMessages();
138
139 if( Verbose() >= 3 )
140 cout << "HAK regiongraph: " << *this << endl;
141 }
142
143
144 string HAK::identify() const {
145 stringstream result (stringstream::out);
146 result << Name << GetProperties();
147 return result.str();
148 }
149
150
151 void HAK::init( const VarSet &ns ) {
152 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
153 if( alpha->vars() && ns )
154 alpha->fill( 1.0 / alpha->stateSpace() );
155
156 for( size_t beta = 0; beta < nr_IRs(); beta++ )
157 if( IR(beta) && ns ) {
158 _Qb[beta].fill( 1.0 / IR(beta).stateSpace() );
159 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ ) {
160 muab(*alpha,beta).fill( 1.0 / IR(beta).stateSpace() );
161 muba(beta,*alpha).fill( 1.0 / IR(beta).stateSpace() );
162 }
163 }
164 }
165
166
167 void HAK::init() {
168 assert( checkProperties() );
169
170 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
171 alpha->fill( 1.0 / alpha->stateSpace() );
172
173 for( vector<Factor>::iterator beta = _Qb.begin(); beta != _Qb.end(); beta++ )
174 beta->fill( 1.0 / beta->stateSpace() );
175
176 for( size_t ab = 0; ab < nr_Redges(); ab++ ) {
177 _muab[ab].fill( 1.0 / _muab[ab].stateSpace() );
178 _muba[ab].fill( 1.0 / _muba[ab].stateSpace() );
179 }
180 }
181
182
183 double HAK::doGBP() {
184 if( Verbose() >= 1 )
185 cout << "Starting " << identify() << "...";
186 if( Verbose() >= 3)
187 cout << endl;
188
189 clock_t tic = toc();
190
191 // Check whether counting numbers won't lead to problems
192 for( size_t beta = 0; beta < nr_IRs(); beta++ )
193 assert( nbIR(beta).size() + IR(beta).c() != 0.0 );
194
195 // Keep old beliefs to check convergence
196 vector<Factor> old_beliefs;
197 old_beliefs.reserve( nrVars() );
198 for( size_t i = 0; i < nrVars(); i++ )
199 old_beliefs.push_back( belief( var(i) ) );
200
201 // Differences in single node beliefs
202 Diffs diffs(nrVars(), 1.0);
203
204 size_t iter = 0;
205 // do several passes over the network until maximum number of iterations has
206 // been reached or until the maximum belief difference is smaller than tolerance
207 for( iter = 0; iter < MaxIter() && diffs.max() > Tol(); iter++ ) {
208 for( size_t beta = 0; beta < nr_IRs(); beta++ ) {
209 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ )
210 muab(*alpha,beta) = _Qa[*alpha].marginal(IR(beta)).divided_by( muba(beta,*alpha) );
211
212 Factor Qb_new;
213 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ )
214 Qb_new *= muab(*alpha,beta) ^ (1 / (nbIR(beta).size() + IR(beta).c()));
215 Qb_new.normalize( _normtype );
216 if( Qb_new.hasNaNs() ) {
217 cout << "HAK::doGBP: Qb_new has NaNs!" << endl;
218 return NAN;
219 }
220 // _Qb[beta] = Qb_new.makeZero(1e-100); // damping?
221 _Qb[beta] = Qb_new;
222
223 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ ) {
224 muba(beta,*alpha) = _Qb[beta].divided_by( muab(*alpha,beta) );
225
226 Factor Qa_new = OR(*alpha);
227 for( R_nb_cit gamma = nbOR(*alpha).begin(); gamma != nbOR(*alpha).end(); gamma++ )
228 Qa_new *= muba(*gamma,*alpha);
229 Qa_new ^= (1.0 / OR(*alpha).c());
230 Qa_new.normalize( _normtype );
231 if( Qa_new.hasNaNs() ) {
232 cout << "HAK::doGBP: Qa_new has NaNs!" << endl;
233 return NAN;
234 }
235 // _Qa[*alpha] = Qa_new.makeZero(1e-100); // damping?
236 _Qa[*alpha] = Qa_new;
237 }
238 }
239
240 // Calculate new single variable beliefs and compare with old ones
241 for( size_t i = 0; i < nrVars(); i++ ) {
242 Factor new_belief = belief( var( i ) );
243 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
244 old_beliefs[i] = new_belief;
245 }
246
247 if( Verbose() >= 3 )
248 cout << "HAK::doGBP: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
249 }
250
251 updateMaxDiff( diffs.max() );
252
253 if( Verbose() >= 1 ) {
254 if( diffs.max() > Tol() ) {
255 if( Verbose() == 1 )
256 cout << endl;
257 cout << "HAK::doGBP: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
258 } else {
259 if( Verbose() >= 2 )
260 cout << "HAK::doGBP: ";
261 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
262 }
263 }
264
265 return diffs.max();
266 }
267
268
269 double HAK::doDoubleLoop() {
270 if( Verbose() >= 1 )
271 cout << "Starting " << identify() << "...";
272 if( Verbose() >= 3)
273 cout << endl;
274
275 clock_t tic = toc();
276
277 // Save original outer regions
278 vector<FRegion> org_ORs = ORs();
279
280 // Save original inner counting numbers and set negative counting numbers to zero
281 vector<double> org_IR_cs( nr_IRs(), 0.0 );
282 for( size_t beta = 0; beta < nr_IRs(); beta++ ) {
283 org_IR_cs[beta] = IR(beta).c();
284 if( IR(beta).c() < 0.0 )
285 IR(beta).c() = 0.0;
286 }
287
288 // Keep old beliefs to check convergence
289 vector<Factor> old_beliefs;
290 old_beliefs.reserve( nrVars() );
291 for( size_t i = 0; i < nrVars(); i++ )
292 old_beliefs.push_back( belief( var(i) ) );
293
294 // Differences in single node beliefs
295 Diffs diffs(nrVars(), 1.0);
296
297 size_t outer_maxiter = MaxIter();
298 double outer_tol = Tol();
299 size_t outer_verbose = Verbose();
300 double org_maxdiff = MaxDiff();
301
302 // Set parameters for inner loop
303 MaxIter( 5 );
304 Verbose( outer_verbose ? outer_verbose - 1 : 0 );
305
306 size_t outer_iter = 0;
307 for( outer_iter = 0; outer_iter < outer_maxiter && diffs.max() > outer_tol; outer_iter++ ) {
308 // Calculate new outer regions
309 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
310 OR(alpha) = org_ORs[alpha];
311 for( R_nb_cit beta = nbOR(alpha).begin(); beta != nbOR(alpha).end(); beta++ )
312 OR(alpha) *= _Qb[*beta] ^ ((IR(*beta).c() - org_IR_cs[*beta]) / nbIR(*beta).size());
313 }
314
315 // Inner loop
316 if( isnan( doGBP() ) )
317 return NAN;
318
319 // Calculate new single variable beliefs and compare with old ones
320 for( size_t i = 0; i < nrVars(); ++i ) {
321 Factor new_belief = belief( var( i ) );
322 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
323 old_beliefs[i] = new_belief;
324 }
325
326 if( Verbose() >= 3 )
327 cout << "HAK::doDoubleLoop: maxdiff " << diffs.max() << " after " << outer_iter+1 << " passes" << endl;
328 }
329
330 // restore _maxiter, _verbose and _maxdiff
331 MaxIter( outer_maxiter );
332 Verbose( outer_verbose );
333 MaxDiff( org_maxdiff );
334
335 updateMaxDiff( diffs.max() );
336
337 // Restore original outer regions
338 ORs() = org_ORs;
339
340 // Restore original inner counting numbers
341 for( size_t beta = 0; beta < nr_IRs(); ++beta )
342 IR(beta).c() = org_IR_cs[beta];
343
344 if( Verbose() >= 1 ) {
345 if( diffs.max() > Tol() ) {
346 if( Verbose() == 1 )
347 cout << endl;
348 cout << "HAK::doDoubleLoop: WARNING: not converged within " << outer_maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
349 } else {
350 if( Verbose() >= 3 )
351 cout << "HAK::doDoubleLoop: ";
352 cout << "converged in " << outer_iter << " passes (" << toc() - tic << " clocks)." << endl;
353 }
354 }
355
356 return diffs.max();
357 }
358
359
360 double HAK::run() {
361 if( DoubleLoop() )
362 return doDoubleLoop();
363 else
364 return doGBP();
365 }
366
367
368 Factor HAK::belief( const VarSet &ns ) const {
369 vector<Factor>::const_iterator beta;
370 for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
371 if( beta->vars() >> ns )
372 break;
373 if( beta != _Qb.end() )
374 return( beta->marginal(ns) );
375 else {
376 vector<Factor>::const_iterator alpha;
377 for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
378 if( alpha->vars() >> ns )
379 break;
380 assert( alpha != _Qa.end() );
381 return( alpha->marginal(ns) );
382 }
383 }
384
385
386 Factor HAK::belief( const Var &n ) const {
387 return belief( (VarSet)n );
388 }
389
390
391 vector<Factor> HAK::beliefs() const {
392 vector<Factor> result;
393 for( size_t beta = 0; beta < nr_IRs(); beta++ )
394 result.push_back( Qb(beta) );
395 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
396 result.push_back( Qa(alpha) );
397 return result;
398 }
399
400
401 Complex HAK::logZ() const {
402 Complex sum = 0.0;
403 for( size_t beta = 0; beta < nr_IRs(); beta++ )
404 sum += Complex(IR(beta).c()) * Qb(beta).entropy();
405 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
406 sum += Complex(OR(alpha).c()) * Qa(alpha).entropy();
407 sum += (OR(alpha).log0() * Qa(alpha)).totalSum();
408 }
409 return sum;
410 }
411
412
413 } // end of namespace dai