2a60738b351ea0ea71fb10c5d30795a5e8b510e6
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <vector>
25 #include <dai/jtree.h>
26 #include <dai/treeep.h>
27 #include <dai/util.h>
28 #include <dai/diffs.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *TreeEP::Name = "TREEEP";
38
39
40 bool TreeEP::checkProperties() {
41 if( !HasProperty("type") )
42 return false;
43 if( !HasProperty("tol") )
44 return false;
45 if (!HasProperty("maxiter") )
46 return false;
47 if (!HasProperty("verbose") )
48 return false;
49
50 ConvertPropertyTo<TypeType>("type");
51 ConvertPropertyTo<double>("tol");
52 ConvertPropertyTo<size_t>("maxiter");
53 ConvertPropertyTo<size_t>("verbose");
54
55 return true;
56 }
57
58
59 TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
60 _ns = _I->vars();
61
62 // Make _Qa, _Qb, _a and _b corresponding to the subtree
63 _b.reserve( subRTree.size() );
64 _Qb.reserve( subRTree.size() );
65 _RTree.reserve( subRTree.size() );
66 for( size_t i = 0; i < subRTree.size(); i++ ) {
67 size_t alpha1 = subRTree[i].n1; // old index 1
68 size_t alpha2 = subRTree[i].n2; // old index 2
69 size_t beta; // old sep index
70 for( beta = 0; beta < jt_RTree.size(); beta++ )
71 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
72 break;
73 assert( beta != jt_RTree.size() );
74
75 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
76 if( newalpha1 == _a.size() ) {
77 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
78 _a.push_back( alpha1 ); // save old index in index conversion table
79 }
80
81 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
82 if( newalpha2 == _a.size() ) {
83 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
84 _a.push_back( alpha2 ); // save old index in index conversion table
85 }
86
87 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
88 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
89 _b.push_back( beta );
90 }
91
92 // Find remaining variables (which are not in the new root)
93 _nsrem = _ns / _Qa[0].vars();
94 };
95
96
97 void TreeEPSubTree::init() {
98 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
99 _Qa[alpha].fill( 1.0 );
100 for( size_t beta = 0; beta < _Qb.size(); beta++ )
101 _Qb[beta].fill( 1.0 );
102 }
103
104
105 void TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
106 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
107 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
108
109 for( size_t beta = 0; beta < _Qb.size(); beta++ )
110 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
111 }
112
113
114 void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
115 // Backup _Qa and _Qb
116 vector<Factor> _Qa_old(_Qa);
117 vector<Factor> _Qb_old(_Qb);
118
119 // Clear Qa and Qb
120 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
121 Qa[_a[alpha]].fill( 0.0 );
122 for( size_t beta = 0; beta < _Qb.size(); beta++ )
123 Qb[_b[beta]].fill( 0.0 );
124
125 // For all states of _nsrem
126 for( State s(_nsrem); s.valid(); s++ ) {
127 // Multiply root with slice of I
128 _Qa[0] *= _I->slice( _nsrem, s );
129
130 // CollectEvidence
131 for( size_t i = _RTree.size(); (i--) != 0; ) {
132 // clamp variables in nsrem
133 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
134 if( _Qa[_RTree[i].n2].vars() >> *n ) {
135 Factor delta( *n, 0.0 );
136 delta[s(*n)] = 1.0;
137 _Qa[_RTree[i].n2] *= delta;
138 }
139 Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
140 _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
141 _Qb[i] = new_Qb;
142 }
143
144 // DistributeEvidence
145 for( size_t i = 0; i < _RTree.size(); i++ ) {
146 Factor new_Qb = _Qa[_RTree[i].n1].part_sum( _Qb[i].vars() );
147 _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
148 _Qb[i] = new_Qb;
149 }
150
151 // Store Qa's and Qb's
152 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
153 Qa[_a[alpha]].p() += _Qa[alpha].p();
154 for( size_t beta = 0; beta < _Qb.size(); beta++ )
155 Qb[_b[beta]].p() += _Qb[beta].p();
156
157 // Restore _Qa and _Qb
158 _Qa = _Qa_old;
159 _Qb = _Qb_old;
160 }
161
162 // Normalize Qa and Qb
163 _logZ = 0.0;
164 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
165 _logZ += log(Qa[_a[alpha]].totalSum());
166 Qa[_a[alpha]].normalize( Prob::NORMPROB );
167 }
168 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
169 _logZ -= log(Qb[_b[beta]].totalSum());
170 Qb[_b[beta]].normalize( Prob::NORMPROB );
171 }
172 }
173
174
175 double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
176 double sum = 0.0;
177 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
178 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
179 for( size_t beta = 0; beta < _Qb.size(); beta++ )
180 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
181 return sum + _logZ;
182 }
183
184
185 TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts("updates",string("HUGIN")), false) {
186 assert( checkProperties() );
187
188 assert( fg.G.isConnected() );
189
190 if( opts.hasKey("tree") ) {
191 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
192 } else {
193 if( Type() == TypeType::ORG ) {
194 // construct weighted graph with as weights a crude estimate of the
195 // mutual information between the nodes
196 WeightedGraph<double> wg;
197 for( size_t i = 0; i < nrVars(); ++i ) {
198 Var v_i = var(i);
199 VarSet di = delta(i);
200 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
201 if( v_i < *j ) {
202 Factor piet;
203 for( size_t I = 0; I < nrFactors(); I++ ) {
204 VarSet Ivars = factor(I).vars();
205 if( (Ivars == v_i) || (Ivars == *j) )
206 piet *= factor(I);
207 else if( Ivars >> (v_i | *j) )
208 piet *= factor(I).marginal( v_i | *j );
209 }
210 if( piet.vars() >> (v_i | *j) ) {
211 piet = piet.marginal( v_i | *j );
212 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
213 wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
214 } else
215 wg[UEdge(i,findVar(*j))] = 0;
216 }
217 }
218
219 // find maximal spanning tree
220 ConstructRG( MaxSpanningTreePrims( wg ) );
221
222 // cout << "Constructing maximum spanning tree..." << endl;
223 // DEdgeVec MST = MaxSpanningTreePrims( wg );
224 // cout << "Maximum spanning tree:" << endl;
225 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
226 // cout << *e << endl;
227 // ConstructRG( MST );
228 } else if( Type() == TypeType::ALT ) {
229 // construct weighted graph with as weights an upper bound on the
230 // effective interaction strength between pairs of nodes
231 WeightedGraph<double> wg;
232 for( size_t i = 0; i < nrVars(); ++i ) {
233 Var v_i = var(i);
234 VarSet di = delta(i);
235 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
236 if( v_i < *j ) {
237 Factor piet;
238 for( size_t I = 0; I < nrFactors(); I++ ) {
239 VarSet Ivars = factor(I).vars();
240 if( Ivars >> (v_i | *j) )
241 piet *= factor(I);
242 }
243 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
244 }
245 }
246
247 // find maximal spanning tree
248 ConstructRG( MaxSpanningTreePrims( wg ) );
249 } else {
250 DAI_THROW(INTERNAL_ERROR);
251 }
252 }
253 }
254
255
256 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
257 vector<VarSet> Cliques;
258 for( size_t i = 0; i < tree.size(); i++ )
259 Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
260
261 // Construct a weighted graph (each edge is weighted with the cardinality
262 // of the intersection of the nodes, where the nodes are the elements of
263 // Cliques).
264 WeightedGraph<int> JuncGraph;
265 for( size_t i = 0; i < Cliques.size(); i++ )
266 for( size_t j = i+1; j < Cliques.size(); j++ ) {
267 size_t w = (Cliques[i] & Cliques[j]).size();
268 JuncGraph[UEdge(i,j)] = w;
269 }
270
271 // Construct maximal spanning tree using Prim's algorithm
272 _RTree = MaxSpanningTreePrims( JuncGraph );
273
274 // Construct corresponding region graph
275
276 // Create outer regions
277 ORs.reserve( Cliques.size() );
278 for( size_t i = 0; i < Cliques.size(); i++ )
279 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
280
281 // For each factor, find an outer region that subsumes that factor.
282 // Then, multiply the outer region with that factor.
283 // If no outer region can be found subsuming that factor, label the
284 // factor as off-tree.
285 fac2OR.clear();
286 fac2OR.resize( nrFactors(), -1U );
287 for( size_t I = 0; I < nrFactors(); I++ ) {
288 size_t alpha;
289 for( alpha = 0; alpha < nrORs(); alpha++ )
290 if( OR(alpha).vars() >> factor(I).vars() ) {
291 fac2OR[I] = alpha;
292 break;
293 }
294 // DIFF WITH JTree::GenerateJT: assert
295 }
296 RecomputeORs();
297
298 // Create inner regions and edges
299 IRs.reserve( _RTree.size() );
300 vector<Edge> edges;
301 edges.reserve( 2 * _RTree.size() );
302 for( size_t i = 0; i < _RTree.size(); i++ ) {
303 edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
304 edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
305 // inner clusters have counting number -1
306 IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
307 }
308
309 // create bipartite graph
310 G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );
311
312 // Check counting numbers
313 Check_Counting_Numbers();
314
315 // Create messages and beliefs
316 _Qa.clear();
317 _Qa.reserve( nrORs() );
318 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
319 _Qa.push_back( OR(alpha) );
320
321 _Qb.clear();
322 _Qb.reserve( nrIRs() );
323 for( size_t beta = 0; beta < nrIRs(); beta++ )
324 _Qb.push_back( Factor( IR(beta), 1.0 ) );
325
326 // DIFF with JTree::GenerateJT: no messages
327
328 // DIFF with JTree::GenerateJT:
329 // Create factor approximations
330 _Q.clear();
331 size_t PreviousRoot = (size_t)-1;
332 for( size_t I = 0; I < nrFactors(); I++ )
333 if( offtree(I) ) {
334 // find efficient subtree
335 DEdgeVec subTree;
336 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
337 PreviousRoot = subTree[0].n1;
338 //subTree.resize( subTreeSize ); // FIXME
339 // cout << "subtree " << I << " has size " << subTreeSize << endl;
340
341 /*
342 char fn[30];
343 sprintf( fn, "/tmp/subtree_%d.dot", I );
344 std::ofstream dots(fn);
345 dots << "graph G {" << endl;
346 dots << "graph[size=\"9,9\"];" << endl;
347 dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
348 for( size_t i = 0; i < nrVars(); i++ )
349 dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
350 dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
351 for( size_t J = 0; J < nrFactors(); J++ )
352 dots << "\tp" << J << ";" << endl;
353 for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
354 dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
355 for( size_t a = 0; a < tree.size(); a++ )
356 dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
357 dots << "}" << endl;
358 dots.close();
359 */
360
361 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
362 _Q[I] = QI;
363 }
364 // Previous root of first off-tree factor should be the root of the last off-tree factor
365 for( size_t I = 0; I < nrFactors(); I++ )
366 if( offtree(I) ) {
367 DEdgeVec subTree;
368 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
369 PreviousRoot = subTree[0].n1;
370 //subTree.resize( subTreeSize ); // FIXME
371 // cout << "subtree " << I << " has size " << subTreeSize << endl;
372
373 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
374 _Q[I] = QI;
375 break;
376 }
377
378 if( Verbose() >= 3 ) {
379 cout << "Resulting regiongraph: " << *this << endl;
380 }
381 }
382
383
384 string TreeEP::identify() const {
385 stringstream result (stringstream::out);
386 result << Name << GetProperties();
387 return result.str();
388 }
389
390
391 void TreeEP::init() {
392 assert( checkProperties() );
393
394 runHUGIN();
395
396 // Init factor approximations
397 for( size_t I = 0; I < nrFactors(); I++ )
398 if( offtree(I) )
399 _Q[I].init();
400 }
401
402
403 double TreeEP::run() {
404 if( Verbose() >= 1 )
405 cout << "Starting " << identify() << "...";
406 if( Verbose() >= 3)
407 cout << endl;
408
409 double tic = toc();
410 Diffs diffs(nrVars(), 1.0);
411
412 vector<Factor> old_beliefs;
413 old_beliefs.reserve( nrVars() );
414 for( size_t i = 0; i < nrVars(); i++ )
415 old_beliefs.push_back(belief(var(i)));
416
417 size_t iter = 0;
418
419 // do several passes over the network until maximum number of iterations has
420 // been reached or until the maximum belief difference is smaller than tolerance
421 for( iter=0; iter < MaxIter() && diffs.maxDiff() > Tol(); iter++ ) {
422 for( size_t I = 0; I < nrFactors(); I++ )
423 if( offtree(I) ) {
424 _Q[I].InvertAndMultiply( _Qa, _Qb );
425 _Q[I].HUGIN_with_I( _Qa, _Qb );
426 _Q[I].InvertAndMultiply( _Qa, _Qb );
427 }
428
429 // calculate new beliefs and compare with old ones
430 for( size_t i = 0; i < nrVars(); i++ ) {
431 Factor nb( belief(var(i)) );
432 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
433 old_beliefs[i] = nb;
434 }
435
436 if( Verbose() >= 3 )
437 cout << "TreeEP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
438 }
439
440 updateMaxDiff( diffs.maxDiff() );
441
442 if( Verbose() >= 1 ) {
443 if( diffs.maxDiff() > Tol() ) {
444 if( Verbose() == 1 )
445 cout << endl;
446 cout << "TreeEP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
447 } else {
448 if( Verbose() >= 3 )
449 cout << "TreeEP::run: ";
450 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
451 }
452 }
453
454 return diffs.maxDiff();
455 }
456
457
458 Real TreeEP::logZ() const {
459 double sum = 0.0;
460
461 // entropy of the tree
462 for( size_t beta = 0; beta < nrIRs(); beta++ )
463 sum -= _Qb[beta].entropy();
464 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
465 sum += _Qa[alpha].entropy();
466
467 // energy of the on-tree factors
468 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
469 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
470
471 // energy of the off-tree factors
472 for( size_t I = 0; I < nrFactors(); I++ )
473 if( offtree(I) )
474 sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
475
476 return sum;
477 }
478
479
480 } // end of namespace dai