Added examples example_sprinkler_gibbs and example_sprinkler_em
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *TreeEP::Name = "TREEEP";
27
28
29 void TreeEP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32 DAI_ASSERT( opts.hasKey("verbose") );
33 DAI_ASSERT( opts.hasKey("type") );
34
35 props.tol = opts.getStringAs<Real>("tol");
36 props.maxiter = opts.getStringAs<size_t>("maxiter");
37 props.verbose = opts.getStringAs<size_t>("verbose");
38 props.type = opts.getStringAs<Properties::TypeType>("type");
39 }
40
41
42 PropertySet TreeEP::getProperties() const {
43 PropertySet opts;
44 opts.Set( "tol", props.tol );
45 opts.Set( "maxiter", props.maxiter );
46 opts.Set( "verbose", props.verbose );
47 opts.Set( "type", props.type );
48 return opts;
49 }
50
51
52 string TreeEP::printProperties() const {
53 stringstream s( stringstream::out );
54 s << "[";
55 s << "tol=" << props.tol << ",";
56 s << "maxiter=" << props.maxiter << ",";
57 s << "verbose=" << props.verbose << ",";
58 s << "type=" << props.type << "]";
59 return s.str();
60 }
61
62
63 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
64 _ns = _I->vars();
65
66 // Make _Qa, _Qb, _a and _b corresponding to the subtree
67 _b.reserve( subRTree.size() );
68 _Qb.reserve( subRTree.size() );
69 _RTree.reserve( subRTree.size() );
70 for( size_t i = 0; i < subRTree.size(); i++ ) {
71 size_t alpha1 = subRTree[i].n1; // old index 1
72 size_t alpha2 = subRTree[i].n2; // old index 2
73 size_t beta; // old sep index
74 for( beta = 0; beta < jt_RTree.size(); beta++ )
75 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
76 break;
77 DAI_ASSERT( beta != jt_RTree.size() );
78
79 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
80 if( newalpha1 == _a.size() ) {
81 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
82 _a.push_back( alpha1 ); // save old index in index conversion table
83 }
84
85 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
86 if( newalpha2 == _a.size() ) {
87 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
88 _a.push_back( alpha2 ); // save old index in index conversion table
89 }
90
91 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
92 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
93 _b.push_back( beta );
94 }
95
96 // Find remaining variables (which are not in the new root)
97 _nsrem = _ns / _Qa[0].vars();
98 }
99
100
101 void TreeEP::TreeEPSubTree::init() {
102 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
103 _Qa[alpha].fill( 1.0 );
104 for( size_t beta = 0; beta < _Qb.size(); beta++ )
105 _Qb[beta].fill( 1.0 );
106 }
107
108
109 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
110 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
111 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
112
113 for( size_t beta = 0; beta < _Qb.size(); beta++ )
114 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
115 }
116
117
118 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
119 // Backup _Qa and _Qb
120 vector<Factor> _Qa_old(_Qa);
121 vector<Factor> _Qb_old(_Qb);
122
123 // Clear Qa and Qb
124 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
125 Qa[_a[alpha]].fill( 0.0 );
126 for( size_t beta = 0; beta < _Qb.size(); beta++ )
127 Qb[_b[beta]].fill( 0.0 );
128
129 // For all states of _nsrem
130 for( State s(_nsrem); s.valid(); s++ ) {
131 // Multiply root with slice of I
132 _Qa[0] *= _I->slice( _nsrem, s );
133
134 // CollectEvidence
135 for( size_t i = _RTree.size(); (i--) != 0; ) {
136 // clamp variables in nsrem
137 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
138 if( _Qa[_RTree[i].n2].vars() >> *n ) {
139 Factor delta( *n, 0.0 );
140 delta[s(*n)] = 1.0;
141 _Qa[_RTree[i].n2] *= delta;
142 }
143 Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
144 _Qa[_RTree[i].n1] *= new_Qb / _Qb[i];
145 _Qb[i] = new_Qb;
146 }
147
148 // DistributeEvidence
149 for( size_t i = 0; i < _RTree.size(); i++ ) {
150 Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
151 _Qa[_RTree[i].n2] *= new_Qb / _Qb[i];
152 _Qb[i] = new_Qb;
153 }
154
155 // Store Qa's and Qb's
156 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
157 Qa[_a[alpha]].p() += _Qa[alpha].p();
158 for( size_t beta = 0; beta < _Qb.size(); beta++ )
159 Qb[_b[beta]].p() += _Qb[beta].p();
160
161 // Restore _Qa and _Qb
162 _Qa = _Qa_old;
163 _Qb = _Qb_old;
164 }
165
166 // Normalize Qa and Qb
167 _logZ = 0.0;
168 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
169 _logZ += log(Qa[_a[alpha]].sum());
170 Qa[_a[alpha]].normalize();
171 }
172 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
173 _logZ -= log(Qb[_b[beta]].sum());
174 Qb[_b[beta]].normalize();
175 }
176 }
177
178
179 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
180 Real s = 0.0;
181 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
182 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
183 for( size_t beta = 0; beta < _Qb.size(); beta++ )
184 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
185 return s + _logZ;
186 }
187
188
189 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
190 setProperties( opts );
191
192 if( !isConnected() )
193 DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
194
195 if( opts.hasKey("tree") ) {
196 construct( opts.GetAs<RootedTree>("tree") );
197 } else {
198 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
199 // ORG: construct weighted graph with as weights a crude estimate of the
200 // mutual information between the nodes
201 // ALT: construct weighted graph with as weights an upper bound on the
202 // effective interaction strength between pairs of nodes
203
204 WeightedGraph<Real> wg;
205 for( size_t i = 0; i < nrVars(); ++i ) {
206 Var v_i = var(i);
207 VarSet di = delta(i);
208 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
209 if( v_i < *j ) {
210 VarSet ij(v_i,*j);
211 Factor piet;
212 for( size_t I = 0; I < nrFactors(); I++ ) {
213 VarSet Ivars = factor(I).vars();
214 if( props.type == Properties::TypeType::ORG ) {
215 if( (Ivars == v_i) || (Ivars == *j) )
216 piet *= factor(I);
217 else if( Ivars >> ij )
218 piet *= factor(I).marginal( ij );
219 } else {
220 if( Ivars >> ij )
221 piet *= factor(I);
222 }
223 }
224 if( props.type == Properties::TypeType::ORG ) {
225 if( piet.vars() >> ij ) {
226 piet = piet.marginal( ij );
227 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
228 wg[UEdge(i,findVar(*j))] = dist( piet, pietf, Prob::DISTKL );
229 } else
230 wg[UEdge(i,findVar(*j))] = 0;
231 } else {
232 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
233 }
234 }
235 }
236
237 // find maximal spanning tree
238 construct( MaxSpanningTreePrims( wg ) );
239 } else
240 DAI_THROW(UNKNOWN_ENUM_VALUE);
241 }
242 }
243
244
245 void TreeEP::construct( const RootedTree &tree ) {
246 vector<VarSet> Cliques;
247 for( size_t i = 0; i < tree.size(); i++ )
248 Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
249
250 // Construct a weighted graph (each edge is weighted with the cardinality
251 // of the intersection of the nodes, where the nodes are the elements of
252 // Cliques).
253 WeightedGraph<int> JuncGraph;
254 for( size_t i = 0; i < Cliques.size(); i++ )
255 for( size_t j = i+1; j < Cliques.size(); j++ ) {
256 size_t w = (Cliques[i] & Cliques[j]).size();
257 if( w )
258 JuncGraph[UEdge(i,j)] = w;
259 }
260
261 // Construct maximal spanning tree using Prim's algorithm
262 RTree = MaxSpanningTreePrims( JuncGraph );
263
264 // Construct corresponding region graph
265
266 // Create outer regions
267 ORs.reserve( Cliques.size() );
268 for( size_t i = 0; i < Cliques.size(); i++ )
269 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
270
271 // For each factor, find an outer region that subsumes that factor.
272 // Then, multiply the outer region with that factor.
273 // If no outer region can be found subsuming that factor, label the
274 // factor as off-tree.
275 fac2OR.clear();
276 fac2OR.resize( nrFactors(), -1U );
277 for( size_t I = 0; I < nrFactors(); I++ ) {
278 size_t alpha;
279 for( alpha = 0; alpha < nrORs(); alpha++ )
280 if( OR(alpha).vars() >> factor(I).vars() ) {
281 fac2OR[I] = alpha;
282 break;
283 }
284 // DIFF WITH JTree::GenerateJT: assert
285 }
286 RecomputeORs();
287
288 // Create inner regions and edges
289 IRs.reserve( RTree.size() );
290 vector<Edge> edges;
291 edges.reserve( 2 * RTree.size() );
292 for( size_t i = 0; i < RTree.size(); i++ ) {
293 edges.push_back( Edge( RTree[i].n1, IRs.size() ) );
294 edges.push_back( Edge( RTree[i].n2, IRs.size() ) );
295 // inner clusters have counting number -1
296 IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
297 }
298
299 // create bipartite graph
300 G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
301
302 // Check counting numbers
303 checkCountingNumbers();
304
305 // Create messages and beliefs
306 Qa.clear();
307 Qa.reserve( nrORs() );
308 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
309 Qa.push_back( OR(alpha) );
310
311 Qb.clear();
312 Qb.reserve( nrIRs() );
313 for( size_t beta = 0; beta < nrIRs(); beta++ )
314 Qb.push_back( Factor( IR(beta), 1.0 ) );
315
316 // DIFF with JTree::GenerateJT: no messages
317
318 // DIFF with JTree::GenerateJT:
319 // Create factor approximations
320 _Q.clear();
321 size_t PreviousRoot = (size_t)-1;
322 for( size_t I = 0; I < nrFactors(); I++ )
323 if( offtree(I) ) {
324 // find efficient subtree
325 RootedTree subTree;
326 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
327 PreviousRoot = subTree[0].n1;
328 //subTree.resize( subTreeSize ); // FIXME
329 // cerr << "subtree " << I << " has size " << subTreeSize << endl;
330
331 TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
332 _Q[I] = QI;
333 }
334 // Previous root of first off-tree factor should be the root of the last off-tree factor
335 for( size_t I = 0; I < nrFactors(); I++ )
336 if( offtree(I) ) {
337 RootedTree subTree;
338 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
339 PreviousRoot = subTree[0].n1;
340 //subTree.resize( subTreeSize ); // FIXME
341 // cerr << "subtree " << I << " has size " << subTreeSize << endl;
342
343 TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
344 _Q[I] = QI;
345 break;
346 }
347
348 if( props.verbose >= 3 ) {
349 cerr << "Resulting regiongraph: " << *this << endl;
350 }
351 }
352
353
354 string TreeEP::identify() const {
355 return string(Name) + printProperties();
356 }
357
358
359 void TreeEP::init() {
360 runHUGIN();
361
362 // Init factor approximations
363 for( size_t I = 0; I < nrFactors(); I++ )
364 if( offtree(I) )
365 _Q[I].init();
366 }
367
368
369 Real TreeEP::run() {
370 if( props.verbose >= 1 )
371 cerr << "Starting " << identify() << "...";
372 if( props.verbose >= 3 )
373 cerr << endl;
374
375 double tic = toc();
376
377 vector<Factor> oldBeliefsV;
378 oldBeliefsV.reserve( nrVars() );
379 for( size_t i = 0; i < nrVars(); i++ )
380 oldBeliefsV.push_back( beliefV(i) );
381
382 // do several passes over the network until maximum number of iterations has
383 // been reached or until the maximum belief difference is smaller than tolerance
384 Real maxDiff = INFINITY;
385 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
386 for( size_t I = 0; I < nrFactors(); I++ )
387 if( offtree(I) ) {
388 _Q[I].InvertAndMultiply( Qa, Qb );
389 _Q[I].HUGIN_with_I( Qa, Qb );
390 _Q[I].InvertAndMultiply( Qa, Qb );
391 }
392
393 // calculate new beliefs and compare with old ones
394 maxDiff = -INFINITY;
395 for( size_t i = 0; i < nrVars(); i++ ) {
396 Factor nb( beliefV(i) );
397 maxDiff = std::max( maxDiff, dist( nb, oldBeliefsV[i], Prob::DISTLINF ) );
398 oldBeliefsV[i] = nb;
399 }
400
401 if( props.verbose >= 3 )
402 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
403 }
404
405 if( maxDiff > _maxdiff )
406 _maxdiff = maxDiff;
407
408 if( props.verbose >= 1 ) {
409 if( maxDiff > props.tol ) {
410 if( props.verbose == 1 )
411 cerr << endl;
412 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
413 } else {
414 if( props.verbose >= 3 )
415 cerr << Name << "::run: ";
416 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
417 }
418 }
419
420 return maxDiff;
421 }
422
423
424 Real TreeEP::logZ() const {
425 Real s = 0.0;
426
427 // entropy of the tree
428 for( size_t beta = 0; beta < nrIRs(); beta++ )
429 s -= Qb[beta].entropy();
430 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
431 s += Qa[alpha].entropy();
432
433 // energy of the on-tree factors
434 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
435 s += (OR(alpha).log(true) * Qa[alpha]).sum();
436
437 // energy of the off-tree factors
438 for( size_t I = 0; I < nrFactors(); I++ )
439 if( offtree(I) )
440 s += (_Q.find(I))->second.logZ( Qa, Qb );
441
442 return s;
443 }
444
445
446 } // end of namespace dai