Finished release 0.2.4
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <iostream>
#include <fstream>
#include <set>
#include <vector>
#include <dai/jtree.h>
#include <dai/treeep.h>
#include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *TreeEP::Name = "TREEEP";
27
28
29 void TreeEP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32 DAI_ASSERT( opts.hasKey("verbose") );
33 DAI_ASSERT( opts.hasKey("type") );
34
35 props.tol = opts.getStringAs<Real>("tol");
36 props.maxiter = opts.getStringAs<size_t>("maxiter");
37 props.verbose = opts.getStringAs<size_t>("verbose");
38 props.type = opts.getStringAs<Properties::TypeType>("type");
39 }
40
41
42 PropertySet TreeEP::getProperties() const {
43 PropertySet opts;
44 opts.Set( "tol", props.tol );
45 opts.Set( "maxiter", props.maxiter );
46 opts.Set( "verbose", props.verbose );
47 opts.Set( "type", props.type );
48 return opts;
49 }
50
51
52 string TreeEP::printProperties() const {
53 stringstream s( stringstream::out );
54 s << "[";
55 s << "tol=" << props.tol << ",";
56 s << "maxiter=" << props.maxiter << ",";
57 s << "verbose=" << props.verbose << ",";
58 s << "type=" << props.type << "]";
59 return s.str();
60 }
61
62
63 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
64 setProperties( opts );
65
66 if( !isConnected() )
67 DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
68
69 if( opts.hasKey("tree") ) {
70 construct( opts.GetAs<RootedTree>("tree") );
71 } else {
72 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
73 // ORG: construct weighted graph with as weights a crude estimate of the
74 // mutual information between the nodes
75 // ALT: construct weighted graph with as weights an upper bound on the
76 // effective interaction strength between pairs of nodes
77
78 WeightedGraph<Real> wg;
79 for( size_t i = 0; i < nrVars(); ++i ) {
80 Var v_i = var(i);
81 VarSet di = delta(i);
82 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
83 if( v_i < *j ) {
84 VarSet ij(v_i,*j);
85 Factor piet;
86 for( size_t I = 0; I < nrFactors(); I++ ) {
87 VarSet Ivars = factor(I).vars();
88 if( props.type == Properties::TypeType::ORG ) {
89 if( (Ivars == v_i) || (Ivars == *j) )
90 piet *= factor(I);
91 else if( Ivars >> ij )
92 piet *= factor(I).marginal( ij );
93 } else {
94 if( Ivars >> ij )
95 piet *= factor(I);
96 }
97 }
98 if( props.type == Properties::TypeType::ORG ) {
99 if( piet.vars() >> ij ) {
100 piet = piet.marginal( ij );
101 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
102 wg[UEdge(i,findVar(*j))] = dist( piet, pietf, Prob::DISTKL );
103 } else
104 wg[UEdge(i,findVar(*j))] = 0;
105 } else {
106 wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
107 }
108 }
109 }
110
111 // find maximal spanning tree
112 construct( MaxSpanningTreePrims( wg ) );
113 } else
114 DAI_THROW(UNKNOWN_ENUM_VALUE);
115 }
116 }
117
118
119 void TreeEP::construct( const RootedTree &tree ) {
120 vector<VarSet> cl;
121 for( size_t i = 0; i < tree.size(); i++ )
122 cl.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
123
124 // If no outer region can be found subsuming that factor, label the
125 // factor as off-tree.
126 JTree::construct( cl, false );
127
128 if( props.verbose >= 1 )
129 cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
130 if( props.verbose >= 3 )
131 cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;
132
133 // Create factor approximations
134 _Q.clear();
135 size_t PreviousRoot = (size_t)-1;
136 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
137 for( size_t repeats = 0; repeats < 2; repeats++ )
138 for( size_t I = 0; I < nrFactors(); I++ )
139 if( offtree(I) ) {
140 // find efficient subtree
141 RootedTree subTree;
142 size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
143 PreviousRoot = subTree[0].n1;
144 subTree.resize( subTreeSize );
145 if( props.verbose >= 1 )
146 cerr << "Subtree " << I << " has size " << subTreeSize << endl;
147 if( props.verbose >= 3 )
148 cerr << " it is " << subTree << endl;
149 _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
150 if( repeats == 1 )
151 break;
152 }
153
154 if( props.verbose >= 3 )
155 cerr << "Resulting regiongraph: " << *this << endl;
156 }
157
158
159 string TreeEP::identify() const {
160 return string(Name) + printProperties();
161 }
162
163
164 void TreeEP::init() {
165 runHUGIN();
166
167 // Init factor approximations
168 for( size_t I = 0; I < nrFactors(); I++ )
169 if( offtree(I) )
170 _Q[I].init();
171 }
172
173
174 Real TreeEP::run() {
175 if( props.verbose >= 1 )
176 cerr << "Starting " << identify() << "...";
177 if( props.verbose >= 3 )
178 cerr << endl;
179
180 double tic = toc();
181
182 vector<Factor> oldBeliefs = beliefs();
183
184 // do several passes over the network until maximum number of iterations has
185 // been reached or until the maximum belief difference is smaller than tolerance
186 Real maxDiff = INFINITY;
187 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
188 for( size_t I = 0; I < nrFactors(); I++ )
189 if( offtree(I) ) {
190 _Q[I].InvertAndMultiply( Qa, Qb );
191 _Q[I].HUGIN_with_I( Qa, Qb );
192 _Q[I].InvertAndMultiply( Qa, Qb );
193 }
194
195 // calculate new beliefs and compare with old ones
196 vector<Factor> newBeliefs = beliefs();
197 maxDiff = -INFINITY;
198 for( size_t t = 0; t < oldBeliefs.size(); t++ )
199 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], Prob::DISTLINF ) );
200 swap( newBeliefs, oldBeliefs );
201
202 if( props.verbose >= 3 )
203 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
204 }
205
206 if( maxDiff > _maxdiff )
207 _maxdiff = maxDiff;
208
209 if( props.verbose >= 1 ) {
210 if( maxDiff > props.tol ) {
211 if( props.verbose == 1 )
212 cerr << endl;
213 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
214 } else {
215 if( props.verbose >= 3 )
216 cerr << Name << "::run: ";
217 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
218 }
219 }
220
221 return maxDiff;
222 }
223
224
225 Real TreeEP::logZ() const {
226 Real s = 0.0;
227
228 // entropy of the tree
229 for( size_t beta = 0; beta < nrIRs(); beta++ )
230 s -= Qb[beta].entropy();
231 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
232 s += Qa[alpha].entropy();
233
234 // energy of the on-tree factors
235 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
236 s += (OR(alpha).log(true) * Qa[alpha]).sum();
237
238 // energy of the off-tree factors
239 for( size_t I = 0; I < nrFactors(); I++ )
240 if( offtree(I) )
241 s += (_Q.find(I))->second.logZ( Qa, Qb );
242
243 return s;
244 }
245
246
247 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
248 _ns = _I->vars();
249
250 // Make _Qa, _Qb, _a and _b corresponding to the subtree
251 _b.reserve( subRTree.size() );
252 _Qb.reserve( subRTree.size() );
253 _RTree.reserve( subRTree.size() );
254 for( size_t i = 0; i < subRTree.size(); i++ ) {
255 size_t alpha1 = subRTree[i].n1; // old index 1
256 size_t alpha2 = subRTree[i].n2; // old index 2
257 size_t beta; // old sep index
258 for( beta = 0; beta < jt_RTree.size(); beta++ )
259 if( UEdge( jt_RTree[beta].n1, jt_RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
260 break;
261 DAI_ASSERT( beta != jt_RTree.size() );
262
263 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
264 if( newalpha1 == _a.size() ) {
265 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
266 _a.push_back( alpha1 ); // save old index in index conversion table
267 }
268
269 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
270 if( newalpha2 == _a.size() ) {
271 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
272 _a.push_back( alpha2 ); // save old index in index conversion table
273 }
274
275 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
276 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
277 _b.push_back( beta );
278 }
279
280 // Find remaining variables (which are not in the new root)
281 _nsrem = _ns / _Qa[0].vars();
282 }
283
284
285 void TreeEP::TreeEPSubTree::init() {
286 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
287 _Qa[alpha].fill( 1.0 );
288 for( size_t beta = 0; beta < _Qb.size(); beta++ )
289 _Qb[beta].fill( 1.0 );
290 }
291
292
293 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
294 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
295 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
296
297 for( size_t beta = 0; beta < _Qb.size(); beta++ )
298 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
299 }
300
301
302 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
303 // Backup _Qa and _Qb
304 vector<Factor> _Qa_old(_Qa);
305 vector<Factor> _Qb_old(_Qb);
306
307 // Clear Qa and Qb
308 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
309 Qa[_a[alpha]].fill( 0.0 );
310 for( size_t beta = 0; beta < _Qb.size(); beta++ )
311 Qb[_b[beta]].fill( 0.0 );
312
313 // For all states of _nsrem
314 for( State s(_nsrem); s.valid(); s++ ) {
315 // Multiply root with slice of I
316 _Qa[0] *= _I->slice( _nsrem, s );
317
318 // CollectEvidence
319 for( size_t i = _RTree.size(); (i--) != 0; ) {
320 // clamp variables in nsrem
321 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
322 if( _Qa[_RTree[i].n2].vars() >> *n ) {
323 Factor delta( *n, 0.0 );
324 delta[s(*n)] = 1.0;
325 _Qa[_RTree[i].n2] *= delta;
326 }
327 Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
328 _Qa[_RTree[i].n1] *= new_Qb / _Qb[i];
329 _Qb[i] = new_Qb;
330 }
331
332 // DistributeEvidence
333 for( size_t i = 0; i < _RTree.size(); i++ ) {
334 Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
335 _Qa[_RTree[i].n2] *= new_Qb / _Qb[i];
336 _Qb[i] = new_Qb;
337 }
338
339 // Store Qa's and Qb's
340 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
341 Qa[_a[alpha]].p() += _Qa[alpha].p();
342 for( size_t beta = 0; beta < _Qb.size(); beta++ )
343 Qb[_b[beta]].p() += _Qb[beta].p();
344
345 // Restore _Qa and _Qb
346 _Qa = _Qa_old;
347 _Qb = _Qb_old;
348 }
349
350 // Normalize Qa and Qb
351 _logZ = 0.0;
352 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
353 _logZ += log(Qa[_a[alpha]].sum());
354 Qa[_a[alpha]].normalize();
355 }
356 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
357 _logZ -= log(Qb[_b[beta]].sum());
358 Qb[_b[beta]].normalize();
359 }
360 }
361
362
363 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
364 Real s = 0.0;
365 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
366 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
367 for( size_t beta = 0; beta < _Qb.size(); beta++ )
368 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
369 return s + _logZ;
370 }
371
372
373 } // end of namespace dai