// [libdai.git] src/jtree.cpp
/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/


#include <iostream>
#include <dai/jtree.h>


namespace dai {


using namespace std;


const char *JTree::Name = "JTREE";


void JTree::setProperties( const PropertySet &opts ) {
    assert( opts.hasKey("verbose") );
    assert( opts.hasKey("updates") );

    props.verbose = opts.getStringAs<size_t>("verbose");
    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
}


PropertySet JTree::getProperties() const {
    PropertySet opts;
    opts.Set( "verbose", props.verbose );
    opts.Set( "updates", props.updates );
    return opts;
}


string JTree::printProperties() const {
    stringstream s( stringstream::out );
    s << "[";
    s << "verbose=" << props.verbose << ",";
    s << "updates=" << props.updates << "]";
    return s.str();
}


JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
    setProperties( opts );

    if( automatic ) {
        // Copy VarSets of factors
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < nrFactors(); I++ )
            cl.push_back( factor(I).vars() );
        ClusterGraph _cg( cl );

        if( props.verbose >= 3 )
            cout << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( props.verbose >= 3 )
            cout << "Maximal clusters: " << _cg << endl;

        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
        if( props.verbose >= 3 )
            cout << "VarElim_MinFill result: " << ElimVec << endl;

        GenerateJT( ElimVec );
    }
}
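
// Typical usage (a sketch, not part of libDAI; the option names follow
// setProperties() above, and "updates" accepts HUGIN or SHSH):
//
//     FactorGraph fg( factors );              // factors: a std::vector<Factor>
//     PropertySet opts;
//     opts.Set( "verbose", string("0") );
//     opts.Set( "updates", string("HUGIN") );
//     JTree jt( fg, opts );
//     jt.run();
//     Factor marg = jt.belief( fg.var(0) );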


void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            JuncGraph[UEdge(i,j)] = w;
        }
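    // A maximum-weight spanning tree of this junction graph (edge weights =
    // separator sizes) is a junction tree for the cliques produced by the
    // variable elimination above, i.e. it satisfies the running intersection
    // property.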

    // Construct maximal spanning tree using Prim's algorithm
    _RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
//              OR(alpha) *= factor(I);
                fac2OR.push_back( alpha );
                break;
            }
        assert( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges
    IRs.reserve( _RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( _RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }
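
    // The inner regions correspond to the separators of the junction tree.
    // With outer counting numbers +1 and inner counting numbers -1, the
    // region-based free energy is exact, so logZ() below returns the exact
    // log partition sum.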

    // create bipartite graph
    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    Check_Counting_Numbers();

    if( props.verbose >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}


string JTree::identify() const {
    return string(Name) + printProperties();
}


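// Returns the marginal on ns by looking it up in a single clique or separator
// belief; ns must be contained in some outer or inner region, otherwise the
// assert below fails. For an arbitrary VarSet, use calcMarginal() instead.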
Factor JTree::belief( const VarSet &ns ) const {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        assert( alpha != _Qa.end() );
        return( alpha->marginal(ns) );
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( _Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( _Qa[alpha] );
    return result;
}


Factor JTree::belief( const Var &n ) const {
    return belief( (VarSet)n );
}


// Needs no init
void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta].fill( 1.0 );

    // CollectEvidence
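    // (upward pass: process the rooted tree edges from the leaves towards the
    // root, absorbing each child clique belief into its parent via the
    // separator; the discarded normalization constants accumulate in _logZ)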
    _logZ = 0.0;
    for( size_t i = _RTree.size(); (i--) != 0; ) {
        // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
        // IR(i) = separator of OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n2].partSum( IR( i ) );
        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
        _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }
    if( _RTree.empty() )
        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ));
    else
        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));

    // DistributeEvidence
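    // (downward pass: process the edges from the root towards the leaves,
    // updating each child clique with the new separator belief so that all
    // clique beliefs become globally consistent)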
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
        // IR(i) = separator of OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
        _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha].normalize( Prob::NORMPROB );
}


// Really needs no init! Initial messages can be anything.
void JTree::runShaferShenoy() {
    // First pass
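    // Shafer-Shenoy propagation: the message over separator IR(e) from one of
    // its two cliques to the other is the product of that clique's factor and
    // all its other incoming messages, marginalized onto the separator. The
    // first pass sends messages towards the root _RTree[0].n1, the second
    // pass sends them back towards the leaves.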
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from _RTree[e].n2 to _RTree[e].n1
        // or, actually, from the separator IR(e) to _RTree[e].n1

        size_t i = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t j = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.partSum( IR(e) );
        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t j = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.marginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else
            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta] = _Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}


double JTree::run() {
    if( props.updates == Properties::UpdateType::HUGIN )
        runHUGIN();
    else if( props.updates == Properties::UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}


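// The log partition sum follows from the exact clique and separator beliefs:
//   logZ = sum_alpha ( H(_Qa[alpha]) + E_{_Qa[alpha]}[ log OR(alpha) ] ) - sum_beta H(_Qb[beta])
// i.e. minus the region graph free energy, which is exact on a junction tree.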
Real JTree::logZ() const {
    Real sum = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        sum += IR(beta).c() * _Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        sum += OR(alpha).c() * _Qa[alpha].entropy();
        sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
    }
    return sum;
}


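// Reorders the junction tree into Tree such that the edges whose cliques are
// needed to cover ns (and, if specified, to reach PreviousRoot) come first;
// the new root is the clique with the largest state space overlap with ns.
// Returns the number of these leading edges (the size of that subtree).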
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal state space overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = (ns & OR(alpha).vars()).states();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

//  for( size_t e = 0; e < _RTree.size(); e++ )
//      cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
//  cout << endl;
    // grow new tree
    Graph oldTree;
    for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
//  cout << ns << ": ";
//  for( size_t e = 0; e < newTree.size(); e++ )
//      cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
//  cout << endl;

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
//  cout << "nsrem:" << nsrem << endl;
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find the first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) )
                break;
        }
        assert( e != newTree.size() );

        // trace the path back to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha ) {
        // find the first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        assert( e != newTree.size() );

        // trace the path back to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
//  cout << "subTree: " << endl;
//  for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
//      cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
//  cout << endl;

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
//          cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
        }
//  cout << endl;
    // Then add edges pointing away from nsrem
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
//                  cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
                }
            }*/
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
            }
        }
    cout << endl;*/
    size_t subTreeSize = Tree.size();
    // Then add the remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}


// Cutset conditioning
// assumes that run() has been called already
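// For a VarSet ns that is not contained in any single clique, the marginal is
// computed by cutset conditioning: the variables of ns outside the root clique
// are clamped to each of their joint states, evidence is collected over the
// subtree returned by findEfficientTree(), and the weighted clamped marginals
// are accumulated and finally normalized.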
Factor JTree::calcMarginal( const VarSet& ns ) {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != _Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns( ns, 0.0 );

            // Save _Qa and _Qb on the subtree
            map<size_t,Factor> _Qa_old;
            map<size_t,Factor> _Qb_old;
            vector<size_t> b( Tsize, 0 );
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                assert( beta != nrIRs() );
                b[i] = beta;

                if( !_Qa_old.count( alpha1 ) )
                    _Qa_old[alpha1] = _Qa[alpha1];
                if( !_Qa_old.count( alpha2 ) )
                    _Qa_old[alpha2] = _Qa[alpha2];
                if( !_Qb_old.count( beta ) )
                    _Qb_old[beta] = _Qb[beta];
            }

            // For all states of nsrem
            for( State s(nsrem); s.valid(); s++ ) {

                // CollectEvidence
                double logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator of OR(T[i].n1) and OR(T[i].n2)

                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                        if( _Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            _Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = _Qa[T[i].n2].partSum( IR( b[i] ) );
                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
                    _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
                    _Qb[b[i]] = new_Qb;
                }
                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));

                Factor piet( nsrem, 0.0 );
                piet[s] = exp(logZ);
                Pns += piet * _Qa[T[0].n1].partSum( ns / nsrem );      // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
                    _Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
                    _Qb[beta->first] = beta->second;
            }

            return( Pns.normalized(Prob::NORMPROB) );
        }
    }
}


// first return value is treewidth
// second return value is number of states in largest clique
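// (as implemented, these are the size and the state space of the largest
// clique produced by the MinFill elimination heuristic)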
pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
    ClusterGraph _cg;

    // Copy factors
    for( size_t I = 0; I < fg.nrFactors(); I++ )
        _cg.insert( fg.factor(I).vars() );

    // Retain only maximal clusters
    _cg.eraseNonMaximal();

    // Obtain elimination sequence
    vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();

    // Calculate treewidth
    size_t treewidth = 0;
    size_t nrstates = 0;
    for( size_t i = 0; i < ElimVec.size(); i++ ) {
        if( ElimVec[i].size() > treewidth )
            treewidth = ElimVec[i].size();
        if( ElimVec[i].states() > nrstates )
            nrstates = ElimVec[i].states();
    }

    return pair<size_t,size_t>(treewidth, nrstates);
}


} // end of namespace dai