/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/


#include <iostream>
#include <dai/jtree.h>


namespace dai {


using namespace std;


const char *JTree::Name = "JTREE";


void JTree::setProperties( const PropertySet &opts ) {
    assert( opts.hasKey("verbose") );
    assert( opts.hasKey("updates") );

    props.verbose = opts.getStringAs<size_t>("verbose");
    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
}


PropertySet JTree::getProperties() const {
    PropertySet opts;
    opts.Set( "verbose", props.verbose );
    opts.Set( "updates", props.updates );
    return opts;
}


string JTree::printProperties() const {
    stringstream s( stringstream::out );
    s << "[";
    s << "verbose=" << props.verbose << ",";
    s << "updates=" << props.updates << "]";
    return s.str();
}


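// Example usage (a minimal sketch, assuming a FactorGraph fg has already been
// constructed; property values are given as strings and parsed by setProperties):
//
//     PropertySet opts;
//     opts.Set( "verbose", string("0") );
//     opts.Set( "updates", string("HUGIN") );   // or "SHSH" for Shafer-Shenoy
//     JTree jt( fg, opts );
//     jt.init();
//     jt.run();
//     Factor marg = jt.belief( fg.var(0) );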
JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
    setProperties( opts );

    if( !isConnected() )
        DAI_THROW(FACTORGRAPH_NOT_CONNECTED);

    if( automatic ) {
        // Create ClusterGraph which contains factors as clusters
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < nrFactors(); I++ )
            cl.push_back( factor(I).vars() );
        ClusterGraph _cg( cl );

        if( props.verbose >= 3 )
            cout << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( props.verbose >= 3 )
            cout << "Maximal clusters: " << _cg << endl;

        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
        if( props.verbose >= 3 )
            cout << "VarElim_MinFill result: " << ElimVec << endl;

        GenerateJT( ElimVec );
    }
}


void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

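    // A maximum-weight spanning tree of this junction graph maximizes the total
    // separator size; for cliques obtained from a triangulation (as produced by
    // VarElim_MinFill above) such a tree satisfies the running intersection
    // property, i.e., it is a junction tree.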
    // Construct maximal spanning tree using Prim's algorithm
    _RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR.push_back( alpha );
                break;
            }
        assert( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges
    IRs.reserve( _RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( _RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    Check_Counting_Numbers();

    if( props.verbose >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}


string JTree::identify() const {
    return string(Name) + printProperties();
}


Factor JTree::belief( const VarSet &ns ) const {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        assert( alpha != _Qa.end() );
        return( alpha->marginal(ns) );
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( _Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( _Qa[alpha] );
    return result;
}


Factor JTree::belief( const Var &n ) const {
    return belief( (VarSet)n );
}


// Needs no init
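// HUGIN propagation: a CollectEvidence pass makes each parent clique consistent
// with its children (working from the leaves towards the root), accumulating
// _logZ from the normalization constants, followed by a DistributeEvidence pass
// from the root back towards the leaves. Afterwards the clique beliefs _Qa and
// separator beliefs _Qb are exact marginals.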
void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta].fill( 1.0 );

    // CollectEvidence
    _logZ = 0.0;
    for( size_t i = _RTree.size(); (i--) != 0; ) {
        // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
        // IR(i) = separator between OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n2].partSum( IR( i ) );
        _logZ += log(new_Qb.normalize());
        _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }
    if( _RTree.empty() )
        _logZ += log(_Qa[0].normalize() );
    else
        _logZ += log(_Qa[_RTree[0].n1].normalize());

    // DistributeEvidence
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
        // IR(i) = separator between OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
        _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha].normalize();
}


// Really needs no init! Initial messages can be anything.
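// Shafer-Shenoy propagation on the junction tree: messages live on the
// separators (inner regions). The first pass sends messages towards the root
// clique, the second pass sends them back towards the leaves; each clique
// belief is then the product of its local potential with all incoming messages.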
void JTree::runShaferShenoy() {
    // First pass
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from _RTree[e].n2 to _RTree[e].n1
        // or, actually, from the separator IR(e) to _RTree[e].n1

        size_t i = nbIR(e)[1].node; // = _RTree[e].n2
        size_t j = nbIR(e)[0].node; // = _RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.partSum( IR(e) );
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = _RTree[e].n1
        size_t j = nbIR(e)[1].node; // = _RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.marginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize() );
            _Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
            _logZ += log( piet.normalize() );
            _Qa[alpha] = piet;
        } else
            _Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta] = _Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}


double JTree::run() {
    if( props.updates == Properties::UpdateType::HUGIN )
        runHUGIN();
    else if( props.updates == Properties::UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}


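// Computes logZ from the region-based free energy of the current beliefs
// (clique and separator entropies plus the expected log-potentials, weighted by
// the counting numbers); for the exact beliefs on a junction tree this equals
// the exact log partition sum.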
Real JTree::logZ() const {
    Real sum = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        sum += IR(beta).c() * _Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        sum += OR(alpha).c() * _Qa[alpha].entropy();
        sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
    }
    return sum;
}


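// Re-roots the junction tree at the clique with maximal state space overlap
// with ns and reorders the edges so that the subtree covering all variables of
// ns comes first; returns the number of edges in that subtree. This subtree is
// the part that has to be re-propagated when computing a marginal of ns by
// conditioning (see calcMarginal below).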
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = (ns & OR(alpha).vars()).states();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // for( size_t e = 0; e < _RTree.size(); e++ )
    //     cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
    // cout << endl;
    // grow new tree
    Graph oldTree;
    for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
    // cout << ns << ": ";
    // for( size_t e = 0; e < newTree.size(); e++ )
    //     cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
    // cout << endl;

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
    // cout << "nsrem:" << nsrem << endl;
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) )
                break;
        }
        assert( e != newTree.size() );

        // trace back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        assert( e != newTree.size() );

        // trace back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    // cout << "subTree: " << endl;
    // for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
    //     cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
    // cout << endl;

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
            // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
        }
    // cout << endl;
    // Then add edges pointing away from nsrem
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
                    // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
                }
            }*/
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
            }
        }
    cout << endl;*/
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}


// Cutset conditioning
// assumes that run() has been called already
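// If no single clique contains all of ns, the marginal is obtained by
// conditioning: the variables of ns outside the root clique are clamped to each
// of their joint states in turn, the collect pass is repeated on the subtree
// returned by findEfficientTree, and the resulting weighted partition sums are
// accumulated into P(ns).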
Factor JTree::calcMarginal( const VarSet& ns ) {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != _Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns (ns, 0.0);

            // Save _Qa and _Qb on the subtree
            map<size_t,Factor> _Qa_old;
            map<size_t,Factor> _Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                assert( beta != nrIRs() );
                b[i] = beta;

                if( !_Qa_old.count( alpha1 ) )
                    _Qa_old[alpha1] = _Qa[alpha1];
                if( !_Qa_old.count( alpha2 ) )
                    _Qa_old[alpha2] = _Qa[alpha2];
                if( !_Qb_old.count( beta ) )
                    _Qb_old[beta] = _Qb[beta];
            }

            // For all states of nsrem
            for( State s(nsrem); s.valid(); s++ ) {
                // CollectEvidence
                double logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator between OR(T[i].n1) and OR(T[i].n2)

                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                        if( _Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            _Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = _Qa[T[i].n2].partSum( IR( b[i] ) );
                    logZ += log(new_Qb.normalize());
                    _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
                    _Qb[b[i]] = new_Qb;
                }
                logZ += log(_Qa[T[0].n1].normalize());

                Factor piet( nsrem, 0.0 );
                piet[s] = exp(logZ);
                Pns += piet * _Qa[T[0].n1].partSum( ns / nsrem );    // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
                    _Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
                    _Qb[beta->first] = beta->second;
            }

            return( Pns.normalized() );
        }
    }
}


// first return value is treewidth
// second return value is number of states in largest clique
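// Note: both values are derived from the MinFill heuristic elimination order,
// which need not be optimal, so they should be regarded as upper bounds.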
pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
    ClusterGraph _cg;

    // Copy factors
    for( size_t I = 0; I < fg.nrFactors(); I++ )
        _cg.insert( fg.factor(I).vars() );

    // Retain only maximal clusters
    _cg.eraseNonMaximal();

    // Obtain elimination sequence
    vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();

    // Calculate treewidth
    size_t treewidth = 0;
    size_t nrstates = 0;
    for( size_t i = 0; i < ElimVec.size(); i++ ) {
        if( ElimVec[i].size() > treewidth )
            treewidth = ElimVec[i].size();
        if( ElimVec[i].states() > nrstates )
            nrstates = ElimVec[i].states();
    }

    return pair<size_t,size_t>(treewidth, nrstates);
}


} // end of namespace dai