Merge branch 'no-edges2'
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cassert>
#include <cmath>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include <dai/jtree.h>
#include <dai/treeep.h>
#include <dai/util.h>
#include <dai/diffs.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *TreeEP::Name = "TREEEP";
38
39
40 bool TreeEP::checkProperties() {
41 if( !HasProperty("type") )
42 return false;
43 if( !HasProperty("tol") )
44 return false;
45 if (!HasProperty("maxiter") )
46 return false;
47 if (!HasProperty("verbose") )
48 return false;
49
50 ConvertPropertyTo<TypeType>("type");
51 ConvertPropertyTo<double>("tol");
52 ConvertPropertyTo<size_t>("maxiter");
53 ConvertPropertyTo<size_t>("verbose");
54
55 return true;
56 }
57
58
59 TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const vector<Factor> &jt_Qa, const vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
60 _ns = _I->vars();
61
62 // Make _Qa, _Qb, _a and _b corresponding to the subtree
63 _b.reserve( subRTree.size() );
64 _Qb.reserve( subRTree.size() );
65 _RTree.reserve( subRTree.size() );
66 for( size_t i = 0; i < subRTree.size(); i++ ) {
67 size_t alpha1 = subRTree[i].n1; // old index 1
68 size_t alpha2 = subRTree[i].n2; // old index 2
69 size_t beta; // old sep index
70 for( beta = 0; beta < jt_RTree.size(); beta++ )
71 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
72 break;
73 assert( beta != jt_RTree.size() );
74
75 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
76 if( newalpha1 == _a.size() ) {
77 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
78 _a.push_back( alpha1 ); // save old index in index conversion table
79 }
80
81 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
82 if( newalpha2 == _a.size() ) {
83 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
84 _a.push_back( alpha2 ); // save old index in index conversion table
85 }
86
87 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
88 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
89 _b.push_back( beta );
90 }
91
92 // Find remaining variables (which are not in the new root)
93 _nsrem = _ns / _Qa[0].vars();
94 };
95
96
97 void TreeEPSubTree::init() {
98 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
99 _Qa[alpha].fill( 1.0 );
100 for( size_t beta = 0; beta < _Qb.size(); beta++ )
101 _Qb[beta].fill( 1.0 );
102 }
103
104
105 void TreeEPSubTree::InvertAndMultiply( const vector<Factor> &Qa, const vector<Factor> &Qb ) {
106 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
107 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
108
109 for( size_t beta = 0; beta < _Qb.size(); beta++ )
110 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
111 }
112
113
114 void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
115 multind mi( _nsrem );
116
117 // Backup _Qa and _Qb
118 vector<Factor> _Qa_old(_Qa);
119 vector<Factor> _Qb_old(_Qb);
120
121 // Clear Qa and Qb
122 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
123 Qa[_a[alpha]].fill( 0.0 );
124 for( size_t beta = 0; beta < _Qb.size(); beta++ )
125 Qb[_b[beta]].fill( 0.0 );
126
127 // For all states of _nsrem
128 for( size_t j = 0; j < mi.max(); j++ ) {
129 vector<size_t> vi = mi.vi( j );
130
131 // Multiply root with slice of I
132 _Qa[0] *= _I->slice( _nsrem, j );
133
134 // CollectEvidence
135 for( size_t i = _RTree.size(); (i--) != 0; ) {
136 // clamp variables in nsrem
137 size_t k = 0;
138 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++, k++ )
139 if( _Qa[_RTree[i].n2].vars() >> *n ) {
140 Factor delta( *n, 0.0 );
141 delta[vi[k]] = 1.0;
142 _Qa[_RTree[i].n2] *= delta;
143 }
144 Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
145 _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
146 _Qb[i] = new_Qb;
147 }
148
149 // DistributeEvidence
150 for( size_t i = 0; i < _RTree.size(); i++ ) {
151 Factor new_Qb = _Qa[_RTree[i].n1].part_sum( _Qb[i].vars() );
152 _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
153 _Qb[i] = new_Qb;
154 }
155
156 // Store Qa's and Qb's
157 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
158 Qa[_a[alpha]].p() += _Qa[alpha].p();
159 for( size_t beta = 0; beta < _Qb.size(); beta++ )
160 Qb[_b[beta]].p() += _Qb[beta].p();
161
162 // Restore _Qa and _Qb
163 _Qa = _Qa_old;
164 _Qb = _Qb_old;
165 }
166
167 // Normalize Qa and Qb
168 _logZ = 0.0;
169 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
170 _logZ += log(Qa[_a[alpha]].totalSum());
171 Qa[_a[alpha]].normalize( Prob::NORMPROB );
172 }
173 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
174 _logZ -= log(Qb[_b[beta]].totalSum());
175 Qb[_b[beta]].normalize( Prob::NORMPROB );
176 }
177 }
178
179
180 double TreeEPSubTree::logZ( const vector<Factor> &Qa, const vector<Factor> &Qb ) const {
181 double sum = 0.0;
182 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
183 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
184 for( size_t beta = 0; beta < _Qb.size(); beta++ )
185 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
186 return sum + _logZ;
187 }
188
189
190 TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts("updates",string("HUGIN")), false) {
191 assert( checkProperties() );
192
193 assert( fg.G.isConnected() );
194
195 if( opts.hasKey("tree") ) {
196 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
197 } else {
198 if( Type() == TypeType::ORG ) {
199 // construct weighted graph with as weights a crude estimate of the
200 // mutual information between the nodes
201 WeightedGraph<double> wg;
202 for( size_t i = 0; i < nrVars(); ++i ) {
203 Var v_i = var(i);
204 VarSet di = delta(i);
205 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
206 if( v_i < *j ) {
207 Factor piet;
208 for( size_t I = 0; I < nrFactors(); I++ ) {
209 VarSet Ivars = factor(I).vars();
210 if( (Ivars == v_i) || (Ivars == *j) )
211 piet *= factor(I);
212 else if( Ivars >> (v_i | *j) )
213 piet *= factor(I).marginal( v_i | *j );
214 }
215 if( piet.vars() >> (v_i | *j) ) {
216 piet = piet.marginal( v_i | *j );
217 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
218 wg[UEdge(i,findVar(*j))] = real( KL_dist( piet, pietf ) );
219 } else
220 wg[UEdge(i,findVar(*j))] = 0;
221 }
222 }
223
224 // find maximal spanning tree
225 ConstructRG( MaxSpanningTreePrim( wg ) );
226
227 // cout << "Constructing maximum spanning tree..." << endl;
228 // DEdgeVec MST = MaxSpanningTreePrim( wg );
229 // cout << "Maximum spanning tree:" << endl;
230 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
231 // cout << *e << endl;
232 // ConstructRG( MST );
233 } else if( Type() == TypeType::ALT ) {
234 // construct weighted graph with as weights an upper bound on the
235 // effective interaction strength between pairs of nodes
236 WeightedGraph<double> wg;
237 for( size_t i = 0; i < nrVars(); ++i ) {
238 Var v_i = var(i);
239 VarSet di = delta(i);
240 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
241 if( v_i < *j ) {
242 Factor piet;
243 for( size_t I = 0; I < nrFactors(); I++ ) {
244 VarSet Ivars = factor(I).vars();
245 if( Ivars >> (v_i | *j) )
246 piet *= factor(I);
247 }
248 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
249 }
250 }
251
252 // find maximal spanning tree
253 ConstructRG( MaxSpanningTreePrim( wg ) );
254 } else {
255 assert( 0 == 1 );
256 }
257 }
258 }
259
260
261 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
262 vector<VarSet> Cliques;
263 for( size_t i = 0; i < tree.size(); i++ )
264 Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
265
266 // Construct a weighted graph (each edge is weighted with the cardinality
267 // of the intersection of the nodes, where the nodes are the elements of
268 // Cliques).
269 WeightedGraph<int> JuncGraph;
270 for( size_t i = 0; i < Cliques.size(); i++ )
271 for( size_t j = i+1; j < Cliques.size(); j++ ) {
272 size_t w = (Cliques[i] & Cliques[j]).size();
273 JuncGraph[UEdge(i,j)] = w;
274 }
275
276 // Construct maximal spanning tree using Prim's algorithm
277 _RTree = MaxSpanningTreePrim( JuncGraph );
278
279 // Construct corresponding region graph
280
281 // Create outer regions
282 ORs.reserve( Cliques.size() );
283 for( size_t i = 0; i < Cliques.size(); i++ )
284 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
285
286 // For each factor, find an outer region that subsumes that factor.
287 // Then, multiply the outer region with that factor.
288 // If no outer region can be found subsuming that factor, label the
289 // factor as off-tree.
290 fac2OR.clear();
291 fac2OR.resize( nrFactors(), -1U );
292 for( size_t I = 0; I < nrFactors(); I++ ) {
293 size_t alpha;
294 for( alpha = 0; alpha < nrORs(); alpha++ )
295 if( OR(alpha).vars() >> factor(I).vars() ) {
296 fac2OR[I] = alpha;
297 break;
298 }
299 // DIFF WITH JTree::GenerateJT: assert
300 }
301 RecomputeORs();
302
303 // Create inner regions and edges
304 IRs.reserve( _RTree.size() );
305 typedef pair<size_t,size_t> Edge;
306 vector<Edge> edges;
307 edges.reserve( 2 * _RTree.size() );
308 for( size_t i = 0; i < _RTree.size(); i++ ) {
309 edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
310 edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
311 // inner clusters have counting number -1
312 IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
313 }
314
315 // create bipartite graph
316 G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );
317
318 // Check counting numbers
319 Check_Counting_Numbers();
320
321 // Create messages and beliefs
322 _Qa.clear();
323 _Qa.reserve( nrORs() );
324 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
325 _Qa.push_back( OR(alpha) );
326
327 _Qb.clear();
328 _Qb.reserve( nrIRs() );
329 for( size_t beta = 0; beta < nrIRs(); beta++ )
330 _Qb.push_back( Factor( IR(beta), 1.0 ) );
331
332 // DIFF with JTree::GenerateJT: no messages
333
334 // DIFF with JTree::GenerateJT:
335 // Create factor approximations
336 _Q.clear();
337 size_t PreviousRoot = (size_t)-1;
338 for( size_t I = 0; I < nrFactors(); I++ )
339 if( offtree(I) ) {
340 // find efficient subtree
341 DEdgeVec subTree;
342 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
343 PreviousRoot = subTree[0].n1;
344 //subTree.resize( subTreeSize ); // FIXME
345 // cout << "subtree " << I << " has size " << subTreeSize << endl;
346
347 /*
348 char fn[30];
349 sprintf( fn, "/tmp/subtree_%d.dot", I );
350 std::ofstream dots(fn);
351 dots << "graph G {" << endl;
352 dots << "graph[size=\"9,9\"];" << endl;
353 dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
354 for( size_t i = 0; i < nrVars(); i++ )
355 dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
356 dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
357 for( size_t J = 0; J < nrFactors(); J++ )
358 dots << "\tp" << J << ";" << endl;
359 for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
360 dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
361 for( size_t a = 0; a < tree.size(); a++ )
362 dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
363 dots << "}" << endl;
364 dots.close();
365 */
366
367 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
368 _Q[I] = QI;
369 }
370 // Previous root of first off-tree factor should be the root of the last off-tree factor
371 for( size_t I = 0; I < nrFactors(); I++ )
372 if( offtree(I) ) {
373 DEdgeVec subTree;
374 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
375 PreviousRoot = subTree[0].n1;
376 //subTree.resize( subTreeSize ); // FIXME
377 // cout << "subtree " << I << " has size " << subTreeSize << endl;
378
379 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
380 _Q[I] = QI;
381 break;
382 }
383
384 if( Verbose() >= 3 ) {
385 cout << "Resulting regiongraph: " << *this << endl;
386 }
387 }
388
389
390 string TreeEP::identify() const {
391 stringstream result (stringstream::out);
392 result << Name << GetProperties();
393 return result.str();
394 }
395
396
397 void TreeEP::init() {
398 assert( checkProperties() );
399
400 runHUGIN();
401
402 // Init factor approximations
403 for( size_t I = 0; I < nrFactors(); I++ )
404 if( offtree(I) )
405 _Q[I].init();
406 }
407
408
409 double TreeEP::run() {
410 if( Verbose() >= 1 )
411 cout << "Starting " << identify() << "...";
412 if( Verbose() >= 3)
413 cout << endl;
414
415 clock_t tic = toc();
416 Diffs diffs(nrVars(), 1.0);
417
418 vector<Factor> old_beliefs;
419 old_beliefs.reserve( nrVars() );
420 for( size_t i = 0; i < nrVars(); i++ )
421 old_beliefs.push_back(belief(var(i)));
422
423 size_t iter = 0;
424
425 // do several passes over the network until maximum number of iterations has
426 // been reached or until the maximum belief difference is smaller than tolerance
427 for( iter=0; iter < MaxIter() && diffs.max() > Tol(); iter++ ) {
428 for( size_t I = 0; I < nrFactors(); I++ )
429 if( offtree(I) ) {
430 _Q[I].InvertAndMultiply( _Qa, _Qb );
431 _Q[I].HUGIN_with_I( _Qa, _Qb );
432 _Q[I].InvertAndMultiply( _Qa, _Qb );
433 }
434
435 // calculate new beliefs and compare with old ones
436 for( size_t i = 0; i < nrVars(); i++ ) {
437 Factor nb( belief(var(i)) );
438 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
439 old_beliefs[i] = nb;
440 }
441
442 if( Verbose() >= 3 )
443 cout << "TreeEP::run: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
444 }
445
446 updateMaxDiff( diffs.max() );
447
448 if( Verbose() >= 1 ) {
449 if( diffs.max() > Tol() ) {
450 if( Verbose() == 1 )
451 cout << endl;
452 cout << "TreeEP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
453 } else {
454 if( Verbose() >= 3 )
455 cout << "TreeEP::run: ";
456 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
457 }
458 }
459
460 return diffs.max();
461 }
462
463
464 Complex TreeEP::logZ() const {
465 double sum = 0.0;
466
467 // entropy of the tree
468 for( size_t beta = 0; beta < nrIRs(); beta++ )
469 sum -= real(_Qb[beta].entropy());
470 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
471 sum += real(_Qa[alpha].entropy());
472
473 // energy of the on-tree factors
474 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
475 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
476
477 // energy of the off-tree factors
478 for( size_t I = 0; I < nrFactors(); I++ )
479 if( offtree(I) )
480 sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
481
482 return sum;
483 }
484
485
486 } // end of namespace dai