a938ee716201c09f5a6f818ecf5d30baa566d95d
[libdai.git] / src / jtree.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <stack>
14 #include <dai/jtree.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
/// Name of this inference algorithm, used to identify it in properties/aliases
const char *JTree::Name = "JTREE";
24
25
26 void JTree::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("verbose") );
28 DAI_ASSERT( opts.hasKey("updates") );
29
30 props.verbose = opts.getStringAs<size_t>("verbose");
31 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
32 if( opts.hasKey("inference") )
33 props.inference = opts.getStringAs<Properties::InfType>("inference");
34 else
35 props.inference = Properties::InfType::SUMPROD;
36 }
37
38
39 PropertySet JTree::getProperties() const {
40 PropertySet opts;
41 opts.Set( "verbose", props.verbose );
42 opts.Set( "updates", props.updates );
43 opts.Set( "inference", props.inference );
44 return opts;
45 }
46
47
48 string JTree::printProperties() const {
49 stringstream s( stringstream::out );
50 s << "[";
51 s << "verbose=" << props.verbose << ",";
52 s << "updates=" << props.updates << ",";
53 s << "inference=" << props.inference << "]";
54 return s.str();
55 }
56
57
58 JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
59 setProperties( opts );
60
61 if( !isConnected() )
62 DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
63
64 if( automatic ) {
65 // Create ClusterGraph which contains factors as clusters
66 vector<VarSet> cl;
67 cl.reserve( fg.nrFactors() );
68 for( size_t I = 0; I < nrFactors(); I++ )
69 cl.push_back( factor(I).vars() );
70 ClusterGraph _cg( cl );
71
72 if( props.verbose >= 3 )
73 cerr << "Initial clusters: " << _cg << endl;
74
75 // Retain only maximal clusters
76 _cg.eraseNonMaximal();
77 if( props.verbose >= 3 )
78 cerr << "Maximal clusters: " << _cg << endl;
79
80 vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
81 if( props.verbose >= 3 )
82 cerr << "VarElim_MinFill result: " << ElimVec << endl;
83
84 GenerateJT( ElimVec );
85 }
86 }
87
88
// Builds the junction tree and corresponding region graph from the given
// cliques: constructs a weighted junction graph, extracts a maximal spanning
// tree, then sets up outer/inner regions, messages and beliefs.
void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            // only connect cliques that share at least one variable
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions (one per clique, counting number 1)
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    // NOTE(review): assumes every factor fits inside some clique; the
    // assertion below fires otherwise.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR.push_back( alpha );
                break;
            }
        DAI_ASSERT( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions (separators) and edges; inner region i
    // corresponds to tree edge RTree[i]
    IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // connect both endpoints of tree edge i to the new inner region
        edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph between outer and inner regions
    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs

    // Outer-region beliefs start as the outer-region factors
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    // Inner-region (separator) beliefs start uniform
    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );

    // One message per (outer region, neighboring separator) pair
    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    checkCountingNumbers();

    if( props.verbose >= 3 ) {
        cerr << "Resulting regiongraph: " << *this << endl;
    }
}
165
166
167 string JTree::identify() const {
168 return string(Name) + printProperties();
169 }
170
171
172 Factor JTree::belief( const VarSet &ns ) const {
173 vector<Factor>::const_iterator beta;
174 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
175 if( beta->vars() >> ns )
176 break;
177 if( beta != Qb.end() )
178 return( beta->marginal(ns) );
179 else {
180 vector<Factor>::const_iterator alpha;
181 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
182 if( alpha->vars() >> ns )
183 break;
184 DAI_ASSERT( alpha != Qa.end() );
185 return( alpha->marginal(ns) );
186 }
187 }
188
189
190 vector<Factor> JTree::beliefs() const {
191 vector<Factor> result;
192 for( size_t beta = 0; beta < nrIRs(); beta++ )
193 result.push_back( Qb[beta] );
194 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
195 result.push_back( Qa[alpha] );
196 return result;
197 }
198
199
200 Factor JTree::belief( const Var &n ) const {
201 return belief( (VarSet)n );
202 }
203
204
// Needs no init
// Runs HUGIN-style propagation: a CollectEvidence sweep (leaves to root)
// followed by a DistributeEvidence sweep (root to leaves), updating the
// clique beliefs Qa and separator beliefs Qb in place and accumulating the
// log partition sum in _logZ.
void JTree::runHUGIN() {
    // Reset clique beliefs to the outer-region factors
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    // Reset separator beliefs to uniform
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence: traverse tree edges in reverse order (leaves first)
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
        // IR(i) = separator OR(RTree[i].n1) && OR(RTree[i].n2)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );

        // normalization constants accumulate into the log partition sum
        _logZ += log(new_Qb.normalize());
        // absorb the separator update into the parent clique
        Qa[RTree[i].n1] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    // add the normalization of the root clique (Qa[0] for a single-clique tree)
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize() );
    else
        _logZ += log(Qa[RTree[0].n1].normalize());

    // DistributeEvidence: traverse tree edges in forward order (root first)
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
        // IR(i) = separator OR(RTree[i].n1) && OR(RTree[i].n2)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );

        Qa[RTree[i].n2] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize the resulting clique beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
