Improved TreeEP (added 'maxtime' property)
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *TreeEP::Name = "TREEEP";
27
28
29 void TreeEP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("type") );
32
33 props.tol = opts.getStringAs<Real>("tol");
34 props.type = opts.getStringAs<Properties::TypeType>("type");
35 if( opts.hasKey("maxiter") )
36 props.maxiter = opts.getStringAs<size_t>("maxiter");
37 else
38 props.maxiter = 10000;
39 if( opts.hasKey("maxtime") )
40 props.maxtime = opts.getStringAs<Real>("maxtime");
41 else
42 props.maxtime = INFINITY;
43 if( opts.hasKey("verbose") )
44 props.verbose = opts.getStringAs<size_t>("verbose");
45 else
46 props.verbose = 0;
47 }
48
49
50 PropertySet TreeEP::getProperties() const {
51 PropertySet opts;
52 opts.set( "tol", props.tol );
53 opts.set( "maxiter", props.maxiter );
54 opts.set( "maxtime", props.maxtime );
55 opts.set( "verbose", props.verbose );
56 opts.set( "type", props.type );
57 return opts;
58 }
59
60
61 string TreeEP::printProperties() const {
62 stringstream s( stringstream::out );
63 s << "[";
64 s << "tol=" << props.tol << ",";
65 s << "maxiter=" << props.maxiter << ",";
66 s << "maxtime=" << props.maxtime << ",";
67 s << "verbose=" << props.verbose << ",";
68 s << "type=" << props.type << "]";
69 return s.str();
70 }
71
72
/// Constructs a TreeEP inference object for the factor graph \a fg with properties \a opts.
/** The JTree base is initialised from opts("updates", string("HUGIN")),
 *  presumably a copy of \a opts with "updates" set to HUGIN; setProperties()
 *  then reads the TreeEP-specific properties.
 *  If \a opts has a "tree" entry, that RootedTree is used as the approximating
 *  tree. Otherwise a maximum spanning tree over the variables is computed,
 *  with edge weights chosen according to props.type:
 *  - ORG: a crude estimate of the mutual information between the two variables;
 *  - ALT: an upper bound on the effective interaction strength between them.
 *  Any other props.type throws UNKNOWN_ENUM_VALUE.
 */
TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
    setProperties( opts );

    if( opts.hasKey("tree") ) {
        // The caller supplied the approximating tree explicitly
        construct( fg, opts.getAs<RootedTree>("tree") );
    } else {
        if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
            // ORG: construct weighted graph with as weights a crude estimate of the
            // mutual information between the nodes
            // ALT: construct weighted graph with as weights an upper bound on the
            // effective interaction strength between pairs of nodes

            WeightedGraph<Real> wg;
            // in order to get a connected weighted graph, we start
            // by connecting every variable to the zero'th variable with weight 0
            for( size_t i = 1; i < fg.nrVars(); i++ )
                wg[UEdge(i,0)] = 0.0;
            for( size_t i = 0; i < fg.nrVars(); i++ ) {
                // delta_i: variables that share a factor with variable i
                // (the 'false' argument presumably excludes i itself -- verify)
                SmallSet<size_t> delta_i = fg.bipGraph().delta1( i, false );
                const Var& v_i = fg.var(i);
                foreach( size_t j, delta_i )
                    if( i < j ) {   // handle each unordered pair {i,j} only once
                        const Var& v_j = fg.var(j);
                        VarSet v_ij( v_i, v_j );
                        // All factors adjacent to variable i or to variable j
                        SmallSet<size_t> nb_ij = fg.bipGraph().nb1Set( i ) | fg.bipGraph().nb1Set( j );
                        // piet: local product of the factors relevant for the pair (i,j)
                        Factor piet;
                        foreach( size_t I, nb_ij ) {
                            const VarSet& Ivars = fg.factor(I).vars();
                            if( props.type == Properties::TypeType::ORG ) {
                                // ORG: single-variable factors on v_i or v_j, plus the
                                // (v_i,v_j)-marginal of each factor whose variables include both
                                if( (Ivars == v_i) || (Ivars == v_j) )
                                    piet *= fg.factor(I);
                                else if( Ivars >> v_ij )
                                    piet *= fg.factor(I).marginal( v_ij );
                            } else {
                                // ALT: full product of each factor whose variables include both
                                if( Ivars >> v_ij )
                                    piet *= fg.factor(I);
                            }
                        }
                        if( props.type == Properties::TypeType::ORG ) {
                            if( piet.vars() >> v_ij ) {
                                // Weight = KL-divergence (DISTKL) between the pairwise
                                // marginal and the product of its single-variable
                                // marginals, i.e. a mutual-information estimate
                                piet = piet.marginal( v_ij );
                                Factor pietf = piet.marginal(v_i) * piet.marginal(v_j);
                                wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
                            } else {
                                // this should never happen...
                                // (j is a neighbour of i, so presumably some factor covers both)
                                DAI_ASSERT( 0 == 1 );
                                wg[UEdge(i,j)] = 0;
                            }
                        } else
                            wg[UEdge(i,j)] = piet.strength(v_i, v_j);
                    }
            }

            // find maximal spanning tree
            if( props.verbose >= 3 )
                cerr << "WeightedGraph: " << wg << endl;
            RootedTree t = MaxSpanningTree( wg, true );
            if( props.verbose >= 3 )
                cerr << "Spanningtree: " << t << endl;
            construct( fg, t );
        } else
            DAI_THROW(UNKNOWN_ENUM_VALUE);
    }
}
137
138
/// Builds the junction tree from \a tree and a subtree approximation for every off-tree factor.
/** Overwrites the FactorGraph part of *this with a copy of \a fg. Each edge
 *  (i,j) of \a tree becomes a two-variable clique {var(i), var(j)}, from which
 *  the JTree part is constructed. Then, for every factor I with offtree(I), a
 *  TreeEPSubTree is stored in _Q[I]. It is built over the first subTreeSize
 *  edges of the tree returned by findEfficientTree() for that factor's variables.
 */
