d2253bb088f14da72e1b3af9e94270a219f33a62
[libdai.git] / src / jtree.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <dai/jtree.h>
14
15
16 namespace dai {
17
18
19 using namespace std;
20
21
22 void JTree::setProperties( const PropertySet &opts ) {
23 DAI_ASSERT( opts.hasKey("updates") );
24
25 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
26 if( opts.hasKey("verbose") )
27 props.verbose = opts.getStringAs<size_t>("verbose");
28 else
29 props.verbose = 0;
30 if( opts.hasKey("inference") )
31 props.inference = opts.getStringAs<Properties::InfType>("inference");
32 else
33 props.inference = Properties::InfType::SUMPROD;
34 if( opts.hasKey("heuristic") )
35 props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
36 else
37 props.heuristic = Properties::HeuristicType::MINFILL;
38 if( opts.hasKey("maxmem") )
39 props.maxmem = opts.getStringAs<size_t>("maxmem");
40 else
41 props.maxmem = 0;
42 }
43
44
45 PropertySet JTree::getProperties() const {
46 PropertySet opts;
47 opts.set( "verbose", props.verbose );
48 opts.set( "updates", props.updates );
49 opts.set( "inference", props.inference );
50 opts.set( "heuristic", props.heuristic );
51 opts.set( "maxmem", props.maxmem );
52 return opts;
53 }
54
55
56 string JTree::printProperties() const {
57 stringstream s( stringstream::out );
58 s << "[";
59 s << "verbose=" << props.verbose << ",";
60 s << "updates=" << props.updates << ",";
61 s << "heuristic=" << props.heuristic << ",";
62 s << "inference=" << props.inference << ",";
63 s << "maxmem=" << props.maxmem << "]";
64 return s.str();
65 }
66
67
/// Constructs a JTree object for factor graph \a fg with properties \a opts.
/// If \a automatic == \c true, a junction tree is built automatically:
/// a greedy variable elimination (using the heuristic chosen in the
/// properties) determines the cliques, a rough memory estimate is checked
/// against props.maxmem, and GenerateJT() builds the region graph.
/// Otherwise, the caller is expected to supply its own clusters via
/// GenerateJT() later.
JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
    setProperties( opts );

    if( automatic ) {
        // Create ClusterGraph which contains maximal factors as clusters
        ClusterGraph _cg( fg, true );
        if( props.verbose >= 3 )
            cerr << "Initial clusters: " << _cg << endl;

        // Use heuristic to guess optimal elimination sequence
        greedyVariableElimination::eliminationCostFunction ec(NULL);
        switch( (size_t)props.heuristic ) {
            case Properties::HeuristicType::MINNEIGHBORS:
                ec = eliminationCost_MinNeighbors;
                break;
            case Properties::HeuristicType::MINWEIGHT:
                ec = eliminationCost_MinWeight;
                break;
            case Properties::HeuristicType::MINFILL:
                ec = eliminationCost_MinFill;
                break;
            case Properties::HeuristicType::WEIGHTEDMINFILL:
                ec = eliminationCost_WeightedMinFill;
                break;
            default:
                DAI_THROW(UNKNOWN_ENUM_VALUE);
        }
        // Empirical bytes-per-state multiplier used to convert maxmem (bytes)
        // into a bound on the total number of states during elimination.
        size_t fudge = 6; // this yields a rough estimate of the memory needed (for some reason not yet clearly understood)
        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ), props.maxmem / (sizeof(Real) * fudge) ).eraseNonMaximal().clusters();
        if( props.verbose >= 3 )
            cerr << "VarElim result: " << ElimVec << endl;

        // Estimate memory needed (rough upper bound): sum of clique state
        // space sizes, scaled back to bytes with the same fudge factor.
        long double memneeded = 0;
        foreach( const VarSet& cl, ElimVec )
            memneeded += cl.nrStates();
        memneeded *= sizeof(Real) * fudge;
        if( props.verbose >= 1 ) {
            cerr << "Estimate of needed memory: " << memneeded / 1024 << "kB" << endl;
            cerr << "Maximum memory: ";
            if( props.maxmem )
                cerr << props.maxmem / 1024 << "kB" << endl;
            else
                cerr << "unlimited" << endl;
        }
        // maxmem == 0 means "unlimited", so only check when it was set
        if( props.maxmem && memneeded > props.maxmem )
            DAI_THROW(OUT_OF_MEMORY);

        // Generate the junction tree corresponding to the elimination sequence
        GenerateJT( fg, ElimVec );
    }
}
120
121
/// Builds the junction-tree region graph for clusters \a cl on factor graph \a fg.
/// Steps: copy the factor graph, build a weighted junction graph on the
/// clusters (edge weight = separator cardinality), take a maximal spanning
/// tree (stored in RTree), then create the outer regions (cliques), inner
/// regions (separators), the bipartite region graph, and the initial beliefs
/// Qa / Qb. If \a verify == \c true, asserts that every factor is subsumed
/// by some cluster.
void JTree::construct( const FactorGraph &fg, const std::vector<VarSet> &cl, bool verify ) {
    // Copy the factor graph
    FactorGraph::operator=( fg );

    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of cl).
    WeightedGraph<int> JuncGraph;
    // Start by connecting all clusters with cluster zero, and weight zero,
    // in order to get a connected weighted graph
    for( size_t i = 1; i < cl.size(); i++ )
        JuncGraph[UEdge(i,0)] = 0;
    for( size_t i = 0; i < cl.size(); i++ ) {
        for( size_t j = i + 1; j < cl.size(); j++ ) {
            // weight = number of shared variables (separator size)
            size_t w = (cl[i] & cl[j]).size();
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }
    }
    if( props.verbose >= 3 )
        cerr << "Weightedgraph: " << JuncGraph << endl;

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTree( JuncGraph, true );
    if( props.verbose >= 3 )
        cerr << "Spanning tree: " << RTree << endl;
    DAI_DEBASSERT( RTree.size() == cl.size() - 1 );

    // Construct corresponding region graph

    // Create outer regions (one per cluster, counting number 1)
    _ORs.clear();
    _ORs.reserve( cl.size() );
    for( size_t i = 0; i < cl.size(); i++ )
        _ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    _fac2OR.clear();
    _fac2OR.resize( nrFactors(), -1U );
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                _fac2OR[I] = alpha;
                break;
            }
        // alpha == nrORs() here means no cluster subsumes factor I
        if( verify )
            DAI_ASSERT( alpha != nrORs() );
    }
    recomputeORs();

    // Create inner regions and edges; inner region i is the separator of
    // tree edge RTree[i] and connects to both of its endpoint cliques.
    _IRs.clear();
    _IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].first, nrIRs() ) );
        edges.push_back( Edge( RTree[i].second, nrIRs() ) );
        // inner clusters have counting number -1, except if they are empty
        VarSet intersection = cl[RTree[i].first] & cl[RTree[i].second];
        _IRs.push_back( Region( intersection, intersection.size() ? -1.0 : 0.0 ) );
    }

    // create bipartite graph
    _G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
