// libdai.git / src/jtree.cpp
/* This file is part of libDAI - http://www.libdai.org/
 *
 * libDAI is licensed under the terms of the GNU General Public License version
 * 2, or (at your option) any later version. libDAI is distributed without any
 * warranty. See the file COPYING for more details.
 *
 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
 */


#include <iostream>
#include <stack>
#include <dai/jtree.h>


namespace dai {


using namespace std;


const char *JTree::Name = "JTREE";


void JTree::setProperties( const PropertySet &opts ) {
    DAI_ASSERT( opts.hasKey("verbose") );
    DAI_ASSERT( opts.hasKey("updates") );

    props.verbose = opts.getStringAs<size_t>("verbose");
    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
    if( opts.hasKey("inference") )
        props.inference = opts.getStringAs<Properties::InfType>("inference");
    else
        props.inference = Properties::InfType::SUMPROD;
}


PropertySet JTree::getProperties() const {
    PropertySet opts;
    opts.Set( "verbose", props.verbose );
    opts.Set( "updates", props.updates );
    opts.Set( "inference", props.inference );
    return opts;
}


string JTree::printProperties() const {
    stringstream s( stringstream::out );
    s << "[";
    s << "verbose=" << props.verbose << ",";
    s << "updates=" << props.updates << ",";
    s << "inference=" << props.inference << "]";
    return s.str();
}
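
// Note on the properties parsed above: "updates" selects the propagation scheme
// (HUGIN, or SHSH for Shafer-Shenoy), the optional "inference" property selects
// the sum-product semiring (SUMPROD, the default) or its max-product alternative
// (MAXPROD), and "verbose" controls the amount of diagnostic output written to cerr.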


JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
    setProperties( opts );

    if( !isConnected() )
        DAI_THROW(FACTORGRAPH_NOT_CONNECTED);

    if( automatic ) {
        // Create ClusterGraph which contains factors as clusters
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < nrFactors(); I++ )
            cl.push_back( factor(I).vars() );
        ClusterGraph _cg( cl );

        if( props.verbose >= 3 )
            cerr << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( props.verbose >= 3 )
            cerr << "Maximal clusters: " << _cg << endl;

        // Use the MinFill heuristic to guess an optimal elimination sequence
        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
        if( props.verbose >= 3 )
            cerr << "VarElim_MinFill result: " << ElimVec << endl;

        // Generate the junction tree corresponding to the elimination sequence
        GenerateJT( ElimVec );
    }
}
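
// Typical usage, as an illustrative sketch (file name and option values are
// placeholders; see setProperties above for the recognized keys):
//
//     FactorGraph fg;
//     fg.ReadFromFile( "network.fg" );
//     PropertySet opts;
//     opts.Set( "verbose", string("0") );
//     opts.Set( "updates", string("HUGIN") );
//     JTree jt( fg, opts );
//     jt.init();
//     jt.run();
//     Factor marg = jt.belief( fg.var(0) );   // marginal of the first variable
//     Real logZ = jt.logZ();                  // log partition sum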


void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR.push_back( alpha );
                break;
            }
        DAI_ASSERT( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges
    IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
        // Inner regions (separators) have counting number -1
        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
    }

    // Create bipartite graph
    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );

    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
#ifdef DAI_DEBUG
    checkCountingNumbers();
#endif

    if( props.verbose >= 3 )
        cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
}
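
// The region graph built above is just the junction tree: outer regions are the
// cliques (counting number 1, each carrying the product of the factors assigned
// to it via fac2OR), and inner regions are the separators on the spanning tree
// edges (counting number -1). For every variable, the counting numbers of the
// regions containing it therefore sum to one, which is what
// checkCountingNumbers() verifies in debug builds.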


string JTree::identify() const {
    return string(Name) + printProperties();
}


Factor JTree::belief( const VarSet &vs ) const {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() )
        return( beta->marginal(vs) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha == Qa.end() ) {
            DAI_THROW(BELIEF_NOT_AVAILABLE);
            return Factor();
        } else
            return( alpha->marginal(vs) );
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( Qa[alpha] );
    return result;
}


Factor JTree::belief( const Var &v ) const {
    return belief( (VarSet)v );
}


void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
        // IR(i) = separator of OR(RTree[i].n1) and OR(RTree[i].n2)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );

        _logZ += log(new_Qb.normalize());
        Qa[RTree[i].n1] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize());
    else
        _logZ += log(Qa[RTree[0].n1].normalize());

    // DistributeEvidence
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
        // IR(i) = separator of OR(RTree[i].n1) and OR(RTree[i].n2)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );

        Qa[RTree[i].n2] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
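
// In the HUGIN scheme above, making clique n1 consistent with clique n2 across
// their separator amounts to the update
//
//     Q_{n1} <- Q_{n1} * ( Q_sep_new / Q_sep_old ),    Q_sep <- Q_sep_new,
//
// where Q_sep_new is the (max-)marginal of Q_{n2} on IR(i). The log of each
// normalization constant encountered during the upward (CollectEvidence) pass,
// plus that of the root clique, accumulates in _logZ.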


void JTree::runShaferShenoy() {
    // First pass
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // Send a message from RTree[e].n2 to RTree[e].n1,
        // or, actually, from the separator IR(e) to RTree[e].n1

        size_t i = nbIR(e)[1].node; // = RTree[e].n2
        size_t j = nbIR(e)[0].node; // = RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].n1
        size_t j = nbIR(e)[1].node; // = RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
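
// Shafer-Shenoy propagation never divides beliefs; instead, the message sent
// from outer region i towards outer region j across separator e is
//
//     mu_{i->j}( x_{IR(e)} ) = (max-)sum over x_{OR(i)} \ x_{IR(e)} of
//         OR(i)(x) * prod_{e' in nb(i), e' != e} mu_{->i}( x_{IR(e')} ),
//
// and a clique belief Q_alpha is OR(alpha) times all of its incoming messages,
// as computed in the "Calculate beliefs" loop above.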


Real JTree::run() {
    if( props.updates == Properties::UpdateType::HUGIN )
        runHUGIN();
    else if( props.updates == Properties::UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}


Real JTree::logZ() const {
    Real s = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        s += IR(beta).c() * Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        s += OR(alpha).c() * Qa[alpha].entropy();
        s += (OR(alpha).log(true) * Qa[alpha]).sum();
    }
    return s;
}
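
// logZ() evaluates the region-based free energy, which is exact on a junction
// tree once Qa and Qb are calibrated:
//
//     log Z = sum_alpha [ c_alpha H(Q_alpha) + sum_x Q_alpha(x) log OR_alpha(x) ]
//           + sum_beta c_beta H(Q_beta),
//
// with c_alpha = 1 for cliques, c_beta = -1 for separators, and H the entropy.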


size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
    // Find new root clique (the one with maximal state space overlap with vs)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = VarSet(vs & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // Reorder the tree edges such that maxalpha becomes the new root
    RootedTree newTree( Graph( RTree.begin(), RTree.end() ), maxalpha );

    // Identify the subtree that contains all variables of vs which are not in the new root
    VarSet vsrem = vs / OR(maxalpha).vars();
    set<DEdge> subTree;
    // For each variable in vs that is not in the root clique
    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ ) {
        // Find the first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // Trace back the path to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha ) {
        // Find the first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // Trace back the path to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }

    // The resulting Tree is a reordered copy of newTree:
    // first add the edges in subTree to Tree
    Tree.clear();
    vector<DEdge> remTree;
    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) )
            Tree.push_back( *e );
        else
            remTree.push_back( *e );
    size_t subTreeSize = Tree.size();
    // Then add the remaining edges
    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );

    return subTreeSize;
}


Factor JTree::calcMarginal( const VarSet& vs ) {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() )
        return( beta->marginal(vs) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha != Qa.end() )
            return( alpha->marginal(vs) );
        else {
            // Find a subtree on which to do efficient inference
            RootedTree T;
            size_t Tsize = findEfficientTree( vs, T );

            // Find the remaining variables (which are not in the new root)
            VarSet vsrem = vs / OR(T.front().n1).vars();
            Factor Pvs( vs, 0.0 );

            // Save Qa and Qb on the subtree
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            vector<size_t> b( Tsize, 0 );
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all states of vsrem
            for( State s(vsrem); s.valid(); s++ ) {
                // CollectEvidence
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator of OR(T[i].n1) and OR(T[i].n2)

                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                        if( Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].n1] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].n1].normalize());

                Factor piet( vsrem, 0.0 );
                piet[s] = exp(logZ);
                Pvs += piet * Qa[T[0].n1].marginal( vs / vsrem, false ); // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pvs.normalized() );
        }
    }
}


std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
    ClusterGraph _cg;

    // Copy factors
    for( size_t I = 0; I < fg.nrFactors(); I++ )
        _cg.insert( fg.factor(I).vars() );

    // Retain only maximal clusters
    _cg.eraseNonMaximal();

    // Obtain elimination sequence
    vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();

    // Calculate treewidth
    size_t treewidth = 0;
    size_t nrstates = 0;
    for( size_t i = 0; i < ElimVec.size(); i++ ) {
        if( ElimVec[i].size() > treewidth )
            treewidth = ElimVec[i].size();
        size_t s = ElimVec[i].nrStates();
        if( s > nrstates )
            nrstates = s;
    }

    return pair<size_t,size_t>(treewidth, nrstates);
}
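
// boundTreewidth() returns the size of the largest clique produced by MinFill
// elimination (reported here as the treewidth) together with the largest number
// of joint states of any such clique. A hedged sketch of using it to decide
// whether exact inference is affordable (the threshold is an arbitrary example):
//
//     std::pair<size_t,size_t> tw = boundTreewidth( fg );
//     if( tw.second > 10000000 )
//         ;   // too large: fall back to an approximate algorithm
//     else
//         JTree jt( fg, opts );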


std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
    return boundTreewidth( fg );
}


std::vector<size_t> JTree::findMaximum() const {
    vector<size_t> maximum( nrVars() );
    vector<bool> visitedVars( nrVars(), false );
    vector<bool> visitedFactors( nrFactors(), false );
    stack<size_t> scheduledFactors;
    for( size_t i = 0; i < nrVars(); ++i ) {
        if( visitedVars[i] )
            continue;
        visitedVars[i] = true;

        // Maximise with respect to variable i
        Prob prod = beliefV(i).p();
        maximum[i] = prod.argmax().first;

        foreach( const Neighbor &I, nbV(i) )
            if( !visitedFactors[I] )
                scheduledFactors.push(I);

        while( !scheduledFactors.empty() ) {
            size_t I = scheduledFactors.top();
            scheduledFactors.pop();
            if( visitedFactors[I] )
                continue;
            visitedFactors[I] = true;

            // Check whether some neighboring variables still need to be fixed; if not, we're done
            bool allDetermined = true;
            foreach( const Neighbor &j, nbF(I) )
                if( !visitedVars[j.node] ) {
                    allDetermined = false;
                    break;
                }
            if( allDetermined )
                continue;

            // Calculate the product of incoming messages on factor I
            Prob prod2 = beliefF(I).p();

            // The allowed configurations are restricted according to the variables assigned so far:
            // pick the argmax amongst the allowed states
            Real maxProb = numeric_limits<Real>::min();
            State maxState( factor(I).vars() );
            for( State s( factor(I).vars() ); s.valid(); ++s ) {
                // First, calculate whether this state is consistent with the variables that
                // have been assigned already
                bool allowedState = true;
                foreach( const Neighbor &j, nbF(I) )
                    if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
                        allowedState = false;
                        break;
                    }
                // If it is consistent, check whether its probability is larger than what we have seen so far
                if( allowedState && prod2[s] > maxProb ) {
                    maxState = s;
                    maxProb = prod2[s];
                }
            }

            // Decode the argmax
            foreach( const Neighbor &j, nbF(I) ) {
                if( visitedVars[j.node] ) {
                    // We have already visited j earlier - hopefully our state is consistent
                    if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
                        cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
                } else {
                    // We found a consistent state for variable j
                    visitedVars[j.node] = true;
                    maximum[j.node] = maxState( var(j.node) );
                    foreach( const Neighbor &J, nbV(j) )
                        if( !visitedFactors[J] )
                            scheduledFactors.push(J);
                }
            }
        }
    }
    return maximum;
}
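
// findMaximum() decodes a MAP state from the (max-product) clique beliefs by
// fixing variables clique by clique, keeping each new assignment consistent with
// the ones already made. A hedged usage sketch (the max-product option value is
// assumed to be spelled MAXPROD):
//
//     PropertySet opts;
//     opts.Set( "verbose", string("0") );
//     opts.Set( "updates", string("HUGIN") );
//     opts.Set( "inference", string("MAXPROD") );
//     JTree jt( fg, opts );
//     jt.init();
//     jt.run();
//     std::vector<size_t> mapState = jt.findMaximum();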


} // end of namespace dai