void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
    // Copy the factor graph
    FactorGraph::operator=( fg );

    // One clique per edge of the spanning tree, containing the edge's two variables
    vector<VarSet> cl;
    for( size_t i = 0; i < tree.size(); i++ )
        cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );

    // Build the junction tree from these cliques. If no outer region can be
    // found subsuming a factor, that factor is labelled as off-tree
    // (see offtree() below).
    JTree::construct( *this, cl, false );

    if( props.verbose >= 1 )
        cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
    if( props.verbose >= 3 )
        cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;

    // Create factor approximations
    _Q.clear();
    // Root of the most recently built subtree, passed on to findEfficientTree()
    // (presumably as a preferred root; (size_t)-1 means "none yet")
    size_t PreviousRoot = (size_t)-1;
    // First repetition: build a subtree for every off-tree factor.
    // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
    // (only the first off-tree factor is rebuilt, because of the 'break' below)
    for( size_t repeats = 0; repeats < 2; repeats++ )
        for( size_t I = 0; I < nrFactors(); I++ )
            if( offtree(I) ) {
                // find efficient subtree
                RootedTree subTree;
                size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
                PreviousRoot = subTree[0].first;
                // keep only the first subTreeSize edges: the part used for factor I
                subTree.resize( subTreeSize );
                if( props.verbose >= 1 )
                    cerr << "Subtree " << I << " has size " << subTreeSize << endl;
                if( props.verbose >= 3 )
                    cerr << " it is " << subTree << endl;
                // The subtree keeps a pointer to factor(I) of *this (hence the copy of fg above)
                _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
                if( repeats == 1 )
                    break;
            }

    if( props.verbose >= 3 )
        cerr << "Resulting regiongraph: " << *this << endl;
}
180
181
182 string TreeEP::identify() const {
183 return string(Name) + printProperties();
184 }
185
186
187 void TreeEP::init() {
188 runHUGIN();
189
190 // Init factor approximations
191 for( size_t I = 0; I < nrFactors(); I++ )
192 if( offtree(I) )
193 _Q[I].init();
194 }
195
196
197 Real TreeEP::run() {
198 if( props.verbose >= 1 )
199 cerr << "Starting " << identify() << "...";
200 if( props.verbose >= 3 )
201 cerr << endl;
202
203 double tic = toc();
204
205 vector<Factor> oldBeliefs = beliefs();
206
207 // do several passes over the network until maximum number of iterations has
208 // been reached or until the maximum belief difference is smaller than tolerance
209 Real maxDiff = INFINITY;
210 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
211 for( size_t I = 0; I < nrFactors(); I++ )
212 if( offtree(I) ) {
213 _Q[I].InvertAndMultiply( Qa, Qb );
214 _Q[I].HUGIN_with_I( Qa, Qb );
215 _Q[I].InvertAndMultiply( Qa, Qb );
216 }
217
218 // calculate new beliefs and compare with old ones
219 vector<Factor> newBeliefs = beliefs();
220 maxDiff = -INFINITY;
221 for( size_t t = 0; t < oldBeliefs.size(); t++ )
222 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
223 swap( newBeliefs, oldBeliefs );
224
225 if( props.verbose >= 3 )
226 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
227 }
228
229 if( maxDiff > _maxdiff )
230 _maxdiff = maxDiff;
231
232 if( props.verbose >= 1 ) {
233 if( maxDiff > props.tol ) {
234 if( props.verbose == 1 )
235 cerr << endl;
236 cerr << Name << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
237 } else {
238 if( props.verbose >= 3 )
239 cerr << Name << "::run: ";
240 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
241 }
242 }
243
244 return maxDiff;
245 }
246
247
248 Real TreeEP::logZ() const {
249 Real s = 0.0;
250
251 // entropy of the tree
252 for( size_t beta = 0; beta < nrIRs(); beta++ )
253 s -= Qb[beta].entropy();
254 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
255 s += Qa[alpha].entropy();
256
257 // energy of the on-tree factors
258 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
259 s += (OR(alpha).log(true) * Qa[alpha]).sum();
260
261 // energy of the off-tree factors
262 for( size_t I = 0; I < nrFactors(); I++ )
263 if( offtree(I) )
264 s += (_Q.find(I))->second.logZ( Qa, Qb );
265
266 return s;
267 }
268
269
270 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
271 _ns = _I->vars();
272
273 // Make _Qa, _Qb, _a and _b corresponding to the subtree
274 _b.reserve( subRTree.size() );
275 _Qb.reserve( subRTree.size() );
276 _RTree.reserve( subRTree.size() );
277 for( size_t i = 0; i < subRTree.size(); i++ ) {
278 size_t alpha1 = subRTree[i].first; // old index 1
279 size_t alpha2 = subRTree[i].second; // old index 2
280 size_t beta; // old sep index
281 for( beta = 0; beta < jt_RTree.size(); beta++ )
282 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
283 break;
284 DAI_ASSERT( beta != jt_RTree.size() );
285
286 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
287 if( newalpha1 == _a.size() ) {
288 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
289 _a.push_back( alpha1 ); // save old index in index conversion table
290 }
291
292 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
293 if( newalpha2 == _a.size() ) {
294 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
295 _a.push_back( alpha2 ); // save old index in index conversion table
296 }
297
298 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
299 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
300 _b.push_back( beta );
301 }
302
303 // Find remaining variables (which are not in the new root)
304 _nsrem = _ns / _Qa[0].vars();
305 }
306
307
308 void TreeEP::TreeEPSubTree::init() {
309 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
310 _Qa[alpha].fill( 1.0 );
311 for( size_t beta = 0; beta < _Qb.size(); beta++ )
312 _Qb[beta].fill( 1.0 );
313 }
314
315
316 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
317 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
318 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
319
320 for( size_t beta = 0; beta < _Qb.size(); beta++ )
321 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
322 }
323
324
/// HUGIN-style propagation over this subtree with the off-tree factor *_I multiplied in.
/** All joint states s of _nsrem (the variables of *_I outside the local root)
 *  are enumerated. For each s:
 *  - the root _Qa[0] is multiplied by the slice of *_I at s;
 *  - evidence is collected over the subtree edges in reverse order, with the
 *    _nsrem variables clamped to their values in s in the child regions that
 *    contain them;
 *  - evidence is distributed back over the edges in forward order;
 *  - the resulting local factors are added into the global entries Qa[_a[..]]
 *    and Qb[_b[..]], after which _Qa/_Qb are restored.
 *  Finally the accumulated global entries are normalized. _logZ is set to the
 *  sum of the log normalization constants of the regions minus that of the
 *  separators.
 *  \param Qa global outer-region beliefs; the entries of this subtree are overwritten
 *  \param Qb global separator beliefs; the entries of this subtree are overwritten
 */
