Adopted contributions by Christian.
[libdai.git] / hak.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <map>
23 #include "hak.h"
24 #include "util.h"
25 #include "diffs.h"
26
27
28 namespace dai {
29
30
31 const char *HAK::Name = "HAK";
32
33
34 bool HAK::checkProperties() {
35 if( !HasProperty("tol") )
36 return false;
37 if (!HasProperty("maxiter") )
38 return false;
39 if (!HasProperty("verbose") )
40 return false;
41 if( !HasProperty("doubleloop") )
42 return false;
43 if( !HasProperty("clusters") )
44 return false;
45
46 ConvertPropertyTo<double>("tol");
47 ConvertPropertyTo<size_t>("maxiter");
48 ConvertPropertyTo<size_t>("verbose");
49 ConvertPropertyTo<bool>("doubleloop");
50 ConvertPropertyTo<ClustersType>("clusters");
51
52 if( HasProperty("loopdepth") )
53 ConvertPropertyTo<size_t>("loopdepth");
54 else if( Clusters() == ClustersType::LOOP )
55 return false;
56
57 return true;
58 }
59
60
61 void HAK::constructMessages() {
62 // Create outer beliefs
63 _Qa.clear();
64 _Qa.reserve(nr_ORs());
65 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
66 _Qa.push_back( Factor( OR(alpha).vars() ) );
67
68 // Create inner beliefs
69 _Qb.clear();
70 _Qb.reserve(nr_IRs());
71 for( size_t beta = 0; beta < nr_IRs(); beta++ )
72 _Qb.push_back( Factor( IR(beta) ) );
73
74 // Create messages
75 _muab.clear();
76 _muab.reserve(nr_Redges());
77 _muba.clear();
78 _muba.reserve(nr_Redges());
79 for( vector<R_edge_t>::const_iterator ab = Redges().begin(); ab != Redges().end(); ab++ ) {
80 _muab.push_back( Factor( IR(ab->second) ) );
81 _muba.push_back( Factor( IR(ab->second) ) );
82 }
83 }
84
85
86 HAK::HAK(const RegionGraph & rg, const Properties &opts) : DAIAlgRG(rg, opts) {
87 assert( checkProperties() );
88
89 constructMessages();
90 }
91
92
93 void HAK::findLoopClusters( const FactorGraph & fg, set<VarSet> &allcl, VarSet newcl, const Var & root, size_t length, VarSet vars ) {
94 for( VarSet::const_iterator in = vars.begin(); in != vars.end(); in++ ) {
95 VarSet ind = fg.delta( *in );
96 if( (newcl.size()) >= 2 && (ind >> root) ) {
97 allcl.insert( newcl | *in );
98 }
99 else if( length > 1 )
100 findLoopClusters( fg, allcl, newcl | *in, root, length - 1, ind / newcl );
101 }
102 }
103
104
105 HAK::HAK(const FactorGraph & fg, const Properties &opts) : DAIAlgRG(opts) {
106 assert( checkProperties() );
107
108 vector<VarSet> cl;
109 if( Clusters() == ClustersType::MIN ) {
110 cl = fg.Cliques();
111 } else if( Clusters() == ClustersType::DELTA ) {
112 for( size_t i = 0; i < fg.nrVars(); i++ )
113 cl.push_back(fg.Delta(fg.var(i)));
114 } else if( Clusters() == ClustersType::LOOP ) {
115 cl = fg.Cliques();
116 set<VarSet> scl;
117 for( vector<Var>::const_iterator i0 = fg.vars().begin(); i0 != fg.vars().end(); i0++ ) {
118 VarSet i0d = fg.delta(*i0);
119 if( LoopDepth() > 1 )
120 findLoopClusters( fg, scl, *i0, *i0, LoopDepth() - 1, fg.delta(*i0) );
121 }
122 for( set<VarSet>::const_iterator c = scl.begin(); c != scl.end(); c++ )
123 cl.push_back(*c);
124 if( Verbose() >= 3 ) {
125 cout << "HAK uses the following clusters: " << endl;
126 for( vector<VarSet>::const_iterator cli = cl.begin(); cli != cl.end(); cli++ )
127 cout << *cli << endl;
128 }
129 } else
130 throw "Invalid Clusters type";
131
132 RegionGraph rg(fg,cl);
133 RegionGraph::operator=(rg);
134 constructMessages();
135
136 if( Verbose() >= 3 )
137 cout << "HAK regiongraph: " << *this << endl;
138 }
139
140
141 string HAK::identify() const {
142 stringstream result (stringstream::out);
143 result << Name << GetProperties();
144 return result.str();
145 }
146
147
148 void HAK::init( const VarSet &ns ) {
149 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
150 if( alpha->vars() && ns )
151 alpha->fill( 1.0 / alpha->stateSpace() );
152
153 for( size_t beta = 0; beta < nr_IRs(); beta++ )
154 if( IR(beta) && ns ) {
155 _Qb[beta].fill( 1.0 / IR(beta).stateSpace() );
156 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ ) {
157 muab(*alpha,beta).fill( 1.0 / IR(beta).stateSpace() );
158 muba(beta,*alpha).fill( 1.0 / IR(beta).stateSpace() );
159 }
160 }
161 }
162
163
164 void HAK::init() {
165 assert( checkProperties() );
166
167 for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
168 alpha->fill( 1.0 / alpha->stateSpace() );
169
170 for( vector<Factor>::iterator beta = _Qb.begin(); beta != _Qb.end(); beta++ )
171 beta->fill( 1.0 / beta->stateSpace() );
172
173 for( size_t ab = 0; ab < nr_Redges(); ab++ ) {
174 _muab[ab].fill( 1.0 / _muab[ab].stateSpace() );
175 _muba[ab].fill( 1.0 / _muba[ab].stateSpace() );
176 }
177 }
178
179
180 double HAK::doGBP() {
181 if( Verbose() >= 1 )
182 cout << "Starting " << identify() << "...";
183 if( Verbose() >= 3)
184 cout << endl;
185
186 clock_t tic = toc();
187
188 // Check whether counting numbers won't lead to problems
189 for( size_t beta = 0; beta < nr_IRs(); beta++ )
190 assert( nbIR(beta).size() + IR(beta).c() != 0.0 );
191
192 // Keep old beliefs to check convergence
193 vector<Factor> old_beliefs;
194 old_beliefs.reserve( nrVars() );
195 for( size_t i = 0; i < nrVars(); i++ )
196 old_beliefs.push_back( belief( var(i) ) );
197
198 // Differences in single node beliefs
199 Diffs diffs(nrVars(), 1.0);
200
201 size_t iter = 0;
202 // do several passes over the network until maximum number of iterations has
203 // been reached or until the maximum belief difference is smaller than tolerance
204 for( iter = 0; iter < MaxIter() && diffs.max() > Tol(); iter++ ) {
205 for( size_t beta = 0; beta < nr_IRs(); beta++ ) {
206 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ )
207 muab(*alpha,beta) = _Qa[*alpha].marginal(IR(beta)).divided_by( muba(beta,*alpha) );
208
209 Factor Qb_new;
210 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ )
211 Qb_new *= muab(*alpha,beta) ^ (1 / (nbIR(beta).size() + IR(beta).c()));
212 Qb_new.normalize( _normtype );
213 if( Qb_new.hasNaNs() ) {
214 cout << "HAK::doGBP: Qb_new has NaNs!" << endl;
215 return NAN;
216 }
217 // _Qb[beta] = Qb_new.makeZero(1e-100); // damping?
218 _Qb[beta] = Qb_new;
219
220 for( R_nb_cit alpha = nbIR(beta).begin(); alpha != nbIR(beta).end(); alpha++ ) {
221 muba(beta,*alpha) = _Qb[beta].divided_by( muab(*alpha,beta) );
222
223 Factor Qa_new = OR(*alpha);
224 for( R_nb_cit gamma = nbOR(*alpha).begin(); gamma != nbOR(*alpha).end(); gamma++ )
225 Qa_new *= muba(*gamma,*alpha);
226 Qa_new ^= (1.0 / OR(*alpha).c());
227 Qa_new.normalize( _normtype );
228 if( Qa_new.hasNaNs() ) {
229 cout << "HAK::doGBP: Qa_new has NaNs!" << endl;
230 return NAN;
231 }
232 // _Qa[*alpha] = Qa_new.makeZero(1e-100); // damping?
233 _Qa[*alpha] = Qa_new;
234 }
235 }
236
237 // Calculate new single variable beliefs and compare with old ones
238 for( size_t i = 0; i < nrVars(); i++ ) {
239 Factor new_belief = belief( var( i ) );
240 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
241 old_beliefs[i] = new_belief;
242 }
243
244 if( Verbose() >= 3 )
245 cout << "HAK::doGBP: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
246 }
247
248 updateMaxDiff( diffs.max() );
249
250 if( Verbose() >= 1 ) {
251 if( diffs.max() > Tol() ) {
252 if( Verbose() == 1 )
253 cout << endl;
254 cout << "HAK::doGBP: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
255 } else {
256 if( Verbose() >= 2 )
257 cout << "HAK::doGBP: ";
258 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
259 }
260 }
261
262 return diffs.max();
263 }
264
265
266 double HAK::doDoubleLoop() {
267 if( Verbose() >= 1 )
268 cout << "Starting " << identify() << "...";
269 if( Verbose() >= 3)
270 cout << endl;
271
272 clock_t tic = toc();
273
274 // Save original outer regions
275 vector<FRegion> org_ORs = ORs();
276
277 // Save original inner counting numbers and set negative counting numbers to zero
278 vector<double> org_IR_cs( nr_IRs(), 0.0 );
279 for( size_t beta = 0; beta < nr_IRs(); beta++ ) {
280 org_IR_cs[beta] = IR(beta).c();
281 if( IR(beta).c() < 0.0 )
282 IR(beta).c() = 0.0;
283 }
284
285 // Keep old beliefs to check convergence
286 vector<Factor> old_beliefs;
287 old_beliefs.reserve( nrVars() );
288 for( size_t i = 0; i < nrVars(); i++ )
289 old_beliefs.push_back( belief( var(i) ) );
290
291 // Differences in single node beliefs
292 Diffs diffs(nrVars(), 1.0);
293
294 size_t outer_maxiter = MaxIter();
295 double outer_tol = Tol();
296 size_t outer_verbose = Verbose();
297 double org_maxdiff = MaxDiff();
298
299 // Set parameters for inner loop
300 MaxIter( 5 );
301 Verbose( outer_verbose ? outer_verbose - 1 : 0 );
302
303 size_t outer_iter = 0;
304 for( outer_iter = 0; outer_iter < outer_maxiter && diffs.max() > outer_tol; outer_iter++ ) {
305 // Calculate new outer regions
306 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
307 OR(alpha) = org_ORs[alpha];
308 for( R_nb_cit beta = nbOR(alpha).begin(); beta != nbOR(alpha).end(); beta++ )
309 OR(alpha) *= _Qb[*beta] ^ ((IR(*beta).c() - org_IR_cs[*beta]) / nbIR(*beta).size());
310 }
311
312 // Inner loop
313 if( isnan( doGBP() ) )
314 return NAN;
315
316 // Calculate new single variable beliefs and compare with old ones
317 for( size_t i = 0; i < nrVars(); i++ ) {
318 Factor new_belief = belief( var( i ) );
319 diffs.push( dist( new_belief, old_beliefs[i], Prob::DISTLINF ) );
320 old_beliefs[i] = new_belief;
321 }
322
323 if( Verbose() >= 3 )
324 cout << "HAK::doDoubleLoop: maxdiff " << diffs.max() << " after " << outer_iter+1 << " passes" << endl;
325 }
326
327 // restore _maxiter, _verbose and _maxdiff
328 MaxIter( outer_maxiter );
329 Verbose( outer_verbose );
330 MaxDiff( org_maxdiff );
331
332 updateMaxDiff( diffs.max() );
333
334 // Restore original outer regions
335 ORs() = org_ORs;
336
337 // Restore original inner counting numbers
338 for( size_t beta = 0; beta < nr_IRs(); beta++ )
339 IR(beta).c() = org_IR_cs[beta];
340
341 if( Verbose() >= 1 ) {
342 if( diffs.max() > Tol() ) {
343 if( Verbose() == 1 )
344 cout << endl;
345 cout << "HAK::doDoubleLoop: WARNING: not converged within " << outer_maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
346 } else {
347 if( Verbose() >= 3 )
348 cout << "HAK::doDoubleLoop: ";
349 cout << "converged in " << outer_iter << " passes (" << toc() - tic << " clocks)." << endl;
350 }
351 }
352
353 return diffs.max();
354 }
355
356
357 double HAK::run() {
358 if( DoubleLoop() )
359 return doDoubleLoop();
360 else
361 return doGBP();
362 }
363
364
365 Factor HAK::belief( const VarSet &ns ) const {
366 vector<Factor>::const_iterator beta;
367 for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
368 if( beta->vars() >> ns )
369 break;
370 if( beta != _Qb.end() )
371 return( beta->marginal(ns) );
372 else {
373 vector<Factor>::const_iterator alpha;
374 for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
375 if( alpha->vars() >> ns )
376 break;
377 assert( alpha != _Qa.end() );
378 return( alpha->marginal(ns) );
379 }
380 }
381
382
383 Factor HAK::belief( const Var &n ) const {
384 return belief( (VarSet)n );
385 }
386
387
388 vector<Factor> HAK::beliefs() const {
389 vector<Factor> result;
390 for( size_t beta = 0; beta < nr_IRs(); beta++ )
391 result.push_back( Qb(beta) );
392 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
393 result.push_back( Qa(alpha) );
394 return result;
395 }
396
397
398 Complex HAK::logZ() const {
399 Complex sum = 0.0;
400 for( size_t beta = 0; beta < nr_IRs(); beta++ )
401 sum += Complex(IR(beta).c()) * Qb(beta).entropy();
402 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
403 sum += Complex(OR(alpha).c()) * Qa(alpha).entropy();
404 sum += (OR(alpha).log0() * Qa(alpha)).totalSum();
405 }
406 return sum;
407 }
408
409
410 }