251
252
// Really needs no init! Initial messages can be anything.
// Runs Shafer-Shenoy propagation: two message-passing sweeps over the
// junction tree (leaves-to-root, then root-to-leaves), after which clique
// and separator beliefs are computed from the messages. Accumulates the
// log partition sum in _logZ.
void JTree::runShaferShenoy() {
    // First pass: traverse separators in reverse order (leaves towards root)
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from RTree[e].n2 to RTree[e].n1
        // or, actually, from the separator IR(e) to RTree[e].n1

        size_t i = nbIR(e)[1].node; // = RTree[e].n2
        size_t j = nbIR(e)[0].node; // = RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        // product of the outer region and all incoming messages, except the
        // message over separator e itself
        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        // normalization constants accumulate into the log partition sum
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass: traverse separators in forward order (root towards leaves)
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].n1
        size_t j = nbIR(e)[1].node; // = RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate clique beliefs: outer region times all incoming messages
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            // no separators at all: this region carries the normalization
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
            // the root clique contributes the final normalization constant
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    // Separator beliefs are marginals of an adjacent clique belief
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
315
316
317 Real JTree::run() {
318 if( props.updates == Properties::UpdateType::HUGIN )
319 runHUGIN();
320 else if( props.updates == Properties::UpdateType::SHSH )
321 runShaferShenoy();
322 return 0.0;
323 }
324
325
326 Real JTree::logZ() const {
327 Real s = 0.0;
328 for( size_t beta = 0; beta < nrIRs(); beta++ )
329 s += IR(beta).c() * Qb[beta].entropy();
330 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
331 s += OR(alpha).c() * Qa[alpha].entropy();
332 s += (OR(alpha).log(true) * Qa[alpha]).sum();
333 }
334 return s;
335 }
336
337
338
// Finds a rooted subtree of the junction tree suitable for computing a
// marginal on ns (used by calcMarginal for cutset conditioning): reroots the
// tree at the clique with maximal state-space overlap with ns, writes the
// reordered edge list into Tree with the edges needed for ns first, and
// returns the number of those leading edges.
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = VarSet(ns & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // grow new tree: same edges as RTree, but rooted at maxalpha
    Graph oldTree;
    for( DEdgeVec::const_iterator e = RTree.begin(); e != RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    // also connect the previous root (if any, and if different) to the subtree,
    // so that state restored there remains consistent
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
        }
    // Then add edges pointing away from nsrem
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
                }
            }*/
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
            }
        }*/
    // number of edges that cover ns (callers only need to propagate on these)
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}
435
436
// Cutset conditioning
// assumes that run() has been called already
// Computes the marginal on ns. If some existing belief already covers ns, it
// is marginalized directly; otherwise the marginal is computed by clamping
// the variables of ns outside the root clique to each joint state in turn
// ("cutset conditioning") and re-propagating on an efficient subtree.
Factor JTree::calcMarginal( const VarSet& ns ) {
    // look for a separator belief that covers ns
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != Qb.end() )
        return( beta->marginal(ns) );
    else {
        // look for a clique belief that covers ns
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns (ns, 0.0);

            // Save Qa and Qb on the subtree, so they can be restored after
            // each clamped propagation below
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            // b[i] = index of the inner region (separator) of subtree edge T[i]
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;  // NOTE: shadows the outer iterator 'beta'
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all joint states of nsrem
            for( State s(nsrem); s.valid(); s++ ) {
                // CollectEvidence on the subtree with nsrem clamped to state s
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator OR(T[i].n1) && OR(T[i].n2)

                    // clamp each nsrem variable occurring in clique T[i].n2
                    // by multiplying with a delta factor on its state in s
                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                        if( Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].n1] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].n1].normalize());

                // weigh the conditional marginal by the probability of state s
                Factor piet( nsrem, 0.0 );
                piet[s] = exp(logZ);
                Pns += piet * Qa[T[0].n1].marginal( ns / nsrem, false );      // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pns.normalized() );
        }
    }
}
521
522
523 /// Calculates upper bound to the treewidth of a FactorGraph
524 /** \relates JTree
525 * \return a pair (number of variables in largest clique, number of states in largest clique)
526 */
527 std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
528 ClusterGraph _cg;
529
530 // Copy factors
531 for( size_t I = 0; I < fg.nrFactors(); I++ )
532 _cg.insert( fg.factor(I).vars() );
533
534 // Retain only maximal clusters
535 _cg.eraseNonMaximal();
536
537 // Obtain elimination sequence
538 vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
539
540 // Calculate treewidth
541 size_t treewidth = 0;
542 size_t nrstates = 0;
543 for( size_t i = 0; i < ElimVec.size(); i++ ) {
544 if( ElimVec[i].size() > treewidth )
545 treewidth = ElimVec[i].size();
546 size_t s = ElimVec[i].nrStates();
547 if( s > nrstates )
548 nrstates = s;
549 }
550
551 return pair<size_t,size_t>(treewidth, nrstates);
552 }
553
554
555 std::vector<size_t> JTree::findMaximum() const {
556 vector<size_t> maximum( nrVars() );
557 vector<bool> visitedVars( nrVars(), false );
558 vector<bool> visitedFactors( nrFactors(), false );
559 stack<size_t> scheduledFactors;
560 for( size_t i = 0; i < nrVars(); ++i ) {
561 if( visitedVars[i] )
562 continue;
563 visitedVars[i] = true;
564
565 // Maximise with respect to variable i
566 Prob prod = beliefV(i).p();
567 maximum[i] = prod.argmax().first;
568
569 foreach( const Neighbor &I, nbV(i) )
570 if( !visitedFactors[I] )
571 scheduledFactors.push(I);
572
573 while( !scheduledFactors.empty() ){
574 size_t I = scheduledFactors.top();
575 scheduledFactors.pop();
576 if( visitedFactors[I] )
577 continue;
578 visitedFactors[I] = true;
579
580 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
581 bool allDetermined = true;
582 foreach( const Neighbor &j, nbF(I) )
583 if( !visitedVars[j.node] ) {
584 allDetermined = false;
585 break;
586 }
587 if( allDetermined )
588 continue;
589
590 // Calculate product of incoming messages on factor I
591 Prob prod2 = beliefF(I).p();
592
593 // The allowed configuration is restrained according to the variables assigned so far:
594 // pick the argmax amongst the allowed states
595 Real maxProb = numeric_limits<Real>::min();
596 State maxState( factor(I).vars() );
597 for( State s( factor(I).vars() ); s.valid(); ++s ){
598 // First, calculate whether this state is consistent with variables that
599 // have been assigned already
600 bool allowedState = true;
601 foreach( const Neighbor &j, nbF(I) )
602 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
603 allowedState = false;
604 break;
605 }
606 // If it is consistent, check if its probability is larger than what we have seen so far
607 if( allowedState && prod2[s] > maxProb ) {
608 maxState = s;
609 maxProb = prod2[s];
610 }
611 }
612
613 // Decode the argmax
614 foreach( const Neighbor &j, nbF(I) ) {
615 if( visitedVars[j.node] ) {
616 // We have already visited j earlier - hopefully our state is consistent
617 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
618 cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
619 } else {
620 // We found a consistent state for variable j
621 visitedVars[j.node] = true;
622 maximum[j.node] = maxState( var(j.node) );
623 foreach( const Neighbor &J, nbV(j) )
624 if( !visitedFactors[J] )
625 scheduledFactors.push(J);
626 }
627 }
628 }
629 }
630 return maximum;
631 }
632
633
634 } // end of namespace dai