void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
    // Backup _Qa and _Qb (restored after every state of _nsrem)
    vector<Factor> _Qa_old(_Qa);
    vector<Factor> _Qb_old(_Qb);

    // Clear Qa and Qb: the subtree's global entries serve as accumulators below
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
        Qa[_a[alpha]].fill( 0.0 );
    for( size_t beta = 0; beta < _Qb.size(); beta++ )
        Qb[_b[beta]].fill( 0.0 );

    // For all states of _nsrem
    for( State s(_nsrem); s.valid(); s++ ) {
        // Multiply root with slice of I
        _Qa[0] *= _I->slice( _nsrem, s );

        // CollectEvidence: edges in reverse order, child (.second) -> parent (.first)
        for( size_t i = _RTree.size(); (i--) != 0; ) {
            // clamp variables in nsrem: multiply the child region by a delta
            // fixing each such variable to its value in state s
            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                if( _Qa[_RTree[i].second].vars() >> *n )
                    _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
            // New separator message (the 'false' argument presumably requests an
            // unnormalized marginal; normalization happens at the end)
            Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // DistributeEvidence: edges in forward order, parent (.first) -> child (.second)
        for( size_t i = 0; i < _RTree.size(); i++ ) {
            Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // Store Qa's and Qb's: add this state's contribution to the global entries
        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
            Qa[_a[alpha]].p() += _Qa[alpha].p();
        for( size_t beta = 0; beta < _Qb.size(); beta++ )
            Qb[_b[beta]].p() += _Qb[beta].p();

        // Restore _Qa and _Qb for the next state
        _Qa = _Qa_old;
        _Qb = _Qb_old;
    }

    // Normalize Qa and Qb, recording the log normalization constants in _logZ
    _logZ = 0.0;
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
        _logZ += log(Qa[_a[alpha]].sum());
        Qa[_a[alpha]].normalize();
    }
    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
        _logZ -= log(Qb[_b[beta]].sum());
        Qb[_b[beta]].normalize();
    }
}
381
382
383 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
384 Real s = 0.0;
385 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
386 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
387 for( size_t beta = 0; beta < _Qb.size(); beta++ )
388 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
389 return s + _logZ;
390 }
391
392
393 } // end of namespace dai