Some improvements to jtree and regiongraph and started work on regiongraph unit tests
[libdai.git] / src / jtree.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <stack>
14 #include <dai/jtree.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
// Textual identifier of this algorithm; returned as part of identify()
const char *JTree::Name = "JTREE";
24
25
26 void JTree::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("verbose") );
28 DAI_ASSERT( opts.hasKey("updates") );
29
30 props.verbose = opts.getStringAs<size_t>("verbose");
31 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
32 if( opts.hasKey("inference") )
33 props.inference = opts.getStringAs<Properties::InfType>("inference");
34 else
35 props.inference = Properties::InfType::SUMPROD;
36 if( opts.hasKey("heuristic") )
37 props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
38 else
39 props.heuristic = Properties::HeuristicType::MINFILL;
40 }
41
42
43 PropertySet JTree::getProperties() const {
44 PropertySet opts;
45 opts.set( "verbose", props.verbose );
46 opts.set( "updates", props.updates );
47 opts.set( "inference", props.inference );
48 opts.set( "heuristic", props.heuristic );
49 return opts;
50 }
51
52
53 string JTree::printProperties() const {
54 stringstream s( stringstream::out );
55 s << "[";
56 s << "verbose=" << props.verbose << ",";
57 s << "updates=" << props.updates << ",";
58 s << "heuristic=" << props.heuristic << ",";
59 s << "inference=" << props.inference << "]";
60 return s.str();
61 }
62
63
// Constructs a JTree object for factor graph fg with options opts.
// If automatic == true, a junction tree is built here by greedy variable
// elimination (heuristic chosen via the "heuristic" property); otherwise
// the caller is expected to invoke GenerateJT() with its own cluster list.
JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
    setProperties( opts );

    // A junction tree can only be built for a connected factor graph
    if( !fg.isConnected() )
        DAI_THROW(FACTORGRAPH_NOT_CONNECTED);

    if( automatic ) {
        // Create ClusterGraph which contains factors as clusters
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < fg.nrFactors(); I++ )
            cl.push_back( fg.factor(I).vars() );
        ClusterGraph _cg( cl );

        if( props.verbose >= 3 )
            cerr << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( props.verbose >= 3 )
            cerr << "Maximal clusters: " << _cg << endl;

        // Use heuristic to guess optimal elimination sequence
        greedyVariableElimination::eliminationCostFunction ec(NULL);
        switch( (size_t)props.heuristic ) {
            case Properties::HeuristicType::MINNEIGHBORS:
                ec = eliminationCost_MinNeighbors;
                break;
            case Properties::HeuristicType::MINWEIGHT:
                ec = eliminationCost_MinWeight;
                break;
            case Properties::HeuristicType::MINFILL:
                ec = eliminationCost_MinFill;
                break;
            case Properties::HeuristicType::WEIGHTEDMINFILL:
                ec = eliminationCost_WeightedMinFill;
                break;
            default:
                DAI_THROW(UNKNOWN_ENUM_VALUE);
        }
        // The clusters produced by variable elimination form the junction tree cliques
        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ) ).eraseNonMaximal().clusters();
        if( props.verbose >= 3 )
            cerr << "VarElim result: " << ElimVec << endl;

        // Generate the junction tree corresponding to the elimination sequence
        GenerateJT( fg, ElimVec );
    }
}
112
113
// Builds the junction tree and the corresponding region graph from the
// given cluster list cl. If verify == true, asserts that every factor of
// fg is subsumed by at least one cluster.
void JTree::construct( const FactorGraph &fg, const std::vector<VarSet> &cl, bool verify ) {
    // Copy the factor graph
    FactorGraph::operator=( fg );

    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of cl).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < cl.size(); i++ )
        for( size_t j = i+1; j < cl.size(); j++ ) {
            size_t w = (cl[i] & cl[j]).size();
            // only connect clusters that share at least one variable
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTree( JuncGraph, true );

    // Construct corresponding region graph

    // Create outer regions (one per cluster, counting number 1)
    _ORs.clear();
    _ORs.reserve( cl.size() );
    for( size_t i = 0; i < cl.size(); i++ )
        _ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    _fac2OR.clear();
    _fac2OR.resize( nrFactors(), -1U );  // -1U marks "no outer region assigned"
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                _fac2OR[I] = alpha;
                break;
            }
        // if requested, make sure each factor found a subsuming outer region
        if( verify )
            DAI_ASSERT( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges: one inner region (separator) per tree
    // edge, connected to both of its outer-region endpoints
    _IRs.clear();
    _IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].first, nrIRs() ) );
        edges.push_back( Edge( RTree[i].second, nrIRs() ) );
        // inner clusters have counting number -1
        _IRs.push_back( Region( cl[RTree[i].first] & cl[RTree[i].second], -1.0 ) );
    }

    // create bipartite graph
    _G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
