Miscellaneous smaller improvements
[libdai.git] / src / jtree.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <stack>
14 #include <dai/jtree.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 const char *JTree::Name = "JTREE";
24
25
26 void JTree::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("verbose") );
28 DAI_ASSERT( opts.hasKey("updates") );
29
30 props.verbose = opts.getStringAs<size_t>("verbose");
31 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
32 if( opts.hasKey("inference") )
33 props.inference = opts.getStringAs<Properties::InfType>("inference");
34 else
35 props.inference = Properties::InfType::SUMPROD;
36 }
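
// Example (sketch, not part of the original file): the properties above are
// normally supplied as a PropertySet when constructing the algorithm. A
// minimal illustration in the style of the libDAI examples ("verbose" and
// "updates" are required, "inference" is optional):
//
//   PropertySet opts;
//   opts.Set( "verbose", (size_t)1 );
//   opts.Set( "updates", string("HUGIN") );     // or "SHSH" (Shafer-Shenoy)
//   opts.Set( "inference", string("SUMPROD") ); // sum-product; the max-product
//                                               // value is assumed to be "MAXPROD"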
37
38
39 PropertySet JTree::getProperties() const {
40 PropertySet opts;
41 opts.Set( "verbose", props.verbose );
42 opts.Set( "updates", props.updates );
43 opts.Set( "inference", props.inference );
44 return opts;
45 }
46
47
48 string JTree::printProperties() const {
49 stringstream s( stringstream::out );
50 s << "[";
51 s << "verbose=" << props.verbose << ",";
52 s << "updates=" << props.updates << ",";
53 s << "inference=" << props.inference << "]";
54 return s.str();
55 }
56
57
58 JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
59 setProperties( opts );
60
61 if( !isConnected() )
62 DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
63
64 if( automatic ) {
65 // Create ClusterGraph which contains factors as clusters
66 vector<VarSet> cl;
67 cl.reserve( fg.nrFactors() );
68 for( size_t I = 0; I < nrFactors(); I++ )
69 cl.push_back( factor(I).vars() );
70 ClusterGraph _cg( cl );
71
72 if( props.verbose >= 3 )
73 cerr << "Initial clusters: " << _cg << endl;
74
75 // Retain only maximal clusters
76 _cg.eraseNonMaximal();
77 if( props.verbose >= 3 )
78 cerr << "Maximal clusters: " << _cg << endl;
79
80 // Use the MinFill heuristic to find a good elimination sequence
81 vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
82 if( props.verbose >= 3 )
83 cerr << "VarElim result: " << ElimVec << endl;
84
85 // Generate the junction tree corresponding to the elimination sequence
86 GenerateJT( ElimVec );
87 }
88 }
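
// Example (sketch): typical use of the automatic construction path above, in
// the style of the libDAI examples; the file name is a placeholder and the
// constructor is assumed to default to automatic=true:
//
//   FactorGraph fg;
//   fg.ReadFromFile( "example.fg" );
//   PropertySet opts;
//   opts.Set( "verbose", (size_t)0 );
//   opts.Set( "updates", string("HUGIN") );
//   JTree jt( fg, opts );                     // triangulates and builds the junction tree
//   jt.init();
//   jt.run();                                 // exact inference
//   cout << "logZ = " << jt.logZ() << endl;
//   cout << jt.belief( fg.var(0) ) << endl;   // exact single-variable marginal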
89
90
91 void JTree::construct( const std::vector<VarSet> &cl, bool verify ) {
92 // Construct a weighted graph (each edge is weighted with the cardinality
93 // of the intersection of the nodes, where the nodes are the elements of cl).
94 WeightedGraph<int> JuncGraph;
95 for( size_t i = 0; i < cl.size(); i++ )
96 for( size_t j = i+1; j < cl.size(); j++ ) {
97 size_t w = (cl[i] & cl[j]).size();
98 if( w )
99 JuncGraph[UEdge(i,j)] = w;
100 }
101
102 // Construct maximal spanning tree using Prim's algorithm
103 RTree = MaxSpanningTreePrims( JuncGraph );
104
105 // Construct corresponding region graph
106
107 // Create outer regions
108 ORs.clear();
109 ORs.reserve( cl.size() );
110 for( size_t i = 0; i < cl.size(); i++ )
111 ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );
112
113 // For each factor, find an outer region that subsumes that factor.
114 // Then, multiply the outer region with that factor.
115 fac2OR.clear();
116 fac2OR.resize( nrFactors(), -1U );
117 for( size_t I = 0; I < nrFactors(); I++ ) {
118 size_t alpha;
119 for( alpha = 0; alpha < nrORs(); alpha++ )
120 if( OR(alpha).vars() >> factor(I).vars() ) {
121 fac2OR[I] = alpha;
122 break;
123 }
124 if( verify )
125 DAI_ASSERT( alpha != nrORs() );
126 }
127 RecomputeORs();
128
129 // Create inner regions and edges
130 IRs.clear();
131 IRs.reserve( RTree.size() );
132 vector<Edge> edges;
133 edges.reserve( 2 * RTree.size() );
134 for( size_t i = 0; i < RTree.size(); i++ ) {
135 edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
136 edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
137 // inner clusters have counting number -1
138 IRs.push_back( Region( cl[RTree[i].n1] & cl[RTree[i].n2], -1.0 ) );
139 }
140
141 // create bipartite graph
142 G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
143
144 // Check counting numbers
145 #ifdef DAI_DEBUG
146 checkCountingNumbers();
147 #endif
148
149 // Create beliefs
150 Qa.clear();
151 Qa.reserve( nrORs() );
152 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
153 Qa.push_back( OR(alpha) );
154
155 Qb.clear();
156 Qb.reserve( nrIRs() );
157 for( size_t beta = 0; beta < nrIRs(); beta++ )
158 Qb.push_back( Factor( IR(beta), 1.0 ) );
159 }
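
// Worked illustration of the construction above (not code from the library):
// for cl = { {A,B}, {B,C}, {C,D} }, the junction graph gets weighted edges
// (0,1) and (1,2), each of weight 1 (the sizes of the shared variable sets
// {B} and {C}); edge (0,2) is omitted because {A,B} and {C,D} do not
// intersect. The maximal spanning tree keeps both edges, and the inner
// regions become the separators {B} and {C}, each with counting number -1.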
160
161
162 void JTree::GenerateJT( const std::vector<VarSet> &cl ) {
163 construct( cl, true );
164
165 // Create messages
166 _mes.clear();
167 _mes.reserve( nrORs() );
168 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
169 _mes.push_back( vector<Factor>() );
170 _mes[alpha].reserve( nbOR(alpha).size() );
171 foreach( const Neighbor &beta, nbOR(alpha) )
172 _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
173 }
174
175 if( props.verbose >= 3 )
176 cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
177 }
178
179
180 string JTree::identify() const {
181 return string(Name) + printProperties();
182 }
183
184
185 Factor JTree::belief( const VarSet &vs ) const {
186 vector<Factor>::const_iterator beta;
187 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
188 if( beta->vars() >> vs )
189 break;
190 if( beta != Qb.end() )
191 return( beta->marginal(vs) );
192 else {
193 vector<Factor>::const_iterator alpha;
194 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
195 if( alpha->vars() >> vs )
196 break;
197 if( alpha == Qa.end() ) {
198 DAI_THROW(BELIEF_NOT_AVAILABLE);
199 return Factor();
200 } else
201 return( alpha->marginal(vs) );
202 }
203 }
204
205
206 vector<Factor> JTree::beliefs() const {
207 vector<Factor> result;
208 for( size_t beta = 0; beta < nrIRs(); beta++ )
209 result.push_back( Qb[beta] );
210 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
211 result.push_back( Qa[alpha] );
212 return result;
213 }
214
215
216 void JTree::runHUGIN() {
217 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
218 Qa[alpha] = OR(alpha);
219
220 for( size_t beta = 0; beta < nrIRs(); beta++ )
221 Qb[beta].fill( 1.0 );
222
223 // CollectEvidence
224 _logZ = 0.0;
225 for( size_t i = RTree.size(); (i--) != 0; ) {
226 // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
227 // IR(i) = separator of OR(RTree[i].n1) and OR(RTree[i].n2)
228 Factor new_Qb;
229 if( props.inference == Properties::InfType::SUMPROD )
230 new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
231 else
232 new_Qb = Qa[RTree[i].n2].maxMarginal( IR( i ), false );
233
234 _logZ += log(new_Qb.normalize());
235 Qa[RTree[i].n1] *= new_Qb / Qb[i];
236 Qb[i] = new_Qb;
237 }
238 if( RTree.empty() )
239 _logZ += log(Qa[0].normalize() );
240 else
241 _logZ += log(Qa[RTree[0].n1].normalize());
242
243 // DistributeEvidence
244 for( size_t i = 0; i < RTree.size(); i++ ) {
245 // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
246 // IR(i) = separator of OR(RTree[i].n1) and OR(RTree[i].n2)
247 Factor new_Qb;
248 if( props.inference == Properties::InfType::SUMPROD )
249 new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
250 else
251 new_Qb = Qa[RTree[i].n1].maxMarginal( IR( i ) );
252
253 Qa[RTree[i].n2] *= new_Qb / Qb[i];
254 Qb[i] = new_Qb;
255 }
256
257 // Normalize
258 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
259 Qa[alpha].normalize();
260 }
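
// Summary of the HUGIN updates performed above. For tree edge i with
// separator S = IR(i), parent cluster C1 = RTree[i].n1 and child cluster
// C2 = RTree[i].n2 (sum is replaced by max for max-product inference):
//
//   collect:    Qb_new(S)  = sum_{C2 \ S} Qa(C2)
//               Qa(C1)    *= Qb_new(S) / Qb_old(S)
//   distribute: Qb_new(S)  = sum_{C1 \ S} Qa(C1)
//               Qa(C2)    *= Qb_new(S) / Qb_old(S)
//
// The normalization constants absorbed during the collect pass accumulate in
// _logZ; for sum-product on a junction tree this yields the exact log
// partition sum.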
261
262
263 void JTree::runShaferShenoy() {
264 // First pass
265 _logZ = 0.0;
266 for( size_t e = nrIRs(); (e--) != 0; ) {
267 // send a message from RTree[e].n2 to RTree[e].n1
268 // or, more precisely, from the separator IR(e) to RTree[e].n1
269
270 size_t i = nbIR(e)[1].node; // = RTree[e].n2
271 size_t j = nbIR(e)[0].node; // = RTree[e].n1
272 size_t _e = nbIR(e)[0].dual;
273
274 Factor msg = OR(i);
275 foreach( const Neighbor &k, nbOR(i) )
276 if( k != e )
277 msg *= message( i, k.iter );
278 if( props.inference == Properties::InfType::SUMPROD )
279 message( j, _e ) = msg.marginal( IR(e), false );
280 else
281 message( j, _e ) = msg.maxMarginal( IR(e), false );
282 _logZ += log( message(j,_e).normalize() );
283 }
284
285 // Second pass
286 for( size_t e = 0; e < nrIRs(); e++ ) {
287 size_t i = nbIR(e)[0].node; // = RTree[e].n1
288 size_t j = nbIR(e)[1].node; // = RTree[e].n2
289 size_t _e = nbIR(e)[1].dual;
290
291 Factor msg = OR(i);
292 foreach( const Neighbor &k, nbOR(i) )
293 if( k != e )
294 msg *= message( i, k.iter );
295 if( props.inference == Properties::InfType::SUMPROD )
296 message( j, _e ) = msg.marginal( IR(e) );
297 else
298 message( j, _e ) = msg.maxMarginal( IR(e) );
299 }
300
301 // Calculate beliefs
302 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
303 Factor piet = OR(alpha);
304 foreach( const Neighbor &k, nbOR(alpha) )
305 piet *= message( alpha, k.iter );
306 if( nrIRs() == 0 ) {
307 _logZ += log( piet.normalize() );
308 Qa[alpha] = piet;
309 } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
310 _logZ += log( piet.normalize() );
311 Qa[alpha] = piet;
312 } else
313 Qa[alpha] = piet.normalized();
314 }
315
316 // Calculate inner region beliefs (only needed for logZ and for belief queries)
317 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
318 if( props.inference == Properties::InfType::SUMPROD )
319 Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
320 else
321 Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
322 }
323 }
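
// Summary of the Shafer-Shenoy updates performed above. For a separator e
// between clusters i and j, the message stored at j for e is
//
//   mu_{e->j}( IR(e) ) = sum_{OR(i) \ IR(e)} OR(i) * prod_{e' != e} mu_{e'->i}
//
// (max instead of sum for max-product), and each cluster belief Qa(i) is the
// normalized product of OR(i) with all messages incoming to i.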
324
325
326 Real JTree::run() {
327 if( props.updates == Properties::UpdateType::HUGIN )
328 runHUGIN();
329 else if( props.updates == Properties::UpdateType::SHSH )
330 runShaferShenoy();
331 return 0.0;
332 }
333
334
335 Real JTree::logZ() const {
336 /* Real s = 0.0;
337 for( size_t beta = 0; beta < nrIRs(); beta++ )
338 s += IR(beta).c() * Qb[beta].entropy();
339 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
340 s += OR(alpha).c() * Qa[alpha].entropy();
341 s += (OR(alpha).log(true) * Qa[alpha]).sum();
342 }
343 DAI_ASSERT( abs( _logZ - s ) < 1e-8 );
344 return s;*/
345 return _logZ;
346 }
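
// Note on the commented-out block above: it recomputes logZ from the final
// beliefs via the region free energy, which is exact on a junction tree,
//
//   logZ = sum_alpha ( H(Qa_alpha) + <log OR(alpha), Qa_alpha> ) - sum_beta H(Qb_beta),
//
// using counting numbers +1 for outer and -1 for inner regions. run() instead
// accumulates _logZ from the message normalization constants, which is cheaper.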
347
348
349 size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
350 // find new root clique (the one with maximal state space overlap with vs)
351 size_t maxval = 0, maxalpha = 0;
352 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
353 size_t val = VarSet(vs & OR(alpha).vars()).nrStates();
354 if( val > maxval ) {
355 maxval = val;
356 maxalpha = alpha;
357 }
358 }
359
360 // reorder the tree edges such that maxalpha becomes the new root
361 RootedTree newTree( Graph( RTree.begin(), RTree.end() ), maxalpha );
362
363 // identify the subtree that contains all variables of vs that are not in the new root clique
364 VarSet vsrem = vs / OR(maxalpha).vars();
365 set<DEdge> subTree;
366 // for each variable in vs that is not in the root clique
367 for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ ) {
368 // find the first occurrence of *n in the tree, i.e., the one closest to the root
369 size_t e = 0;
370 for( ; e != newTree.size(); e++ ) {
371 if( OR(newTree[e].n2).vars().contains( *n ) )
372 break;
373 }
374 DAI_ASSERT( e != newTree.size() );
375
376 // trace the path back to the root and add its edges to subTree
377 subTree.insert( newTree[e] );
378 size_t pos = newTree[e].n1;
379 for( ; e > 0; e-- )
380 if( newTree[e-1].n2 == pos ) {
381 subTree.insert( newTree[e-1] );
382 pos = newTree[e-1].n1;
383 }
384 }
385 if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
386 // find the first occurrence of PreviousRoot in the tree, i.e., the one closest to the new root
387 size_t e = 0;
388 for( ; e != newTree.size(); e++ ) {
389 if( newTree[e].n2 == PreviousRoot )
390 break;
391 }
392 DAI_ASSERT( e != newTree.size() );
393
394 // trace the path back to the root and add its edges to subTree
395 subTree.insert( newTree[e] );
396 size_t pos = newTree[e].n1;
397 for( ; e > 0; e-- )
398 if( newTree[e-1].n2 == pos ) {
399 subTree.insert( newTree[e-1] );
400 pos = newTree[e-1].n1;
401 }
402 }
403
404 // Resulting Tree is a reordered copy of newTree
405 // First add edges in subTree to Tree
406 Tree.clear();
407 vector<DEdge> remTree;
408 for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
409 if( subTree.count( *e ) )
410 Tree.push_back( *e );
411 else
412 remTree.push_back( *e );
413 size_t subTreeSize = Tree.size();
414 // Then add remaining edges
415 copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );
416
417 return subTreeSize;
418 }
419
420
421 Factor JTree::calcMarginal( const VarSet& vs ) {
422 vector<Factor>::const_iterator beta;
423 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
424 if( beta->vars() >> vs )
425 break;
426 if( beta != Qb.end() )
427 return( beta->marginal(vs) );
428 else {
429 vector<Factor>::const_iterator alpha;
430 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
431 if( alpha->vars() >> vs )
432 break;
433 if( alpha != Qa.end() )
434 return( alpha->marginal(vs) );
435 else {
436 // Find subtree to do efficient inference
437 RootedTree T;
438 size_t Tsize = findEfficientTree( vs, T );
439
440 // Find remaining variables (which are not in the new root)
441 VarSet vsrem = vs / OR(T.front().n1).vars();
442 Factor Pvs (vs, 0.0);
443
444 // Save Qa and Qb on the subtree
445 map<size_t,Factor> Qa_old;
446 map<size_t,Factor> Qb_old;
447 vector<size_t> b(Tsize, 0);
448 for( size_t i = Tsize; (i--) != 0; ) {
449 size_t alpha1 = T[i].n1;
450 size_t alpha2 = T[i].n2;
451 size_t beta;
452 for( beta = 0; beta < nrIRs(); beta++ )
453 if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
454 break;
455 DAI_ASSERT( beta != nrIRs() );
456 b[i] = beta;
457
458 if( !Qa_old.count( alpha1 ) )
459 Qa_old[alpha1] = Qa[alpha1];
460 if( !Qa_old.count( alpha2 ) )
461 Qa_old[alpha2] = Qa[alpha2];
462 if( !Qb_old.count( beta ) )
463 Qb_old[beta] = Qb[beta];
464 }
465
466 // For all states of vsrem
467 for( State s(vsrem); s.valid(); s++ ) {
468 // CollectEvidence
469 Real logZ = 0.0;
470 for( size_t i = Tsize; (i--) != 0; ) {
471 // Make outer region T[i].n1 consistent with outer region T[i].n2
472 // IR(b[i]) = separator of OR(T[i].n1) and OR(T[i].n2)
473
474 for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
475 if( Qa[T[i].n2].vars() >> *n ) {
476 Factor piet( *n, 0.0 );
477 piet[s(*n)] = 1.0;
478 Qa[T[i].n2] *= piet;
479 }
480
481 Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
482 logZ += log(new_Qb.normalize());
483 Qa[T[i].n1] *= new_Qb / Qb[b[i]];
484 Qb[b[i]] = new_Qb;
485 }
486 logZ += log(Qa[T[0].n1].normalize());
487
488 Factor piet( vsrem, 0.0 );
489 piet[s] = exp(logZ);
490 Pvs += piet * Qa[T[0].n1].marginal( vs / vsrem, false ); // OPTIMIZE ME
491
492 // Restore clamped beliefs
493 for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
494 Qa[alpha->first] = alpha->second;
495 for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
496 Qb[beta->first] = beta->second;
497 }
498
499 return( Pvs.normalized() );
500 }
501 }
502 }
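
// Example (sketch): calcMarginal computes an exact joint marginal even when
// the requested variables do not share a single clique, by clamping the
// variables outside the root clique and re-running the collect pass on a
// subtree. Minimal illustration (variable indices are placeholders):
//
//   JTree jt( fg, opts );
//   jt.init();
//   jt.run();
//   VarSet vs( fg.var(0), fg.var(5) );     // possibly far apart in the tree
//   Factor joint = jt.calcMarginal( vs );  // cost grows with the number of
//                                          // joint states of the clamped variables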
503
504
505 std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
506 ClusterGraph _cg;
507
508 // Copy factors
509 for( size_t I = 0; I < fg.nrFactors(); I++ )
510 _cg.insert( fg.factor(I).vars() );
511
512 // Retain only maximal clusters
513 _cg.eraseNonMaximal();
514
515 // Obtain elimination sequence
516 vector<VarSet> ElimVec = _cg.VarElim( eliminationChoice_MinFill ).eraseNonMaximal().toVector();
517
518 // Calculate treewidth
519 size_t treewidth = 0;
520 size_t nrstates = 0;
521 for( size_t i = 0; i < ElimVec.size(); i++ ) {
522 if( ElimVec[i].size() > treewidth )
523 treewidth = ElimVec[i].size();
524 size_t s = ElimVec[i].nrStates();
525 if( s > nrstates )
526 nrstates = s;
527 }
528
529 return pair<size_t,size_t>(treewidth, nrstates);
530 }
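
// Example (sketch): boundTreewidth can be used to decide whether exact
// inference is feasible before constructing a JTree; the threshold below is
// a placeholder:
//
//   pair<size_t,size_t> tw = boundTreewidth( fg );
//   if( tw.second > 10000000 )   // largest clique state space too big
//       cerr << "Junction tree would use too much memory" << endl;
//   else {
//       JTree jt( fg, opts );
//       jt.init();
//       jt.run();
//   }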
531
532
533 std::pair<size_t,size_t> treewidth( const FactorGraph & fg )
534 {
535 return boundTreewidth( fg );
536 }
537
538
539 std::vector<size_t> JTree::findMaximum() const {
540 vector<size_t> maximum( nrVars() );
541 vector<bool> visitedVars( nrVars(), false );
542 vector<bool> visitedFactors( nrFactors(), false );
543 stack<size_t> scheduledFactors;
544 for( size_t i = 0; i < nrVars(); ++i ) {
545 if( visitedVars[i] )
546 continue;
547 visitedVars[i] = true;
548
549 // Maximise with respect to variable i
550 Prob prod = beliefV(i).p();
551 maximum[i] = prod.argmax().first;
552
553 foreach( const Neighbor &I, nbV(i) )
554 if( !visitedFactors[I] )
555 scheduledFactors.push(I);
556
557 while( !scheduledFactors.empty() ){
558 size_t I = scheduledFactors.top();
559 scheduledFactors.pop();
560 if( visitedFactors[I] )
561 continue;
562 visitedFactors[I] = true;
563
564 // Check whether some neighboring variables still need to be fixed; if not, this factor needs no further processing
565 bool allDetermined = true;
566 foreach( const Neighbor &j, nbF(I) )
567 if( !visitedVars[j.node] ) {
568 allDetermined = false;
569 break;
570 }
571 if( allDetermined )
572 continue;
573
574 // Calculate product of incoming messages on factor I
575 Prob prod2 = beliefF(I).p();
576
577 // The allowed configurations are constrained by the variables assigned so far:
578 // pick the argmax amongst the allowed states
579 Real maxProb = -numeric_limits<Real>::infinity(); // -infinity, so that allowed states with zero probability can still be selected
580 State maxState( factor(I).vars() );
581 for( State s( factor(I).vars() ); s.valid(); ++s ){
582 // First, calculate whether this state is consistent with variables that
583 // have been assigned already
584 bool allowedState = true;
585 foreach( const Neighbor &j, nbF(I) )
586 if( visitedVars[j.node] && maximum[j.node] != s(var(j.node)) ) {
587 allowedState = false;
588 break;
589 }
590 // If it is consistent, check if its probability is larger than what we have seen so far
591 if( allowedState && prod2[s] > maxProb ) {
592 maxState = s;
593 maxProb = prod2[s];
594 }
595 }
596
597 // Decode the argmax
598 foreach( const Neighbor &j, nbF(I) ) {
599 if( visitedVars[j.node] ) {
600 // We have already visited j earlier - hopefully our state is consistent
601 if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
602 cerr << "JTree::findMaximum - warning: maximum not consistent due to loops." << endl;
603 } else {
604 // We found a consistent state for variable j
605 visitedVars[j.node] = true;
606 maximum[j.node] = maxState( var(j.node) );
607 foreach( const Neighbor &J, nbV(j) )
608 if( !visitedFactors[J] )
609 scheduledFactors.push(J);
610 }
611 }
612 }
613 }
614 return maximum;
615 }
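
// Example (sketch): computing a MAP state with the junction tree, in the
// style of the libDAI examples; the max-product value of the "inference"
// property is assumed to be "MAXPROD":
//
//   PropertySet opts;
//   opts.Set( "verbose", (size_t)0 );
//   opts.Set( "updates", string("HUGIN") );
//   opts.Set( "inference", string("MAXPROD") );
//   JTree jtmap( fg, opts );
//   jtmap.init();
//   jtmap.run();
//   vector<size_t> mapstate = jtmap.findMaximum();  // one state index per variable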
616
617
618 } // end of namespace dai