[Frederik Eaton] Major cleanup of BBP and CBP code and documentation
[libdai.git] / src / jtree.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <iostream>
24 #include <dai/jtree.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 const char *JTree::Name = "JTREE";
34
35
36 void JTree::setProperties( const PropertySet &opts ) {
37 assert( opts.hasKey("verbose") );
38 assert( opts.hasKey("updates") );
39
40 props.verbose = opts.getStringAs<size_t>("verbose");
41 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
42 }
43
44
45 PropertySet JTree::getProperties() const {
46 PropertySet opts;
47 opts.Set( "verbose", props.verbose );
48 opts.Set( "updates", props.updates );
49 return opts;
50 }
51
52
53 string JTree::printProperties() const {
54 stringstream s( stringstream::out );
55 s << "[";
56 s << "verbose=" << props.verbose << ",";
57 s << "updates=" << props.updates << "]";
58 return s.str();
59 }
60
61
62 JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
63 setProperties( opts );
64
65 if( !isConnected() )
66 DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
67
68 if( automatic ) {
69 // Create ClusterGraph which contains factors as clusters
70 vector<VarSet> cl;
71 cl.reserve( fg.nrFactors() );
72 for( size_t I = 0; I < nrFactors(); I++ )
73 cl.push_back( factor(I).vars() );
74 ClusterGraph _cg( cl );
75
76 if( props.verbose >= 3 )
77 cerr << "Initial clusters: " << _cg << endl;
78
79 // Retain only maximal clusters
80 _cg.eraseNonMaximal();
81 if( props.verbose >= 3 )
82 cerr << "Maximal clusters: " << _cg << endl;
83
84 vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
85 if( props.verbose >= 3 )
86 cerr << "VarElim_MinFill result: " << ElimVec << endl;
87
88 GenerateJT( ElimVec );
89 }
90 }
91
92
93 void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
94 // Construct a weighted graph (each edge is weighted with the cardinality
95 // of the intersection of the nodes, where the nodes are the elements of
96 // Cliques).
97 WeightedGraph<int> JuncGraph;
98 for( size_t i = 0; i < Cliques.size(); i++ )
99 for( size_t j = i+1; j < Cliques.size(); j++ ) {
100 size_t w = (Cliques[i] & Cliques[j]).size();
101 if( w )
102 JuncGraph[UEdge(i,j)] = w;
103 }
104
105 // Construct maximal spanning tree using Prim's algorithm
106 RTree = MaxSpanningTreePrims( JuncGraph );
107
108 // Construct corresponding region graph
109
110 // Create outer regions
111 ORs.reserve( Cliques.size() );
112 for( size_t i = 0; i < Cliques.size(); i++ )
113 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
114
115 // For each factor, find an outer region that subsumes that factor.
116 // Then, multiply the outer region with that factor.
117 for( size_t I = 0; I < nrFactors(); I++ ) {
118 size_t alpha;
119 for( alpha = 0; alpha < nrORs(); alpha++ )
120 if( OR(alpha).vars() >> factor(I).vars() ) {
121 fac2OR.push_back( alpha );
122 break;
123 }
124 assert( alpha != nrORs() );
125 }
126 RecomputeORs();
127
128 // Create inner regions and edges
129 IRs.reserve( RTree.size() );
130 vector<Edge> edges;
131 edges.reserve( 2 * RTree.size() );
132 for( size_t i = 0; i < RTree.size(); i++ ) {
133 edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
134 edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
135 // inner clusters have counting number -1
136 IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
137 }
138
139 // create bipartite graph
140 G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
141
142 // Create messages and beliefs
143 Qa.clear();
144 Qa.reserve( nrORs() );
145 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
146 Qa.push_back( OR(alpha) );
147
148 Qb.clear();
149 Qb.reserve( nrIRs() );
150 for( size_t beta = 0; beta < nrIRs(); beta++ )
151 Qb.push_back( Factor( IR(beta), 1.0 ) );
152
153 _mes.clear();
154 _mes.reserve( nrORs() );
155 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
156 _mes.push_back( vector<Factor>() );
157 _mes[alpha].reserve( nbOR(alpha).size() );
158 foreach( const Neighbor &beta, nbOR(alpha) )
159 _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
160 }
161
162 // Check counting numbers
163 Check_Counting_Numbers();
164
165 if( props.verbose >= 3 ) {
166 cerr << "Resulting regiongraph: " << *this << endl;
167 }
168 }
169
170
171 string JTree::identify() const {
172 return string(Name) + printProperties();
173 }
174
175
176 Factor JTree::belief( const VarSet &ns ) const {
177 vector<Factor>::const_iterator beta;
178 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
179 if( beta->vars() >> ns )
180 break;
181 if( beta != Qb.end() )
182 return( beta->marginal(ns) );
183 else {
184 vector<Factor>::const_iterator alpha;
185 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
186 if( alpha->vars() >> ns )
187 break;
188 assert( alpha != Qa.end() );
189 return( alpha->marginal(ns) );
190 }
191 }
192
193
194 vector<Factor> JTree::beliefs() const {
195 vector<Factor> result;
196 for( size_t beta = 0; beta < nrIRs(); beta++ )
197 result.push_back( Qb[beta] );
198 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
199 result.push_back( Qa[alpha] );
200 return result;
201 }
202
203
204 Factor JTree::belief( const Var &n ) const {
205 return belief( (VarSet)n );
206 }
207
208
209 // Needs no init
210 void JTree::runHUGIN() {
211 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
212 Qa[alpha] = OR(alpha);
213
214 for( size_t beta = 0; beta < nrIRs(); beta++ )
215 Qb[beta].fill( 1.0 );
216
217 // CollectEvidence
218 _logZ = 0.0;
219 for( size_t i = RTree.size(); (i--) != 0; ) {
220 // Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
221 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
222 Factor new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
223 _logZ += log(new_Qb.normalize());
224 Qa[RTree[i].n1] *= new_Qb / Qb[i];
225 Qb[i] = new_Qb;
226 }
227 if( RTree.empty() )
228 _logZ += log(Qa[0].normalize() );
229 else
230 _logZ += log(Qa[RTree[0].n1].normalize());
231
232 // DistributeEvidence
233 for( size_t i = 0; i < RTree.size(); i++ ) {
234 // Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
235 // IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
236 Factor new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
237 Qa[RTree[i].n2] *= new_Qb / Qb[i];
238 Qb[i] = new_Qb;
239 }
240
241 // Normalize
242 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
243 Qa[alpha].normalize();
244 }
245
246
247 // Really needs no init! Initial messages can be anything.
248 void JTree::runShaferShenoy() {
249 // First pass
250 _logZ = 0.0;
251 for( size_t e = nrIRs(); (e--) != 0; ) {
252 // send a message from RTree[e].n2 to RTree[e].n1
253 // or, actually, from the seperator IR(e) to RTree[e].n1
254
255 size_t i = nbIR(e)[1].node; // = RTree[e].n2
256 size_t j = nbIR(e)[0].node; // = RTree[e].n1
257 size_t _e = nbIR(e)[0].dual;
258
259 Factor piet = OR(i);
260 foreach( const Neighbor &k, nbOR(i) )
261 if( k != e )
262 piet *= message( i, k.iter );
263 message( j, _e ) = piet.marginal( IR(e), false );
264 _logZ += log( message(j,_e).normalize() );
265 }
266
267 // Second pass
268 for( size_t e = 0; e < nrIRs(); e++ ) {
269 size_t i = nbIR(e)[0].node; // = RTree[e].n1
270 size_t j = nbIR(e)[1].node; // = RTree[e].n2
271 size_t _e = nbIR(e)[1].dual;
272
273 Factor piet = OR(i);
274 foreach( const Neighbor &k, nbOR(i) )
275 if( k != e )
276 piet *= message( i, k.iter );
277 message( j, _e ) = piet.marginal( IR(e) );
278 }
279
280 // Calculate beliefs
281 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
282 Factor piet = OR(alpha);
283 foreach( const Neighbor &k, nbOR(alpha) )
284 piet *= message( alpha, k.iter );
285 if( nrIRs() == 0 ) {
286 _logZ += log( piet.normalize() );
287 Qa[alpha] = piet;
288 } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
289 _logZ += log( piet.normalize() );
290 Qa[alpha] = piet;
291 } else
292 Qa[alpha] = piet.normalized();
293 }
294
295 // Only for logZ (and for belief)...
296 for( size_t beta = 0; beta < nrIRs(); beta++ )
297 Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
298 }
299
300
301 double JTree::run() {
302 if( props.updates == Properties::UpdateType::HUGIN )
303 runHUGIN();
304 else if( props.updates == Properties::UpdateType::SHSH )
305 runShaferShenoy();
306 return 0.0;
307 }
308
309
310 Real JTree::logZ() const {
311 Real s = 0.0;
312 for( size_t beta = 0; beta < nrIRs(); beta++ )
313 s += IR(beta).c() * Qb[beta].entropy();
314 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
315 s += OR(alpha).c() * Qa[alpha].entropy();
316 s += (OR(alpha).log(true) * Qa[alpha]).sum();
317 }
318 return s;
319 }
320
321
322
323 size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
324 // find new root clique (the one with maximal statespace overlap with ns)
325 size_t maxval = 0, maxalpha = 0;
326 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
327 size_t val = VarSet(ns & OR(alpha).vars()).nrStates();
328 if( val > maxval ) {
329 maxval = val;
330 maxalpha = alpha;
331 }
332 }
333
334 // grow new tree
335 Graph oldTree;
336 for( DEdgeVec::const_iterator e = RTree.begin(); e != RTree.end(); e++ )
337 oldTree.insert( UEdge(e->n1, e->n2) );
338 DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
339
340 // identify subtree that contains variables of ns which are not in the new root
341 VarSet nsrem = ns / OR(maxalpha).vars();
342 set<DEdge> subTree;
343 // for each variable in ns that is not in the root clique
344 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
345 // find first occurence of *n in the tree, which is closest to the root
346 size_t e = 0;
347 for( ; e != newTree.size(); e++ ) {
348 if( OR(newTree[e].n2).vars().contains( *n ) )
349 break;
350 }
351 assert( e != newTree.size() );
352
353 // track-back path to root and add edges to subTree
354 subTree.insert( newTree[e] );
355 size_t pos = newTree[e].n1;
356 for( ; e > 0; e-- )
357 if( newTree[e-1].n2 == pos ) {
358 subTree.insert( newTree[e-1] );
359 pos = newTree[e-1].n1;
360 }
361 }
362 if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
363 // find first occurence of PreviousRoot in the tree, which is closest to the new root
364 size_t e = 0;
365 for( ; e != newTree.size(); e++ ) {
366 if( newTree[e].n2 == PreviousRoot )
367 break;
368 }
369 assert( e != newTree.size() );
370
371 // track-back path to root and add edges to subTree
372 subTree.insert( newTree[e] );
373 size_t pos = newTree[e].n1;
374 for( ; e > 0; e-- )
375 if( newTree[e-1].n2 == pos ) {
376 subTree.insert( newTree[e-1] );
377 pos = newTree[e-1].n1;
378 }
379 }
380
381 // Resulting Tree is a reordered copy of newTree
382 // First add edges in subTree to Tree
383 Tree.clear();
384 for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
385 if( subTree.count( *e ) ) {
386 Tree.push_back( *e );
387 }
388 // Then add edges pointing away from nsrem
389 // FIXME
390 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
391 for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
392 if( *e != *sTi ) {
393 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
394 e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
395 Tree.push_back( *e );
396 }
397 }*/
398 // FIXME
399 /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
400 if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
401 bool found = false;
402 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
403 if( (OR(e->n1).vars() && *n) ) {
404 found = true;
405 break;
406 }
407 if( found ) {
408 Tree.push_back( *e );
409 }
410 }*/
411 size_t subTreeSize = Tree.size();
412 // Then add remaining edges
413 for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
414 if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
415 Tree.push_back( *e );
416
417 return subTreeSize;
418 }
419
420
421 // Cutset conditioning
422 // assumes that run() has been called already
423 Factor JTree::calcMarginal( const VarSet& ns ) {
424 vector<Factor>::const_iterator beta;
425 for( beta = Qb.begin(); beta != Qb.end(); beta++ )
426 if( beta->vars() >> ns )
427 break;
428 if( beta != Qb.end() )
429 return( beta->marginal(ns) );
430 else {
431 vector<Factor>::const_iterator alpha;
432 for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
433 if( alpha->vars() >> ns )
434 break;
435 if( alpha != Qa.end() )
436 return( alpha->marginal(ns) );
437 else {
438 // Find subtree to do efficient inference
439 DEdgeVec T;
440 size_t Tsize = findEfficientTree( ns, T );
441
442 // Find remaining variables (which are not in the new root)
443 VarSet nsrem = ns / OR(T.front().n1).vars();
444 Factor Pns (ns, 0.0);
445
446 // Save Qa and Qb on the subtree
447 map<size_t,Factor> Qa_old;
448 map<size_t,Factor> Qb_old;
449 vector<size_t> b(Tsize, 0);
450 for( size_t i = Tsize; (i--) != 0; ) {
451 size_t alpha1 = T[i].n1;
452 size_t alpha2 = T[i].n2;
453 size_t beta;
454 for( beta = 0; beta < nrIRs(); beta++ )
455 if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
456 break;
457 assert( beta != nrIRs() );
458 b[i] = beta;
459
460 if( !Qa_old.count( alpha1 ) )
461 Qa_old[alpha1] = Qa[alpha1];
462 if( !Qa_old.count( alpha2 ) )
463 Qa_old[alpha2] = Qa[alpha2];
464 if( !Qb_old.count( beta ) )
465 Qb_old[beta] = Qb[beta];
466 }
467
468 // For all states of nsrem
469 for( State s(nsrem); s.valid(); s++ ) {
470 // CollectEvidence
471 double logZ = 0.0;
472 for( size_t i = Tsize; (i--) != 0; ) {
473 // Make outer region T[i].n1 consistent with outer region T[i].n2
474 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
475
476 for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
477 if( Qa[T[i].n2].vars() >> *n ) {
478 Factor piet( *n, 0.0 );
479 piet[s(*n)] = 1.0;
480 Qa[T[i].n2] *= piet;
481 }
482
483 Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
484 logZ += log(new_Qb.normalize());
485 Qa[T[i].n1] *= new_Qb / Qb[b[i]];
486 Qb[b[i]] = new_Qb;
487 }
488 logZ += log(Qa[T[0].n1].normalize());
489
490 Factor piet( nsrem, 0.0 );
491 piet[s] = exp(logZ);
492 Pns += piet * Qa[T[0].n1].marginal( ns / nsrem, false ); // OPTIMIZE ME
493
494 // Restore clamped beliefs
495 for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
496 Qa[alpha->first] = alpha->second;
497 for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
498 Qb[beta->first] = beta->second;
499 }
500
501 return( Pns.normalized() );
502 }
503 }
504 }
505
506
507 /// Calculates upper bound to the treewidth of a FactorGraph
508 /** \relates JTree
509 * \return a pair (number of variables in largest clique, number of states in largest clique)
510 */
511 std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
512 ClusterGraph _cg;
513
514 // Copy factors
515 for( size_t I = 0; I < fg.nrFactors(); I++ )
516 _cg.insert( fg.factor(I).vars() );
517
518 // Retain only maximal clusters
519 _cg.eraseNonMaximal();
520
521 // Obtain elimination sequence
522 vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
523
524 // Calculate treewidth
525 size_t treewidth = 0;
526 size_t nrstates = 0;
527 for( size_t i = 0; i < ElimVec.size(); i++ ) {
528 if( ElimVec[i].size() > treewidth )
529 treewidth = ElimVec[i].size();
530 size_t s = ElimVec[i].nrStates();
531 if( s > nrstates )
532 nrstates = s;
533 }
534
535 return pair<size_t,size_t>(treewidth, nrstates);
536 }
537
538
539 } // end of namespace dai