#ifdef DAI_DEBUG
    checkCountingNumbers();
#endif

    // Create beliefs: outer beliefs start at the outer region factors...
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    // ...and inner (separator) beliefs start uniform
    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );
}
186
187
// Builds the region graph for cluster list cl (verifying that every factor
// is covered) and allocates the message storage used by runShaferShenoy().
void JTree::GenerateJT( const FactorGraph &fg, const std::vector<VarSet> &cl ) {
    construct( fg, cl, true );

    // Create messages: one per (outer region, neighboring inner region) pair,
    // initialized to the uniform factor on the separator variables
    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    if( props.verbose >= 3 )
        cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
}
204
205
206 string JTree::identify() const {
207 return string(Name) + printProperties();
208 }
209
210
211 Factor JTree::belief( const VarSet &vs ) const {
212 vector<Factor>::const_iterator beta;
213 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
214 if( beta->vars() >> vs )
215 break;
216 if( beta != Qb.end() )
217 return( beta->marginal(vs) );
218 else {
219 vector<Factor>::const_iterator alpha;
220 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
221 if( alpha->vars() >> vs )
222 break;
223 if( alpha == Qa.end() ) {
224 DAI_THROW(BELIEF_NOT_AVAILABLE);
225 return Factor();
226 } else
227 return( alpha->marginal(vs) );
228 }
229 }
230
231
232 vector<Factor> JTree::beliefs() const {
233 vector<Factor> result;
234 for( size_t beta = 0; beta < nrIRs(); beta++ )
235 result.push_back( Qb[beta] );
236 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
237 result.push_back( Qa[alpha] );
238 return result;
239 }
240
241
// Performs exact inference with the HUGIN algorithm: a CollectEvidence
// sweep towards the root followed by a DistributeEvidence sweep back.
// Accumulates the log partition sum in _logZ as a side effect.
void JTree::runHUGIN() {
    // Initialize outer region beliefs with the outer region factors
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    // Initialize inner region (separator) beliefs to uniform
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence: traverse the rooted tree edges in reverse order
    // (leaves towards root)
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].first consistent with outer region RTree[i].second
        // IR(i) = seperator OR(RTree[i].first) && OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].second].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].second].maxMarginal( IR( i ), false );

        // accumulate the local normalization constant into logZ
        _logZ += log(new_Qb.normalize());
        Qa[RTree[i].first] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    // the root clique's normalization constant completes logZ
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize() );
    else
        _logZ += log(Qa[RTree[0].first].normalize());

    // DistributeEvidence: traverse edges from root towards the leaves
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].second consistent with outer region RTree[i].first
        // IR(i) = seperator OR(RTree[i].first) && OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].first].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].first].maxMarginal( IR( i ) );

        Qa[RTree[i].second] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
