dde9725ed4e7997fe44937c94da42f73b925a7b5
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *TreeEP::Name = "TREEEP";
27
28
29 void TreeEP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32 DAI_ASSERT( opts.hasKey("type") );
33
34 props.tol = opts.getStringAs<Real>("tol");
35 props.maxiter = opts.getStringAs<size_t>("maxiter");
36 props.type = opts.getStringAs<Properties::TypeType>("type");
37 if( opts.hasKey("verbose") )
38 props.verbose = opts.getStringAs<size_t>("verbose");
39 else
40 props.verbose = 0;
41 }
42
43
44 PropertySet TreeEP::getProperties() const {
45 PropertySet opts;
46 opts.set( "tol", props.tol );
47 opts.set( "maxiter", props.maxiter );
48 opts.set( "verbose", props.verbose );
49 opts.set( "type", props.type );
50 return opts;
51 }
52
53
54 string TreeEP::printProperties() const {
55 stringstream s( stringstream::out );
56 s << "[";
57 s << "tol=" << props.tol << ",";
58 s << "maxiter=" << props.maxiter << ",";
59 s << "verbose=" << props.verbose << ",";
60 s << "type=" << props.type << "]";
61 return s.str();
62 }
63
64
65 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
66 setProperties( opts );
67
68 if( opts.hasKey("tree") ) {
69 construct( fg, opts.getAs<RootedTree>("tree") );
70 } else {
71 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
72 // ORG: construct weighted graph with as weights a crude estimate of the
73 // mutual information between the nodes
74 // ALT: construct weighted graph with as weights an upper bound on the
75 // effective interaction strength between pairs of nodes
76
77 WeightedGraph<Real> wg;
78 // in order to get a connected weighted graph, we start
79 // by connecting every variable to the zero'th variable with weight 0
80 for( size_t i = 1; i < fg.nrVars(); i++ )
81 wg[UEdge(i,0)] = 0.0;
82 for( size_t i = 0; i < fg.nrVars(); i++ ) {
83 SmallSet<size_t> delta_i = fg.bipGraph().delta1( i, false );
84 const Var& v_i = fg.var(i);
85 foreach( size_t j, delta_i )
86 if( i < j ) {
87 const Var& v_j = fg.var(j);
88 VarSet v_ij( v_i, v_j );
89 SmallSet<size_t> nb_ij = fg.bipGraph().nb1Set( i ) | fg.bipGraph().nb1Set( j );
90 Factor piet;
91 foreach( size_t I, nb_ij ) {
92 const VarSet& Ivars = fg.factor(I).vars();
93 if( props.type == Properties::TypeType::ORG ) {
94 if( (Ivars == v_i) || (Ivars == v_j) )
95 piet *= fg.factor(I);
96 else if( Ivars >> v_ij )
97 piet *= fg.factor(I).marginal( v_ij );
98 } else {
99 if( Ivars >> v_ij )
100 piet *= fg.factor(I);
101 }
102 }
103 if( props.type == Properties::TypeType::ORG ) {
104 if( piet.vars() >> v_ij ) {
105 piet = piet.marginal( v_ij );
106 Factor pietf = piet.marginal(v_i) * piet.marginal(v_j);
107 wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
108 } else {
109 // this should never happen...
110 DAI_ASSERT( 0 == 1 );
111 wg[UEdge(i,j)] = 0;
112 }
113 } else
114 wg[UEdge(i,j)] = piet.strength(v_i, v_j);
115 }
116 }
117
118 // find maximal spanning tree
119 if( props.verbose >= 3 )
120 cerr << "WeightedGraph: " << wg << endl;
121 RootedTree t = MaxSpanningTree( wg, true );
122 if( props.verbose >= 3 )
123 cerr << "Spanningtree: " << t << endl;
124 construct( fg, t );
125 } else
126 DAI_THROW(UNKNOWN_ENUM_VALUE);
127 }
128 }
129
130
131 void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
132 // Copy the factor graph
133 FactorGraph::operator=( fg );
134
135 vector<VarSet> cl;
136 for( size_t i = 0; i < tree.size(); i++ )
137 cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );
138
139 // If no outer region can be found subsuming that factor, label the
140 // factor as off-tree.
141 JTree::construct( *this, cl, false );
142
143 if( props.verbose >= 1 )
144 cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
145 if( props.verbose >= 3 )
146 cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;
147
148 // Create factor approximations
149 _Q.clear();
150 size_t PreviousRoot = (size_t)-1;
151 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
152 for( size_t repeats = 0; repeats < 2; repeats++ )
153 for( size_t I = 0; I < nrFactors(); I++ )
154 if( offtree(I) ) {
155 // find efficient subtree
156 RootedTree subTree;
157 size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
158 PreviousRoot = subTree[0].first;
159 subTree.resize( subTreeSize );
160 if( props.verbose >= 1 )
161 cerr << "Subtree " << I << " has size " << subTreeSize << endl;
162 if( props.verbose >= 3 )
163 cerr << " it is " << subTree << endl;
164 _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
165 if( repeats == 1 )
166 break;
167 }
168
169 if( props.verbose >= 3 )
170 cerr << "Resulting regiongraph: " << *this << endl;
171 }
172
173
174 string TreeEP::identify() const {
175 return string(Name) + printProperties();
176 }
177
178
179 void TreeEP::init() {
180 runHUGIN();
181
182 // Init factor approximations
183 for( size_t I = 0; I < nrFactors(); I++ )
184 if( offtree(I) )
185 _Q[I].init();
186 }
187
188
189 Real TreeEP::run() {
190 if( props.verbose >= 1 )
191 cerr << "Starting " << identify() << "...";
192 if( props.verbose >= 3 )
193 cerr << endl;
194
195 double tic = toc();
196
197 vector<Factor> oldBeliefs = beliefs();
198
199 // do several passes over the network until maximum number of iterations has
200 // been reached or until the maximum belief difference is smaller than tolerance
201 Real maxDiff = INFINITY;
202 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
203 for( size_t I = 0; I < nrFactors(); I++ )
204 if( offtree(I) ) {
205 _Q[I].InvertAndMultiply( Qa, Qb );
206 _Q[I].HUGIN_with_I( Qa, Qb );
207 _Q[I].InvertAndMultiply( Qa, Qb );
208 }
209
210 // calculate new beliefs and compare with old ones
211 vector<Factor> newBeliefs = beliefs();
212 maxDiff = -INFINITY;
213 for( size_t t = 0; t < oldBeliefs.size(); t++ )
214 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
215 swap( newBeliefs, oldBeliefs );
216
217 if( props.verbose >= 3 )
218 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
219 }
220
221 if( maxDiff > _maxdiff )
222 _maxdiff = maxDiff;
223
224 if( props.verbose >= 1 ) {
225 if( maxDiff > props.tol ) {
226 if( props.verbose == 1 )
227 cerr << endl;
228 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
229 } else {
230 if( props.verbose >= 3 )
231 cerr << Name << "::run: ";
232 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
233 }
234 }
235
236 return maxDiff;
237 }
238
239
240 Real TreeEP::logZ() const {
241 Real s = 0.0;
242
243 // entropy of the tree
244 for( size_t beta = 0; beta < nrIRs(); beta++ )
245 s -= Qb[beta].entropy();
246 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
247 s += Qa[alpha].entropy();
248
249 // energy of the on-tree factors
250 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
251 s += (OR(alpha).log(true) * Qa[alpha]).sum();
252
253 // energy of the off-tree factors
254 for( size_t I = 0; I < nrFactors(); I++ )
255 if( offtree(I) )
256 s += (_Q.find(I))->second.logZ( Qa, Qb );
257
258 return s;
259 }
260
261
262 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
263 _ns = _I->vars();
264
265 // Make _Qa, _Qb, _a and _b corresponding to the subtree
266 _b.reserve( subRTree.size() );
267 _Qb.reserve( subRTree.size() );
268 _RTree.reserve( subRTree.size() );
269 for( size_t i = 0; i < subRTree.size(); i++ ) {
270 size_t alpha1 = subRTree[i].first; // old index 1
271 size_t alpha2 = subRTree[i].second; // old index 2
272 size_t beta; // old sep index
273 for( beta = 0; beta < jt_RTree.size(); beta++ )
274 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
275 break;
276 DAI_ASSERT( beta != jt_RTree.size() );
277
278 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
279 if( newalpha1 == _a.size() ) {
280 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
281 _a.push_back( alpha1 ); // save old index in index conversion table
282 }
283
284 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
285 if( newalpha2 == _a.size() ) {
286 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
287 _a.push_back( alpha2 ); // save old index in index conversion table
288 }
289
290 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
291 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
292 _b.push_back( beta );
293 }
294
295 // Find remaining variables (which are not in the new root)
296 _nsrem = _ns / _Qa[0].vars();
297 }
298
299
300 void TreeEP::TreeEPSubTree::init() {
301 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
302 _Qa[alpha].fill( 1.0 );
303 for( size_t beta = 0; beta < _Qb.size(); beta++ )
304 _Qb[beta].fill( 1.0 );
305 }
306
307
308 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
309 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
310 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
311
312 for( size_t beta = 0; beta < _Qb.size(); beta++ )
313 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
314 }
315
316
317 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
318 // Backup _Qa and _Qb
319 vector<Factor> _Qa_old(_Qa);
320 vector<Factor> _Qb_old(_Qb);
321
322 // Clear Qa and Qb
323 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
324 Qa[_a[alpha]].fill( 0.0 );
325 for( size_t beta = 0; beta < _Qb.size(); beta++ )
326 Qb[_b[beta]].fill( 0.0 );
327
328 // For all states of _nsrem
329 for( State s(_nsrem); s.valid(); s++ ) {
330 // Multiply root with slice of I
331 _Qa[0] *= _I->slice( _nsrem, s );
332
333 // CollectEvidence
334 for( size_t i = _RTree.size(); (i--) != 0; ) {
335 // clamp variables in nsrem
336 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
337 if( _Qa[_RTree[i].second].vars() >> *n )
338 _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
339 Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
340 _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
341 _Qb[i] = new_Qb;
342 }
343
344 // DistributeEvidence
345 for( size_t i = 0; i < _RTree.size(); i++ ) {
346 Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
347 _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
348 _Qb[i] = new_Qb;
349 }
350
351 // Store Qa's and Qb's
352 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
353 Qa[_a[alpha]].p() += _Qa[alpha].p();
354 for( size_t beta = 0; beta < _Qb.size(); beta++ )
355 Qb[_b[beta]].p() += _Qb[beta].p();
356
357 // Restore _Qa and _Qb
358 _Qa = _Qa_old;
359 _Qb = _Qb_old;
360 }
361
362 // Normalize Qa and Qb
363 _logZ = 0.0;
364 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
365 _logZ += log(Qa[_a[alpha]].sum());
366 Qa[_a[alpha]].normalize();
367 }
368 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
369 _logZ -= log(Qb[_b[beta]].sum());
370 Qb[_b[beta]].normalize();
371 }
372 }
373
374
375 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
376 Real s = 0.0;
377 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
378 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
379 for( size_t beta = 0; beta < _Qb.size(); beta++ )
380 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
381 return s + _logZ;
382 }
383
384
385 } // end of namespace dai