/* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
   Radboud University Nijmegen, The Netherlands

   This file is part of libDAI.

   libDAI is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   libDAI is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with libDAI; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
23 #include <dai/jtree.h>
32 const char *JTree::Name
= "JTREE";
35 bool JTree::checkProperties() {
36 if (!HasProperty("verbose") )
38 if( !HasProperty("updates") )
41 ConvertPropertyTo
<size_t>("verbose");
42 ConvertPropertyTo
<UpdateType
>("updates");
48 JTree::JTree( const FactorGraph
&fg
, const Properties
&opts
, bool automatic
) : DAIAlgRG(fg
, opts
), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
49 assert( checkProperties() );
55 for( size_t I
= 0; I
< nrFactors(); I
++ )
56 _cg
.insert( factor(I
).vars() );
58 cout
<< "Initial clusters: " << _cg
<< endl
;
60 // Retain only maximal clusters
61 _cg
.eraseNonMaximal();
63 cout
<< "Maximal clusters: " << _cg
<< endl
;
65 vector
<VarSet
> ElimVec
= _cg
.VarElim_MinFill().eraseNonMaximal().toVector();
66 if( Verbose() >= 3 ) {
67 cout
<< "VarElim_MinFill result: {" << endl
;
68 for( size_t i
= 0; i
< ElimVec
.size(); i
++ ) {
76 GenerateJT( ElimVec
);
81 void JTree::GenerateJT( const vector
<VarSet
> &Cliques
) {
82 // Construct a weighted graph (each edge is weighted with the cardinality
83 // of the intersection of the nodes, where the nodes are the elements of
85 WeightedGraph
<int> JuncGraph
;
86 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
87 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
88 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
89 JuncGraph
[UEdge(i
,j
)] = w
;
92 // Construct maximal spanning tree using Prim's algorithm
93 _RTree
= MaxSpanningTreePrim( JuncGraph
);
95 // Construct corresponding region graph
97 // Create outer regions
98 ORs().reserve( Cliques
.size() );
99 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
100 ORs().push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
102 // For each factor, find an outer region that subsumes that factor.
103 // Then, multiply the outer region with that factor.
104 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
106 for( alpha
= 0; alpha
< nr_ORs(); alpha
++ )
107 if( OR(alpha
).vars() >> factor(I
).vars() ) {
108 // OR(alpha) *= factor(I);
112 assert( alpha
!= nr_ORs() );
116 // Create inner regions and edges
117 IRs().reserve( _RTree
.size() );
118 Redges().reserve( 2 * _RTree
.size() );
119 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
120 Redges().push_back( R_edge_t( _RTree
[i
].n1
, IRs().size() ) );
121 Redges().push_back( R_edge_t( _RTree
[i
].n2
, IRs().size() ) );
122 // inner clusters have counting number -1
123 IRs().push_back( Region( Cliques
[_RTree
[i
].n1
] & Cliques
[_RTree
[i
].n2
], -1.0 ) );
126 // Regenerate BipartiteGraph internals
129 // Create messages and beliefs
131 _Qa
.reserve( nr_ORs() );
132 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ )
133 _Qa
.push_back( OR(alpha
) );
136 _Qb
.reserve( nr_IRs() );
137 for( size_t beta
= 0; beta
< nr_IRs(); beta
++ )
138 _Qb
.push_back( Factor( IR(beta
), 1.0 ) );
141 _mes
.reserve( nr_Redges() );
142 for( size_t e
= 0; e
< nr_Redges(); e
++ )
143 _mes
.push_back( Factor( IR(Redge(e
).second
), 1.0 ) );
145 // Check counting numbers
146 Check_Counting_Numbers();
148 if( Verbose() >= 3 ) {
149 cout
<< "Resulting regiongraph: " << *this << endl
;
154 string
JTree::identify() const {
155 stringstream
result (stringstream::out
);
156 result
<< Name
<< GetProperties();
161 Factor
JTree::belief( const VarSet
&ns
) const {
162 vector
<Factor
>::const_iterator beta
;
163 for( beta
= _Qb
.begin(); beta
!= _Qb
.end(); beta
++ )
164 if( beta
->vars() >> ns
)
166 if( beta
!= _Qb
.end() )
167 return( beta
->marginal(ns
) );
169 vector
<Factor
>::const_iterator alpha
;
170 for( alpha
= _Qa
.begin(); alpha
!= _Qa
.end(); alpha
++ )
171 if( alpha
->vars() >> ns
)
173 assert( alpha
!= _Qa
.end() );
174 return( alpha
->marginal(ns
) );
179 vector
<Factor
> JTree::beliefs() const {
180 vector
<Factor
> result
;
181 for( size_t beta
= 0; beta
< nr_IRs(); beta
++ )
182 result
.push_back( _Qb
[beta
] );
183 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ )
184 result
.push_back( _Qa
[alpha
] );
189 Factor
JTree::belief( const Var
&n
) const {
190 return belief( (VarSet
)n
);
195 void JTree::runHUGIN() {
196 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ )
197 _Qa
[alpha
] = OR(alpha
);
199 for( size_t beta
= 0; beta
< nr_IRs(); beta
++ )
200 _Qb
[beta
].fill( 1.0 );
204 for( size_t i
= _RTree
.size(); (i
--) != 0; ) {
205 // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
206 // IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
207 Factor new_Qb
= _Qa
[_RTree
[i
].n2
].part_sum( IR( i
) );
208 _logZ
+= log(new_Qb
.normalize( Prob::NORMPROB
));
209 _Qa
[_RTree
[i
].n1
] *= new_Qb
.divided_by( _Qb
[i
] );
213 _logZ
+= log(_Qa
[0].normalize( Prob::NORMPROB
) );
215 _logZ
+= log(_Qa
[_RTree
[0].n1
].normalize( Prob::NORMPROB
));
217 // DistributeEvidence
218 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
219 // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
220 // IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
221 Factor new_Qb
= _Qa
[_RTree
[i
].n1
].marginal( IR( i
) );
222 _Qa
[_RTree
[i
].n2
] *= new_Qb
.divided_by( _Qb
[i
] );
227 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ )
228 _Qa
[alpha
].normalize( Prob::NORMPROB
);
232 // Really needs no init! Initial messages can be anything.
233 void JTree::runShaferShenoy() {
236 for( size_t e
= _RTree
.size(); (e
--) != 0; ) {
237 // send a message from _RTree[e].n2 to _RTree[e].n1
238 // or, actually, from the seperator IR(e) to _RTree[e].n1
240 size_t i
= _RTree
[e
].n2
;
241 size_t j
= _RTree
[e
].n1
;
244 for( R_nb_cit k
= nbOR(i
).begin(); k
!= nbOR(i
).end(); k
++ )
246 piet
*= message( i
, *k
);
247 message( j
, e
) = piet
.part_sum( IR(e
) );
248 _logZ
+= log( message(j
,e
).normalize( Prob::NORMPROB
) );
252 for( size_t e
= 0; e
< _RTree
.size(); e
++ ) {
253 size_t i
= _RTree
[e
].n1
;
254 size_t j
= _RTree
[e
].n2
;
257 for( R_nb_cit k
= nbOR(i
).begin(); k
!= nbOR(i
).end(); k
++ )
259 piet
*= message( i
, *k
);
260 message( j
, e
) = piet
.marginal( IR(e
) );
264 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ ) {
265 Factor piet
= OR(alpha
);
266 for( R_nb_cit k
= nbOR(alpha
).begin(); k
!= nbOR(alpha
).end(); k
++ )
267 piet
*= message( alpha
, *k
);
268 if( _RTree
.empty() ) {
269 _logZ
+= log( piet
.normalize( Prob::NORMPROB
) );
271 } else if( alpha
== _RTree
[0].n1
) {
272 _logZ
+= log( piet
.normalize( Prob::NORMPROB
) );
275 _Qa
[alpha
] = piet
.normalized( Prob::NORMPROB
);
278 // Only for logZ (and for belief)...
279 for( size_t beta
= 0; beta
< nr_IRs(); beta
++ )
280 _Qb
[beta
] = _Qa
[nbIR(beta
)[0]].marginal( IR(beta
) );
284 double JTree::run() {
285 if( Updates() == UpdateType::HUGIN
)
287 else if( Updates() == UpdateType::SHSH
)
293 Complex
JTree::logZ() const {
295 for( size_t beta
= 0; beta
< nr_IRs(); beta
++ )
296 sum
+= Complex(IR(beta
).c()) * _Qb
[beta
].entropy();
297 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ ) {
298 sum
+= Complex(OR(alpha
).c()) * _Qa
[alpha
].entropy();
299 sum
+= (OR(alpha
).log0() * _Qa
[alpha
]).totalSum();
306 size_t JTree::findEfficientTree( const VarSet
& ns
, DEdgeVec
&Tree
, size_t PreviousRoot
) const {
307 // find new root clique (the one with maximal statespace overlap with ns)
308 size_t maxval
= 0, maxalpha
= 0;
309 for( size_t alpha
= 0; alpha
< nr_ORs(); alpha
++ ) {
310 size_t val
= (ns
& OR(alpha
).vars()).stateSpace();
317 // for( size_t e = 0; e < _RTree.size(); e++ )
318 // cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
322 for( DEdgeVec::const_iterator e
= _RTree
.begin(); e
!= _RTree
.end(); e
++ )
323 oldTree
.insert( UEdge(e
->n1
, e
->n2
) );
324 DEdgeVec newTree
= GrowRootedTree( oldTree
, maxalpha
);
325 // cout << ns << ": ";
326 // for( size_t e = 0; e < newTree.size(); e++ )
327 // cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
330 // identify subtree that contains variables of ns which are not in the new root
331 VarSet nsrem
= ns
/ OR(maxalpha
).vars();
332 // cout << "nsrem:" << nsrem << endl;
334 // for each variable in ns that is not in the root clique
335 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++ ) {
336 // find first occurence of *n in the tree, which is closest to the root
338 for( ; e
!= newTree
.size(); e
++ ) {
339 if( OR(newTree
[e
].n2
).vars() && *n
)
342 assert( e
!= newTree
.size() );
344 // track-back path to root and add edges to subTree
345 subTree
.insert( newTree
[e
] );
346 size_t pos
= newTree
[e
].n1
;
348 if( newTree
[e
-1].n2
== pos
) {
349 subTree
.insert( newTree
[e
-1] );
350 pos
= newTree
[e
-1].n1
;
353 if( PreviousRoot
!= (size_t)-1 && PreviousRoot
!= maxalpha
) {
354 // find first occurence of PreviousRoot in the tree, which is closest to the new root
356 for( ; e
!= newTree
.size(); e
++ ) {
357 if( newTree
[e
].n2
== PreviousRoot
)
360 assert( e
!= newTree
.size() );
362 // track-back path to root and add edges to subTree
363 subTree
.insert( newTree
[e
] );
364 size_t pos
= newTree
[e
].n1
;
366 if( newTree
[e
-1].n2
== pos
) {
367 subTree
.insert( newTree
[e
-1] );
368 pos
= newTree
[e
-1].n1
;
371 // cout << "subTree: " << endl;
372 // for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
373 // cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
376 // Resulting Tree is a reordered copy of newTree
377 // First add edges in subTree to Tree
379 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
380 if( subTree
.count( *e
) ) {
381 Tree
.push_back( *e
);
382 // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
385 // Then add edges pointing away from nsrem
387 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
388 for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
390 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
391 e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
392 Tree.push_back( *e );
393 // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
397 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
398 if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
400 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
401 if( (OR(e->n1).vars() && *n) ) {
406 Tree.push_back( *e );
407 cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
411 size_t subTreeSize
= Tree
.size();
412 // Then add remaining edges
413 for( DEdgeVec::const_iterator e
= newTree
.begin(); e
!= newTree
.end(); e
++ )
414 if( find( Tree
.begin(), Tree
.end(), *e
) == Tree
.end() )
415 Tree
.push_back( *e
);
421 // Cutset conditioning
422 // assumes that run() has been called already
423 Factor
JTree::calcMarginal( const VarSet
& ns
) {
424 vector
<Factor
>::const_iterator beta
;
425 for( beta
= _Qb
.begin(); beta
!= _Qb
.end(); beta
++ )
426 if( beta
->vars() >> ns
)
428 if( beta
!= _Qb
.end() )
429 return( beta
->marginal(ns
) );
431 vector
<Factor
>::const_iterator alpha
;
432 for( alpha
= _Qa
.begin(); alpha
!= _Qa
.end(); alpha
++ )
433 if( alpha
->vars() >> ns
)
435 if( alpha
!= _Qa
.end() )
436 return( alpha
->marginal(ns
) );
438 // Find subtree to do efficient inference
440 size_t Tsize
= findEfficientTree( ns
, T
);
442 // Find remaining variables (which are not in the new root)
443 VarSet nsrem
= ns
/ OR(T
.front().n1
).vars();
444 Factor
Pns (ns
, 0.0);
448 // Save _Qa and _Qb on the subtree
449 map
<size_t,Factor
> _Qa_old
;
450 map
<size_t,Factor
> _Qb_old
;
451 vector
<size_t> b(Tsize
, 0);
452 for( size_t i
= Tsize
; (i
--) != 0; ) {
453 size_t alpha1
= T
[i
].n1
;
454 size_t alpha2
= T
[i
].n2
;
456 for( beta
= 0; beta
< nr_IRs(); beta
++ )
457 if( UEdge( _RTree
[beta
].n1
, _RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
459 assert( beta
!= nr_IRs() );
462 if( !_Qa_old
.count( alpha1
) )
463 _Qa_old
[alpha1
] = _Qa
[alpha1
];
464 if( !_Qa_old
.count( alpha2
) )
465 _Qa_old
[alpha2
] = _Qa
[alpha2
];
466 if( !_Qb_old
.count( beta
) )
467 _Qb_old
[beta
] = _Qb
[beta
];
470 // For all states of nsrem
471 for( size_t j
= 0; j
< mi
.max(); j
++ ) {
472 vector
<size_t> vi
= mi
.vi( j
);
476 for( size_t i
= Tsize
; (i
--) != 0; ) {
477 // Make outer region T[i].n1 consistent with outer region T[i].n2
478 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
481 for( VarSet::const_iterator n
= nsrem
.begin(); n
!= nsrem
.end(); n
++, k
++ )
482 if( _Qa
[T
[i
].n2
].vars() >> *n
) {
483 Factor
piet( *n
, 0.0 );
485 _Qa
[T
[i
].n2
] *= piet
;
488 Factor new_Qb
= _Qa
[T
[i
].n2
].part_sum( IR( b
[i
] ) );
489 logZ
+= log(new_Qb
.normalize( Prob::NORMPROB
));
490 _Qa
[T
[i
].n1
] *= new_Qb
.divided_by( _Qb
[b
[i
]] );
493 logZ
+= log(_Qa
[T
[0].n1
].normalize( Prob::NORMPROB
));
495 Factor
piet( nsrem
, 0.0 );
497 Pns
+= piet
* _Qa
[T
[0].n1
].part_sum( ns
/ nsrem
); // OPTIMIZE ME
499 // Restore clamped beliefs
500 for( map
<size_t,Factor
>::const_iterator alpha
= _Qa_old
.begin(); alpha
!= _Qa_old
.end(); alpha
++ )
501 _Qa
[alpha
->first
] = alpha
->second
;
502 for( map
<size_t,Factor
>::const_iterator beta
= _Qb_old
.begin(); beta
!= _Qb_old
.end(); beta
++ )
503 _Qb
[beta
->first
] = beta
->second
;
506 return( Pns
.normalized(Prob::NORMPROB
) );
512 } // end of namespace dai