#ifdef DAI_DEBUG
    checkCountingNumbers();
#endif

    // Create beliefs: cliques start at their outer-region factors,
    // separators start uniform.
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );
}
205
206
207 void JTree::GenerateJT( const FactorGraph &fg, const std::vector<VarSet> &cl ) {
208 construct( fg, cl, true );
209
210 // Create messages
211 _mes.clear();
212 _mes.reserve( nrORs() );
213 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
214 _mes.push_back( vector<Factor>() );
215 _mes[alpha].reserve( nbOR(alpha).size() );
216 foreach( const Neighbor &beta, nbOR(alpha) )
217 _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
218 }
219
220 if( props.verbose >= 3 )
221 cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
222 }
223
224
225 Factor JTree::belief( const VarSet &vs ) const {
226 vector<Factor>::const_iterator beta;
227 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
228 if( beta->vars() >> vs )
229 break;
230 if( beta != Qb.end() ) {
231 if( props.inference == Properties::InfType::SUMPROD )
232 return( beta->marginal(vs) );
233 else
234 return( beta->maxMarginal(vs) );
235 } else {
236 vector<Factor>::const_iterator alpha;
237 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
238 if( alpha->vars() >> vs )
239 break;
240 if( alpha == Qa.end() ) {
241 DAI_THROW(BELIEF_NOT_AVAILABLE);
242 return Factor();
243 } else {
244 if( props.inference == Properties::InfType::SUMPROD )
245 return( alpha->marginal(vs) );
246 else
247 return( alpha->maxMarginal(vs) );
248 }
249 }
250 }
251
252
253 vector<Factor> JTree::beliefs() const {
254 vector<Factor> result;
255 for( size_t beta = 0; beta < nrIRs(); beta++ )
256 result.push_back( Qb[beta] );
257 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
258 result.push_back( Qa[alpha] );
259 return result;
260 }
261
262
/// Runs the HUGIN propagation algorithm on the junction tree.
/// Resets clique beliefs Qa to the outer-region factors and separator
/// beliefs Qb to uniform, then performs a CollectEvidence sweep towards
/// the root followed by a DistributeEvidence sweep away from it.
/// The log partition sum is accumulated in _logZ from the normalization
/// constants encountered along the way.
void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence: traverse tree edges from the leaves towards the root
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].first consistent with outer region RTree[i].second
        // IR(i) = separator OR(RTree[i].first) && OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].second].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].second].maxMarginal( IR( i ), false );

        _logZ += log(new_Qb.normalize());
        // Absorb the new separator belief, dividing out the old one
        Qa[RTree[i].first] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    // The root clique's normalization constant contributes the remaining mass
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize() );
    else
        _logZ += log(Qa[RTree[0].first].normalize());

    // DistributeEvidence: traverse tree edges from the root towards the leaves
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].second consistent with outer region RTree[i].first
        // IR(i) = separator OR(RTree[i].first) && OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].first].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].first].maxMarginal( IR( i ) );

        Qa[RTree[i].second] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
