[libdai.git] / src / jtree.cpp
/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/


#include <iostream>
#include <dai/jtree.h>


namespace dai {


using namespace std;


const char *JTree::Name = "JTREE";


bool JTree::checkProperties() {
    if( !HasProperty("verbose") )
        return false;
    if( !HasProperty("updates") )
        return false;

    ConvertPropertyTo<size_t>("verbose");
    ConvertPropertyTo<UpdateType>("updates");

    return true;
}


JTree::JTree( const FactorGraph &fg, const Properties &opts, bool automatic ) : DAIAlgRG(fg, opts), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
    assert( checkProperties() );

    if( automatic ) {
        ClusterGraph _cg;

        // Copy factors
        for( size_t I = 0; I < nrFactors(); I++ )
            _cg.insert( factor(I).vars() );
        if( Verbose() >= 3 )
            cout << "Initial clusters: " << _cg << endl;

        // Retain only maximal clusters
        _cg.eraseNonMaximal();
        if( Verbose() >= 3 )
            cout << "Maximal clusters: " << _cg << endl;

        vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
        if( Verbose() >= 3 ) {
            cout << "VarElim_MinFill result: {" << endl;
            for( size_t i = 0; i < ElimVec.size(); i++ ) {
                if( i != 0 )
                    cout << ", ";
                cout << ElimVec[i];
            }
            cout << "}" << endl;
        }

        GenerateJT( ElimVec );
    }
}


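// A minimal usage sketch (illustration only, not part of the library code):
// given a FactorGraph fg constructed elsewhere and a Properties object opts
// that supplies at least the "verbose" and "updates" keys required by
// checkProperties(), exact inference looks like
//
//     JTree jt( fg, opts, true );   // automatic == true: triangulate via min-fill and build the tree
//     jt.run();                     // HUGIN or Shafer-Shenoy propagation, depending on opts
//     Factor marg = jt.belief( v ); // exact marginal of some Var v occurring in fg
//     Complex lz  = jt.logZ();      // log partition sum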
void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    _RTree = MaxSpanningTreePrim( JuncGraph );

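    // Why a *maximum*-weight spanning tree: picking the edges with the largest
    // separators yields a tree over the elimination cliques that satisfies the
    // running intersection property. A hypothetical example (not taken from any
    // particular factor graph): for cliques {A,B,C}, {B,C,D}, {C,E} the pairwise
    // intersections are {B,C} (size 2), {C} (size 1) and {C} (size 1), so the
    // maximum spanning tree keeps the weight-2 edge plus one of the weight-1
    // edges, giving separators {B,C} and {C}.
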
    // Construct corresponding region graph

    // Create outer regions
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                // OR(alpha) *= factor(I);
                fac2OR.push_back( alpha );
                break;
            }
        assert( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges
    IRs.reserve( _RTree.size() );
    typedef pair<size_t,size_t> Edge;
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( _RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    Check_Counting_Numbers();

    if( Verbose() >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}


string JTree::identify() const {
    stringstream result (stringstream::out);
    result << Name << GetProperties();
    return result.str();
}


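// Note: the lookup below only succeeds when ns is contained in a single inner
// or outer region of the junction tree; otherwise the assert fires. Marginals
// over variables that are spread across several cliques are handled by
// calcMarginal() further down.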
Factor JTree::belief( const VarSet &ns ) const {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        assert( alpha != _Qa.end() );
        return( alpha->marginal(ns) );
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( _Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( _Qa[alpha] );
    return result;
}


Factor JTree::belief( const Var &n ) const {
    return belief( (VarSet)n );
}


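// A compact summary of the HUGIN propagation implemented below
// (a1 = _RTree[i].n1, a2 = _RTree[i].n2, IR(i) their separator):
//
//   CollectEvidence  (i = _RTree.size()-1 down to 0):
//       Qb_new  = sum of Qa[a2] over the variables of OR(a2) not in IR(i)
//       Qa[a1] *= Qb_new / Qb[i]
//       Qb[i]   = Qb_new
//   DistributeEvidence  (i = 0 .. _RTree.size()-1):
//       Qb_new  = marginal of Qa[a1] on IR(i)
//       Qa[a2] *= Qb_new / Qb[i]
//       Qb[i]   = Qb_new
//
// _logZ accumulates the logs of all normalization constants encountered in
// the collect pass (plus the final normalization of the root belief), which
// yields the log partition sum of the factor graph.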
// Needs no init
void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta].fill( 1.0 );

    // CollectEvidence
    _logZ = 0.0;
    for( size_t i = _RTree.size(); (i--) != 0; ) {
        // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
        // IR(i) = separator of OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n2].part_sum( IR( i ) );
        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
        _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }
    if( _RTree.empty() )
        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ));
    else
        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));

    // DistributeEvidence
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
        // IR(i) = separator of OR(_RTree[i].n1) and OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
        _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha].normalize( Prob::NORMPROB );
}


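// In Shafer-Shenoy propagation, the message sent from outer region i over
// separator IR(e) towards the neighbouring outer region j is, schematically,
//
//   m_{i->j}( IR(e) ) = normalized sum over OR(i)\IR(e) of
//                       OR(i) * product of all incoming messages m_{.->i}
//                       except the one arriving over IR(e).
//
// The first pass below sends these messages towards the root, the second pass
// sends them back; each belief Qa[alpha] is then the normalized product of
// OR(alpha) with all of its incoming messages.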
// Really needs no init! Initial messages can be anything.
void JTree::runShaferShenoy() {
    // First pass
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from _RTree[e].n2 to _RTree[e].n1
        // or, actually, from the separator IR(e) to _RTree[e].n1

        size_t i = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t j = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.part_sum( IR(e) );
        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node;  // = _RTree[e].n1
        size_t j = nbIR(e)[1].node;  // = _RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.marginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else
            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta] = _Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}


double JTree::run() {
    if( Updates() == UpdateType::HUGIN )
        runHUGIN();
    else if( Updates() == UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}


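// logZ() evaluates the region-graph free energy at the beliefs computed by
// run(). Since GenerateJT() assigns counting number +1 to every outer region
// and -1 to every inner region (separator), the sum below amounts to
//
//   logZ = sum_alpha ( H(Qa[alpha]) + sum_x Qa[alpha](x) * log OR(alpha)(x) )
//        - sum_beta  H(Qb[beta]),
//
// which is exact on a junction tree.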
Complex JTree::logZ() const {
    Complex sum = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        sum += Complex(IR(beta).c()) * _Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        sum += Complex(OR(alpha).c()) * _Qa[alpha].entropy();
        sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
    }
    return sum;
}


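// findEfficientTree() reorders the junction tree for efficient computation of
// a marginal over ns: the clique with the largest state-space overlap with ns
// becomes the new root, the tree is regrown from that root, and the edges of
// the minimal subtree covering the remaining variables of ns (and
// PreviousRoot, if given) are moved to the front of Tree. The return value is
// the number of those subtree edges; calcMarginal() only needs to repropagate
// over them.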
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal state space overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = (ns & OR(alpha).vars()).stateSpace();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

//  for( size_t e = 0; e < _RTree.size(); e++ )
//      cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
//  cout << endl;

    // grow new tree
    Graph oldTree;
    for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
//  cout << ns << ": ";
//  for( size_t e = 0; e < newTree.size(); e++ )
//      cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
//  cout << endl;

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
//  cout << "nsrem:" << nsrem << endl;
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find first occurrence of *n in the tree, which is closest to the root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars() && *n )
                break;
        }
        assert( e != newTree.size() );

        // trace the path back to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha ) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        assert( e != newTree.size() );

        // trace the path back to the root and add its edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
//  cout << "subTree: " << endl;
//  for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
//      cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
//  cout << endl;

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
//          cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
        }
//  cout << endl;
    // Then add edges pointing away from nsrem
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
//                  cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
                }
            }*/
    // FIXME
/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
            }
        }
    cout << endl;*/
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}


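// calcMarginal() computes an exact marginal over a VarSet ns that is not
// contained in any single clique. Schematically, with R the new root clique
// chosen by findEfficientTree() and nsrem the variables of ns outside R,
//
//   P(ns) is proportional to  sum over joint states s of nsrem of
//                             Z(nsrem = s) * Qa_root(ns \ nsrem, nsrem = s),
//
// where each term is obtained by clamping nsrem to s and repropagating only
// over the subtree returned by findEfficientTree(). The saved _Qa/_Qb values
// are restored after every clamping, so the object's state is left unchanged.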
// Cutset conditioning
// assumes that run() has been called already
Factor JTree::calcMarginal( const VarSet& ns ) {
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != _Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns (ns, 0.0);

            multind mi( nsrem );

            // Save _Qa and _Qb on the subtree
            map<size_t,Factor> _Qa_old;
            map<size_t,Factor> _Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                assert( beta != nrIRs() );
                b[i] = beta;

                if( !_Qa_old.count( alpha1 ) )
                    _Qa_old[alpha1] = _Qa[alpha1];
                if( !_Qa_old.count( alpha2 ) )
                    _Qa_old[alpha2] = _Qa[alpha2];
                if( !_Qb_old.count( beta ) )
                    _Qb_old[beta] = _Qb[beta];
            }

            // For all states of nsrem
            for( size_t j = 0; j < mi.max(); j++ ) {
                vector<size_t> vi = mi.vi( j );

                // CollectEvidence
                double logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator of OR(T[i].n1) and OR(T[i].n2)

                    size_t k = 0;
                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++, k++ )
                        if( _Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[vi[k]] = 1.0;
                            _Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = _Qa[T[i].n2].part_sum( IR( b[i] ) );
                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
                    _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
                    _Qb[b[i]] = new_Qb;
                }
                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));

                Factor piet( nsrem, 0.0 );
                piet[j] = exp(logZ);
                Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
                    _Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
                    _Qb[beta->first] = beta->second;
            }

            return( Pns.normalized(Prob::NORMPROB) );
        }
    }
}


} // end of namespace dai