Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cassert>
#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include <dai/jtree.h>
#include <dai/treeep.h>
#include <dai/util.h>
#include <dai/diffs.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *TreeEP::Name = "TREEEP";
38
39
40 void TreeEP::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("tol") );
42 assert( opts.hasKey("maxiter") );
43 assert( opts.hasKey("verbose") );
44 assert( opts.hasKey("type") );
45
46 props.tol = opts.getStringAs<double>("tol");
47 props.maxiter = opts.getStringAs<size_t>("maxiter");
48 props.verbose = opts.getStringAs<size_t>("verbose");
49 props.type = opts.getStringAs<Properties::TypeType>("type");
50 }
51
52
53 PropertySet TreeEP::getProperties() const {
54 PropertySet opts;
55 opts.Set( "tol", props.tol );
56 opts.Set( "maxiter", props.maxiter );
57 opts.Set( "verbose", props.verbose );
58 opts.Set( "type", props.type );
59 return opts;
60 }
61
62
63 TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
64 _ns = _I->vars();
65
66 // Make _Qa, _Qb, _a and _b corresponding to the subtree
67 _b.reserve( subRTree.size() );
68 _Qb.reserve( subRTree.size() );
69 _RTree.reserve( subRTree.size() );
70 for( size_t i = 0; i < subRTree.size(); i++ ) {
71 size_t alpha1 = subRTree[i].n1; // old index 1
72 size_t alpha2 = subRTree[i].n2; // old index 2
73 size_t beta; // old sep index
74 for( beta = 0; beta < jt_RTree.size(); beta++ )
75 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
76 break;
77 assert( beta != jt_RTree.size() );
78
79 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
80 if( newalpha1 == _a.size() ) {
81 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
82 _a.push_back( alpha1 ); // save old index in index conversion table
83 }
84
85 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
86 if( newalpha2 == _a.size() ) {
87 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
88 _a.push_back( alpha2 ); // save old index in index conversion table
89 }
90
91 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
92 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
93 _b.push_back( beta );
94 }
95
96 // Find remaining variables (which are not in the new root)
97 _nsrem = _ns / _Qa[0].vars();
98 };
99
100
101 void TreeEPSubTree::init() {
102 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
103 _Qa[alpha].fill( 1.0 );
104 for( size_t beta = 0; beta < _Qb.size(); beta++ )
105 _Qb[beta].fill( 1.0 );
106 }
107
108
109 void TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
110 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
111 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
112
113 for( size_t beta = 0; beta < _Qb.size(); beta++ )
114 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
115 }
116
117
/// Propagates the off-tree factor *_I through the sub-tree with HUGIN updates.
///
/// For every joint state s of _nsrem (the variables of *_I outside the root
/// clique _Qa[0]), the root is multiplied by the slice of *_I at s, every other
/// clique is clamped to s on the _nsrem variables it contains, and a
/// collect/distribute pass is run. The resulting clique and separator tables are
/// summed over all states s into the global tables.
///
/// \param Qa  global outer-region tables; entries Qa[_a[alpha]] are overwritten
///            with the normalized result, other entries are not touched
/// \param Qb  global separator tables; entries Qb[_b[beta]] are overwritten
///            likewise
///
/// Side effects: _Qa and _Qb are restored to their values on entry (they are
/// reset after every state), and _logZ is set to
///   sum_alpha log(sum Qa[_a[alpha]]) - sum_beta log(sum Qb[_b[beta]])
/// using the totals before normalization.
void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
    // Backup _Qa and _Qb
    vector<Factor> _Qa_old(_Qa);
    vector<Factor> _Qb_old(_Qb);

    // Clear Qa and Qb (they serve as accumulators over the states of _nsrem)
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
        Qa[_a[alpha]].fill( 0.0 );
    for( size_t beta = 0; beta < _Qb.size(); beta++ )
        Qb[_b[beta]].fill( 0.0 );

    // For all states of _nsrem
    for( State s(_nsrem); s.valid(); s++ ) {
        // Multiply root with slice of I
        _Qa[0] *= _I->slice( _nsrem, s );

        // CollectEvidence: walk the edges in reverse order, sending a message
        // from the n2 endpoint to the n1 endpoint of each edge
        for( size_t i = _RTree.size(); (i--) != 0; ) {
            // clamp variables in nsrem: multiply by an indicator on state s(*n).
            // NOTE(review): only n2 endpoints are clamped here; this relies on every
            // non-root clique appearing as the n2 of some edge — confirm.
            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                if( _Qa[_RTree[i].n2].vars() >> *n ) {
                    Factor delta( *n, 0.0 );
                    delta[s(*n)] = 1.0;
                    _Qa[_RTree[i].n2] *= delta;
                }
            // HUGIN update: new separator = marginal of the sender; receiver is
            // multiplied by new/old separator
            Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
            _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
            _Qb[i] = new_Qb;
        }

        // DistributeEvidence: walk the edges in forward order, from n1 to n2
        for( size_t i = 0; i < _RTree.size(); i++ ) {
            Factor new_Qb = _Qa[_RTree[i].n1].part_sum( _Qb[i].vars() );
            _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
            _Qb[i] = new_Qb;
        }

        // Store Qa's and Qb's (accumulate the contribution of state s)
        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
            Qa[_a[alpha]].p() += _Qa[alpha].p();
        for( size_t beta = 0; beta < _Qb.size(); beta++ )
            Qb[_b[beta]].p() += _Qb[beta].p();

        // Restore _Qa and _Qb before processing the next state
        _Qa = _Qa_old;
        _Qb = _Qb_old;
    }

    // Normalize Qa and Qb, recording the log normalizers in _logZ
    _logZ = 0.0;
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
        _logZ += log(Qa[_a[alpha]].totalSum());
        Qa[_a[alpha]].normalize( Prob::NORMPROB );
    }
    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
        _logZ -= log(Qb[_b[beta]].totalSum());
        Qb[_b[beta]].normalize( Prob::NORMPROB );
    }
}
177
178
179 double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
180 double sum = 0.0;
181 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
182 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
183 for( size_t beta = 0; beta < _Qb.size(); beta++ )
184 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
185 return sum + _logZ;
186 }
187
188
189 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), props(), maxdiff(0.0) {
190 setProperties( opts );
191
192 assert( fg.G.isConnected() );
193
194 if( opts.hasKey("tree") ) {
195 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
196 } else {
197 if( props.type == Properties::TypeType::ORG ) {
198 // construct weighted graph with as weights a crude estimate of the
199 // mutual information between the nodes
200 WeightedGraph<double> wg;
201 for( size_t i = 0; i < nrVars(); ++i ) {
202 Var v_i = var(i);
203 VarSet di = delta(i);
204 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
205 if( v_i < *j ) {
206 Factor piet;
207 for( size_t I = 0; I < nrFactors(); I++ ) {
208 VarSet Ivars = factor(I).vars();
209 if( (Ivars == v_i) || (Ivars == *j) )
210 piet *= factor(I);
211 else if( Ivars >> (v_i | *j) )
212 piet *= factor(I).marginal( v_i | *j );
213 }
214 if( piet.vars() >> (v_i | *j) ) {
215 piet = piet.marginal( v_i | *j );
216 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
217 wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
218 } else
219 wg[UEdge(i,findVar(*j))] = 0;
220 }
221 }
222
223 // find maximal spanning tree
224 ConstructRG( MaxSpanningTreePrims( wg ) );
225
226 // cout << "Constructing maximum spanning tree..." << endl;
227 // DEdgeVec MST = MaxSpanningTreePrims( wg );
228 // cout << "Maximum spanning tree:" << endl;
229 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
230 // cout << *e << endl;
231 // ConstructRG( MST );
232 } else if( props.type == Properties::TypeType::ALT ) {
233 // construct weighted graph with as weights an upper bound on the
234 // effective interaction strength between pairs of nodes
235 WeightedGraph<double> wg;
236 for( size_t i = 0; i < nrVars(); ++i ) {
237 Var v_i = var(i);
238 VarSet di = delta(i);
239 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
240 if( v_i < *j ) {
241 Factor piet;
242 for( size_t I = 0; I < nrFactors(); I++ ) {
243 VarSet Ivars = factor(I).vars();
244 if( Ivars >> (v_i | *j) )
245 piet *= factor(I);
246 }
247 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
248 }
249 }
250
251 // find maximal spanning tree
252 ConstructRG( MaxSpanningTreePrims( wg ) );
253 } else {
254 DAI_THROW(INTERNAL_ERROR);
255 }
256 }
257 }
258
259
/// Builds the region graph and all TreeEP state from a spanning tree over the
/// variables.
///
/// Each edge of \a tree yields one clique (outer region) containing its two
/// variables. The cliques are connected by a maximum spanning tree (_RTree) over
/// the sizes of their pairwise intersections, and the separators of that tree
/// become the inner regions. The function then fills fac2OR, the tables _Qa / _Qb,
/// and, for every factor not subsumed by any clique (off-tree), a TreeEPSubTree
/// in _Q.
///
/// \param tree  edges between variable indices of the factor graph
void TreeEP::ConstructRG( const DEdgeVec &tree ) {
    vector<VarSet> Cliques;
    for( size_t i = 0; i < tree.size(); i++ )
        Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );

    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of
    // Cliques).
    WeightedGraph<int> JuncGraph;
    for( size_t i = 0; i < Cliques.size(); i++ )
        for( size_t j = i+1; j < Cliques.size(); j++ ) {
            size_t w = (Cliques[i] & Cliques[j]).size();
            JuncGraph[UEdge(i,j)] = w;
        }

    // Construct maximal spanning tree using Prim's algorithm
    _RTree = MaxSpanningTreePrims( JuncGraph );

    // Construct corresponding region graph

    // Create outer regions (one per clique, initially an all-ones factor with
    // counting number 1)
    ORs.reserve( Cliques.size() );
    for( size_t i = 0; i < Cliques.size(); i++ )
        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    // If no outer region can be found subsuming that factor, label the
    // factor as off-tree.
    // (fac2OR[I] keeps the value -1U in that case; offtree(I) presumably
    // tests for this — NOTE(review): confirm against the header.)
    fac2OR.clear();
    fac2OR.resize( nrFactors(), -1U );
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                fac2OR[I] = alpha;
                break;
            }
        // DIFF WITH JTree::GenerateJT: assert
        // (no assertion here: factors without a subsuming region are allowed
        // and are handled as off-tree factors below)
    }
    RecomputeORs();

    // Create inner regions and edges: one inner region per edge of _RTree,
    // consisting of the intersection of the two cliques it joins
    IRs.reserve( _RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * _RTree.size() );
    for( size_t i = 0; i < _RTree.size(); i++ ) {
        edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
        edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
        // inner clusters have counting number -1
        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
    }

    // create bipartite graph
    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
    Check_Counting_Numbers();

    // Create messages and beliefs: _Qa starts as the outer-region factors,
    // _Qb as all-ones factors over the separators
    _Qa.clear();
    _Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        _Qa.push_back( OR(alpha) );

    _Qb.clear();
    _Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        _Qb.push_back( Factor( IR(beta), 1.0 ) );

    // DIFF with JTree::GenerateJT: no messages

    // DIFF with JTree::GenerateJT:
    // Create factor approximations: one sub-tree per off-tree factor. The root
    // chosen for one factor is passed as a hint when searching the next one.
    _Q.clear();
    size_t PreviousRoot = (size_t)-1;
    for( size_t I = 0; I < nrFactors(); I++ )
        if( offtree(I) ) {
            // find efficient subtree
            DEdgeVec subTree;
            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
            // NOTE(review): subTree[0] is read without checking that
            // findEfficientTree returned a non-empty tree.
            PreviousRoot = subTree[0].n1;
            //subTree.resize( subTreeSize ); // FIXME
            // cout << "subtree " << I << " has size " << subTreeSize << endl;

            // Disabled debugging aid: writes a Graphviz file for this factor.
            // NOTE(review): it draws the edges of `tree`, not of `subTree`.
            /*
            char fn[30];
            sprintf( fn, "/tmp/subtree_%d.dot", I );
            std::ofstream dots(fn);
            dots << "graph G {" << endl;
            dots << "graph[size=\"9,9\"];" << endl;
            dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
            for( size_t i = 0; i < nrVars(); i++ )
                dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
            dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
            for( size_t J = 0; J < nrFactors(); J++ )
                dots << "\tp" << J << ";" << endl;
            for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
                dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
            for( size_t a = 0; a < tree.size(); a++ )
                dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
            dots << "}" << endl;
            dots.close();
            */

            // The sub-tree keeps the pointer &factor(I)
            TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
            _Q[I] = QI;
        }
    // Previous root of first off-tree factor should be the root of the last off-tree factor
    // (so the sub-tree of the first off-tree factor is rebuilt once more, now
    // seeded with the root chosen for the last one)
    for( size_t I = 0; I < nrFactors(); I++ )
        if( offtree(I) ) {
            DEdgeVec subTree;
            /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
            PreviousRoot = subTree[0].n1;
            //subTree.resize( subTreeSize ); // FIXME
            // cout << "subtree " << I << " has size " << subTreeSize << endl;

            TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
            _Q[I] = QI;
            break;
        }

    if( props.verbose >= 3 ) {
        cout << "Resulting regiongraph: " << *this << endl;
    }
}
386
387
388 string TreeEP::identify() const {
389 stringstream result (stringstream::out);
390 result << Name << getProperties();
391 return result.str();
392 }
393
394
395 void TreeEP::init() {
396 runHUGIN();
397
398 // Init factor approximations
399 for( size_t I = 0; I < nrFactors(); I++ )
400 if( offtree(I) )
401 _Q[I].init();
402 }
403
404
405 double TreeEP::run() {
406 if( props.verbose >= 1 )
407 cout << "Starting " << identify() << "...";
408 if( props.verbose >= 3)
409 cout << endl;
410
411 double tic = toc();
412 Diffs diffs(nrVars(), 1.0);
413
414 vector<Factor> old_beliefs;
415 old_beliefs.reserve( nrVars() );
416 for( size_t i = 0; i < nrVars(); i++ )
417 old_beliefs.push_back(belief(var(i)));
418
419 size_t iter = 0;
420
421 // do several passes over the network until maximum number of iterations has
422 // been reached or until the maximum belief difference is smaller than tolerance
423 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
424 for( size_t I = 0; I < nrFactors(); I++ )
425 if( offtree(I) ) {
426 _Q[I].InvertAndMultiply( _Qa, _Qb );
427 _Q[I].HUGIN_with_I( _Qa, _Qb );
428 _Q[I].InvertAndMultiply( _Qa, _Qb );
429 }
430
431 // calculate new beliefs and compare with old ones
432 for( size_t i = 0; i < nrVars(); i++ ) {
433 Factor nb( belief(var(i)) );
434 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
435 old_beliefs[i] = nb;
436 }
437
438 if( props.verbose >= 3 )
439 cout << "TreeEP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
440 }
441
442 if( diffs.maxDiff() > maxdiff )
443 maxdiff = diffs.maxDiff();
444
445 if( props.verbose >= 1 ) {
446 if( diffs.maxDiff() > props.tol ) {
447 if( props.verbose == 1 )
448 cout << endl;
449 cout << "TreeEP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
450 } else {
451 if( props.verbose >= 3 )
452 cout << "TreeEP::run: ";
453 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
454 }
455 }
456
457 return diffs.maxDiff();
458 }
459
460
461 Real TreeEP::logZ() const {
462 double sum = 0.0;
463
464 // entropy of the tree
465 for( size_t beta = 0; beta < nrIRs(); beta++ )
466 sum -= _Qb[beta].entropy();
467 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
468 sum += _Qa[alpha].entropy();
469
470 // energy of the on-tree factors
471 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
472 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
473
474 // energy of the off-tree factors
475 for( size_t I = 0; I < nrFactors(); I++ )
476 if( offtree(I) )
477 sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
478
479 return sum;
480 }
481
482
483 } // end of namespace dai