2a3a47ede619b6f9f55d9059a00de199a443b443
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <fstream>
11 #include <vector>
12 #include <dai/jtree.h>
13 #include <dai/treeep.h>
14 #include <dai/util.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 void TreeEP::setProperties( const PropertySet &opts ) {
24 DAI_ASSERT( opts.hasKey("tol") );
25 DAI_ASSERT( opts.hasKey("type") );
26
27 props.tol = opts.getStringAs<Real>("tol");
28 props.type = opts.getStringAs<Properties::TypeType>("type");
29 if( opts.hasKey("maxiter") )
30 props.maxiter = opts.getStringAs<size_t>("maxiter");
31 else
32 props.maxiter = 10000;
33 if( opts.hasKey("maxtime") )
34 props.maxtime = opts.getStringAs<Real>("maxtime");
35 else
36 props.maxtime = INFINITY;
37 if( opts.hasKey("verbose") )
38 props.verbose = opts.getStringAs<size_t>("verbose");
39 else
40 props.verbose = 0;
41 }
42
43
44 PropertySet TreeEP::getProperties() const {
45 PropertySet opts;
46 opts.set( "tol", props.tol );
47 opts.set( "maxiter", props.maxiter );
48 opts.set( "maxtime", props.maxtime );
49 opts.set( "verbose", props.verbose );
50 opts.set( "type", props.type );
51 return opts;
52 }
53
54
55 string TreeEP::printProperties() const {
56 stringstream s( stringstream::out );
57 s << "[";
58 s << "tol=" << props.tol << ",";
59 s << "maxiter=" << props.maxiter << ",";
60 s << "maxtime=" << props.maxtime << ",";
61 s << "verbose=" << props.verbose << ",";
62 s << "type=" << props.type << "]";
63 return s.str();
64 }
65
66
67 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
68 setProperties( opts );
69
70 if( opts.hasKey("tree") ) {
71 construct( fg, opts.getAs<RootedTree>("tree") );
72 } else {
73 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
74 // ORG: construct weighted graph with as weights a crude estimate of the
75 // mutual information between the nodes
76 // ALT: construct weighted graph with as weights an upper bound on the
77 // effective interaction strength between pairs of nodes
78
79 WeightedGraph<Real> wg;
80 // in order to get a connected weighted graph, we start
81 // by connecting every variable to the zero'th variable with weight 0
82 for( size_t i = 1; i < fg.nrVars(); i++ )
83 wg[UEdge(i,0)] = 0.0;
84 for( size_t i = 0; i < fg.nrVars(); i++ ) {
85 SmallSet<size_t> delta_i = fg.bipGraph().delta1( i, false );
86 const Var& v_i = fg.var(i);
87 bforeach( size_t j, delta_i )
88 if( i < j ) {
89 const Var& v_j = fg.var(j);
90 VarSet v_ij( v_i, v_j );
91 SmallSet<size_t> nb_ij = fg.bipGraph().nb1Set( i ) | fg.bipGraph().nb1Set( j );
92 Factor piet;
93 bforeach( size_t I, nb_ij ) {
94 const VarSet& Ivars = fg.factor(I).vars();
95 if( props.type == Properties::TypeType::ORG ) {
96 if( (Ivars == v_i) || (Ivars == v_j) )
97 piet *= fg.factor(I);
98 else if( Ivars >> v_ij )
99 piet *= fg.factor(I).marginal( v_ij );
100 } else {
101 if( Ivars >> v_ij )
102 piet *= fg.factor(I);
103 }
104 }
105 if( props.type == Properties::TypeType::ORG ) {
106 if( piet.vars() >> v_ij ) {
107 piet = piet.marginal( v_ij );
108 Factor pietf = piet.marginal(v_i) * piet.marginal(v_j);
109 wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
110 } else {
111 // this should never happen...
112 DAI_ASSERT( 0 == 1 );
113 wg[UEdge(i,j)] = 0;
114 }
115 } else
116 wg[UEdge(i,j)] = piet.strength(v_i, v_j);
117 }
118 }
119
120 // find maximal spanning tree
121 if( props.verbose >= 3 )
122 cerr << "WeightedGraph: " << wg << endl;
123 RootedTree t = MaxSpanningTree( wg, true );
124 if( props.verbose >= 3 )
125 cerr << "Spanningtree: " << t << endl;
126 construct( fg, t );
127 } else
128 DAI_THROW(UNKNOWN_ENUM_VALUE);
129 }
130 }
131
132
133 void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
134 // Copy the factor graph
135 FactorGraph::operator=( fg );
136
137 vector<VarSet> cl;
138 for( size_t i = 0; i < tree.size(); i++ )
139 cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );
140
141 // If no outer region can be found subsuming that factor, label the
142 // factor as off-tree.
143 JTree::construct( *this, cl, false );
144
145 if( props.verbose >= 1 )
146 cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
147 if( props.verbose >= 3 )
148 cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;
149
150 // Create factor approximations
151 _Q.clear();
152 size_t PreviousRoot = (size_t)-1;
153 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
154 for( size_t repeats = 0; repeats < 2; repeats++ )
155 for( size_t I = 0; I < nrFactors(); I++ )
156 if( offtree(I) ) {
157 // find efficient subtree
158 RootedTree subTree;
159 size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
160 PreviousRoot = subTree[0].first;
161 subTree.resize( subTreeSize );
162 if( props.verbose >= 1 )
163 cerr << "Subtree " << I << " has size " << subTreeSize << endl;
164 if( props.verbose >= 3 )
165 cerr << " it is " << subTree << endl;
166 _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
167 if( repeats == 1 )
168 break;
169 }
170
171 if( props.verbose >= 3 )
172 cerr << "Resulting regiongraph: " << *this << endl;
173 }
174
175
176 void TreeEP::init() {
177 runHUGIN();
178
179 // Init factor approximations
180 for( size_t I = 0; I < nrFactors(); I++ )
181 if( offtree(I) )
182 _Q[I].init();
183 }
184
185
186 Real TreeEP::run() {
187 if( props.verbose >= 1 )
188 cerr << "Starting " << identify() << "...";
189 if( props.verbose >= 3 )
190 cerr << endl;
191
192 double tic = toc();
193
194 vector<Factor> oldBeliefs = beliefs();
195
196 // do several passes over the network until maximum number of iterations has
197 // been reached or until the maximum belief difference is smaller than tolerance
198 Real maxDiff = INFINITY;
199 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
200 for( size_t I = 0; I < nrFactors(); I++ )
201 if( offtree(I) ) {
202 _Q[I].InvertAndMultiply( Qa, Qb );
203 _Q[I].HUGIN_with_I( Qa, Qb );
204 _Q[I].InvertAndMultiply( Qa, Qb );
205 }
206
207 // calculate new beliefs and compare with old ones
208 vector<Factor> newBeliefs = beliefs();
209 maxDiff = -INFINITY;
210 for( size_t t = 0; t < oldBeliefs.size(); t++ )
211 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
212 swap( newBeliefs, oldBeliefs );
213
214 if( props.verbose >= 3 )
215 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
216 }
217
218 if( maxDiff > _maxdiff )
219 _maxdiff = maxDiff;
220
221 if( props.verbose >= 1 ) {
222 if( maxDiff > props.tol ) {
223 if( props.verbose == 1 )
224 cerr << endl;
225 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
226 } else {
227 if( props.verbose >= 3 )
228 cerr << name() << "::run: ";
229 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
230 }
231 }
232
233 return maxDiff;
234 }
235
236
237 Real TreeEP::logZ() const {
238 Real s = 0.0;
239
240 // entropy of the tree
241 for( size_t beta = 0; beta < nrIRs(); beta++ )
242 s -= Qb[beta].entropy();
243 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
244 s += Qa[alpha].entropy();
245
246 // energy of the on-tree factors
247 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
248 s += (OR(alpha).log(true) * Qa[alpha]).sum();
249
250 // energy of the off-tree factors
251 for( size_t I = 0; I < nrFactors(); I++ )
252 if( offtree(I) )
253 s += (_Q.find(I))->second.logZ( Qa, Qb );
254
255 return s;
256 }
257
258
259 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
260 _ns = _I->vars();
261
262 // Make _Qa, _Qb, _a and _b corresponding to the subtree
263 _b.reserve( subRTree.size() );
264 _Qb.reserve( subRTree.size() );
265 _RTree.reserve( subRTree.size() );
266 for( size_t i = 0; i < subRTree.size(); i++ ) {
267 size_t alpha1 = subRTree[i].first; // old index 1
268 size_t alpha2 = subRTree[i].second; // old index 2
269 size_t beta; // old sep index
270 for( beta = 0; beta < jt_RTree.size(); beta++ )
271 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
272 break;
273 DAI_ASSERT( beta != jt_RTree.size() );
274
275 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
276 if( newalpha1 == _a.size() ) {
277 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
278 _a.push_back( alpha1 ); // save old index in index conversion table
279 }
280
281 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
282 if( newalpha2 == _a.size() ) {
283 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
284 _a.push_back( alpha2 ); // save old index in index conversion table
285 }
286
287 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
288 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
289 _b.push_back( beta );
290 }
291
292 // Find remaining variables (which are not in the new root)
293 _nsrem = _ns / _Qa[0].vars();
294 }
295
296
297 void TreeEP::TreeEPSubTree::init() {
298 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
299 _Qa[alpha].fill( 1.0 );
300 for( size_t beta = 0; beta < _Qb.size(); beta++ )
301 _Qb[beta].fill( 1.0 );
302 }
303
304
305 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
306 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
307 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
308
309 for( size_t beta = 0; beta < _Qb.size(); beta++ )
310 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
311 }
312
313
314 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
315 // Backup _Qa and _Qb
316 vector<Factor> _Qa_old(_Qa);
317 vector<Factor> _Qb_old(_Qb);
318
319 // Clear Qa and Qb
320 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
321 Qa[_a[alpha]].fill( 0.0 );
322 for( size_t beta = 0; beta < _Qb.size(); beta++ )
323 Qb[_b[beta]].fill( 0.0 );
324
325 // For all states of _nsrem
326 for( State s(_nsrem); s.valid(); s++ ) {
327 // Multiply root with slice of I
328 _Qa[0] *= _I->slice( _nsrem, s );
329
330 // CollectEvidence
331 for( size_t i = _RTree.size(); (i--) != 0; ) {
332 // clamp variables in nsrem
333 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
334 if( _Qa[_RTree[i].second].vars() >> *n )
335 _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
336 Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
337 _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
338 _Qb[i] = new_Qb;
339 }
340
341 // DistributeEvidence
342 for( size_t i = 0; i < _RTree.size(); i++ ) {
343 Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
344 _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
345 _Qb[i] = new_Qb;
346 }
347
348 // Store Qa's and Qb's
349 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
350 Qa[_a[alpha]].p() += _Qa[alpha].p();
351 for( size_t beta = 0; beta < _Qb.size(); beta++ )
352 Qb[_b[beta]].p() += _Qb[beta].p();
353
354 // Restore _Qa and _Qb
355 _Qa = _Qa_old;
356 _Qb = _Qb_old;
357 }
358
359 // Normalize Qa and Qb
360 _logZ = 0.0;
361 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
362 _logZ += log(Qa[_a[alpha]].sum());
363 Qa[_a[alpha]].normalize();
364 }
365 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
366 _logZ -= log(Qb[_b[beta]].sum());
367 Qb[_b[beta]].normalize();
368 }
369 }
370
371
372 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
373 Real s = 0.0;
374 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
375 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
376 for( size_t beta = 0; beta < _Qb.size(); beta++ )
377 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
378 return s + _logZ;
379 }
380
381
382 } // end of namespace dai