287
288
// Performs exact inference with the Shafer-Shenoy algorithm: an upward
// message pass over the separators followed by a downward pass, after
// which the region beliefs and _logZ are computed.
void JTree::runShaferShenoy() {
    // First pass: send messages towards the root (edges in reverse order)
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from RTree[e].second to RTree[e].first
        // or, actually, from the seperator IR(e) to RTree[e].first

        size_t i = nbIR(e)[1].node; // = RTree[e].second
        size_t j = nbIR(e)[0].node; // = RTree[e].first
        size_t _e = nbIR(e)[0].dual;

        // multiply the region factor with all incoming messages except the
        // one on edge e itself
        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        // accumulate the message normalization constant into logZ
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass: send messages away from the root (edges in forward order)
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].first
        size_t j = nbIR(e)[1].node; // = RTree[e].second
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate beliefs: each outer belief is its region factor times all
    // incoming messages; the root's normalization constant completes logZ
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].first*/ ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
350
351
352 Real JTree::run() {
353 if( props.updates == Properties::UpdateType::HUGIN )
354 runHUGIN();
355 else if( props.updates == Properties::UpdateType::SHSH )
356 runShaferShenoy();
357 return 0.0;
358 }
359
360
// Returns the log partition sum computed during the last call to run().
// The disabled code below recomputes it from the region beliefs and
// counting numbers as a consistency check.
Real JTree::logZ() const {
/*  Real s = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        s += IR(beta).c() * Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        s += OR(alpha).c() * Qa[alpha].entropy();
        s += (OR(alpha).log(true) * Qa[alpha]).sum();
    }
    DAI_ASSERT( abs( _logZ - s ) < 1e-8 );
    return s;*/
    return _logZ;
}
373
374
// Reorders the junction tree such that computing the marginal of vs is
// cheap: picks a new root with maximal state-space overlap with vs, finds
// the subtree containing all variables of vs (and, if given, the path to
// PreviousRoot), and writes the reordered edge list into Tree with the
// subtree edges first. Returns the number of subtree edges.
size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with vs)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = VarSet(vs & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // reorder the tree edges such that maxalpha becomes the new root
    RootedTree newTree( GraphEL( RTree.begin(), RTree.end() ), maxalpha );

    // identify subtree that contains all variables of vs which are not in the new root
    set<DEdge> subTree;
    // for each variable in vs
    for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ ) {
        for( size_t e = 0; e < newTree.size(); e++ ) {
            if( OR(newTree[e].second).vars().contains( *n ) ) {
                // add this edge and walk its ancestors back to the root
                // (in newTree, a parent edge always precedes its children)
                size_t f = e;
                subTree.insert( newTree[f] );
                size_t pos = newTree[f].first;
                for( ; f > 0; f-- )
                    if( newTree[f-1].second == pos ) {
                        subTree.insert( newTree[f-1] );
                        pos = newTree[f-1].first;
                    }
            }
        }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].second == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].first;
        for( ; e > 0; e-- )
            if( newTree[e-1].second == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].first;
            }
    }

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    vector<DEdge> remTree;
    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) )
            Tree.push_back( *e );
        else
            remTree.push_back( *e );
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );

    return subTreeSize;
}
440
441
// Computes the marginal on vs even when vs is not contained in any single
// region: if a covering region exists, marginalizes its belief; otherwise,
// clamps the variables of vs outside the root clique to each joint state in
// turn, re-collects evidence on an efficient subtree, and sums the results.
// Note: mutates Qa/Qb temporarily but restores them before returning.
Factor JTree::calcMarginal( const VarSet& vs ) {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() )
        // cheap case: some inner region contains vs
        return( beta->marginal(vs) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha != Qa.end() )
            // cheap case: some outer region contains vs
            return( alpha->marginal(vs) );
        else {
            // Find subtree to do efficient inference
            RootedTree T;
            size_t Tsize = findEfficientTree( vs, T );

            // Find remaining variables (which are not in the new root)
            VarSet vsrem = vs / OR(T.front().first).vars();
            Factor Pvs (vs, 0.0);

            // Save Qa and Qb on the subtree, and map each subtree edge to
            // its separator index in b[]
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].first;
                size_t alpha2 = T[i].second;
                size_t beta;
                // locate the inner region corresponding to this tree edge
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].first, RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all states of vsrem
            for( State s(vsrem); s.valid(); s++ ) {
                // CollectEvidence
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].first consistent with outer region T[i].second
                    // IR(i) = seperator OR(T[i].first) && OR(T[i].second)

                    // clamp the variables of vsrem appearing in this region
                    // to their value in state s
                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                        if( Qa[T[i].second].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet.set( s(*n), 1.0 );
                            Qa[T[i].second] *= piet;
                        }

                    Factor new_Qb = Qa[T[i].second].marginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].first] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].first].normalize());

                // weight the root marginal by the probability mass of state s
                Factor piet( vsrem, 0.0 );
                piet.set( s, exp(logZ) );
                Pvs += piet * Qa[T[0].first].marginal( vs / vsrem, false ); // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pvs.normalized() );
        }
    }
}
524
525
526 std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn ) {
527 ClusterGraph _cg;
528
529 // Copy factors
530 for( size_t I = 0; I < fg.nrFactors(); I++ )
531 _cg.insert( fg.factor(I).vars() );
532
533 // Retain only maximal clusters
534 _cg.eraseNonMaximal();
535
536 // Obtain elimination sequence
537 vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ) ).eraseNonMaximal().clusters();
538
539 // Calculate treewidth
540 size_t treewidth = 0;
541 double nrstates = 0.0;
542 for( size_t i = 0; i < ElimVec.size(); i++ ) {
543 if( ElimVec[i].size() > treewidth )
544 treewidth = ElimVec[i].size();
545 size_t s = ElimVec[i].nrStates();
546 if( s > nrstates )
547 nrstates = s;
548 }
549
550 return make_pair(treewidth, nrstates);
551 }
552
553
// Decodes a joint MAP state from the (max-product) beliefs by greedily
// fixing variables: start from each unvisited variable's argmax, then
// propagate through neighboring factors, always choosing the most probable
// factor state consistent with the variables fixed so far.
// NOTE(review): assumes run() was called with inference==MAXPROD so that
// beliefV/beliefF hold max-marginals — verify at call sites.
std::vector<size_t> JTree::findMaximum() const {
    vector<size_t> maximum( nrVars() );
    vector<bool> visitedVars( nrVars(), false );
    vector<bool> visitedFactors( nrFactors(), false );
    stack<size_t> scheduledFactors;
    for( size_t i = 0; i < nrVars(); ++i ) {
        if( visitedVars[i] )
            continue;
        visitedVars[i] = true;

        // Maximise with respect to variable i
        Prob prod = beliefV(i).p();
        maximum[i] = prod.argmax().first;

        // schedule the factors containing variable i
        foreach( const Neighbor &I, nbV(i) )
            if( !visitedFactors[I] )
                scheduledFactors.push(I);

        while( !scheduledFactors.empty() ){
            size_t I = scheduledFactors.top();
            scheduledFactors.pop();
            if( visitedFactors[I] )
                continue;
            visitedFactors[I] = true;

            // Evaluate if some neighboring variables still need to be fixed; if not, we're done
            bool allDetermined = true;
            foreach( const Neighbor &j, nbF(I) )
                if( !visitedVars[j.node] ) {
                    allDetermined = false;
                    break;
                }
            if( allDetermined )
                continue;

            // Calculate product of incoming messages on factor I
            Prob prod2 = beliefF(I).p();

            // The allowed configuration is restrained according to the variables assigned so far:
            // pick the argmax amongst the allowed states
            // NOTE(review): numeric_limits<Real>::min() is the smallest
            // POSITIVE value, so all-zero-probability factors leave maxState
            // at its initial (all-zero) state — confirm this is intended
            Real maxProb = numeric_limits<Real>::min();
            State maxState( factor(I).vars() );
            for( State s( factor(I).vars() ); s.valid(); ++s ){
                // First, calculate whether this state is consistent with variables that
                // have been assigned already
                bool allowedState = true;
                foreach( const Neighbor &j, nbF(I) )
                    if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
                        allowedState = false;
                        break;
                    }
                // If it is consistent, check if its probability is larger than what we have seen so far
                if( allowedState && prod2[s] > maxProb ) {
                    maxState = s;
                    maxProb = prod2[s];
                }
            }

            // Decode the argmax
            foreach( const Neighbor &j, nbF(I) ) {
                if( visitedVars[j.node] ) {
                    // We have already visited j earlier - hopefully our state is consistent
                    if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
                        cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
                } else {
                    // We found a consistent state for variable j
                    visitedVars[j.node] = true;
                    maximum[j.node] = maxState( var(j.node) );
                    // schedule the newly reachable factors
                    foreach( const Neighbor &J, nbV(j) )
                        if( !visitedFactors[J] )
                            scheduledFactors.push(J);
                }
            }
        }
    }
    return maximum;
}
631
632
633 } // end of namespace dai