/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
   Radboud University Nijmegen, The Netherlands /
   Max Planck Institute for Biological Cybernetics, Germany

   This file is part of libDAI.

   libDAI is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   libDAI is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with libDAI; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
24 #include <dai/jtree.h>
33 const char *JTree::Name
= "JTREE";
36 void JTree::setProperties( const PropertySet
&opts
) {
37 assert( opts
.hasKey("verbose") );
38 assert( opts
.hasKey("updates") );
40 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
41 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
45 PropertySet
JTree::getProperties() const {
47 opts
.Set( "verbose", props
.verbose
);
48 opts
.Set( "updates", props
.updates
);
53 string
JTree::printProperties() const {
54 stringstream
s( stringstream::out
);
56 s
<< "verbose=" << props
.verbose
<< ",";
57 s
<< "updates=" << props
.updates
<< "]";
62 JTree::JTree( const FactorGraph
&fg
, const PropertySet
&opts
, bool automatic
) : DAIAlgRG(fg
), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
63 setProperties( opts
);
66 DAI_THROW(FACTORGRAPH_NOT_CONNECTED
);
69 // Create ClusterGraph which contains factors as clusters
71 cl
.reserve( fg
.nrFactors() );
72 for( size_t I
= 0; I
< nrFactors(); I
++ )
73 cl
.push_back( factor(I
).vars() );
74 ClusterGraph
_cg( cl
);
76 if( props
.verbose
>= 3 )
77 cout
<< "Initial clusters: " << _cg
<< endl
;
79 // Retain only maximal clusters
80 _cg
.eraseNonMaximal();
81 if( props
.verbose
>= 3 )
82 cout
<< "Maximal clusters: " << _cg
<< endl
;
84 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
85 if( props
.verbose
>= 3 )
86 cout
<< "VarElim_MinFill result: " << ElimVec
<< endl
;
88 GenerateJT( ElimVec
);
93 void JTree::GenerateJT( const std::vector
<VarSet
> &Cliques
) {
94 // Construct a weighted graph (each edge is weighted with the cardinality
95 // of the intersection of the nodes, where the nodes are the elements of
97 WeightedGraph
<int> JuncGraph
;
98 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
99 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
100 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
102 JuncGraph
[UEdge(i
,j
)] = w
;
105 // Construct maximal spanning tree using Prim's algorithm
106 RTree
= MaxSpanningTreePrims( JuncGraph
);
108 // Construct corresponding region graph
110 // Create outer regions
111 ORs
.reserve( Cliques
.size() );
112 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
113 ORs
.push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
115 // For each factor, find an outer region that subsumes that factor.
116 // Then, multiply the outer region with that factor.
117 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
119 for( alpha
= 0; alpha
< nrORs(); alpha
++ )
120 if( OR(alpha
).vars() >> factor(I
).vars() ) {
121 fac2OR
.push_back( alpha
);
124 assert( alpha
!= nrORs() );
128 // Create inner regions and edges
129 IRs
.reserve( RTree
.size() );
131 edges
.reserve( 2 * RTree
.size() );
132 for( size_t i
= 0; i
< RTree
.size(); i
++ ) {
133 edges
.push_back( Edge( RTree
[i
].n1
, nrIRs() ) );
134 edges
.push_back( Edge( RTree
[i
].n2
, nrIRs() ) );
135 // inner clusters have counting number -1
136 IRs
.push_back( Region( Cliques
[RTree
[i
].n1
] & Cliques
[RTree
[i
].n2
], -1.0 ) );
139 // create bipartite graph
140 G
.construct( nrORs(), nrIRs(), edges
.begin(), edges
.end() );
142 // Create messages and beliefs
144 Qa
.reserve( nrORs() );
145 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
146 Qa
.push_back( OR(alpha
) );
149 Qb
.reserve( nrIRs() );
150 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
151 Qb
.push_back( Factor( IR(beta
), 1.0 ) );
154 _mes
.reserve( nrORs() );
155 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
156 _mes
.push_back( vector
<Factor
>() );
157 _mes
[alpha
].reserve( nbOR(alpha
).size() );
158 foreach( const Neighbor
&beta
, nbOR(alpha
) )
159 _mes
[alpha
].push_back( Factor( IR(beta
), 1.0 ) );
162 // Check counting numbers
163 Check_Counting_Numbers();
165 if( props
.verbose
>= 3 ) {
166 cout
<< "Resulting regiongraph: " << *this << endl
;
171 string
JTree::identify() const {
172 return string(Name
) + printProperties();
176 Factor
JTree::belief( const VarSet
&ns
) const {
177 vector
<Factor
>::const_iterator beta
;
178 for( beta
= Qb
.begin(); beta
!= Qb
.end(); beta
++ )
179 if( beta
->vars() >> ns
)
181 if( beta
!= Qb
.end() )
182 return( beta
->marginal(ns
) );
184 vector
<Factor
>::const_iterator alpha
;
185 for( alpha
= Qa
.begin(); alpha
!= Qa
.end(); alpha
++ )
186 if( alpha
->vars() >> ns
)
188 assert( alpha
!= Qa
.end() );
189 return( alpha
->marginal(ns
) );
194 vector
<Factor
> JTree::beliefs() const {
195 vector
<Factor
> result
;
196 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
197 result
.push_back( Qb
[beta
] );
198 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
199 result
.push_back( Qa
[alpha
] );
204 Factor
JTree::belief( const Var
&n
) const {
205 return belief( (VarSet
)n
);
210 void JTree::runHUGIN() {
211 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
212 Qa
[alpha
] = OR(alpha
);
214 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
215 Qb
[beta
].fill( 1.0 );
219 for( size_t i
= RTree
.size(); (i
--) != 0; ) {
220 // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
221 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
222 Factor new_Qb
= Qa
[RTree
[i
].n2
].partSum( IR( i
) );
223 _logZ
+= log(new_Qb
.normalize());
224 Qa
[RTree
[i
].n1
] *= new_Qb
.divided_by( Qb
[i
] );
228 _logZ
+= log(Qa
[0].normalize() );
230 _logZ
+= log(Qa
[RTree
[0].n1
].normalize());
232 // DistributeEvidence
233 for( size_t i
= 0; i
< RTree
.size(); i
++ ) {
234 // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
235 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
236 Factor new_Qb
= Qa
[RTree
[i
].n1
].marginal( IR( i
) );
237 Qa
[RTree
[i
].n2
] *= new_Qb
.divided_by( Qb
[i
] );
242 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
243 Qa
[alpha
].normalize();
247 // Really needs no init! Initial messages can be anything.
248 void JTree::runShaferShenoy() {
251 for( size_t e
= nrIRs(); (e
--) != 0; ) {
252 // send a message from RTree[e].n2 to RTree[e].n1
253 // or, actually, from the seperator IR(e) to RTree[e].n1
255 size_t i
= nbIR(e
)[1].node
; // = RTree[e].n2
256 size_t j
= nbIR(e
)[0].node
; // = RTree[e].n1
257 size_t _e
= nbIR(e
)[0].dual
;
260 foreach( const Neighbor
&k
, nbOR(i
) )
262 piet
*= message( i
, k
.iter
);
263 message( j
, _e
) = piet
.partSum( IR(e
) );
264 _logZ
+= log( message(j
,_e
).normalize() );
268 for( size_t e
= 0; e
< nrIRs(); e
++ ) {
269 size_t i
= nbIR(e
)[0].node
; // = RTree[e].n1
270 size_t j
= nbIR(e
)[1].node
; // = RTree[e].n2
271 size_t _e
= nbIR(e
)[1].dual
;
274 foreach( const Neighbor
&k
, nbOR(i
) )
276 piet
*= message( i
, k
.iter
);
277 message( j
, _e
) = piet
.marginal( IR(e
) );
281 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
282 Factor piet
= OR(alpha
);
283 foreach( const Neighbor
&k
, nbOR(alpha
) )
284 piet
*= message( alpha
, k
.iter
);
286 _logZ
+= log( piet
.normalize() );
288 } else if( alpha
== nbIR(0)[0].node
/*RTree[0].n1*/ ) {
289 _logZ
+= log( piet
.normalize() );
292 Qa
[alpha
] = piet
.normalized();
295 // Only for logZ (and for belief)...
296 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
297 Qb
[beta
] = Qa
[nbIR(beta
)[0].node
].marginal( IR(beta
) );
301 double JTree::run() {
302 if( props
.updates
== Properties::UpdateType::HUGIN
)
304 else if( props
.updates
== Properties::UpdateType::SHSH
)
310 Real
JTree::logZ() const {
312 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
313 sum
+= IR(beta
).c() * Qb
[beta
].entropy();
314 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
315 sum
+= OR(alpha
).c() * Qa
[alpha
].entropy();
316 sum
+= (OR(alpha
).log0() * Qa
[alpha
]).totalSum();
323 size_t JTree::findEfficientTree( const VarSet
& ns
, DEdgeVec
&Tree
, size_t PreviousRoot
) const {
324 // find new root clique (the one with maximal statespace overlap with ns)
325 size_t maxval
= 0, maxalpha
= 0;
326 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
327 size_t val
= VarSet(ns
& OR(alpha
).vars()).nrStates();
336 for( DEdgeVec::const_iterator e
= RTree
.begin(); e
!= RTree
.end(); e
++ )
337 oldTree
.insert( UEdge(e
->n1
, e
->n2
) );
338 DEdgeVec newTree
= GrowRootedTree( oldTree
, maxalpha
);
340 // identify subtree that contains variables of ns which are not in the new root
341 VarSet nsrem
= ns
/ OR(maxalpha
).vars();
343 // for each variable in ns that is not in the root clique
344 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ ) {
345 // find first occurence of *n in the tree, which is closest to the root
347 for( ; e
!= newTree
.size(); e
++ ) {
348 if( OR(newTree
[e
].n2
).vars().contains( *n
) )
351 assert( e
!= newTree
.size() );
353 // track-back path to root and add edges to subTree
354 subTree
.insert( newTree
[e
] );
355 size_t pos
= newTree
[e
].n1
;
357 if( newTree
[e
-1].n2
== pos
) {
358 subTree
.insert( newTree
[e
-1] );
359 pos
= newTree
[e
-1].n1
;
362 if( PreviousRoot
!= (size_t)-1 && PreviousRoot
!= maxalpha
) {
363 // find first occurence of PreviousRoot in the tree, which is closest to the new root
365 for( ; e
!= newTree
.size(); e
++ ) {
366 if( newTree
[e
].n2
== PreviousRoot
)
369 assert( e
!= newTree
.size() );
371 // track-back path to root and add edges to subTree
372 subTree
.insert( newTree
[e
] );
373 size_t pos
= newTree
[e
].n1
;
375 if( newTree
[e
-1].n2
== pos
) {
376 subTree
.insert( newTree
[e
-1] );
377 pos
= newTree
[e
-1].n1
;
381 // Resulting Tree is a reordered copy of newTree
382 // First add edges in subTree to Tree
384 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
385 if( subTree
.count( *e
) ) {
386 Tree
.push_back( *e
);
388 // Then add edges pointing away from nsrem
390 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
391 for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
393 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
394 e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
395 Tree.push_back( *e );
399 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
400 if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
402 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
403 if( (OR(e->n1).vars() && *n) ) {
408 Tree.push_back( *e );
411 size_t subTreeSize
= Tree
.size();
412 // Then add remaining edges
413 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
414 if( find( Tree
.begin(), Tree
.end(), *e
) == Tree
.end() )
415 Tree
.push_back( *e
);
421 // Cutset conditioning
422 // assumes that run() has been called already
423 Factor
JTree::calcMarginal( const VarSet
& ns
) {
424 vector
<Factor
>::const_iterator beta
;
425 for( beta
= Qb
.begin(); beta
!= Qb
.end(); beta
++ )
426 if( beta
->vars() >> ns
)
428 if( beta
!= Qb
.end() )
429 return( beta
->marginal(ns
) );
431 vector
<Factor
>::const_iterator alpha
;
432 for( alpha
= Qa
.begin(); alpha
!= Qa
.end(); alpha
++ )
433 if( alpha
->vars() >> ns
)
435 if( alpha
!= Qa
.end() )
436 return( alpha
->marginal(ns
) );
438 // Find subtree to do efficient inference
440 size_t Tsize
= findEfficientTree( ns
, T
);
442 // Find remaining variables (which are not in the new root)
443 VarSet nsrem
= ns
/ OR(T
.front().n1
).vars();
444 Factor
Pns (ns
, 0.0);
446 // Save Qa and Qb on the subtree
447 map
<size_t,Factor
> Qa_old
;
448 map
<size_t,Factor
> Qb_old
;
449 vector
<size_t> b(Tsize
, 0);
450 for( size_t i
= Tsize
; (i
--) != 0; ) {
451 size_t alpha1
= T
[i
].n1
;
452 size_t alpha2
= T
[i
].n2
;
454 for( beta
= 0; beta
< nrIRs(); beta
++ )
455 if( UEdge( RTree
[beta
].n1
, RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
457 assert( beta
!= nrIRs() );
460 if( !Qa_old
.count( alpha1
) )
461 Qa_old
[alpha1
] = Qa
[alpha1
];
462 if( !Qa_old
.count( alpha2
) )
463 Qa_old
[alpha2
] = Qa
[alpha2
];
464 if( !Qb_old
.count( beta
) )
465 Qb_old
[beta
] = Qb
[beta
];
468 // For all states of nsrem
469 for( State
s(nsrem
); s
.valid(); s
++ ) {
472 for( size_t i
= Tsize
; (i
--) != 0; ) {
473 // Make outer region T[i].n1 consistent with outer region T[i].n2
474 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
476 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ )
477 if( Qa
[T
[i
].n2
].vars() >> *n
) {
478 Factor
piet( *n
, 0.0 );
483 Factor new_Qb
= Qa
[T
[i
].n2
].partSum( IR( b
[i
] ) );
484 logZ
+= log(new_Qb
.normalize());
485 Qa
[T
[i
].n1
] *= new_Qb
.divided_by( Qb
[b
[i
]] );
488 logZ
+= log(Qa
[T
[0].n1
].normalize());
490 Factor
piet( nsrem
, 0.0 );
492 Pns
+= piet
* Qa
[T
[0].n1
].partSum( ns
/ nsrem
); // OPTIMIZE ME
494 // Restore clamped beliefs
495 for( map
<size_t,Factor
>::const_iterator alpha
= Qa_old
.begin(); alpha
!= Qa_old
.end(); alpha
++ )
496 Qa
[alpha
->first
] = alpha
->second
;
497 for( map
<size_t,Factor
>::const_iterator beta
= Qb_old
.begin(); beta
!= Qb_old
.end(); beta
++ )
498 Qb
[beta
->first
] = beta
->second
;
501 return( Pns
.normalized() );
507 /// Calculates upper bound to the treewidth of a FactorGraph
509 * \return a pair (number of variables in largest clique, number of states in largest clique)
511 std::pair
<size_t,size_t> treewidth( const FactorGraph
& fg
) {
515 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
516 _cg
.insert( fg
.factor(I
).vars() );
518 // Retain only maximal clusters
519 _cg
.eraseNonMaximal();
521 // Obtain elimination sequence
522 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
524 // Calculate treewidth
525 size_t treewidth
= 0;
527 for( size_t i
= 0; i
< ElimVec
.size(); i
++ ) {
528 if( ElimVec
[i
].size() > treewidth
)
529 treewidth
= ElimVec
[i
].size();
530 size_t s
= ElimVec
[i
].nrStates();
535 return pair
<size_t,size_t>(treewidth
, nrstates
);
539 } // end of namespace dai