Various cleanups
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <dai/jtree.h>
#include <dai/treeep.h>
#include <dai/util.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
// Name of this inference algorithm; used as the prefix of identify() and of
// the diagnostic messages written by run().
const char *TreeEP::Name = "TREEEP";
38
39
40 void TreeEP::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("tol") );
42 assert( opts.hasKey("maxiter") );
43 assert( opts.hasKey("verbose") );
44 assert( opts.hasKey("type") );
45
46 props.tol = opts.getStringAs<double>("tol");
47 props.maxiter = opts.getStringAs<size_t>("maxiter");
48 props.verbose = opts.getStringAs<size_t>("verbose");
49 props.type = opts.getStringAs<Properties::TypeType>("type");
50 }
51
52
53 PropertySet TreeEP::getProperties() const {
54 PropertySet opts;
55 opts.Set( "tol", props.tol );
56 opts.Set( "maxiter", props.maxiter );
57 opts.Set( "verbose", props.verbose );
58 opts.Set( "type", props.type );
59 return opts;
60 }
61
62
63 string TreeEP::printProperties() const {
64 stringstream s( stringstream::out );
65 s << "[";
66 s << "tol=" << props.tol << ",";
67 s << "maxiter=" << props.maxiter << ",";
68 s << "verbose=" << props.verbose << ",";
69 s << "type=" << props.type << "]";
70 return s.str();
71 }
72
73
74 TreeEP::TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
75 _ns = _I->vars();
76
77 // Make _Qa, _Qb, _a and _b corresponding to the subtree
78 _b.reserve( subRTree.size() );
79 _Qb.reserve( subRTree.size() );
80 _RTree.reserve( subRTree.size() );
81 for( size_t i = 0; i < subRTree.size(); i++ ) {
82 size_t alpha1 = subRTree[i].n1; // old index 1
83 size_t alpha2 = subRTree[i].n2; // old index 2
84 size_t beta; // old sep index
85 for( beta = 0; beta < jt_RTree.size(); beta++ )
86 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
87 break;
88 assert( beta != jt_RTree.size() );
89
90 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
91 if( newalpha1 == _a.size() ) {
92 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
93 _a.push_back( alpha1 ); // save old index in index conversion table
94 }
95
96 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
97 if( newalpha2 == _a.size() ) {
98 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
99 _a.push_back( alpha2 ); // save old index in index conversion table
100 }
101
102 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
103 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
104 _b.push_back( beta );
105 }
106
107 // Find remaining variables (which are not in the new root)
108 _nsrem = _ns / _Qa[0].vars();
109 }
110
111
112 void TreeEP::TreeEPSubTree::init() {
113 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
114 _Qa[alpha].fill( 1.0 );
115 for( size_t beta = 0; beta < _Qb.size(); beta++ )
116 _Qb[beta].fill( 1.0 );
117 }
118
119
120 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
121 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
122 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
123
124 for( size_t beta = 0; beta < _Qb.size(); beta++ )
125 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
126 }
127
128
// Runs HUGIN on this subtree with the off-tree factor *_I included.
//
// For every joint state s of _nsrem (the variables of *_I outside the root
// cluster _Qa[0]), this function does the following:
//   - multiply the root by the slice of *_I at s;
//   - run a collect pass (edges in reverse order, n2 -> n1) with the _nsrem
//     variables clamped to s;
//   - run a distribute pass (edges in order, n1 -> n2);
//   - add the resulting local factors into the global entries Qa[_a[..]] and
//     Qb[_b[..]];
//   - restore the local factors.
//
// Finally those global entries are normalized. _logZ receives the sum of the
// logs of the cluster normalizers minus the sum of the logs of the separator
// normalizers.
//
// Qa, Qb: global cluster / separator factors. Only the entries indexed by _a
//         and _b are overwritten.
void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
    // Backup _Qa and _Qb
    // (every state s below starts again from these local factors)
    vector<Factor> _Qa_old(_Qa);
    vector<Factor> _Qb_old(_Qb);

    // Clear Qa and Qb
    // (they are used as accumulators for the sum over the states of _nsrem)
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
        Qa[_a[alpha]].fill( 0.0 );
    for( size_t beta = 0; beta < _Qb.size(); beta++ )
        Qb[_b[beta]].fill( 0.0 );

    // For all states of _nsrem
    for( State s(_nsrem); s.valid(); s++ ) {
        // Multiply root with slice of I
        // (I with the variables in _nsrem fixed to their values in s)
        _Qa[0] *= _I->slice( _nsrem, s );

        // CollectEvidence
        // Edges are visited in reverse order; each passes a HUGIN update from
        // cluster n2 to cluster n1 through separator i.
        for( size_t i = _RTree.size(); (i--) != 0; ) {
            // clamp variables in nsrem
            // (multiply cluster n2 by an indicator that is 1 only at the value s(*n))
            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                if( _Qa[_RTree[i].n2].vars() >> *n ) {
                    Factor delta( *n, 0.0 );
                    delta[s(*n)] = 1.0;
                    _Qa[_RTree[i].n2] *= delta;
                }
            // NOTE(review): the second argument false presumably requests an
            // unnormalized marginal -- confirm in factor.h
            Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].n1] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // DistributeEvidence
        // Edges are visited in order, passing updates from n1 to n2.
        for( size_t i = 0; i < _RTree.size(); i++ ) {
            Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].n2] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // Store Qa's and Qb's
        // (add this state's contribution to the global accumulators)
        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
            Qa[_a[alpha]].p() += _Qa[alpha].p();
        for( size_t beta = 0; beta < _Qb.size(); beta++ )
            Qb[_b[beta]].p() += _Qb[beta].p();

        // Restore _Qa and _Qb
        _Qa = _Qa_old;
        _Qb = _Qb_old;
    }

    // Normalize Qa and Qb
    // _logZ = sum of log cluster normalizers - sum of log separator normalizers.
    // NOTE(review): a zero sum here would give log(0) and a division by zero
    // in normalize(); nothing in this function guards against that.
    _logZ = 0.0;
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
        _logZ += log(Qa[_a[alpha]].sum());
        Qa[_a[alpha]].normalize();
    }
    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
        _logZ -= log(Qb[_b[beta]].sum());
        Qb[_b[beta]].normalize();
    }
}
188
189
190 double TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
191 double s = 0.0;
192 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
193 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
194 for( size_t beta = 0; beta < _Qb.size(); beta++ )
195 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
196 return s + _logZ;
197 }
198
199
200 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
201 setProperties( opts );
202
203 assert( fg.isConnected() );
204
205 if( opts.hasKey("tree") ) {
206 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
207 } else {
208 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
209 // ORG: construct weighted graph with as weights a crude estimate of the
210 // mutual information between the nodes
211 // ALT: construct weighted graph with as weights an upper bound on the
212 // effective interaction strength between pairs of nodes
213
214 WeightedGraph<double> wg;
215 for( size_t i = 0; i < nrVars(); ++i ) {
216 Var v_i = var(i);
217 VarSet di = delta(i);
218 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
219 if( v_i < *j ) {
220 VarSet ij(v_i,*j);
221 Factor piet;
222 for( size_t I = 0; I < nrFactors(); I++ ) {
223 VarSet Ivars = factor(I).vars();
224 if( props.type == Properties::TypeType::ORG ) {
225 if( (Ivars == v_i) || (Ivars == *j) )
226 piet *= factor(I);
227 else if( Ivars >> ij )
228 piet *= factor(I).marginal( ij );
229 } else {
230 if( Ivars >> ij )
231 piet *= factor(I);
232 }
233 }
234 if( props.type == Properties::TypeType::ORG ) {
235 if( piet.vars() >> ij ) {
236 piet = piet.marginal( ij );
237 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
238 wg[UEdge(i,findVar(*j))] = dist( piet, pietf, Prob::DISTKL );
239 } else
240 wg[UEdge(i,findVar(*j))] = 0;
241 } else {
242 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
243 }
244 }
245 }
246
247 // find maximal spanning tree
248 ConstructRG( MaxSpanningTreePrims( wg ) );
249 } else
250 DAI_THROW(INTERNAL_ERROR);
251 }
252 }
253
254
// Builds the tree-structured region graph used by TreeEP from a tree over
// the variables, and creates the off-tree factor approximations.
//
// The construction proceeds as follows:
//   - Every tree edge (n1, n2) becomes an outer region over {var(n1), var(n2)}.
//   - These cliques are linked by a maximal spanning tree (RTree). The edge
//     weights are the sizes of the clique intersections.
//   - Each factor is assigned to the first outer region whose variables
//     contain it. Factors that fit in no outer region are off-tree.
//   - Each RTree edge gets an inner region (the clique intersection, with
//     counting number -1).
//   - Every off-tree factor gets a TreeEPSubTree in _Q.
//
// tree: edges whose endpoints are variable indices (as accepted by var()).
void TreeEP::ConstructRG( const DEdgeVec &tree ) {
    vector<VarSet> Cliques;
    for( size_t i = 0; i < tree.size(); i++ )
        Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );

    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions
    // (one per clique, initially the constant factor 1)
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    // If no outer region can be found subsuming that factor, label the
    // factor as off-tree.
    // (Here only fac2OR is filled in: the first subsuming region wins, and
    // factors without one keep the value -1U. The multiplication presumably
    // happens in RecomputeORs(). NOTE(review): offtree(I) is assumed to test
    // fac2OR[I] against -1U -- confirm, since -1U is an unsigned int constant.)
    fac2OR.clear();
    fac2OR.resize( nrFactors(), -1U );
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR[I] = alpha;
                break;
            }
        // DIFF WITH JTree::GenerateJT: assert
        // (here a factor without a subsuming outer region is allowed; it is off-tree)
    }
    RecomputeORs();

    // Create inner regions and edges
    // (one inner region per RTree edge, connected to both of its outer
    // regions; IRs.size() before the push_back is the new region's index)
    IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].n1, IRs.size() ) );
        edges.push_back( Edge( RTree[i].n2, IRs.size() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
    Check_Counting_Numbers();

    // Create messages and beliefs
    // (Qa starts as a copy of each outer region; Qb as constant-1 factors
    // over the inner regions)
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );

    // DIFF with JTree::GenerateJT: no messages

    // DIFF with JTree::GenerateJT:
    // Create factor approximations
    // First pass: build a subtree for every off-tree factor. Each search is
    // seeded with the root (subTree[0].n1) of the previous factor's subtree.
    // The very first search uses (size_t)-1 as its seed.
    // NOTE(review): subTree[0] is accessed unconditionally, so findEfficientTree
    // is assumed to return at least one edge.
    _Q.clear();
    size_t PreviousRoot = (size_t)-1;
    for( size_t I = 0; I < nrFactors(); I++ )
        if( offtree(I) ) {
            // find efficient subtree
            DEdgeVec subTree;
            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
            PreviousRoot = subTree[0].n1;
            //subTree.resize( subTreeSize ); // FIXME
            // cerr << "subtree " << I << " has size " << subTreeSize << endl;

            TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
            _Q[I] = QI;
        }
    // Previous root of first off-tree factor should be the root of the last off-tree factor
    // Second pass: rebuild only the FIRST off-tree factor's subtree, now seeded
    // with the root of the last one (note the break).
    for( size_t I = 0; I < nrFactors(); I++ )
        if( offtree(I) ) {
            DEdgeVec subTree;
            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
            PreviousRoot = subTree[0].n1;
            //subTree.resize( subTreeSize ); // FIXME
            // cerr << "subtree " << I << " has size " << subTreeSize << endl;

            TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
            _Q[I] = QI;
            break;
        }

    if( props.verbose >= 3 ) {
        cerr << "Resulting regiongraph: " << *this << endl;
    }
}
362
363
364 string TreeEP::identify() const {
365 return string(Name) + printProperties();
366 }
367
368
369 void TreeEP::init() {
370 runHUGIN();
371
372 // Init factor approximations
373 for( size_t I = 0; I < nrFactors(); I++ )
374 if( offtree(I) )
375 _Q[I].init();
376 }
377
378
379 double TreeEP::run() {
380 if( props.verbose >= 1 )
381 cerr << "Starting " << identify() << "...";
382 if( props.verbose >= 3)
383 cerr << endl;
384
385 double tic = toc();
386 Diffs diffs(nrVars(), 1.0);
387
388 vector<Factor> old_beliefs;
389 old_beliefs.reserve( nrVars() );
390 for( size_t i = 0; i < nrVars(); i++ )
391 old_beliefs.push_back(belief(var(i)));
392
393 // do several passes over the network until maximum number of iterations has
394 // been reached or until the maximum belief difference is smaller than tolerance
395 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; _iters++ ) {
396 for( size_t I = 0; I < nrFactors(); I++ )
397 if( offtree(I) ) {
398 _Q[I].InvertAndMultiply( Qa, Qb );
399 _Q[I].HUGIN_with_I( Qa, Qb );
400 _Q[I].InvertAndMultiply( Qa, Qb );
401 }
402
403 // calculate new beliefs and compare with old ones
404 for( size_t i = 0; i < nrVars(); i++ ) {
405 Factor nb( belief(var(i)) );
406 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
407 old_beliefs[i] = nb;
408 }
409
410 if( props.verbose >= 3 )
411 cerr << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
412 }
413
414 if( diffs.maxDiff() > _maxdiff )
415 _maxdiff = diffs.maxDiff();
416
417 if( props.verbose >= 1 ) {
418 if( diffs.maxDiff() > props.tol ) {
419 if( props.verbose == 1 )
420 cerr << endl;
421 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
422 } else {
423 if( props.verbose >= 3 )
424 cerr << Name << "::run: ";
425 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
426 }
427 }
428
429 return diffs.maxDiff();
430 }
431
432
433 Real TreeEP::logZ() const {
434 double s = 0.0;
435
436 // entropy of the tree
437 for( size_t beta = 0; beta < nrIRs(); beta++ )
438 s -= Qb[beta].entropy();
439 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
440 s += Qa[alpha].entropy();
441
442 // energy of the on-tree factors
443 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
444 s += (OR(alpha).log(true) * Qa[alpha]).sum();
445
446 // energy of the off-tree factors
447 for( size_t I = 0; I < nrFactors(); I++ )
448 if( offtree(I) )
449 s += (_Q.find(I))->second.logZ( Qa, Qb );
450
451 return s;
452 }
453
454
455 } // end of namespace dai