308
309
/// Runs the Shafer-Shenoy propagation algorithm on the junction tree.
/// Performs an upward message pass (towards the root), a downward pass
/// (away from it), then computes clique beliefs Qa and separator beliefs
/// Qb, accumulating the log partition sum in _logZ.
void JTree::runShaferShenoy() {
    // First pass (upward: from leaves towards the root)
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from RTree[e].second to RTree[e].first
        // or, actually, from the separator IR(e) to RTree[e].first

        size_t i = nbIR(e)[1].node; // = RTree[e].second
        size_t j = nbIR(e)[0].node; // = RTree[e].first
        size_t _e = nbIR(e)[0].dual;

        // Product of clique factor and all incoming messages except over e itself
        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        // Normalization constants of the upward messages accumulate into logZ
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass (downward: from the root towards the leaves)
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].first
        size_t j = nbIR(e)[1].node; // = RTree[e].second
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate beliefs: clique factor times all incoming messages
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            // No separators at all: the single normalization yields logZ directly
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].first*/ ) {
            // The root clique contributes the remaining mass to logZ
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
371
372
373 Real JTree::run() {
374 if( props.updates == Properties::UpdateType::HUGIN )
375 runHUGIN();
376 else if( props.updates == Properties::UpdateType::SHSH )
377 runShaferShenoy();
378 return 0.0;
379 }
380
381
/// Returns the log partition sum accumulated during the last propagation run.
/// The commented-out code recomputes it from the region-based free energy
/// (entropies and energies of Qa/Qb) as a consistency check against the
/// value gathered during message passing; it is kept for debugging purposes.
Real JTree::logZ() const {
/*    Real s = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        s += IR(beta).c() * Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        s += OR(alpha).c() * Qa[alpha].entropy();
        s += (OR(alpha).log(true) * Qa[alpha]).sum();
    }
    DAI_ASSERT( abs( _logZ - s ) < 1e-8 );
    return s;*/
    return _logZ;
}
394
395
/// Finds a rooted subtree of the junction tree suitable for computing the
/// marginal of \a vs. The clique with maximal state-space overlap with
/// \a vs becomes the new root; the subtree consists of all tree edges
/// needed to reach every variable of \a vs (and, if \a PreviousRoot is
/// given, that clique as well). On return, \a Tree holds the reordered
/// tree with the subtree edges first; the return value is the number of
/// edges in that subtree.
size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with vs)
    long double maxval = 0.0;
    size_t maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        long double val = VarSet(vs & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // reorder the tree edges such that maxalpha becomes the new root
    RootedTree newTree( GraphEL( RTree.begin(), RTree.end() ), maxalpha );

    // identify subtree that contains all variables of vs which are not in the new root
    set<DEdge> subTree;
    // for each variable in vs
    for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ ) {
        for( size_t e = 0; e < newTree.size(); e++ ) {
            if( OR(newTree[e].second).vars().contains( *n ) ) {
                // add this edge and walk back towards the root, adding all
                // ancestor edges on the path (newTree is root-first ordered)
                size_t f = e;
                subTree.insert( newTree[f] );
                size_t pos = newTree[f].first;
                for( ; f > 0; f-- )
                    if( newTree[f-1].second == pos ) {
                        subTree.insert( newTree[f-1] );
                        pos = newTree[f-1].first;
                    }
            }
        }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].second == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].first;
        for( ; e > 0; e-- )
            if( newTree[e-1].second == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].first;
            }
    }

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    vector<DEdge> remTree;
    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) )
            Tree.push_back( *e );
        else
            remTree.push_back( *e );
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );

    return subTreeSize;
}
462
463
/// Computes the marginal of \a vs, even when \a vs is not contained in any
/// single clique or separator. If some existing belief subsumes \a vs, its
/// (max)marginal is returned directly. Otherwise an efficient subtree is
/// selected and the marginal is assembled by clamping the variables of
/// \a vs outside the root clique to each joint state in turn, re-running
/// CollectEvidence on the subtree, and restoring the original beliefs
/// afterwards (cost exponential in the number of clamped variables).
Factor JTree::calcMarginal( const VarSet& vs ) {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() ) {
        // A separator belief already covers vs
        if( props.inference == Properties::InfType::SUMPROD )
            return( beta->marginal(vs) );
        else
            return( beta->maxMarginal(vs) );
    } else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha != Qa.end() ) {
            // A clique belief already covers vs
            if( props.inference == Properties::InfType::SUMPROD )
                return( alpha->marginal(vs) );
            else
                return( alpha->maxMarginal(vs) );
        } else {
            // Find subtree to do efficient inference
            RootedTree T;
            size_t Tsize = findEfficientTree( vs, T );

            // Find remaining variables (which are not in the new root)
            VarSet vsrem = vs / OR(T.front().first).vars();
            Factor Pvs (vs, 0.0);

            // Save Qa and Qb on the subtree, so they can be restored after clamping
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].first;
                size_t alpha2 = T[i].second;
                size_t beta;
                // Locate the separator index corresponding to this subtree edge
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].first, RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all states of vsrem
            for( State s(vsrem); s.valid(); s++ ) {
                // CollectEvidence
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].first consistent with outer region T[i].second
                    // IR(i) = separator OR(T[i].first) && OR(T[i].second)

                    // Clamp the remaining variables present in this clique to state s
                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                        if( Qa[T[i].second].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet.set( s(*n), 1.0 );
                            Qa[T[i].second] *= piet;
                        }

                    Factor new_Qb;
                    if( props.inference == Properties::InfType::SUMPROD )
                        new_Qb = Qa[T[i].second].marginal( IR( b[i] ), false );
                    else
                        new_Qb = Qa[T[i].second].maxMarginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].first] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].first].normalize());

                // Weight the root marginal by the probability mass of state s
                Factor piet( vsrem, 0.0 );
                piet.set( s, exp(logZ) );
                if( props.inference == Properties::InfType::SUMPROD )
                    Pvs += piet * Qa[T[0].first].marginal( vs / vsrem, false ); // OPTIMIZE ME
                else
                    Pvs += piet * Qa[T[0].first].maxMarginal( vs / vsrem, false ); // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pvs.normalized() );
        }
    }
}
559
560
561 std::pair<size_t,long double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates ) {
562 // Create cluster graph from factor graph
563 ClusterGraph _cg( fg, true );
564
565 // Obtain elimination sequence
566 vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ), maxStates ).eraseNonMaximal().clusters();
567
568 // Calculate treewidth
569 size_t treewidth = 0;
570 double nrstates = 0.0;
571 for( size_t i = 0; i < ElimVec.size(); i++ ) {
572 if( ElimVec[i].size() > treewidth )
573 treewidth = ElimVec[i].size();
574 long double s = ElimVec[i].nrStates();
575 if( s > nrstates )
576 nrstates = s;
577 }
578
579 return make_pair(treewidth, nrstates);
580 }
581
582
583 } // end of namespace dai