1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
14 #include <dai/jtree.h>
// Textual identifier of this algorithm, returned as part of identify()
const char *JTree::Name = "JTREE";
26 void JTree::setProperties( const PropertySet
&opts
) {
27 DAI_ASSERT( opts
.hasKey("verbose") );
28 DAI_ASSERT( opts
.hasKey("updates") );
30 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
31 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
32 if( opts
.hasKey("inference") )
33 props
.inference
= opts
.getStringAs
<Properties::InfType
>("inference");
35 props
.inference
= Properties::InfType::SUMPROD
;
39 PropertySet
JTree::getProperties() const {
41 opts
.Set( "verbose", props
.verbose
);
42 opts
.Set( "updates", props
.updates
);
43 opts
.Set( "inference", props
.inference
);
48 string
JTree::printProperties() const {
49 stringstream
s( stringstream::out
);
51 s
<< "verbose=" << props
.verbose
<< ",";
52 s
<< "updates=" << props
.updates
<< ",";
53 s
<< "inference=" << props
.inference
<< "]";
58 JTree::JTree( const FactorGraph
&fg
, const PropertySet
&opts
, bool automatic
) : DAIAlgRG(fg
), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
59 setProperties( opts
);
62 DAI_THROW(FACTORGRAPH_NOT_CONNECTED
);
65 // Create ClusterGraph which contains factors as clusters
67 cl
.reserve( fg
.nrFactors() );
68 for( size_t I
= 0; I
< nrFactors(); I
++ )
69 cl
.push_back( factor(I
).vars() );
70 ClusterGraph
_cg( cl
);
72 if( props
.verbose
>= 3 )
73 cerr
<< "Initial clusters: " << _cg
<< endl
;
75 // Retain only maximal clusters
76 _cg
.eraseNonMaximal();
77 if( props
.verbose
>= 3 )
78 cerr
<< "Maximal clusters: " << _cg
<< endl
;
80 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
81 if( props
.verbose
>= 3 )
82 cerr
<< "VarElim_MinFill result: " << ElimVec
<< endl
;
84 GenerateJT( ElimVec
);
89 void JTree::GenerateJT( const std::vector
<VarSet
> &Cliques
) {
90 // Construct a weighted graph (each edge is weighted with the cardinality
91 // of the intersection of the nodes, where the nodes are the elements of
93 WeightedGraph
<int> JuncGraph
;
94 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
95 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
96 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
98 JuncGraph
[UEdge(i
,j
)] = w
;
101 // Construct maximal spanning tree using Prim's algorithm
102 RTree
= MaxSpanningTreePrims( JuncGraph
);
104 // Construct corresponding region graph
106 // Create outer regions
107 ORs
.reserve( Cliques
.size() );
108 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
109 ORs
.push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
111 // For each factor, find an outer region that subsumes that factor.
112 // Then, multiply the outer region with that factor.
113 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
115 for( alpha
= 0; alpha
< nrORs(); alpha
++ )
116 if( OR(alpha
).vars() >> factor(I
).vars() ) {
117 fac2OR
.push_back( alpha
);
120 DAI_ASSERT( alpha
!= nrORs() );
124 // Create inner regions and edges
125 IRs
.reserve( RTree
.size() );
127 edges
.reserve( 2 * RTree
.size() );
128 for( size_t i
= 0; i
< RTree
.size(); i
++ ) {
129 edges
.push_back( Edge( RTree
[i
].n1
, nrIRs() ) );
130 edges
.push_back( Edge( RTree
[i
].n2
, nrIRs() ) );
131 // inner clusters have counting number -1
132 IRs
.push_back( Region( Cliques
[RTree
[i
].n1
] & Cliques
[RTree
[i
].n2
], -1.0 ) );
135 // create bipartite graph
136 G
.construct( nrORs(), nrIRs(), edges
.begin(), edges
.end() );
138 // Create messages and beliefs
140 Qa
.reserve( nrORs() );
141 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
142 Qa
.push_back( OR(alpha
) );
145 Qb
.reserve( nrIRs() );
146 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
147 Qb
.push_back( Factor( IR(beta
), 1.0 ) );
150 _mes
.reserve( nrORs() );
151 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
152 _mes
.push_back( vector
<Factor
>() );
153 _mes
[alpha
].reserve( nbOR(alpha
).size() );
154 foreach( const Neighbor
&beta
, nbOR(alpha
) )
155 _mes
[alpha
].push_back( Factor( IR(beta
), 1.0 ) );
158 // Check counting numbers
159 Check_Counting_Numbers();
161 if( props
.verbose
>= 3 ) {
162 cerr
<< "Resulting regiongraph: " << *this << endl
;
167 string
JTree::identify() const {
168 return string(Name
) + printProperties();
172 Factor
JTree::belief( const VarSet
&ns
) const {
173 vector
<Factor
>::const_iterator beta
;
174 for( beta
= Qb
.begin(); beta
!= Qb
.end(); beta
++ )
175 if( beta
->vars() >> ns
)
177 if( beta
!= Qb
.end() )
178 return( beta
->marginal(ns
) );
180 vector
<Factor
>::const_iterator alpha
;
181 for( alpha
= Qa
.begin(); alpha
!= Qa
.end(); alpha
++ )
182 if( alpha
->vars() >> ns
)
184 DAI_ASSERT( alpha
!= Qa
.end() );
185 return( alpha
->marginal(ns
) );
190 vector
<Factor
> JTree::beliefs() const {
191 vector
<Factor
> result
;
192 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
193 result
.push_back( Qb
[beta
] );
194 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
195 result
.push_back( Qa
[alpha
] );
200 Factor
JTree::belief( const Var
&n
) const {
201 return belief( (VarSet
)n
);
206 void JTree::runHUGIN() {
207 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
208 Qa
[alpha
] = OR(alpha
);
210 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
211 Qb
[beta
].fill( 1.0 );
215 for( size_t i
= RTree
.size(); (i
--) != 0; ) {
216 // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
217 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
219 if( props
.inference
== Properties::InfType::SUMPROD
)
220 new_Qb
= Qa
[RTree
[i
].n2
].marginal( IR( i
), false );
222 new_Qb
= Qa
[RTree
[i
].n2
].maxMarginal( IR( i
), false );
224 _logZ
+= log(new_Qb
.normalize());
225 Qa
[RTree
[i
].n1
] *= new_Qb
/ Qb
[i
];
229 _logZ
+= log(Qa
[0].normalize() );
231 _logZ
+= log(Qa
[RTree
[0].n1
].normalize());
233 // DistributeEvidence
234 for( size_t i
= 0; i
< RTree
.size(); i
++ ) {
235 // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
236 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
238 if( props
.inference
== Properties::InfType::SUMPROD
)
239 new_Qb
= Qa
[RTree
[i
].n1
].marginal( IR( i
) );
241 new_Qb
= Qa
[RTree
[i
].n1
].maxMarginal( IR( i
) );
243 Qa
[RTree
[i
].n2
] *= new_Qb
/ Qb
[i
];
248 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
249 Qa
[alpha
].normalize();
253 // Really needs no init! Initial messages can be anything.
254 void JTree::runShaferShenoy() {
257 for( size_t e
= nrIRs(); (e
--) != 0; ) {
258 // send a message from RTree[e].n2 to RTree[e].n1
259 // or, actually, from the seperator IR(e) to RTree[e].n1
261 size_t i
= nbIR(e
)[1].node
; // = RTree[e].n2
262 size_t j
= nbIR(e
)[0].node
; // = RTree[e].n1
263 size_t _e
= nbIR(e
)[0].dual
;
266 foreach( const Neighbor
&k
, nbOR(i
) )
268 msg
*= message( i
, k
.iter
);
269 if( props
.inference
== Properties::InfType::SUMPROD
)
270 message( j
, _e
) = msg
.marginal( IR(e
), false );
272 message( j
, _e
) = msg
.maxMarginal( IR(e
), false );
273 _logZ
+= log( message(j
,_e
).normalize() );
277 for( size_t e
= 0; e
< nrIRs(); e
++ ) {
278 size_t i
= nbIR(e
)[0].node
; // = RTree[e].n1
279 size_t j
= nbIR(e
)[1].node
; // = RTree[e].n2
280 size_t _e
= nbIR(e
)[1].dual
;
283 foreach( const Neighbor
&k
, nbOR(i
) )
285 msg
*= message( i
, k
.iter
);
286 if( props
.inference
== Properties::InfType::SUMPROD
)
287 message( j
, _e
) = msg
.marginal( IR(e
) );
289 message( j
, _e
) = msg
.maxMarginal( IR(e
) );
293 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
294 Factor piet
= OR(alpha
);
295 foreach( const Neighbor
&k
, nbOR(alpha
) )
296 piet
*= message( alpha
, k
.iter
);
298 _logZ
+= log( piet
.normalize() );
300 } else if( alpha
== nbIR(0)[0].node
/*RTree[0].n1*/ ) {
301 _logZ
+= log( piet
.normalize() );
304 Qa
[alpha
] = piet
.normalized();
307 // Only for logZ (and for belief)...
308 for( size_t beta
= 0; beta
< nrIRs(); beta
++ ) {
309 if( props
.inference
== Properties::InfType::SUMPROD
)
310 Qb
[beta
] = Qa
[nbIR(beta
)[0].node
].marginal( IR(beta
) );
312 Qb
[beta
] = Qa
[nbIR(beta
)[0].node
].maxMarginal( IR(beta
) );
317 double JTree::run() {
318 if( props
.updates
== Properties::UpdateType::HUGIN
)
320 else if( props
.updates
== Properties::UpdateType::SHSH
)
326 Real
JTree::logZ() const {
328 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
329 s
+= IR(beta
).c() * Qb
[beta
].entropy();
330 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
331 s
+= OR(alpha
).c() * Qa
[alpha
].entropy();
332 s
+= (OR(alpha
).log(true) * Qa
[alpha
]).sum();
339 size_t JTree::findEfficientTree( const VarSet
& ns
, DEdgeVec
&Tree
, size_t PreviousRoot
) const {
340 // find new root clique (the one with maximal statespace overlap with ns)
341 size_t maxval
= 0, maxalpha
= 0;
342 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
343 size_t val
= VarSet(ns
& OR(alpha
).vars()).nrStates();
352 for( DEdgeVec::const_iterator e
= RTree
.begin(); e
!= RTree
.end(); e
++ )
353 oldTree
.insert( UEdge(e
->n1
, e
->n2
) );
354 DEdgeVec newTree
= GrowRootedTree( oldTree
, maxalpha
);
356 // identify subtree that contains variables of ns which are not in the new root
357 VarSet nsrem
= ns
/ OR(maxalpha
).vars();
359 // for each variable in ns that is not in the root clique
360 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ ) {
361 // find first occurence of *n in the tree, which is closest to the root
363 for( ; e
!= newTree
.size(); e
++ ) {
364 if( OR(newTree
[e
].n2
).vars().contains( *n
) )
367 DAI_ASSERT( e
!= newTree
.size() );
369 // track-back path to root and add edges to subTree
370 subTree
.insert( newTree
[e
] );
371 size_t pos
= newTree
[e
].n1
;
373 if( newTree
[e
-1].n2
== pos
) {
374 subTree
.insert( newTree
[e
-1] );
375 pos
= newTree
[e
-1].n1
;
378 if( PreviousRoot
!= (size_t)-1 && PreviousRoot
!= maxalpha
) {
379 // find first occurence of PreviousRoot in the tree, which is closest to the new root
381 for( ; e
!= newTree
.size(); e
++ ) {
382 if( newTree
[e
].n2
== PreviousRoot
)
385 DAI_ASSERT( e
!= newTree
.size() );
387 // track-back path to root and add edges to subTree
388 subTree
.insert( newTree
[e
] );
389 size_t pos
= newTree
[e
].n1
;
391 if( newTree
[e
-1].n2
== pos
) {
392 subTree
.insert( newTree
[e
-1] );
393 pos
= newTree
[e
-1].n1
;
397 // Resulting Tree is a reordered copy of newTree
398 // First add edges in subTree to Tree
400 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
401 if( subTree
.count( *e
) ) {
402 Tree
.push_back( *e
);
404 // Then add edges pointing away from nsrem
406 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
407 for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
409 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
410 e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
411 Tree.push_back( *e );
415 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
416 if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
418 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
419 if( (OR(e->n1).vars() && *n) ) {
424 Tree.push_back( *e );
427 size_t subTreeSize
= Tree
.size();
428 // Then add remaining edges
429 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
430 if( find( Tree
.begin(), Tree
.end(), *e
) == Tree
.end() )
431 Tree
.push_back( *e
);
437 // Cutset conditioning
438 // assumes that run() has been called already
439 Factor
JTree::calcMarginal( const VarSet
& ns
) {
440 vector
<Factor
>::const_iterator beta
;
441 for( beta
= Qb
.begin(); beta
!= Qb
.end(); beta
++ )
442 if( beta
->vars() >> ns
)
444 if( beta
!= Qb
.end() )
445 return( beta
->marginal(ns
) );
447 vector
<Factor
>::const_iterator alpha
;
448 for( alpha
= Qa
.begin(); alpha
!= Qa
.end(); alpha
++ )
449 if( alpha
->vars() >> ns
)
451 if( alpha
!= Qa
.end() )
452 return( alpha
->marginal(ns
) );
454 // Find subtree to do efficient inference
456 size_t Tsize
= findEfficientTree( ns
, T
);
458 // Find remaining variables (which are not in the new root)
459 VarSet nsrem
= ns
/ OR(T
.front().n1
).vars();
460 Factor
Pns (ns
, 0.0);
462 // Save Qa and Qb on the subtree
463 map
<size_t,Factor
> Qa_old
;
464 map
<size_t,Factor
> Qb_old
;
465 vector
<size_t> b(Tsize
, 0);
466 for( size_t i
= Tsize
; (i
--) != 0; ) {
467 size_t alpha1
= T
[i
].n1
;
468 size_t alpha2
= T
[i
].n2
;
470 for( beta
= 0; beta
< nrIRs(); beta
++ )
471 if( UEdge( RTree
[beta
].n1
, RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
473 DAI_ASSERT( beta
!= nrIRs() );
476 if( !Qa_old
.count( alpha1
) )
477 Qa_old
[alpha1
] = Qa
[alpha1
];
478 if( !Qa_old
.count( alpha2
) )
479 Qa_old
[alpha2
] = Qa
[alpha2
];
480 if( !Qb_old
.count( beta
) )
481 Qb_old
[beta
] = Qb
[beta
];
484 // For all states of nsrem
485 for( State
s(nsrem
); s
.valid(); s
++ ) {
488 for( size_t i
= Tsize
; (i
--) != 0; ) {
489 // Make outer region T[i].n1 consistent with outer region T[i].n2
490 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
492 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ )
493 if( Qa
[T
[i
].n2
].vars() >> *n
) {
494 Factor
piet( *n
, 0.0 );
499 Factor new_Qb
= Qa
[T
[i
].n2
].marginal( IR( b
[i
] ), false );
500 logZ
+= log(new_Qb
.normalize());
501 Qa
[T
[i
].n1
] *= new_Qb
/ Qb
[b
[i
]];
504 logZ
+= log(Qa
[T
[0].n1
].normalize());
506 Factor
piet( nsrem
, 0.0 );
508 Pns
+= piet
* Qa
[T
[0].n1
].marginal( ns
/ nsrem
, false ); // OPTIMIZE ME
510 // Restore clamped beliefs
511 for( map
<size_t,Factor
>::const_iterator alpha
= Qa_old
.begin(); alpha
!= Qa_old
.end(); alpha
++ )
512 Qa
[alpha
->first
] = alpha
->second
;
513 for( map
<size_t,Factor
>::const_iterator beta
= Qb_old
.begin(); beta
!= Qb_old
.end(); beta
++ )
514 Qb
[beta
->first
] = beta
->second
;
517 return( Pns
.normalized() );
523 /// Calculates upper bound to the treewidth of a FactorGraph
525 * \return a pair (number of variables in largest clique, number of states in largest clique)
527 std::pair
<size_t,size_t> treewidth( const FactorGraph
& fg
) {
531 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
532 _cg
.insert( fg
.factor(I
).vars() );
534 // Retain only maximal clusters
535 _cg
.eraseNonMaximal();
537 // Obtain elimination sequence
538 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
540 // Calculate treewidth
541 size_t treewidth
= 0;
543 for( size_t i
= 0; i
< ElimVec
.size(); i
++ ) {
544 if( ElimVec
[i
].size() > treewidth
)
545 treewidth
= ElimVec
[i
].size();
546 size_t s
= ElimVec
[i
].nrStates();
551 return pair
<size_t,size_t>(treewidth
, nrstates
);
555 std::vector
<size_t> JTree::findMaximum() const {
556 vector
<size_t> maximum( nrVars() );
557 vector
<bool> visitedVars( nrVars(), false );
558 vector
<bool> visitedFactors( nrFactors(), false );
559 stack
<size_t> scheduledFactors
;
560 for( size_t i
= 0; i
< nrVars(); ++i
) {
563 visitedVars
[i
] = true;
565 // Maximise with respect to variable i
566 Prob prod
= beliefV(i
).p();
567 maximum
[i
] = max_element( prod
.begin(), prod
.end() ) - prod
.begin();
569 foreach( const Neighbor
&I
, nbV(i
) )
570 if( !visitedFactors
[I
] )
571 scheduledFactors
.push(I
);
573 while( !scheduledFactors
.empty() ){
574 size_t I
= scheduledFactors
.top();
575 scheduledFactors
.pop();
576 if( visitedFactors
[I
] )
578 visitedFactors
[I
] = true;
580 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
581 bool allDetermined
= true;
582 foreach( const Neighbor
&j
, nbF(I
) )
583 if( !visitedVars
[j
.node
] ) {
584 allDetermined
= false;
590 // Calculate product of incoming messages on factor I
591 Prob prod2
= beliefF(I
).p();
593 // The allowed configuration is restrained according to the variables assigned so far:
594 // pick the argmax amongst the allowed states
595 Real maxProb
= numeric_limits
<Real
>::min();
596 State
maxState( factor(I
).vars() );
597 for( State
s( factor(I
).vars() ); s
.valid(); ++s
){
598 // First, calculate whether this state is consistent with variables that
599 // have been assigned already
600 bool allowedState
= true;
601 foreach( const Neighbor
&j
, nbF(I
) )
602 if( visitedVars
[j
.node
] && maximum
[j
.node
] != s(var(j
.node
)) ) {
603 allowedState
= false;
606 // If it is consistent, check if its probability is larger than what we have seen so far
607 if( allowedState
&& prod2
[s
] > maxProb
) {
614 foreach( const Neighbor
&j
, nbF(I
) ) {
615 if( visitedVars
[j
.node
] ) {
616 // We have already visited j earlier - hopefully our state is consistent
617 if( maximum
[j
.node
] != maxState(var(j
.node
)) && props
.verbose
>= 1 )
618 cerr
<< "JTree::findMaximum - warning: maximum not consistent due to loops." << endl
;
620 // We found a consistent state for variable j
621 visitedVars
[j
.node
] = true;
622 maximum
[j
.node
] = maxState( var(j
.node
) );
623 foreach( const Neighbor
&J
, nbV(j
) )
624 if( !visitedFactors
[J
] )
625 scheduledFactors
.push(J
);
634 } // end of namespace dai