Partial adoption of contributions by Giuseppe:
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <vector>
25 #include <dai/jtree.h>
26 #include <dai/treeep.h>
27 #include <dai/util.h>
28 #include <dai/diffs.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *TreeEP::Name = "TREEEP";
38
39
40 bool TreeEP::checkProperties() {
41 if( !HasProperty("type") )
42 return false;
43 if( !HasProperty("tol") )
44 return false;
45 if (!HasProperty("maxiter") )
46 return false;
47 if (!HasProperty("verbose") )
48 return false;
49
50 ConvertPropertyTo<TypeType>("type");
51 ConvertPropertyTo<double>("tol");
52 ConvertPropertyTo<size_t>("maxiter");
53 ConvertPropertyTo<size_t>("verbose");
54
55 return true;
56 }
57
58
59 TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const vector<Factor> &jt_Qa, const vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
60 _ns = _I->vars();
61
62 // Make _Qa, _Qb, _a and _b corresponding to the subtree
63 _b.reserve( subRTree.size() );
64 _Qb.reserve( subRTree.size() );
65 _RTree.reserve( subRTree.size() );
66 for( size_t i = 0; i < subRTree.size(); i++ ) {
67 size_t alpha1 = subRTree[i].n1; // old index 1
68 size_t alpha2 = subRTree[i].n2; // old index 2
69 size_t beta; // old sep index
70 for( beta = 0; beta < jt_RTree.size(); beta++ )
71 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
72 break;
73 assert( beta != jt_RTree.size() );
74
75 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
76 if( newalpha1 == _a.size() ) {
77 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
78 _a.push_back( alpha1 ); // save old index in index conversion table
79 }
80
81 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
82 if( newalpha2 == _a.size() ) {
83 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
84 _a.push_back( alpha2 ); // save old index in index conversion table
85 }
86
87 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
88 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
89 _b.push_back( beta );
90 }
91
92 // Find remaining variables (which are not in the new root)
93 _nsrem = _ns / _Qa[0].vars();
94 };
95
96
97 void TreeEPSubTree::init() {
98 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
99 _Qa[alpha].fill( 1.0 );
100 for( size_t beta = 0; beta < _Qb.size(); beta++ )
101 _Qb[beta].fill( 1.0 );
102 }
103
104
105 void TreeEPSubTree::InvertAndMultiply( const vector<Factor> &Qa, const vector<Factor> &Qb ) {
106 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
107 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
108
109 for( size_t beta = 0; beta < _Qb.size(); beta++ )
110 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
111 }
112
113
114 void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
115 multind mi( _nsrem );
116
117 // Backup _Qa and _Qb
118 vector<Factor> _Qa_old(_Qa);
119 vector<Factor> _Qb_old(_Qb);
120
121 // Clear Qa and Qb
122 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
123 Qa[_a[alpha]].fill( 0.0 );
124 for( size_t beta = 0; beta < _Qb.size(); beta++ )
125 Qb[_b[beta]].fill( 0.0 );
126
127 // For all states of _nsrem
128 for( size_t j = 0; j < mi.max(); j++ ) {
129 vector<size_t> vi = mi.vi( j );
130
131 // Multiply root with slice of I
132 _Qa[0] *= _I->slice( _nsrem, j );
133
134 // CollectEvidence
135 for( size_t i = _RTree.size(); (i--) != 0; ) {
136 // clamp variables in nsrem
137 size_t k = 0;
138 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++, k++ )
139 if( _Qa[_RTree[i].n2].vars() >> *n ) {
140 Factor delta( *n, 0.0 );
141 delta[vi[k]] = 1.0;
142 _Qa[_RTree[i].n2] *= delta;
143 }
144 Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
145 _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
146 _Qb[i] = new_Qb;
147 }
148
149 // DistributeEvidence
150 for( size_t i = 0; i < _RTree.size(); i++ ) {
151 Factor new_Qb = _Qa[_RTree[i].n1].part_sum( _Qb[i].vars() );
152 _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
153 _Qb[i] = new_Qb;
154 }
155
156 // Store Qa's and Qb's
157 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
158 Qa[_a[alpha]].p() += _Qa[alpha].p();
159 for( size_t beta = 0; beta < _Qb.size(); beta++ )
160 Qb[_b[beta]].p() += _Qb[beta].p();
161
162 // Restore _Qa and _Qb
163 _Qa = _Qa_old;
164 _Qb = _Qb_old;
165 }
166
167 // Normalize Qa and Qb
168 _logZ = 0.0;
169 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
170 _logZ += log(Qa[_a[alpha]].totalSum());
171 Qa[_a[alpha]].normalize( Prob::NORMPROB );
172 }
173 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
174 _logZ -= log(Qb[_b[beta]].totalSum());
175 Qb[_b[beta]].normalize( Prob::NORMPROB );
176 }
177 }
178
179
180 double TreeEPSubTree::logZ( const vector<Factor> &Qa, const vector<Factor> &Qb ) const {
181 double sum = 0.0;
182 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
183 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
184 for( size_t beta = 0; beta < _Qb.size(); beta++ )
185 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
186 return sum + _logZ;
187 }
188
189
190 TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts("updates",string("HUGIN")), false) {
191 assert( checkProperties() );
192
193 assert( fg.isConnected() );
194
195 if( opts.hasKey("tree") ) {
196 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
197 } else {
198 if( Type() == TypeType::ORG ) {
199 // construct weighted graph with as weights a crude estimate of the
200 // mutual information between the nodes
201 WeightedGraph<double> wg;
202 for( vector<Var>::const_iterator i = vars().begin(); i != vars().end(); i++ ) {
203 VarSet di = delta(*i);
204 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
205 if( *i < *j ) {
206 Factor piet;
207 for( size_t I = 0; I < nrFactors(); I++ ) {
208 VarSet Ivars = factor(I).vars();
209 if( (Ivars == *i) || (Ivars == *j) )
210 piet *= factor(I);
211 else if( Ivars >> (*i | *j) )
212 piet *= factor(I).marginal( *i | *j );
213 }
214 if( piet.vars() >> (*i | *j) ) {
215 piet = piet.marginal( *i | *j );
216 Factor pietf = piet.marginal(*i) * piet.marginal(*j);
217 wg[UEdge(findVar(*i),findVar(*j))] = real( KL_dist( piet, pietf ) );
218 } else
219 wg[UEdge(findVar(*i),findVar(*j))] = 0;
220 }
221 }
222
223 // find maximal spanning tree
224 ConstructRG( MaxSpanningTreePrim( wg ) );
225
226 // cout << "Constructing maximum spanning tree..." << endl;
227 // DEdgeVec MST = MaxSpanningTreePrim( wg );
228 // cout << "Maximum spanning tree:" << endl;
229 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
230 // cout << *e << endl;
231 // ConstructRG( MST );
232 } else if( Type() == TypeType::ALT ) {
233 // construct weighted graph with as weights an upper bound on the
234 // effective interaction strength between pairs of nodes
235 WeightedGraph<double> wg;
236 for( vector<Var>::const_iterator i = vars().begin(); i != vars().end(); i++ ) {
237 VarSet di = delta(*i);
238 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
239 if( *i < *j ) {
240 Factor piet;
241 for( size_t I = 0; I < nrFactors(); I++ ) {
242 VarSet Ivars = factor(I).vars();
243 if( Ivars >> (*i | *j) )
244 piet *= factor(I);
245 }
246 wg[UEdge(findVar(*i),findVar(*j))] = piet.strength(*i, *j);
247 }
248 }
249
250 // find maximal spanning tree
251 ConstructRG( MaxSpanningTreePrim( wg ) );
252 } else {
253 assert( 0 == 1 );
254 }
255 }
256 }
257
258
259 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
260 vector<VarSet> Cliques;
261 for( size_t i = 0; i < tree.size(); i++ )
262 Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
263
264 // Construct a weighted graph (each edge is weighted with the cardinality
265 // of the intersection of the nodes, where the nodes are the elements of
266 // Cliques).
267 WeightedGraph<int> JuncGraph;
268 for( size_t i = 0; i < Cliques.size(); i++ )
269 for( size_t j = i+1; j < Cliques.size(); j++ ) {
270 size_t w = (Cliques[i] & Cliques[j]).size();
271 JuncGraph[UEdge(i,j)] = w;
272 }
273
274 // Construct maximal spanning tree using Prim's algorithm
275 _RTree = MaxSpanningTreePrim( JuncGraph );
276
277 // Construct corresponding region graph
278
279 // Create outer regions
280 ORs().reserve( Cliques.size() );
281 for( size_t i = 0; i < Cliques.size(); i++ )
282 ORs().push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
283
284 // For each factor, find an outer region that subsumes that factor.
285 // Then, multiply the outer region with that factor.
286 // If no outer region can be found subsuming that factor, label the
287 // factor as off-tree.
288 for( size_t I = 0; I < nrFactors(); I++ ) {
289 size_t alpha;
290 for( alpha = 0; alpha < nr_ORs(); alpha++ )
291 if( OR(alpha).vars() >> factor(I).vars() ) {
292 _fac2OR[I] = alpha;
293 break;
294 }
295 // DIFF WITH JTree::GenerateJT: assert
296 }
297 RecomputeORs();
298
299 // Create inner regions and edges
300 IRs().reserve( _RTree.size() );
301 Redges().reserve( 2 * _RTree.size() );
302 for( size_t i = 0; i < _RTree.size(); i++ ) {
303 Redges().push_back( R_edge_t( _RTree[i].n1, IRs().size() ) );
304 Redges().push_back( R_edge_t( _RTree[i].n2, IRs().size() ) );
305 // inner clusters have counting number -1
306 IRs().push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
307 }
308
309 // Regenerate BipartiteGraph internals
310 Regenerate();
311
312 // Check counting numbers
313 Check_Counting_Numbers();
314
315 // Create messages and beliefs
316 _Qa.clear();
317 _Qa.reserve( nr_ORs() );
318 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
319 _Qa.push_back( OR(alpha) );
320
321 _Qb.clear();
322 _Qb.reserve( nr_IRs() );
323 for( size_t beta = 0; beta < nr_IRs(); beta++ )
324 _Qb.push_back( Factor( IR(beta), 1.0 ) );
325
326 // DIFF with JTree::GenerateJT: no messages
327
328 // DIFF with JTree::GenerateJT:
329 // Create factor approximations
330 _Q.clear();
331 size_t PreviousRoot = (size_t)-1;
332 for( size_t I = 0; I < nrFactors(); I++ )
333 if( offtree(I) ) {
334 // find efficient subtree
335 DEdgeVec subTree;
336 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
337 PreviousRoot = subTree[0].n1;
338 //subTree.resize( subTreeSize ); // FIXME
339 // cout << "subtree " << I << " has size " << subTreeSize << endl;
340
341 /*
342 char fn[30];
343 sprintf( fn, "/tmp/subtree_%d.dot", I );
344 std::ofstream dots(fn);
345 dots << "graph G {" << endl;
346 dots << "graph[size=\"9,9\"];" << endl;
347 dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
348 for( size_t i = 0; i < nrVars(); i++ )
349 dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
350 dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
351 for( size_t J = 0; J < nrFactors(); J++ )
352 dots << "\tp" << J << ";" << endl;
353 for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
354 dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
355 for( size_t a = 0; a < tree.size(); a++ )
356 dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
357 dots << "}" << endl;
358 dots.close();
359 */
360
361 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
362 _Q[I] = QI;
363 }
364 // Previous root of first off-tree factor should be the root of the last off-tree factor
365 for( size_t I = 0; I < nrFactors(); I++ )
366 if( offtree(I) ) {
367 DEdgeVec subTree;
368 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
369 PreviousRoot = subTree[0].n1;
370 //subTree.resize( subTreeSize ); // FIXME
371 // cout << "subtree " << I << " has size " << subTreeSize << endl;
372
373 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
374 _Q[I] = QI;
375 break;
376 }
377
378 if( Verbose() >= 3 ) {
379 cout << "Resulting regiongraph: " << *this << endl;
380 }
381 }
382
383
384 string TreeEP::identify() const {
385 stringstream result (stringstream::out);
386 result << Name << GetProperties();
387 return result.str();
388 }
389
390
391 void TreeEP::init() {
392 assert( checkProperties() );
393
394 runHUGIN();
395
396 // Init factor approximations
397 for( size_t I = 0; I < nrFactors(); I++ )
398 if( offtree(I) )
399 _Q[I].init();
400 }
401
402
403 double TreeEP::run() {
404 if( Verbose() >= 1 )
405 cout << "Starting " << identify() << "...";
406 if( Verbose() >= 3)
407 cout << endl;
408
409 clock_t tic = toc();
410 Diffs diffs(nrVars(), 1.0);
411
412 vector<Factor> old_beliefs;
413 old_beliefs.reserve( nrVars() );
414 for( size_t i = 0; i < nrVars(); i++ )
415 old_beliefs.push_back(belief(var(i)));
416
417 size_t iter = 0;
418
419 // do several passes over the network until maximum number of iterations has
420 // been reached or until the maximum belief difference is smaller than tolerance
421 for( iter=0; iter < MaxIter() && diffs.max() > Tol(); iter++ ) {
422 for( size_t I = 0; I < nrFactors(); I++ )
423 if( offtree(I) ) {
424 _Q[I].InvertAndMultiply( _Qa, _Qb );
425 _Q[I].HUGIN_with_I( _Qa, _Qb );
426 _Q[I].InvertAndMultiply( _Qa, _Qb );
427 }
428
429 // calculate new beliefs and compare with old ones
430 for( size_t i = 0; i < nrVars(); i++ ) {
431 Factor nb( belief(var(i)) );
432 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
433 old_beliefs[i] = nb;
434 }
435
436 if( Verbose() >= 3 )
437 cout << "TreeEP::run: maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
438 }
439
440 updateMaxDiff( diffs.max() );
441
442 if( Verbose() >= 1 ) {
443 if( diffs.max() > Tol() ) {
444 if( Verbose() == 1 )
445 cout << endl;
446 cout << "TreeEP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
447 } else {
448 if( Verbose() >= 3 )
449 cout << "TreeEP::run: ";
450 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
451 }
452 }
453
454 return diffs.max();
455 }
456
457
458 Complex TreeEP::logZ() const {
459 double sum = 0.0;
460
461 // entropy of the tree
462 for( size_t beta = 0; beta < nr_IRs(); beta++ )
463 sum -= real(_Qb[beta].entropy());
464 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
465 sum += real(_Qa[alpha].entropy());
466
467 // energy of the on-tree factors
468 for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
469 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
470
471 // energy of the off-tree factors
472 for( size_t I = 0; I < nrFactors(); I++ )
473 if( offtree(I) )
474 sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
475
476 return sum;
477 }
478
479
480 } // end of namespace dai