Merged mf.* from SVN head (which implements damping)...
[libdai.git] / src / treeep.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <vector>
25 #include <dai/jtree.h>
26 #include <dai/treeep.h>
27 #include <dai/util.h>
28 #include <dai/diffs.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *TreeEP::Name = "TREEEP";
38
39
40 void TreeEP::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("tol") );
42 assert( opts.hasKey("maxiter") );
43 assert( opts.hasKey("verbose") );
44 assert( opts.hasKey("type") );
45
46 props.tol = opts.getStringAs<double>("tol");
47 props.maxiter = opts.getStringAs<size_t>("maxiter");
48 props.verbose = opts.getStringAs<size_t>("verbose");
49 props.type = opts.getStringAs<Properties::TypeType>("type");
50 }
51
52
53 PropertySet TreeEP::getProperties() const {
54 PropertySet opts;
55 opts.Set( "tol", props.tol );
56 opts.Set( "maxiter", props.maxiter );
57 opts.Set( "verbose", props.verbose );
58 opts.Set( "type", props.type );
59 return opts;
60 }
61
62
63 string TreeEP::printProperties() const {
64 stringstream s( stringstream::out );
65 s << "[";
66 s << "tol=" << props.tol << ",";
67 s << "maxiter=" << props.maxiter << ",";
68 s << "verbose=" << props.verbose << ",";
69 s << "type=" << props.type << "]";
70 return s.str();
71 }
72
73
74 TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
75 _ns = _I->vars();
76
77 // Make _Qa, _Qb, _a and _b corresponding to the subtree
78 _b.reserve( subRTree.size() );
79 _Qb.reserve( subRTree.size() );
80 _RTree.reserve( subRTree.size() );
81 for( size_t i = 0; i < subRTree.size(); i++ ) {
82 size_t alpha1 = subRTree[i].n1; // old index 1
83 size_t alpha2 = subRTree[i].n2; // old index 2
84 size_t beta; // old sep index
85 for( beta = 0; beta < jt_RTree.size(); beta++ )
86 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
87 break;
88 assert( beta != jt_RTree.size() );
89
90 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
91 if( newalpha1 == _a.size() ) {
92 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
93 _a.push_back( alpha1 ); // save old index in index conversion table
94 }
95
96 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
97 if( newalpha2 == _a.size() ) {
98 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
99 _a.push_back( alpha2 ); // save old index in index conversion table
100 }
101
102 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
103 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
104 _b.push_back( beta );
105 }
106
107 // Find remaining variables (which are not in the new root)
108 _nsrem = _ns / _Qa[0].vars();
109 };
110
111
112 void TreeEPSubTree::init() {
113 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
114 _Qa[alpha].fill( 1.0 );
115 for( size_t beta = 0; beta < _Qb.size(); beta++ )
116 _Qb[beta].fill( 1.0 );
117 }
118
119
120 void TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
121 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
122 _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
123
124 for( size_t beta = 0; beta < _Qb.size(); beta++ )
125 _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
126 }
127
128
129 void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
130 // Backup _Qa and _Qb
131 vector<Factor> _Qa_old(_Qa);
132 vector<Factor> _Qb_old(_Qb);
133
134 // Clear Qa and Qb
135 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
136 Qa[_a[alpha]].fill( 0.0 );
137 for( size_t beta = 0; beta < _Qb.size(); beta++ )
138 Qb[_b[beta]].fill( 0.0 );
139
140 // For all states of _nsrem
141 for( State s(_nsrem); s.valid(); s++ ) {
142 // Multiply root with slice of I
143 _Qa[0] *= _I->slice( _nsrem, s );
144
145 // CollectEvidence
146 for( size_t i = _RTree.size(); (i--) != 0; ) {
147 // clamp variables in nsrem
148 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
149 if( _Qa[_RTree[i].n2].vars() >> *n ) {
150 Factor delta( *n, 0.0 );
151 delta[s(*n)] = 1.0;
152 _Qa[_RTree[i].n2] *= delta;
153 }
154 Factor new_Qb = _Qa[_RTree[i].n2].partSum( _Qb[i].vars() );
155 _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
156 _Qb[i] = new_Qb;
157 }
158
159 // DistributeEvidence
160 for( size_t i = 0; i < _RTree.size(); i++ ) {
161 Factor new_Qb = _Qa[_RTree[i].n1].partSum( _Qb[i].vars() );
162 _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
163 _Qb[i] = new_Qb;
164 }
165
166 // Store Qa's and Qb's
167 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
168 Qa[_a[alpha]].p() += _Qa[alpha].p();
169 for( size_t beta = 0; beta < _Qb.size(); beta++ )
170 Qb[_b[beta]].p() += _Qb[beta].p();
171
172 // Restore _Qa and _Qb
173 _Qa = _Qa_old;
174 _Qb = _Qb_old;
175 }
176
177 // Normalize Qa and Qb
178 _logZ = 0.0;
179 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
180 _logZ += log(Qa[_a[alpha]].totalSum());
181 Qa[_a[alpha]].normalize( Prob::NORMPROB );
182 }
183 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
184 _logZ -= log(Qb[_b[beta]].totalSum());
185 Qb[_b[beta]].normalize( Prob::NORMPROB );
186 }
187 }
188
189
190 double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
191 double sum = 0.0;
192 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
193 sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
194 for( size_t beta = 0; beta < _Qb.size(); beta++ )
195 sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
196 return sum + _logZ;
197 }
198
199
200 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), props(), maxdiff(0.0) {
201 setProperties( opts );
202
203 assert( fg.isConnected() );
204
205 if( opts.hasKey("tree") ) {
206 ConstructRG( opts.GetAs<DEdgeVec>("tree") );
207 } else {
208 if( props.type == Properties::TypeType::ORG ) {
209 // construct weighted graph with as weights a crude estimate of the
210 // mutual information between the nodes
211 WeightedGraph<double> wg;
212 for( size_t i = 0; i < nrVars(); ++i ) {
213 Var v_i = var(i);
214 VarSet di = delta(i);
215 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
216 if( v_i < *j ) {
217 Factor piet;
218 for( size_t I = 0; I < nrFactors(); I++ ) {
219 VarSet Ivars = factor(I).vars();
220 if( (Ivars == v_i) || (Ivars == *j) )
221 piet *= factor(I);
222 else if( Ivars >> (v_i | *j) )
223 piet *= factor(I).marginal( v_i | *j );
224 }
225 if( piet.vars() >> (v_i | *j) ) {
226 piet = piet.marginal( v_i | *j );
227 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
228 wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
229 } else
230 wg[UEdge(i,findVar(*j))] = 0;
231 }
232 }
233
234 // find maximal spanning tree
235 ConstructRG( MaxSpanningTreePrims( wg ) );
236
237 // cout << "Constructing maximum spanning tree..." << endl;
238 // DEdgeVec MST = MaxSpanningTreePrims( wg );
239 // cout << "Maximum spanning tree:" << endl;
240 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
241 // cout << *e << endl;
242 // ConstructRG( MST );
243 } else if( props.type == Properties::TypeType::ALT ) {
244 // construct weighted graph with as weights an upper bound on the
245 // effective interaction strength between pairs of nodes
246 WeightedGraph<double> wg;
247 for( size_t i = 0; i < nrVars(); ++i ) {
248 Var v_i = var(i);
249 VarSet di = delta(i);
250 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
251 if( v_i < *j ) {
252 Factor piet;
253 for( size_t I = 0; I < nrFactors(); I++ ) {
254 VarSet Ivars = factor(I).vars();
255 if( Ivars >> (v_i | *j) )
256 piet *= factor(I);
257 }
258 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
259 }
260 }
261
262 // find maximal spanning tree
263 ConstructRG( MaxSpanningTreePrims( wg ) );
264 } else {
265 DAI_THROW(INTERNAL_ERROR);
266 }
267 }
268 }
269
270
271 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
272 vector<VarSet> Cliques;
273 for( size_t i = 0; i < tree.size(); i++ )
274 Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
275
276 // Construct a weighted graph (each edge is weighted with the cardinality
277 // of the intersection of the nodes, where the nodes are the elements of
278 // Cliques).
279 WeightedGraph<int> JuncGraph;
280 for( size_t i = 0; i < Cliques.size(); i++ )
281 for( size_t j = i+1; j < Cliques.size(); j++ ) {
282 size_t w = (Cliques[i] & Cliques[j]).size();
283 JuncGraph[UEdge(i,j)] = w;
284 }
285
286 // Construct maximal spanning tree using Prim's algorithm
287 _RTree = MaxSpanningTreePrims( JuncGraph );
288
289 // Construct corresponding region graph
290
291 // Create outer regions
292 ORs.reserve( Cliques.size() );
293 for( size_t i = 0; i < Cliques.size(); i++ )
294 ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
295
296 // For each factor, find an outer region that subsumes that factor.
297 // Then, multiply the outer region with that factor.
298 // If no outer region can be found subsuming that factor, label the
299 // factor as off-tree.
300 fac2OR.clear();
301 fac2OR.resize( nrFactors(), -1U );
302 for( size_t I = 0; I < nrFactors(); I++ ) {
303 size_t alpha;
304 for( alpha = 0; alpha < nrORs(); alpha++ )
305 if( OR(alpha).vars() >> factor(I).vars() ) {
306 fac2OR[I] = alpha;
307 break;
308 }
309 // DIFF WITH JTree::GenerateJT: assert
310 }
311 RecomputeORs();
312
313 // Create inner regions and edges
314 IRs.reserve( _RTree.size() );
315 vector<Edge> edges;
316 edges.reserve( 2 * _RTree.size() );
317 for( size_t i = 0; i < _RTree.size(); i++ ) {
318 edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
319 edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
320 // inner clusters have counting number -1
321 IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
322 }
323
324 // create bipartite graph
325 G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
326
327 // Check counting numbers
328 Check_Counting_Numbers();
329
330 // Create messages and beliefs
331 _Qa.clear();
332 _Qa.reserve( nrORs() );
333 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
334 _Qa.push_back( OR(alpha) );
335
336 _Qb.clear();
337 _Qb.reserve( nrIRs() );
338 for( size_t beta = 0; beta < nrIRs(); beta++ )
339 _Qb.push_back( Factor( IR(beta), 1.0 ) );
340
341 // DIFF with JTree::GenerateJT: no messages
342
343 // DIFF with JTree::GenerateJT:
344 // Create factor approximations
345 _Q.clear();
346 size_t PreviousRoot = (size_t)-1;
347 for( size_t I = 0; I < nrFactors(); I++ )
348 if( offtree(I) ) {
349 // find efficient subtree
350 DEdgeVec subTree;
351 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
352 PreviousRoot = subTree[0].n1;
353 //subTree.resize( subTreeSize ); // FIXME
354 // cout << "subtree " << I << " has size " << subTreeSize << endl;
355
356 /*
357 char fn[30];
358 sprintf( fn, "/tmp/subtree_%d.dot", I );
359 std::ofstream dots(fn);
360 dots << "graph G {" << endl;
361 dots << "graph[size=\"9,9\"];" << endl;
362 dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
363 for( size_t i = 0; i < nrVars(); i++ )
364 dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
365 dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
366 for( size_t J = 0; J < nrFactors(); J++ )
367 dots << "\tp" << J << ";" << endl;
368 for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
369 dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
370 for( size_t a = 0; a < tree.size(); a++ )
371 dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
372 dots << "}" << endl;
373 dots.close();
374 */
375
376 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
377 _Q[I] = QI;
378 }
379 // Previous root of first off-tree factor should be the root of the last off-tree factor
380 for( size_t I = 0; I < nrFactors(); I++ )
381 if( offtree(I) ) {
382 DEdgeVec subTree;
383 /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
384 PreviousRoot = subTree[0].n1;
385 //subTree.resize( subTreeSize ); // FIXME
386 // cout << "subtree " << I << " has size " << subTreeSize << endl;
387
388 TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
389 _Q[I] = QI;
390 break;
391 }
392
393 if( props.verbose >= 3 ) {
394 cout << "Resulting regiongraph: " << *this << endl;
395 }
396 }
397
398
399 string TreeEP::identify() const {
400 return string(Name) + printProperties();
401 }
402
403
404 void TreeEP::init() {
405 runHUGIN();
406
407 // Init factor approximations
408 for( size_t I = 0; I < nrFactors(); I++ )
409 if( offtree(I) )
410 _Q[I].init();
411 }
412
413
414 double TreeEP::run() {
415 if( props.verbose >= 1 )
416 cout << "Starting " << identify() << "...";
417 if( props.verbose >= 3)
418 cout << endl;
419
420 double tic = toc();
421 Diffs diffs(nrVars(), 1.0);
422
423 vector<Factor> old_beliefs;
424 old_beliefs.reserve( nrVars() );
425 for( size_t i = 0; i < nrVars(); i++ )
426 old_beliefs.push_back(belief(var(i)));
427
428 size_t iter = 0;
429
430 // do several passes over the network until maximum number of iterations has
431 // been reached or until the maximum belief difference is smaller than tolerance
432 for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
433 for( size_t I = 0; I < nrFactors(); I++ )
434 if( offtree(I) ) {
435 _Q[I].InvertAndMultiply( _Qa, _Qb );
436 _Q[I].HUGIN_with_I( _Qa, _Qb );
437 _Q[I].InvertAndMultiply( _Qa, _Qb );
438 }
439
440 // calculate new beliefs and compare with old ones
441 for( size_t i = 0; i < nrVars(); i++ ) {
442 Factor nb( belief(var(i)) );
443 diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
444 old_beliefs[i] = nb;
445 }
446
447 if( props.verbose >= 3 )
448 cout << "TreeEP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
449 }
450
451 if( diffs.maxDiff() > maxdiff )
452 maxdiff = diffs.maxDiff();
453
454 if( props.verbose >= 1 ) {
455 if( diffs.maxDiff() > props.tol ) {
456 if( props.verbose == 1 )
457 cout << endl;
458 cout << "TreeEP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
459 } else {
460 if( props.verbose >= 3 )
461 cout << "TreeEP::run: ";
462 cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
463 }
464 }
465
466 return diffs.maxDiff();
467 }
468
469
470 Real TreeEP::logZ() const {
471 double sum = 0.0;
472
473 // entropy of the tree
474 for( size_t beta = 0; beta < nrIRs(); beta++ )
475 sum -= _Qb[beta].entropy();
476 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
477 sum += _Qa[alpha].entropy();
478
479 // energy of the on-tree factors
480 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
481 sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
482
483 // energy of the off-tree factors
484 for( size_t I = 0; I < nrFactors(); I++ )
485 if( offtree(I) )
486 sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
487
488 return sum;
489 }
490
491
492 } // end of namespace dai