libdai.git: src/jtree.cpp
/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/


#include <iostream>
#include <dai/jtree.h>


namespace dai {


using namespace std;


const char *JTree::Name = "JTREE";

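// Checks that the properties required by JTREE ("verbose" and "updates")
// are present, and converts them to their typed representations.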
bool JTree::checkProperties() {
    if( !HasProperty("verbose") )
        return false;
    if( !HasProperty("updates") )
        return false;

    ConvertPropertyTo<size_t>("verbose");
    ConvertPropertyTo<UpdateType>("updates");

    return true;
}

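// Constructs a JTree object for the given factor graph. If automatic == true,
// the cliques are derived from the factor scopes: non-maximal clusters are
// removed, the cluster graph is triangulated by min-fill variable elimination,
// and the resulting elimination cliques are passed to GenerateJT().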
JTree::JTree( const FactorGraph &fg, const Properties &opts, bool automatic ) : DAIAlgRG(fg, opts), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
    assert( checkProperties() );

    if( automatic ) {
        // Copy VarSets of factors
        vector<VarSet> cl;
        cl.reserve( fg.nrFactors() );
        for( size_t I = 0; I < nrFactors(); I++ )
            cl.push_back( factor(I).vars() );
        ClusterGraph _cg( cl );

        if( Verbose() >= 3 )
            cout << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( Verbose() >= 3 )
            cout << "Maximal clusters: " << _cg << endl;

        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
        if( Verbose() >= 3 )
            cout << "VarElim_MinFill result: " << ElimVec << endl;

        GenerateJT( ElimVec );
    }
}

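// Builds the junction tree and the corresponding region graph from a set of
// cliques:
//   1. construct the junction graph, weighting each edge (i,j) by
//      |Cliques[i] & Cliques[j]|, and take a maximum-weight spanning tree;
//   2. create one outer region per clique (counting number +1) and attach
//      each factor to an outer region that contains its variables;
//   3. create one inner region per tree edge, namely the separator
//      Cliques[n1] & Cliques[n2], with counting number -1;
//   4. allocate the beliefs _Qa, _Qb and the messages _mes.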
void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    _RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
//              OR(alpha) *= factor(I);
                fac2OR.push_back( alpha );
                break;
            }
        assert( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges
    IRs.reserve( _RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( _RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    Check_Counting_Numbers();

    if( Verbose() >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}


string JTree::identify() const {
    stringstream result (stringstream::out);
    result << Name << GetProperties();
    return result.str();
}

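// Returns the belief of a set of variables ns, provided that ns is contained
// in some inner or outer region of the junction tree; otherwise the assert
// fails. For sets that are not covered by a single region, use calcMarginal().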
Factor JTree::belief( const VarSet &ns ) const {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        assert( alpha != _Qa.end() );
        return( alpha->marginal(ns) );
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( _Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( _Qa[alpha] );
    return result;
}


Factor JTree::belief( const Var &n ) const {
    return belief( (VarSet)n );
}

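// HUGIN propagation on the junction tree. The collect pass sweeps the rooted
// tree _RTree from the last edge to the first (leaves towards the root): for
// each edge, the separator belief is updated to the unnormalized marginal of
// the child clique,
//     Qb_new = sum_{x not in IR(i)} Qa[n2],
// and the parent clique absorbs the ratio,
//     Qa[n1] *= Qb_new / Qb[i].
// The normalization constants collected along the way accumulate log Z in
// _logZ. The distribute pass then sweeps from the root back to the leaves
// using normalized marginals, after which all clique beliefs are consistent.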
// Needs no init
void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta].fill( 1.0 );

    // CollectEvidence
    _logZ = 0.0;
    for( size_t i = _RTree.size(); (i--) != 0; ) {
        // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
        // IR(i) = separator OR(_RTree[i].n1) && OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n2].part_sum( IR( i ) );
        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
        _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }
    if( _RTree.empty() )
        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ));
    else
        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));

    // DistributeEvidence
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
        // IR(i) = separator OR(_RTree[i].n1) && OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
        _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha].normalize( Prob::NORMPROB );
}

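// Shafer-Shenoy propagation on the junction tree. Messages live on the edges
// between outer regions and separators: the message towards region j over
// separator IR(e) is the product of OR(i) with all incoming messages of i
// except the one on e, summed down to the separator variables. One pass from
// the leaves to the root (accumulating log Z) and one pass back suffice;
// afterwards each clique belief is the product of its potential and all
// incoming messages.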
// Really needs no init! Initial messages can be anything.
void JTree::runShaferShenoy() {
    // First pass
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from _RTree[e].n2 to _RTree[e].n1
        // or, actually, from the separator IR(e) to _RTree[e].n1

        size_t i = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t j = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.part_sum( IR(e) );
        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t j = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.marginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else
            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta] = _Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}

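// Runs the propagation algorithm selected by the "updates" property (HUGIN or
// SHSH). The junction tree algorithm is exact and needs only a single sweep,
// so there is no convergence measure to report and run() simply returns 0.0.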
double JTree::run() {
    if( Updates() == UpdateType::HUGIN )
        runHUGIN();
    else if( Updates() == UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}

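// Computes log Z from the current beliefs via the region graph free energy:
//     log Z = sum_alpha [ c_alpha H(Q_alpha) + sum_x Q_alpha(x) log psi_alpha(x) ]
//           + sum_beta c_beta H(Q_beta),
// where c_alpha = +1 for outer regions, c_beta = -1 for inner regions
// (separators), and psi_alpha is the potential of outer region alpha. For
// consistent junction tree beliefs this equals the exact log partition sum.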
Complex JTree::logZ() const {
    Complex sum = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        sum += Complex(IR(beta).c()) * _Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        sum += Complex(OR(alpha).c()) * _Qa[alpha].entropy();
        sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
    }
    return sum;
}


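// Chooses a root clique with maximal state space overlap with ns, re-roots
// the junction tree there, and collects the subtree spanned by the remaining
// variables of ns (and, if given, by PreviousRoot). The output Tree is a
// reordering of the re-rooted tree in which the edges of that subtree come
// first; the return value is the number of subtree edges, so that a
// subsequent collect pass over Tree[0..returnvalue) reaches all of ns.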
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal state space overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = (ns & OR(alpha).vars()).states();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

//  for( size_t e = 0; e < _RTree.size(); e++ )
//      cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
//  cout << endl;
    // grow new tree
    Graph oldTree;
    for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
//  cout << ns << ": ";
//  for( size_t e = 0; e < newTree.size(); e++ )
//      cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
//  cout << endl;

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
//  cout << "nsrem:" << nsrem << endl;
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars() && *n )
                break;
        }
        assert( e != newTree.size() );

        // trace back the path to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        assert( e != newTree.size() );

        // trace back the path to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
//  cout << "subTree: " << endl;
//  for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
//      cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
//  cout << endl;

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
//          cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
        }
//  cout << endl;
    // Then add edges pointing away from nsrem
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
//                  cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
                }
            }*/
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
            }
        }
    cout << endl;*/
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}


// Cutset conditioning
// assumes that run() has been called already
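// Computes the marginal on ns exactly, even when ns is not contained in a
// single clique. If some region covers ns, its marginal is returned directly.
// Otherwise, cutset conditioning is used: findEfficientTree() selects a small
// subtree, the variables of ns outside the root clique (nsrem) are clamped to
// each of their joint states in turn, a collect pass is run on the subtree,
// and the resulting root marginals, weighted by exp(logZ), are accumulated
// into Pns, which is normalized at the end. The original beliefs on the
// subtree are restored after each clamping.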
Factor JTree::calcMarginal( const VarSet& ns ) {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != _Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns (ns, 0.0);

            // Save _Qa and _Qb on the subtree
            map<size_t,Factor> _Qa_old;
            map<size_t,Factor> _Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                assert( beta != nrIRs() );
                b[i] = beta;

                if( !_Qa_old.count( alpha1 ) )
                    _Qa_old[alpha1] = _Qa[alpha1];
                if( !_Qa_old.count( alpha2 ) )
                    _Qa_old[alpha2] = _Qa[alpha2];
                if( !_Qb_old.count( beta ) )
                    _Qb_old[beta] = _Qb[beta];
            }

            // For all states of nsrem
            for( State s(nsrem); s.valid(); s++ ) {

                // CollectEvidence
                double logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator OR(T[i].n1) && OR(T[i].n2)

                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                        if( _Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            _Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = _Qa[T[i].n2].part_sum( IR( b[i] ) );
                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
                    _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
                    _Qb[b[i]] = new_Qb;
                }
                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));

                Factor piet( nsrem, 0.0 );
                piet[s] = exp(logZ);
                Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
                    _Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
                    _Qb[beta->first] = beta->second;
            }

            return( Pns.normalized(Prob::NORMPROB) );
        }
    }
}


} // end of namespace dai