Source: src/jtree.cpp from the libdai.git repository.
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <stack>
14 #include <dai/jtree.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
/// Textual name of this inference algorithm; prepended to the property string by identify()
const char *JTree::Name = "JTREE";
24
25
26 void JTree::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("verbose") );
28 DAI_ASSERT( opts.hasKey("updates") );
29
30 props.verbose = opts.getStringAs<size_t>("verbose");
31 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
32 if( opts.hasKey("inference") )
33 props.inference = opts.getStringAs<Properties::InfType>("inference");
34 else
35 props.inference = Properties::InfType::SUMPROD;
36 if( opts.hasKey("heuristic") )
37 props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
38 else
39 props.heuristic = Properties::HeuristicType::MINFILL;
40 }
41
42
43 PropertySet JTree::getProperties() const {
44 PropertySet opts;
45 opts.Set( "verbose", props.verbose );
46 opts.Set( "updates", props.updates );
47 opts.Set( "inference", props.inference );
48 opts.Set( "heuristic", props.heuristic );
49 return opts;
50 }
51
52
53 string JTree::printProperties() const {
54 stringstream s( stringstream::out );
55 s << "[";
56 s << "verbose=" << props.verbose << ",";
57 s << "updates=" << props.updates << ",";
58 s << "heuristic=" << props.heuristic << ",";
59 s << "inference=" << props.inference << "]";
60 return s.str();
61 }
62
63
/// Constructs a JTree object for factor graph \a fg with properties \a opts.
/// If \a automatic == true, a junction tree is built immediately: the factor
/// scopes become clusters, non-maximal clusters are dropped, an elimination
/// sequence is chosen by the configured heuristic, and GenerateJT() is called.
/// \throws FACTORGRAPH_NOT_CONNECTED if \a fg is not connected.
JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
    setProperties( opts );

    // A junction tree is only constructed for connected factor graphs
    if( !isConnected() )
        DAI_THROW(FACTORGRAPH_NOT_CONNECTED);

    if( automatic ) {
        // Create ClusterGraph which contains factors as clusters
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < nrFactors(); I++ )
            cl.push_back( factor(I).vars() );
        ClusterGraph _cg( cl );

        if( props.verbose >= 3 )
            cerr << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( props.verbose >= 3 )
            cerr << "Maximal clusters: " << _cg << endl;

        // Use heuristic to guess optimal elimination sequence
        greedyVariableElimination::eliminationCostFunction ec(NULL);
        // cast makes the enum-like property usable as a switch selector
        switch( (size_t)props.heuristic ) {
            case Properties::HeuristicType::MINNEIGHBORS:
                ec = eliminationCost_MinNeighbors;
                break;
            case Properties::HeuristicType::MINWEIGHT:
                ec = eliminationCost_MinWeight;
                break;
            case Properties::HeuristicType::MINFILL:
                ec = eliminationCost_MinFill;
                break;
            case Properties::HeuristicType::WEIGHTEDMINFILL:
                ec = eliminationCost_WeightedMinFill;
                break;
            default:
                DAI_THROW(UNKNOWN_ENUM_VALUE);
        }
        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ) ).eraseNonMaximal().toVector();
        if( props.verbose >= 3 )
            cerr << "VarElim result: " << ElimVec << endl;

        // Generate the junction tree corresponding to the elimination sequence
        GenerateJT( ElimVec );
    }
}
112
113
/// Builds the junction tree and corresponding region graph on clusters \a cl.
/// A maximal spanning tree of the cluster-intersection graph becomes RTree;
/// the clusters become outer regions, the separators become inner regions,
/// and the beliefs Qa/Qb are initialized. If \a verify == true, each factor
/// is asserted to be subsumed by some outer region.
void JTree::construct( const std::vector<VarSet> &cl, bool verify ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of cl).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < cl.size(); i++ )
        for( size_t j = i+1; j < cl.size(); j++ ) {
            size_t w = (cl[i] & cl[j]).size();
            // clusters with empty intersection are left unconnected
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions: one identity factor per cluster, counting number 1
    ORs.clear();
    ORs.reserve( cl.size() );
    for( size_t i = 0; i < cl.size(); i++ )
        ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    fac2OR.clear();
    fac2OR.resize( nrFactors(), -1U );
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR[I] = alpha;
                break;
            }
        // every factor must fit inside some cluster for a valid junction tree
        if( verify )
            DAI_ASSERT( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges: each tree edge yields one separator,
    // connected to both endpoint outer regions
    IRs.clear();
    IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( cl[RTree[i].n1] & cl[RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
#ifdef DAI_DEBUG
    checkCountingNumbers();
#endif

    // Create beliefs: clique beliefs start as the outer-region factors,
    // separator beliefs start uniform
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );
}
183
184
185 void JTree::GenerateJT( const std::vector<VarSet> &cl ) {
186 construct( cl, true );
187
188 // Create messages
189 _mes.clear();
190 _mes.reserve( nrORs() );
191 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
192 _mes.push_back( vector<Factor>() );
193 _mes[alpha].reserve( nbOR(alpha).size() );
194 foreach( const Neighbor &beta, nbOR(alpha) )
195 _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
196 }
197
198 if( props.verbose >= 3 )
199 cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
200 }
201
202
203 string JTree::identify() const {
204 return string(Name) + printProperties();
205 }
206
207
208 Factor JTree::belief( const VarSet &vs ) const {
209 vector<Factor>::const_iterator beta;
210 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
211 if( beta->vars() >> vs )
212 break;
213 if( beta != Qb.end() )
214 return( beta->marginal(vs) );
215 else {
216 vector<Factor>::const_iterator alpha;
217 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
218 if( alpha->vars() >> vs )
219 break;
220 if( alpha == Qa.end() ) {
221 DAI_THROW(BELIEF_NOT_AVAILABLE);
222 return Factor();
223 } else
224 return( alpha->marginal(vs) );
225 }
226 }
227
228
229 vector<Factor> JTree::beliefs() const {
230 vector<Factor> result;
231 for( size_t beta = 0; beta < nrIRs(); beta++ )
232 result.push_back( Qb[beta] );
233 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
234 result.push_back( Qa[alpha] );
235 return result;
236 }
237
238
/// Runs HUGIN-style message passing on the junction tree.
/// Resets clique beliefs Qa to the outer-region factors and separator beliefs
/// Qb to uniform, then performs one upward (CollectEvidence) and one downward
/// (DistributeEvidence) sweep over RTree, accumulating the log partition sum
/// in _logZ from the normalization constants along the way.
void JTree::runHUGIN() {
    // Initialize clique beliefs with the outer-region factors
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    // Initialize separator beliefs to uniform
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence: sweep the tree edges in reverse (leaves towards root)
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
        // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
        Factor new_Qb;
        // SUMPROD uses marginals; otherwise max-marginals are used
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );

        // the normalization constants of the messages accumulate into _logZ
        _logZ += log(new_Qb.normalize());
        Qa[RTree[i].n1] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize() );
    else
        _logZ += log(Qa[RTree[0].n1].normalize());

    // DistributeEvidence: sweep forward (root towards leaves)
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
        // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );

        Qa[RTree[i].n2] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize the final clique beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
