Updated copyrights
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <iostream>
24 #include <fstream>
25 #include <vector>
26 #include <dai/jtree.h>
27 #include <dai/treeep.h>
28 #include <dai/util.h>
29 #include <dai/diffs.h>
30
31
32 namespace dai {
33
34
35 using namespace std;
36
37
38 const char *TreeEP::Name = "TREEEP";
39
40
41 void TreeEP::setProperties( const PropertySet &opts ) {
42 assert( opts.hasKey("tol") );
43 assert( opts.hasKey("maxiter") );
44 assert( opts.hasKey("verbose") );
45 assert( opts.hasKey("type") );
46
47 props.tol = opts.getStringAs<double>("tol");
48 props.maxiter = opts.getStringAs<size_t>("maxiter");
49 props.verbose = opts.getStringAs<size_t>("verbose");
50 props.type = opts.getStringAs<Properties::TypeType>("type");
51 }
52
53
54 PropertySet TreeEP::getProperties() const {
55 PropertySet opts;
56 opts.Set( "tol", props.tol );
57 opts.Set( "maxiter", props.maxiter );
58 opts.Set( "verbose", props.verbose );
59 opts.Set( "type", props.type );
60 return opts;
61 }
62
63
64 string TreeEP::printProperties() const {
65 stringstream s( stringstream::out );
66 s << "[";
67 s << "tol=" << props.tol << ",";
68 s << "maxiter=" << props.maxiter << ",";
69 s << "verbose=" << props.verbose << ",";
70 s << "type=" << props.type << "]";
71 return s.str();
72 }
73
74
75 TreeEP::TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
76 _ns = _I->vars();
77
78 // Make _Qa, _Qb, _a and _b corresponding to the subtree
79 _b.reserve( subRTree.size() );
80 _Qb.reserve( subRTree.size() );
81 _RTree.reserve( subRTree.size() );
82 for( size_t i = 0; i < subRTree.size(); i++ ) {
83 size_t alpha1 = subRTree[i].n1; // old index 1
84 size_t alpha2 = subRTree[i].n2; // old index 2
85 size_t beta; // old sep index
86 for( beta = 0; beta < jt_RTree.size(); beta++ )
87 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
88 break;
89 assert( beta != jt_RTree.size() );
90
91 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
92 if( newalpha1 == _a.size() ) {
93 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
94 _a.push_back( alpha1 ); // save old index in index conversion table
95 }
96
97 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
98 if( newalpha2 == _a.size() ) {
99 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
100 _a.push_back( alpha2 ); // save old index in index conversion table
101 }
102
103 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
104 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
105 _b.push_back( beta );
106 }
107
108 // Find remaining variables (which are not in the new root)
109 _nsrem = _ns / _Qa[0].vars();
110 }
111
112
113 void TreeEP::TreeEPSubTree::init() {
114 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
115 _Qa[alpha].fill( 1.0 );
116 for( size_t beta = 0; beta < _Qb.size(); beta++ )
117 _Qb[beta].fill( 1.0 );
118 }
119
120
121 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
122 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
123 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
124
125 for( size_t beta = 0; beta < _Qb.size(); beta++ )
126 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
127 }
128
129
130 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
131 // Backup _Qa and _Qb
132 vector<Factor> _Qa_old(_Qa);
133 vector<Factor> _Qb_old(_Qb);
134
135 // Clear Qa and Qb
136 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
137 Qa[_a[alpha]].fill( 0.0 );
138 for( size_t beta = 0; beta < _Qb.size(); beta++ )
139 Qb[_b[beta]].fill( 0.0 );
140
141 // For all states of _nsrem
142 for( State s(_nsrem); s.valid(); s++ ) {
143 // Multiply root with slice of I
144 _Qa[0] *= _I->slice( _nsrem, s );
145
146 // CollectEvidence
147 for( size_t i = _RTree.size(); (i--) != 0; ) {
148 // clamp variables in nsrem
149 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
150 if( _Qa[_RTree[i].n2].vars() >> *n ) {
151 Factor delta( *n, 0.0 );
152 delta[s(*n)] = 1.0;
153 _Qa[_RTree[i].n2] *= delta;
154 }
155 Factor new_Qb = _Qa[_RTree[i].n2].partSum( _Qb[i].vars() );
156 _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
157 _Qb[i] = new_Qb;
158 }
159
160 // DistributeEvidence
161 for( size_t i = 0; i < _RTree.size(); i++ ) {
162 Factor new_Qb = _Qa[_RTree[i].n1].partSum( _Qb[i].vars() );
163 _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
164 _Qb[i] = new_Qb;
165 }
166
167 // Store Qa's and Qb's
168 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
169 Qa[_a[alpha]].p() += _Qa[alpha].p();
170 for( size_t beta = 0; beta < _Qb.size(); beta++ )
171 Qb[_b[beta]].p() += _Qb[beta].p();
172
173 // Restore _Qa and _Qb
174 _Qa = _Qa_old;
175 _Qb = _Qb_old;
176 }
177
178 // Normalize Qa and Qb
179 _logZ = 0.0;
180 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
181 _logZ += log(Qa[_a[alpha]].totalSum());
182 Qa[_a[alpha]].normalize();
183 }
184 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
185 _logZ -= log(Qb[_b[beta]].totalSum());
186 Qb[_b[beta]].normalize();
187 }
188 }
189
190
191 double TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
192 double sum = 0.0;
193 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
194 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
195 for( size_t beta = 0; beta < _Qb.size(); beta++ )
196 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
197 return sum + _logZ;
198 }
199
200
201 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
202 setProperties( opts );
203
204 assert( fg.isConnected() );
205
206 if( opts.hasKey("tree") ) {
207 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
208 } else {
209 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
210 // ORG: construct weighted graph with as weights a crude estimate of the
211 // mutual information between the nodes
212 // ALT: construct weighted graph with as weights an upper bound on the
213 // effective interaction strength between pairs of nodes
214
215 WeightedGraph<double> wg;
216 for( size_t i = 0; i < nrVars(); ++i ) {
217 Var v_i = var(i);
218 VarSet di = delta(i);
219 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
220 if( v_i < *j ) {
221 VarSet ij(v_i,*j);
222 Factor piet;
223 for( size_t I = 0; I < nrFactors(); I++ ) {
224 VarSet Ivars = factor(I).vars();
225 if( props.type == Properties::TypeType::ORG ) {
226 if( (Ivars == v_i) || (Ivars == *j) )
227 piet *= factor(I);
228 else if( Ivars >> ij )
229 piet *= factor(I).marginal( ij );
230 } else {
231 if( Ivars >> ij )
232 piet *= factor(I);
233 }
234 }
235 if( props.type == Properties::TypeType::ORG ) {
236 if( piet.vars() >> ij ) {
237 piet = piet.marginal( ij );
238 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
239 wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
240 } else
241 wg[UEdge(i,findVar(*j))] = 0;
242 } else {
243 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
244 }
245 }
246 }
247
248 // find maximal spanning tree
249 ConstructRG( MaxSpanningTreePrims( wg ) );
250 } else
251 DAI_THROW(INTERNAL_ERROR);
252 }
253 }
254
255
256 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
257 vector<VarSet> Cliques;
258 for( size_t i = 0; i < tree.size(); i++ )
259 Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
260
261 // Construct a weighted graph (each edge is weighted with the cardinality
262 // of the intersection of the nodes, where the nodes are the elements of
263 // Cliques).
264 WeightedGraph<int> JuncGraph;
265 for( size_t i = 0; i < Cliques.size(); i++ )
266 for( size_t j = i+1; j < Cliques.size(); j++ ) {
267 size_t w = (Cliques[i] & Cliques[j]).size();
268 if( w )
269 JuncGraph[UEdge(i,j)] = w;
270 }
271
272 // Construct maximal spanning tree using Prim's algorithm
273 RTree = MaxSpanningTreePrims( JuncGraph );
274
275 // Construct corresponding region graph
276
277 // Create outer regions
278 ORs.reserve( Cliques.size() );
279 for( size_t i = 0; i < Cliques.size(); i++ )
280 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
281
282 // For each factor, find an outer region that subsumes that factor.
283 // Then, multiply the outer region with that factor.
284 // If no outer region can be found subsuming that factor, label the
285 // factor as off-tree.
286 fac2OR.clear();
287 fac2OR.resize( nrFactors(), -1U );
288 for( size_t I = 0; I < nrFactors(); I++ ) {
289 size_t alpha;
290 for( alpha = 0; alpha < nrORs(); alpha++ )
291 if( OR(alpha).vars() >> factor(I).vars() ) {
292 fac2OR[I] = alpha;
293 break;
294 }
295 // DIFF WITH JTree::GenerateJT: assert
296 }
297 RecomputeORs();
298
299 // Create inner regions and edges
300 IRs.reserve( RTree.size() );
301 vector<Edge> edges;
302 edges.reserve( 2 * RTree.size() );
303 for( size_t i = 0; i < RTree.size(); i++ ) {
304 edges.push_back( Edge( RTree[i].n1, IRs.size() ) );
305 edges.push_back( Edge( RTree[i].n2, IRs.size() ) );
306 // inner clusters have counting number -1
307 IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
308 }
309
310 // create bipartite graph
311 G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
312
313 // Check counting numbers
314 Check_Counting_Numbers();
315
316 // Create messages and beliefs
317 Qa.clear();
318 Qa.reserve( nrORs() );
319 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
320 Qa.push_back( OR(alpha) );
321
322 Qb.clear();
323 Qb.reserve( nrIRs() );
324 for( size_t beta = 0; beta < nrIRs(); beta++ )
325 Qb.push_back( Factor( IR(beta), 1.0 ) );
326
327 // DIFF with JTree::GenerateJT: no messages
328
329 // DIFF with JTree::GenerateJT:
330 // Create factor approximations
331 _Q.clear();
332 size_t PreviousRoot = (size_t)-1;
333 for( size_t I = 0; I < nrFactors(); I++ )
334 if( offtree(I) ) {
335 // find efficient subtree
336 DEdgeVec subTree;
337 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
338 PreviousRoot = subTree[0].n1;
339 //subTree.resize( subTreeSize ); // FIXME
340 // cout << "subtree " << I << " has size " << subTreeSize << endl;
341
342 TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
343 _Q[I] = QI;
344 }
345 // Previous root of first off-tree factor should be the root of the last off-tree factor
346 for( size_t I = 0; I < nrFactors(); I++ )
347 if( offtree(I) ) {
348 DEdgeVec subTree;
349 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
350 PreviousRoot = subTree[0].n1;
351 //subTree.resize( subTreeSize ); // FIXME
352 // cout << "subtree " << I << " has size " << subTreeSize << endl;
353
354 TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
355 _Q[I] = QI;
356 break;
357 }
358
359 if( props.verbose >= 3 ) {
360 cout << "Resulting regiongraph: " << *this << endl;
361 }
362 }
363
364
365 string TreeEP::identify() const {
366 return string(Name) + printProperties();
367 }
368
369
370 void TreeEP::init() {
371 runHUGIN();
372
373 // Init factor approximations
374 for( size_t I = 0; I < nrFactors(); I++ )
375 if( offtree(I) )
376 _Q[I].init();
377 }
378
379
380 double TreeEP::run() {
381 if( props.verbose >= 1 )
382 cout << "Starting " << identify() << "...";
383 if( props.verbose >= 3)
384 cout << endl;
385
386 double tic = toc();
387 Diffs diffs(nrVars(), 1.0);
388
389 vector<Factor> old_beliefs;
390 old_beliefs.reserve( nrVars() );
391 for( size_t i = 0; i < nrVars(); i++ )
392 old_beliefs.push_back(belief(var(i)));
393
394 // do several passes over the network until maximum number of iterations has
395 // been reached or until the maximum belief difference is smaller than tolerance
396 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; _iters++ ) {
397 for( size_t I = 0; I < nrFactors(); I++ )
398 if( offtree(I) ) {
399 _Q[I].InvertAndMultiply( Qa, Qb );
400 _Q[I].HUGIN_with_I( Qa, Qb );
401 _Q[I].InvertAndMultiply( Qa, Qb );
402 }
403
404 // calculate new beliefs and compare with old ones
405 for( size_t i = 0; i < nrVars(); i++ ) {
406 Factor nb( belief(var(i)) );
407 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
408 old_beliefs[i] = nb;
409 }
410
411 if( props.verbose >= 3 )
412 cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
413 }
414
415 if( diffs.maxDiff() > _maxdiff )
416 _maxdiff = diffs.maxDiff();
417
418 if( props.verbose >= 1 ) {
419 if( diffs.maxDiff() > props.tol ) {
420 if( props.verbose == 1 )
421 cout << endl;
422 cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
423 } else {
424 if( props.verbose >= 3 )
425 cout << Name << "::run: ";
426 cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
427 }
428 }
429
430 return diffs.maxDiff();
431 }
432
433
434 Real TreeEP::logZ() const {
435 double sum = 0.0;
436
437 // entropy of the tree
438 for( size_t beta = 0; beta < nrIRs(); beta++ )
439 sum -= Qb[beta].entropy();
440 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
441 sum += Qa[alpha].entropy();
442
443 // energy of the on-tree factors
444 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
445 sum += (OR(alpha).log0() * Qa[alpha]).totalSum();
446
447 // energy of the off-tree factors
448 for( size_t I = 0; I < nrFactors(); I++ )
449 if( offtree(I) )
450 sum += (_Q.find(I))->second.logZ( Qa, Qb );
451
452 return sum;
453 }
454
455
456 } // end of namespace dai