Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / src / jtree.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <dai/jtree.h>
24
25
26 namespace dai {
27
28
29 using namespace std;
30
31
32 const char *JTree::Name = "JTREE";
33
34
35 void JTree::setProperties( const PropertySet &opts ) {
36 assert( opts.hasKey("verbose") );
37 assert( opts.hasKey("updates") );
38
39 props.verbose = opts.getStringAs<size_t>("verbose");
40 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
41 }
42
43
44 PropertySet JTree::getProperties() const {
45 PropertySet opts;
46 opts.Set( "verbose", props.verbose );
47 opts.Set( "updates", props.updates );
48 return opts;
49 }
50
51
52 JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
53 setProperties( opts );
54
55 if( automatic ) {
56 // Copy VarSets of factors
57 vector<VarSet> cl;
58 cl.reserve( fg.nrFactors() );
59 for( size_t I = 0; I < nrFactors(); I++ )
60 cl.push_back( factor(I).vars() );
61 ClusterGraph _cg( cl );
62
63 if( props.verbose >= 3 )
64 cout << "Initial clusters: " << _cg << endl;
65
66 // Retain only maximal clusters
67 _cg.eraseNonMaximal();
68 if( props.verbose >= 3 )
69 cout << "Maximal clusters: " << _cg << endl;
70
71 vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
72 if( props.verbose >= 3 )
73 cout << "VarElim_MinFill result: " << ElimVec << endl;
74
75 GenerateJT( ElimVec );
76 }
77 }
78
79
// Builds the junction tree and the corresponding region graph from the
// given set of cliques:
//  - forms the junction graph (cliques as nodes, edges weighted by
//    intersection cardinality), takes a maximum spanning tree,
//  - creates one outer region per clique and one inner region
//    (separator) per tree edge,
//  - assigns each factor to an outer region that subsumes it,
//  - allocates the belief (_Qa, _Qb) and message (_mes) storage.
void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm.
    // Maximizing total separator size yields a valid junction tree
    // (running-intersection property).
    _RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions: one per clique, counting number 1
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                // OR(alpha) *= factor(I);
                // (the actual multiplication is deferred to RecomputeORs below)
                fac2OR.push_back( alpha );
                break;
            }
        // every factor must fit in some clique by construction
        assert( alpha != nrORs() );
    }
    RecomputeORs();

    // Create inner regions and edges; inner region i corresponds to
    // tree edge i and connects the two outer regions it separates
    IRs.reserve( _RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, nrIRs() ) );
        edges.push_back( Edge( _RTree[i].n2, nrIRs() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph between outer and inner regions
    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Create messages and beliefs:
    // outer beliefs start as the outer region factors
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    // inner (separator) beliefs start uniform
    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    // one message per (outer region, neighboring inner region) pair,
    // initialized uniform
    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        foreach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    // Check counting numbers
    Check_Counting_Numbers();

    if( props.verbose >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}
156
157
158 string JTree::identify() const {
159 stringstream result (stringstream::out);
160 result << Name << getProperties();
161 return result.str();
162 }
163
164
165 Factor JTree::belief( const VarSet &ns ) const {
166 vector<Factor>::const_iterator beta;
167 for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
168 if( beta->vars() >> ns )
169 break;
170 if( beta != _Qb.end() )
171 return( beta->marginal(ns) );
172 else {
173 vector<Factor>::const_iterator alpha;
174 for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
175 if( alpha->vars() >> ns )
176 break;
177 assert( alpha != _Qa.end() );
178 return( alpha->marginal(ns) );
179 }
180 }
181
182
183 vector<Factor> JTree::beliefs() const {
184 vector<Factor> result;
185 for( size_t beta = 0; beta < nrIRs(); beta++ )
186 result.push_back( _Qb[beta] );
187 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
188 result.push_back( _Qa[alpha] );
189 return result;
190 }
191
192
193 Factor JTree::belief( const Var &n ) const {
194 return belief( (VarSet)n );
195 }
196
197
// HUGIN-style propagation: a collect pass towards the root followed by a
// distribute pass away from it. Afterwards _Qa/_Qb hold the exact clique
// and separator marginals and _logZ the log partition sum.
// Needs no init
void JTree::runHUGIN() {
    // Reset clique beliefs to the outer region factors
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha] = OR(alpha);

    // Reset separator beliefs to uniform
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta].fill( 1.0 );

    // CollectEvidence: walk the rooted tree edges from the leaves inward
    _logZ = 0.0;
    for( size_t i = _RTree.size(); (i--) != 0; ) {
        // Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
        // IR(i) = separator OR(_RTree[i].n1) && OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n2].part_sum( IR( i ) );
        // accumulate the normalization constants into logZ
        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
        // absorb the update ratio into the parent clique
        _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }
    // the root's final normalization completes logZ
    // (a single-clique tree has no edges: normalize _Qa[0] directly)
    if( _RTree.empty() )
        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ) );
    else
        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));

    // DistributeEvidence: walk the edges from the root outward
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        // Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
        // IR(i) = separator OR(_RTree[i].n1) && OR(_RTree[i].n2)
        Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
        _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
        _Qb[i] = new_Qb;
    }

    // Normalize all clique beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa[alpha].normalize( Prob::NORMPROB );
}
234
235
// Shafer-Shenoy propagation: two message passes over the junction tree
// (inward, then outward), then beliefs are computed as products of the
// outer region factor with all incoming messages.
// Really needs no init! Initial messages can be anything.
void JTree::runShaferShenoy() {
    // First pass: messages towards the root, in reverse edge order
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from _RTree[e].n2 to _RTree[e].n1
        // or, actually, from the separator IR(e) to _RTree[e].n1

        size_t i = nbIR(e)[1].node; // = _RTree[e].n2
        size_t j = nbIR(e)[0].node; // = _RTree[e].n1
        size_t _e = nbIR(e)[0].dual;

        // product of OR(i) with all incoming messages except the one over e
        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.part_sum( IR(e) );
        // collect the normalization constants into logZ
        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
    }

    // Second pass: messages away from the root, in forward edge order
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = _RTree[e].n1
        size_t j = nbIR(e)[1].node; // = _RTree[e].n2
        size_t _e = nbIR(e)[1].dual;

        Factor piet = OR(i);
        foreach( const Neighbor &k, nbOR(i) )
            if( k != e )
                piet *= message( i, k.iter );
        message( j, _e ) = piet.marginal( IR(e) );
    }

    // Calculate beliefs: outer factor times all incoming messages
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        foreach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            // no edges at all: this clique's normalization completes logZ
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
            // the root contributes its normalization constant to logZ
            _logZ += log( piet.normalize( Prob::NORMPROB ) );
            _Qa[alpha] = piet;
        } else
            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
    }

    // Only for logZ (and for belief)...
    // separator beliefs are marginals of an adjacent clique belief
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb[beta] = _Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}
288
289
290 double JTree::run() {
291 if( props.updates == Properties::UpdateType::HUGIN )
292 runHUGIN();
293 else if( props.updates == Properties::UpdateType::SHSH )
294 runShaferShenoy();
295 return 0.0;
296 }
297
298
299 Complex JTree::logZ() const {
300 Complex sum = 0.0;
301 for( size_t beta = 0; beta < nrIRs(); beta++ )
302 sum += Complex(IR(beta).c()) * _Qb[beta].entropy();
303 for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
304 sum += Complex(OR(alpha).c()) * _Qa[alpha].entropy();
305 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
306 }
307 return sum;
308 }
309
310
311
// Reorders the junction tree for efficient inference over the variable set
// ns: picks the clique with maximal state-space overlap with ns as the new
// root, re-roots the tree there, and writes the edges into Tree such that
// the subtree needed to cover all of ns (and, optionally, PreviousRoot)
// comes first. Returns the size of that leading subtree.
// NOTE(review): PreviousRoot defaults are declared in the header; (size_t)-1
// appears to mean "no previous root" — confirm against jtree.h.
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with ns)
    size_t maxval = 0, maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        size_t val = (ns & OR(alpha).vars()).states();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // for( size_t e = 0; e < _RTree.size(); e++ )
    //     cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ",  ";
    // cout << endl;

    // grow new tree: re-root the (undirected) junction tree at maxalpha
    Graph oldTree;
    for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
        oldTree.insert( UEdge(e->n1, e->n2) );
    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
    // cout << ns << ": ";
    // for( size_t e = 0; e < newTree.size(); e++ )
    //     cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ",  ";
    // cout << endl;

    // identify subtree that contains variables of ns which are not in the new root
    VarSet nsrem = ns / OR(maxalpha).vars();
    // cout << "nsrem:" << nsrem << endl;
    set<DEdge> subTree;
    // for each variable in ns that is not in the root clique
    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
        // find first occurrence of *n in the tree, which is closest to the root
        // (GrowRootedTree emits edges in breadth-first order)
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( OR(newTree[e].n2).vars().contains( *n ) )
                break;
        }
        assert( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    // also connect the previous root to the subtree, so that beliefs cached
    // there from an earlier call remain reachable
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].n2 == PreviousRoot )
                break;
        }
        assert( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].n1;
        for( ; e > 0; e-- )
            if( newTree[e-1].n2 == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].n1;
            }
    }
    // cout << "subTree: " << endl;
    // for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
    //     cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ",  ";
    // cout << endl;

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) ) {
            Tree.push_back( *e );
            // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
        }
    // cout << endl;
    // Then add edges pointing away from nsrem
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
            if( *e != *sTi ) {
                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                    Tree.push_back( *e );
                    // cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
                }
            }*/
    // FIXME
    /* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
            bool found = false;
            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                if( (OR(e->n1).vars() && *n) ) {
                    found = true;
                    break;
                }
            if( found ) {
                Tree.push_back( *e );
                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
            }
        }
    cout << endl;*/
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
            Tree.push_back( *e );

    return subTreeSize;
}
425
426
// Computes the marginal over an arbitrary variable set ns, even when ns is
// not contained in any single clique, via cutset conditioning: for every
// joint state of the variables of ns outside the root clique, those
// variables are clamped, a collect pass is run on the relevant subtree,
// and the (unnormalized) results are accumulated.
// Cutset conditioning
// assumes that run() has been called already
Factor JTree::calcMarginal( const VarSet& ns ) {
    // Fast path: ns fits in a separator belief...
    vector<Factor>::const_iterator beta;
    for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
        if( beta->vars() >> ns )
            break;
    if( beta != _Qb.end() )
        return( beta->marginal(ns) );
    else {
        // ...or in a clique belief
        vector<Factor>::const_iterator alpha;
        for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
            if( alpha->vars() >> ns )
                break;
        if( alpha != _Qa.end() )
            return( alpha->marginal(ns) );
        else {
            // Find subtree to do efficient inference
            DEdgeVec T;
            size_t Tsize = findEfficientTree( ns, T );

            // Find remaining variables (which are not in the new root)
            VarSet nsrem = ns / OR(T.front().n1).vars();
            Factor Pns (ns, 0.0);

            // Save _Qa and _Qb on the subtree, so they can be restored after
            // each clamped propagation; b[i] maps subtree edge i to the
            // index of its separator in _RTree/IRs
            map<size_t,Factor> _Qa_old;
            map<size_t,Factor> _Qb_old;
            vector<size_t> b(Tsize, 0);
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].n1;
                size_t alpha2 = T[i].n2;
                // NOTE(review): this size_t beta shadows the iterator beta
                // declared above — legal, but rename one of them eventually
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
                        break;
                assert( beta != nrIRs() );
                b[i] = beta;

                if( !_Qa_old.count( alpha1 ) )
                    _Qa_old[alpha1] = _Qa[alpha1];
                if( !_Qa_old.count( alpha2 ) )
                    _Qa_old[alpha2] = _Qa[alpha2];
                if( !_Qb_old.count( beta ) )
                    _Qb_old[beta] = _Qb[beta];
            }

            // For all states of nsrem
            for( State s(nsrem); s.valid(); s++ ) {

                // CollectEvidence with the nsrem variables clamped to state s
                double logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].n1 consistent with outer region T[i].n2
                    // IR(i) = separator OR(T[i].n1) && OR(T[i].n2)

                    // clamp: multiply in a delta factor for every nsrem
                    // variable contained in this clique
                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                        if( _Qa[T[i].n2].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet[s(*n)] = 1.0;
                            _Qa[T[i].n2] *= piet;
                        }

                    Factor new_Qb = _Qa[T[i].n2].part_sum( IR( b[i] ) );
                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
                    _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
                    _Qb[b[i]] = new_Qb;
                }
                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));

                // weight this state's contribution by its probability mass
                Factor piet( nsrem, 0.0 );
                piet[s] = exp(logZ);
                Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
                    _Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
                    _Qb[beta->first] = beta->second;
            }

            return( Pns.normalized(Prob::NORMPROB) );
        }
    }
}
512
513
514 } // end of namespace dai