/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include <dai/jtree.h>

#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>


// Everything below lives in namespace dai (closed at the end of this file).
namespace dai {
32 const char *JTree::Name
= "JTREE";
35 void JTree::setProperties( const PropertySet
&opts
) {
36 assert( opts
.hasKey("verbose") );
37 assert( opts
.hasKey("updates") );
39 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
40 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
44 PropertySet
JTree::getProperties() const {
46 opts
.Set( "verbose", props
.verbose
);
47 opts
.Set( "updates", props
.updates
);
52 string
JTree::printProperties() const {
53 stringstream
s( stringstream::out
);
55 s
<< "verbose=" << props
.verbose
<< ",";
56 s
<< "updates=" << props
.updates
<< "]";
61 JTree::JTree( const FactorGraph
&fg
, const PropertySet
&opts
, bool automatic
) : DAIAlgRG(fg
), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
62 setProperties( opts
);
65 // Copy VarSets of factors
67 cl
.reserve( fg
.nrFactors() );
68 for( size_t I
= 0; I
< nrFactors(); I
++ )
69 cl
.push_back( factor(I
).vars() );
70 ClusterGraph
_cg( cl
);
72 if( props
.verbose
>= 3 )
73 cout
<< "Initial clusters: " << _cg
<< endl
;
75 // Retain only maximal clusters
76 _cg
.eraseNonMaximal();
77 if( props
.verbose
>= 3 )
78 cout
<< "Maximal clusters: " << _cg
<< endl
;
80 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
81 if( props
.verbose
>= 3 )
82 cout
<< "VarElim_MinFill result: " << ElimVec
<< endl
;
84 GenerateJT( ElimVec
);
89 void JTree::GenerateJT( const std::vector
<VarSet
> &Cliques
) {
90 // Construct a weighted graph (each edge is weighted with the cardinality
91 // of the intersection of the nodes, where the nodes are the elements of
93 WeightedGraph
<int> JuncGraph
;
94 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
95 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
96 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
97 JuncGraph
[UEdge(i
,j
)] = w
;
100 // Construct maximal spanning tree using Prim's algorithm
101 _RTree
= MaxSpanningTreePrims( JuncGraph
);
103 // Construct corresponding region graph
105 // Create outer regions
106 ORs
.reserve( Cliques
.size() );
107 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
108 ORs
.push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
110 // For each factor, find an outer region that subsumes that factor.
111 // Then, multiply the outer region with that factor.
112 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
114 for( alpha
= 0; alpha
< nrORs(); alpha
++ )
115 if( OR(alpha
).vars() >> factor(I
).vars() ) {
116 // OR(alpha) *= factor(I);
117 fac2OR
.push_back( alpha
);
120 assert( alpha
!= nrORs() );
124 // Create inner regions and edges
125 IRs
.reserve( _RTree
.size() );
127 edges
.reserve( 2 * _RTree
.size() );
128 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
129 edges
.push_back( Edge( _RTree
[i
].n1
, nrIRs() ) );
130 edges
.push_back( Edge( _RTree
[i
].n2
, nrIRs() ) );
131 // inner clusters have counting number -1
132 IRs
.push_back( Region( Cliques
[_RTree
[i
].n1
] & Cliques
[_RTree
[i
].n2
], -1.0 ) );
135 // create bipartite graph
136 G
.create( nrORs(), nrIRs(), edges
.begin(), edges
.end() );
138 // Create messages and beliefs
140 _Qa
.reserve( nrORs() );
141 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
142 _Qa
.push_back( OR(alpha
) );
145 _Qb
.reserve( nrIRs() );
146 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
147 _Qb
.push_back( Factor( IR(beta
), 1.0 ) );
150 _mes
.reserve( nrORs() );
151 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
152 _mes
.push_back( vector
<Factor
>() );
153 _mes
[alpha
].reserve( nbOR(alpha
).size() );
154 foreach( const Neighbor
&beta
, nbOR(alpha
) )
155 _mes
[alpha
].push_back( Factor( IR(beta
), 1.0 ) );
158 // Check counting numbers
159 Check_Counting_Numbers();
161 if( props
.verbose
>= 3 ) {
162 cout
<< "Resulting regiongraph: " << *this << endl
;
167 string
JTree::identify() const {
168 return string(Name
) + printProperties();
172 Factor
JTree::belief( const VarSet
&ns
) const {
173 vector
<Factor
>::const_iterator beta
;
174 for( beta
= _Qb
.begin(); beta
!= _Qb
.end(); beta
++ )
175 if( beta
->vars() >> ns
)
177 if( beta
!= _Qb
.end() )
178 return( beta
->marginal(ns
) );
180 vector
<Factor
>::const_iterator alpha
;
181 for( alpha
= _Qa
.begin(); alpha
!= _Qa
.end(); alpha
++ )
182 if( alpha
->vars() >> ns
)
184 assert( alpha
!= _Qa
.end() );
185 return( alpha
->marginal(ns
) );
190 vector
<Factor
> JTree::beliefs() const {
191 vector
<Factor
> result
;
192 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
193 result
.push_back( _Qb
[beta
] );
194 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
195 result
.push_back( _Qa
[alpha
] );
200 Factor
JTree::belief( const Var
&n
) const {
201 return belief( (VarSet
)n
);
206 void JTree::runHUGIN() {
207 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
208 _Qa
[alpha
] = OR(alpha
);
210 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
211 _Qb
[beta
].fill( 1.0 );
215 for( size_t i
= _RTree
.size(); (i
--) != 0; ) {
216 // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
217 // IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
218 Factor new_Qb
= _Qa
[_RTree
[i
].n2
].partSum( IR( i
) );
219 _logZ
+= log(new_Qb
.normalize( Prob::NORMPROB
));
220 _Qa
[_RTree
[i
].n1
] *= new_Qb
.divided_by( _Qb
[i
] );
224 _logZ
+= log(_Qa
[0].normalize( Prob::NORMPROB
) );
226 _logZ
+= log(_Qa
[_RTree
[0].n1
].normalize( Prob::NORMPROB
));
228 // DistributeEvidence
229 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
230 // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
231 // IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
232 Factor new_Qb
= _Qa
[_RTree
[i
].n1
].marginal( IR( i
) );
233 _Qa
[_RTree
[i
].n2
] *= new_Qb
.divided_by( _Qb
[i
] );
238 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
239 _Qa
[alpha
].normalize( Prob::NORMPROB
);
243 // Really needs no init! Initial messages can be anything.
244 void JTree::runShaferShenoy() {
247 for( size_t e
= nrIRs(); (e
--) != 0; ) {
248 // send a message from _RTree[e].n2 to _RTree[e].n1
249 // or, actually, from the seperator IR(e) to _RTree[e].n1
251 size_t i
= nbIR(e
)[1].node
; // = _RTree[e].n2
252 size_t j
= nbIR(e
)[0].node
; // = _RTree[e].n1
253 size_t _e
= nbIR(e
)[0].dual
;
256 foreach( const Neighbor
&k
, nbOR(i
) )
258 piet
*= message( i
, k
.iter
);
259 message( j
, _e
) = piet
.partSum( IR(e
) );
260 _logZ
+= log( message(j
,_e
).normalize( Prob::NORMPROB
) );
264 for( size_t e
= 0; e
< nrIRs(); e
++ ) {
265 size_t i
= nbIR(e
)[0].node
; // = _RTree[e].n1
266 size_t j
= nbIR(e
)[1].node
; // = _RTree[e].n2
267 size_t _e
= nbIR(e
)[1].dual
;
270 foreach( const Neighbor
&k
, nbOR(i
) )
272 piet
*= message( i
, k
.iter
);
273 message( j
, _e
) = piet
.marginal( IR(e
) );
277 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
278 Factor piet
= OR(alpha
);
279 foreach( const Neighbor
&k
, nbOR(alpha
) )
280 piet
*= message( alpha
, k
.iter
);
282 _logZ
+= log( piet
.normalize( Prob::NORMPROB
) );
284 } else if( alpha
== nbIR(0)[0].node
/*_RTree[0].n1*/ ) {
285 _logZ
+= log( piet
.normalize( Prob::NORMPROB
) );
288 _Qa
[alpha
] = piet
.normalized( Prob::NORMPROB
);
291 // Only for logZ (and for belief)...
292 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
293 _Qb
[beta
] = _Qa
[nbIR(beta
)[0].node
].marginal( IR(beta
) );
297 double JTree::run() {
298 if( props
.updates
== Properties::UpdateType::HUGIN
)
300 else if( props
.updates
== Properties::UpdateType::SHSH
)
306 Real
JTree::logZ() const {
308 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
309 sum
+= IR(beta
).c() * _Qb
[beta
].entropy();
310 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
311 sum
+= OR(alpha
).c() * _Qa
[alpha
].entropy();
312 sum
+= (OR(alpha
).log0() * _Qa
[alpha
]).totalSum();
319 size_t JTree::findEfficientTree( const VarSet
& ns
, DEdgeVec
&Tree
, size_t PreviousRoot
) const {
320 // find new root clique (the one with maximal statespace overlap with ns)
321 size_t maxval
= 0, maxalpha
= 0;
322 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ ) {
323 size_t val
= (ns
& OR(alpha
).vars()).states();
330 // for( size_t e = 0; e < _RTree.size(); e++ )
331 // cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
335 for( DEdgeVec::const_iterator e
= _RTree
.begin(); e
!= _RTree
.end(); e
++ )
336 oldTree
.insert( UEdge(e
->n1
, e
->n2
) );
337 DEdgeVec newTree
= GrowRootedTree( oldTree
, maxalpha
);
338 // cout << ns << ": ";
339 // for( size_t e = 0; e < newTree.size(); e++ )
340 // cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
343 // identify subtree that contains variables of ns which are not in the new root
344 VarSet nsrem
= ns
/ OR(maxalpha
).vars();
345 // cout << "nsrem:" << nsrem << endl;
347 // for each variable in ns that is not in the root clique
348 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ ) {
349 // find first occurence of *n in the tree, which is closest to the root
351 for( ; e
!= newTree
.size(); e
++ ) {
352 if( OR(newTree
[e
].n2
).vars().contains( *n
) )
355 assert( e
!= newTree
.size() );
357 // track-back path to root and add edges to subTree
358 subTree
.insert( newTree
[e
] );
359 size_t pos
= newTree
[e
].n1
;
361 if( newTree
[e
-1].n2
== pos
) {
362 subTree
.insert( newTree
[e
-1] );
363 pos
= newTree
[e
-1].n1
;
366 if( PreviousRoot
!= (size_t)-1 && PreviousRoot
!= maxalpha
) {
367 // find first occurence of PreviousRoot in the tree, which is closest to the new root
369 for( ; e
!= newTree
.size(); e
++ ) {
370 if( newTree
[e
].n2
== PreviousRoot
)
373 assert( e
!= newTree
.size() );
375 // track-back path to root and add edges to subTree
376 subTree
.insert( newTree
[e
] );
377 size_t pos
= newTree
[e
].n1
;
379 if( newTree
[e
-1].n2
== pos
) {
380 subTree
.insert( newTree
[e
-1] );
381 pos
= newTree
[e
-1].n1
;
384 // cout << "subTree: " << endl;
385 // for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
386 // cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
389 // Resulting Tree is a reordered copy of newTree
390 // First add edges in subTree to Tree
392 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
393 if( subTree
.count( *e
) ) {
394 Tree
.push_back( *e
);
395 // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
398 // Then add edges pointing away from nsrem
400 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
401 for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
403 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
404 e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
405 Tree.push_back( *e );
406 // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
410 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
411 if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
413 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
414 if( (OR(e->n1).vars() && *n) ) {
419 Tree.push_back( *e );
420 cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
424 size_t subTreeSize
= Tree
.size();
425 // Then add remaining edges
426 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
427 if( find( Tree
.begin(), Tree
.end(), *e
) == Tree
.end() )
428 Tree
.push_back( *e
);
434 // Cutset conditioning
435 // assumes that run() has been called already
436 Factor
JTree::calcMarginal( const VarSet
& ns
) {
437 vector
<Factor
>::const_iterator beta
;
438 for( beta
= _Qb
.begin(); beta
!= _Qb
.end(); beta
++ )
439 if( beta
->vars() >> ns
)
441 if( beta
!= _Qb
.end() )
442 return( beta
->marginal(ns
) );
444 vector
<Factor
>::const_iterator alpha
;
445 for( alpha
= _Qa
.begin(); alpha
!= _Qa
.end(); alpha
++ )
446 if( alpha
->vars() >> ns
)
448 if( alpha
!= _Qa
.end() )
449 return( alpha
->marginal(ns
) );
451 // Find subtree to do efficient inference
453 size_t Tsize
= findEfficientTree( ns
, T
);
455 // Find remaining variables (which are not in the new root)
456 VarSet nsrem
= ns
/ OR(T
.front().n1
).vars();
457 Factor
Pns (ns
, 0.0);
459 // Save _Qa and _Qb on the subtree
460 map
<size_t,Factor
> _Qa_old
;
461 map
<size_t,Factor
> _Qb_old
;
462 vector
<size_t> b(Tsize
, 0);
463 for( size_t i
= Tsize
; (i
--) != 0; ) {
464 size_t alpha1
= T
[i
].n1
;
465 size_t alpha2
= T
[i
].n2
;
467 for( beta
= 0; beta
< nrIRs(); beta
++ )
468 if( UEdge( _RTree
[beta
].n1
, _RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
470 assert( beta
!= nrIRs() );
473 if( !_Qa_old
.count( alpha1
) )
474 _Qa_old
[alpha1
] = _Qa
[alpha1
];
475 if( !_Qa_old
.count( alpha2
) )
476 _Qa_old
[alpha2
] = _Qa
[alpha2
];
477 if( !_Qb_old
.count( beta
) )
478 _Qb_old
[beta
] = _Qb
[beta
];
481 // For all states of nsrem
482 for( State
s(nsrem
); s
.valid(); s
++ ) {
486 for( size_t i
= Tsize
; (i
--) != 0; ) {
487 // Make outer region T[i].n1 consistent with outer region T[i].n2
488 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
490 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ )
491 if( _Qa
[T
[i
].n2
].vars() >> *n
) {
492 Factor
piet( *n
, 0.0 );
494 _Qa
[T
[i
].n2
] *= piet
;
497 Factor new_Qb
= _Qa
[T
[i
].n2
].partSum( IR( b
[i
] ) );
498 logZ
+= log(new_Qb
.normalize( Prob::NORMPROB
));
499 _Qa
[T
[i
].n1
] *= new_Qb
.divided_by( _Qb
[b
[i
]] );
502 logZ
+= log(_Qa
[T
[0].n1
].normalize( Prob::NORMPROB
));
504 Factor
piet( nsrem
, 0.0 );
506 Pns
+= piet
* _Qa
[T
[0].n1
].partSum( ns
/ nsrem
); // OPTIMIZE ME
508 // Restore clamped beliefs
509 for( map
<size_t,Factor
>::const_iterator alpha
= _Qa_old
.begin(); alpha
!= _Qa_old
.end(); alpha
++ )
510 _Qa
[alpha
->first
] = alpha
->second
;
511 for( map
<size_t,Factor
>::const_iterator beta
= _Qb_old
.begin(); beta
!= _Qb_old
.end(); beta
++ )
512 _Qb
[beta
->first
] = beta
->second
;
515 return( Pns
.normalized(Prob::NORMPROB
) );
521 // first return value is treewidth
522 // second return value is number of states in largest clique
523 pair
<size_t,size_t> treewidth( const FactorGraph
& fg
) {
527 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
528 _cg
.insert( fg
.factor(I
).vars() );
530 // Retain only maximal clusters
531 _cg
.eraseNonMaximal();
533 // Obtain elimination sequence
534 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
536 // Calculate treewidth
537 size_t treewidth
= 0;
539 for( size_t i
= 0; i
< ElimVec
.size(); i
++ ) {
540 if( ElimVec
[i
].size() > treewidth
)
541 treewidth
= ElimVec
[i
].size();
542 if( ElimVec
[i
].states() > nrstates
)
543 nrstates
= ElimVec
[i
].states();
546 return pair
<size_t,size_t>(treewidth
, nrstates
);
550 } // end of namespace dai