284
285
/// Runs Shafer-Shenoy message passing on the junction tree: one inward pass
/// and one outward pass over the separators, after which the clique beliefs
/// Qa and separator beliefs Qb are computed. The log partition sum is
/// accumulated into _logZ from the message normalization constants.
void JTree::runShaferShenoy() {
    // First pass: messages towards the root (edges visited in reverse order)
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from RTree[e].n2 to RTree[e].n1
        // or, actually, from the seperator IR(e) to RTree[e].n1

        size_t i = nbIR(e)[1].node; // = RTree[e].n2
        size_t j = nbIR(e)[0].node; // = RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        // product of the outer-region factor with all incoming messages,
        // except the one arriving over edge e itself
        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        // normalization constants accumulate into _logZ
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass: messages away from the root
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].n1
        size_t j = nbIR(e)[1].node; // = RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate beliefs: outer-region factor times all incoming messages
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            // single-clique tree: its normalization constant completes _logZ
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
            // root clique: its normalization constant completes _logZ
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
347
348
349 Real JTree::run() {
350 if( props.updates == Properties::UpdateType::HUGIN )
351 runHUGIN();
352 else if( props.updates == Properties::UpdateType::SHSH )
353 runShaferShenoy();
354 return 0.0;
355 }
356
357
358 Real JTree::logZ() const {
359 /* Real s = 0.0;
360 for( size_t beta = 0; beta < nrIRs(); beta++ )
361 s += IR(beta).c() * Qb[beta].entropy();
362 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
363 s += OR(alpha).c() * Qa[alpha].entropy();
364 s += (OR(alpha).log(true) * Qa[alpha]).sum();
365 }
366 DAI_ASSERT( abs( _logZ - s ) < 1e-8 );
367 return s;*/
368 return _logZ;
369 }
370
371
/// Reorders the junction tree so that inference for the variables \a vs is
/// efficient: the clique with maximal statespace overlap with \a vs becomes
/// the new root, and the subtree spanned by all cliques containing variables
/// of \a vs (plus, if given, the path to \a PreviousRoot) is moved to the
/// front of \a Tree.
/// \return the number of edges at the front of \a Tree that form that subtree.
size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with vs)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = VarSet(vs & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // reorder the tree edges such that maxalpha becomes the new root
    RootedTree newTree( GraphEL( RTree.begin(), RTree.end() ), maxalpha );

    // identify subtree that contains all variables of vs which are not in the new root
    set<DEdge> subTree;
    // for each variable in vs
    for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ ) {
        for( size_t e = 0; e < newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) ) {
                // add this edge plus the path from its tail back to the root;
                // edges are stored root-first, so walking backwards and
                // matching n2 == pos finds each parent edge in turn
                size_t f = e;
                subTree.insert( newTree[f] );
                size_t pos = newTree[f].n1;
                for( ; f > 0; f-- )
                    if( newTree[f-1].n2 == pos ) {
                        subTree.insert( newTree[f-1] );
                        pos = newTree[f-1].n1;
                    }
            }
        }
    }
    // (size_t)-1 encodes "no previous root given"
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    vector<DEdge> remTree;
    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) )
            Tree.push_back( *e );
        else
            remTree.push_back( *e );
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );

    return subTreeSize;
}
437
438
/// Calculates the marginal over \a vs, even if \a vs is not contained in any
/// single clique or separator.
/// If some separator or clique belief subsumes \a vs, its marginal is
/// returned directly. Otherwise, a subtree covering \a vs is selected via
/// findEfficientTree(); for every joint state of the variables of \a vs
/// outside the root clique, those variables are clamped, evidence is
/// collected along the subtree, and the weighted root marginals are summed.
Factor JTree::calcMarginal( const VarSet& vs ) {
    // cheap case: some separator belief contains all of vs
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() )
        return( beta->marginal(vs) );
    else {
        // cheap case: some clique belief contains all of vs
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha != Qa.end() )
            return( alpha->marginal(vs) );
        else {
            // Find subtree to do efficient inference
            RootedTree T;
            size_t Tsize = findEfficientTree( vs, T );

            // Find remaining variables (which are not in the new root)
            VarSet vsrem = vs / OR(T.front().n1).vars();
            Factor Pvs (vs, 0.0);

            // Save Qa and Qb on the subtree; b[i] maps subtree edge i to the
            // index of its separator in RTree
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all states of vsrem
            for( State s(vsrem); s.valid(); s++ ) {
                // CollectEvidence
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)

                    // clamp the remaining variables present in clique T[i].n2
                    // to their current state s via an indicator factor
                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                        if( Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].n1] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].n1].normalize());

                // weight the root marginal by the probability mass of state s
                Factor piet( vsrem, 0.0 );
                piet[s] = exp(logZ);
                Pvs += piet * Qa[T[0].n1].marginal( vs / vsrem, false ); // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pvs.normalized() );
        }
    }
}
521
522
523 std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn ) {
524 ClusterGraph _cg;
525
526 // Copy factors
527 for( size_t I = 0; I < fg.nrFactors(); I++ )
528 _cg.insert( fg.factor(I).vars() );
529
530 // Retain only maximal clusters
531 _cg.eraseNonMaximal();
532
533 // Obtain elimination sequence
534 vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ) ).eraseNonMaximal().toVector();
535
536 // Calculate treewidth
537 size_t treewidth = 0;
538 double nrstates = 0.0;
539 for( size_t i = 0; i < ElimVec.size(); i++ ) {
540 if( ElimVec[i].size() > treewidth )
541 treewidth = ElimVec[i].size();
542 size_t s = ElimVec[i].nrStates();
543 if( s > nrstates )
544 nrstates = s;
545 }
546
547 return make_pair(treewidth, nrstates);
548 }
549
550
551 std::vector<size_t> JTree::findMaximum() const {
552 vector<size_t> maximum( nrVars() );
553 vector<bool> visitedVars( nrVars(), false );
554 vector<bool> visitedFactors( nrFactors(), false );
555 stack<size_t> scheduledFactors;
556 for( size_t i = 0; i < nrVars(); ++i ) {
557 if( visitedVars[i] )
558 continue;
559 visitedVars[i] = true;
560
561 // Maximise with respect to variable i
562 Prob prod = beliefV(i).p();
563 maximum[i] = prod.argmax().first;
564
565 foreach( const Neighbor &I, nbV(i) )
566 if( !visitedFactors[I] )
567 scheduledFactors.push(I);
568
569 while( !scheduledFactors.empty() ){
570 size_t I = scheduledFactors.top();
571 scheduledFactors.pop();
572 if( visitedFactors[I] )
573 continue;
574 visitedFactors[I] = true;
575
576 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
577 bool allDetermined = true;
578 foreach( const Neighbor &j, nbF(I) )
579 if( !visitedVars[j.node] ) {
580 allDetermined = false;
581 break;
582 }
583 if( allDetermined )
584 continue;
585
586 // Calculate product of incoming messages on factor I
587 Prob prod2 = beliefF(I).p();
588
589 // The allowed configuration is restrained according to the variables assigned so far:
590 // pick the argmax amongst the allowed states
591 Real maxProb = numeric_limits<Real>::min();
592 State maxState( factor(I).vars() );
593 for( State s( factor(I).vars() ); s.valid(); ++s ){
594 // First, calculate whether this state is consistent with variables that
595 // have been assigned already
596 bool allowedState = true;
597 foreach( const Neighbor &j, nbF(I) )
598 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
599 allowedState = false;
600 break;
601 }
602 // If it is consistent, check if its probability is larger than what we have seen so far
603 if( allowedState && prod2[s] > maxProb ) {
604 maxState = s;
605 maxProb = prod2[s];
606 }
607 }
608
609 // Decode the argmax
610 foreach( const Neighbor &j, nbF(I) ) {
611 if( visitedVars[j.node] ) {
612 // We have already visited j earlier - hopefully our state is consistent
613 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
614 cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
615 } else {
616 // We found a consistent state for variable j
617 visitedVars[j.node] = true;
618 maximum[j.node] = maxState( var(j.node) );
619 foreach( const Neighbor &J, nbV(j) )
620 if( !visitedFactors[J] )
621 scheduledFactors.push(J);
622 }
623 }
624 }
625 }
626 return maximum;
627 }
628
629
630 } // end of namespace dai