New git HEAD version
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/dai_config.h>
10 #ifdef DAI_WITH_TREEEP
11
12
13 #include <iostream>
14 #include <fstream>
15 #include <vector>
16 #include <dai/jtree.h>
17 #include <dai/treeep.h>
18 #include <dai/util.h>
19
20
21 namespace dai {
22
23
24 using namespace std;
25
26
27 void TreeEP::setProperties( const PropertySet &opts ) {
28 DAI_ASSERT( opts.hasKey("tol") );
29 DAI_ASSERT( opts.hasKey("type") );
30
31 props.tol = opts.getStringAs<Real>("tol");
32 props.type = opts.getStringAs<Properties::TypeType>("type");
33 if( opts.hasKey("maxiter") )
34 props.maxiter = opts.getStringAs<size_t>("maxiter");
35 else
36 props.maxiter = 10000;
37 if( opts.hasKey("maxtime") )
38 props.maxtime = opts.getStringAs<Real>("maxtime");
39 else
40 props.maxtime = INFINITY;
41 if( opts.hasKey("verbose") )
42 props.verbose = opts.getStringAs<size_t>("verbose");
43 else
44 props.verbose = 0;
45 }
46
47
48 PropertySet TreeEP::getProperties() const {
49 PropertySet opts;
50 opts.set( "tol", props.tol );
51 opts.set( "maxiter", props.maxiter );
52 opts.set( "maxtime", props.maxtime );
53 opts.set( "verbose", props.verbose );
54 opts.set( "type", props.type );
55 return opts;
56 }
57
58
59 string TreeEP::printProperties() const {
60 stringstream s( stringstream::out );
61 s << "[";
62 s << "tol=" << props.tol << ",";
63 s << "maxiter=" << props.maxiter << ",";
64 s << "maxtime=" << props.maxtime << ",";
65 s << "verbose=" << props.verbose << ",";
66 s << "type=" << props.type << "]";
67 return s.str();
68 }
69
70
71 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
72 setProperties( opts );
73
74 if( opts.hasKey("tree") ) {
75 construct( fg, opts.getAs<RootedTree>("tree") );
76 } else {
77 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
78 // ORG: construct weighted graph with as weights a crude estimate of the
79 // mutual information between the nodes
80 // ALT: construct weighted graph with as weights an upper bound on the
81 // effective interaction strength between pairs of nodes
82
83 WeightedGraph<Real> wg;
84 // in order to get a connected weighted graph, we start
85 // by connecting every variable to the zero'th variable with weight 0
86 for( size_t i = 1; i < fg.nrVars(); i++ )
87 wg[UEdge(i,0)] = 0.0;
88 for( size_t i = 0; i < fg.nrVars(); i++ ) {
89 SmallSet<size_t> delta_i = fg.bipGraph().delta1( i, false );
90 const Var& v_i = fg.var(i);
91 bforeach( size_t j, delta_i )
92 if( i < j ) {
93 const Var& v_j = fg.var(j);
94 VarSet v_ij( v_i, v_j );
95 SmallSet<size_t> nb_ij = fg.bipGraph().nb1Set( i ) | fg.bipGraph().nb1Set( j );
96 Factor piet;
97 bforeach( size_t I, nb_ij ) {
98 const VarSet& Ivars = fg.factor(I).vars();
99 if( props.type == Properties::TypeType::ORG ) {
100 if( (Ivars == v_i) || (Ivars == v_j) )
101 piet *= fg.factor(I);
102 else if( Ivars >> v_ij )
103 piet *= fg.factor(I).marginal( v_ij );
104 } else {
105 if( Ivars >> v_ij )
106 piet *= fg.factor(I);
107 }
108 }
109 if( props.type == Properties::TypeType::ORG ) {
110 if( piet.vars() >> v_ij ) {
111 piet = piet.marginal( v_ij );
112 Factor pietf = piet.marginal(v_i) * piet.marginal(v_j);
113 wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
114 } else {
115 // this should never happen...
116 DAI_ASSERT( 0 == 1 );
117 wg[UEdge(i,j)] = 0;
118 }
119 } else
120 wg[UEdge(i,j)] = piet.strength(v_i, v_j);
121 }
122 }
123
124 // find maximal spanning tree
125 if( props.verbose >= 3 )
126 cerr << "WeightedGraph: " << wg << endl;
127 RootedTree t = MaxSpanningTree( wg, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
128 if( props.verbose >= 3 )
129 cerr << "Spanningtree: " << t << endl;
130 construct( fg, t );
131 } else
132 DAI_THROW(UNKNOWN_ENUM_VALUE);
133 }
134 }
135
136
137 void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
138 // Copy the factor graph
139 FactorGraph::operator=( fg );
140
141 vector<VarSet> cl;
142 for( size_t i = 0; i < tree.size(); i++ )
143 cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );
144
145 // If no outer region can be found subsuming that factor, label the
146 // factor as off-tree.
147 JTree::construct( *this, cl, false );
148
149 if( props.verbose >= 1 )
150 cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
151 if( props.verbose >= 3 )
152 cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;
153
154 // Create factor approximations
155 _Q.clear();
156 size_t PreviousRoot = (size_t)-1;
157 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
158 for( size_t repeats = 0; repeats < 2; repeats++ )
159 for( size_t I = 0; I < nrFactors(); I++ )
160 if( offtree(I) ) {
161 // find efficient subtree
162 RootedTree subTree;
163 size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
164 PreviousRoot = subTree[0].first;
165 subTree.resize( subTreeSize );
166 if( props.verbose >= 1 )
167 cerr << "Subtree " << I << " has size " << subTreeSize << endl;
168 if( props.verbose >= 3 )
169 cerr << " it is " << subTree << endl;
170 _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
171 if( repeats == 1 )
172 break;
173 }
174
175 if( props.verbose >= 3 )
176 cerr << "Resulting regiongraph: " << *this << endl;
177 }
178
179
180 void TreeEP::init() {
181 runHUGIN();
182
183 // Init factor approximations
184 for( size_t I = 0; I < nrFactors(); I++ )
185 if( offtree(I) )
186 _Q[I].init();
187 }
188
189
190 Real TreeEP::run() {
191 if( props.verbose >= 1 )
192 cerr << "Starting " << identify() << "...";
193 if( props.verbose >= 3 )
194 cerr << endl;
195
196 double tic = toc();
197
198 vector<Factor> oldBeliefs = beliefs();
199
200 // do several passes over the network until maximum number of iterations has
201 // been reached or until the maximum belief difference is smaller than tolerance
202 Real maxDiff = INFINITY;
203 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
204 for( size_t I = 0; I < nrFactors(); I++ )
205 if( offtree(I) ) {
206 _Q[I].InvertAndMultiply( Qa, Qb );
207 _Q[I].HUGIN_with_I( Qa, Qb );
208 _Q[I].InvertAndMultiply( Qa, Qb );
209 }
210
211 // calculate new beliefs and compare with old ones
212 vector<Factor> newBeliefs = beliefs();
213 maxDiff = -INFINITY;
214 for( size_t t = 0; t < oldBeliefs.size(); t++ )
215 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
216 swap( newBeliefs, oldBeliefs );
217
218 if( props.verbose >= 3 )
219 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
220 }
221
222 if( maxDiff > _maxdiff )
223 _maxdiff = maxDiff;
224
225 if( props.verbose >= 1 ) {
226 if( maxDiff > props.tol ) {
227 if( props.verbose == 1 )
228 cerr << endl;
229 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
230 } else {
231 if( props.verbose >= 3 )
232 cerr << name() << "::run: ";
233 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
234 }
235 }
236
237 return maxDiff;
238 }
239
240
241 Real TreeEP::logZ() const {
242 Real s = 0.0;
243
244 // entropy of the tree
245 for( size_t beta = 0; beta < nrIRs(); beta++ )
246 s -= Qb[beta].entropy();
247 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
248 s += Qa[alpha].entropy();
249
250 // energy of the on-tree factors
251 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
252 s += (OR(alpha).log(true) * Qa[alpha]).sum();
253
254 // energy of the off-tree factors
255 for( size_t I = 0; I < nrFactors(); I++ )
256 if( offtree(I) )
257 s += (_Q.find(I))->second.logZ( Qa, Qb );
258
259 return s;
260 }
261
262
263 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
264 _ns = _I->vars();
265
266 // Make _Qa, _Qb, _a and _b corresponding to the subtree
267 _b.reserve( subRTree.size() );
268 _Qb.reserve( subRTree.size() );
269 _RTree.reserve( subRTree.size() );
270 for( size_t i = 0; i < subRTree.size(); i++ ) {
271 size_t alpha1 = subRTree[i].first; // old index 1
272 size_t alpha2 = subRTree[i].second; // old index 2
273 size_t beta; // old sep index
274 for( beta = 0; beta < jt_RTree.size(); beta++ )
275 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
276 break;
277 DAI_ASSERT( beta != jt_RTree.size() );
278
279 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
280 if( newalpha1 == _a.size() ) {
281 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
282 _a.push_back( alpha1 ); // save old index in index conversion table
283 }
284
285 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
286 if( newalpha2 == _a.size() ) {
287 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
288 _a.push_back( alpha2 ); // save old index in index conversion table
289 }
290
291 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
292 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
293 _b.push_back( beta );
294 }
295
296 // Find remaining variables (which are not in the new root)
297 _nsrem = _ns / _Qa[0].vars();
298 }
299
300
301 void TreeEP::TreeEPSubTree::init() {
302 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
303 _Qa[alpha].fill( 1.0 );
304 for( size_t beta = 0; beta < _Qb.size(); beta++ )
305 _Qb[beta].fill( 1.0 );
306 }
307
308
309 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
310 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
311 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
312
313 for( size_t beta = 0; beta < _Qb.size(); beta++ )
314 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
315 }
316
317
318 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
319 // Backup _Qa and _Qb
320 vector<Factor> _Qa_old(_Qa);
321 vector<Factor> _Qb_old(_Qb);
322
323 // Clear Qa and Qb
324 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
325 Qa[_a[alpha]].fill( 0.0 );
326 for( size_t beta = 0; beta < _Qb.size(); beta++ )
327 Qb[_b[beta]].fill( 0.0 );
328
329 // For all states of _nsrem
330 for( State s(_nsrem); s.valid(); s++ ) {
331 // Multiply root with slice of I
332 _Qa[0] *= _I->slice( _nsrem, s );
333
334 // CollectEvidence
335 for( size_t i = _RTree.size(); (i--) != 0; ) {
336 // clamp variables in nsrem
337 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
338 if( _Qa[_RTree[i].second].vars() >> *n )
339 _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
340 Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
341 _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
342 _Qb[i] = new_Qb;
343 }
344
345 // DistributeEvidence
346 for( size_t i = 0; i < _RTree.size(); i++ ) {
347 Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
348 _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
349 _Qb[i] = new_Qb;
350 }
351
352 // Store Qa's and Qb's
353 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
354 Qa[_a[alpha]].p() += _Qa[alpha].p();
355 for( size_t beta = 0; beta < _Qb.size(); beta++ )
356 Qb[_b[beta]].p() += _Qb[beta].p();
357
358 // Restore _Qa and _Qb
359 _Qa = _Qa_old;
360 _Qb = _Qb_old;
361 }
362
363 // Normalize Qa and Qb
364 _logZ = 0.0;
365 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
366 _logZ += log(Qa[_a[alpha]].sum());
367 Qa[_a[alpha]].normalize();
368 }
369 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
370 _logZ -= log(Qb[_b[beta]].sum());
371 Qb[_b[beta]].normalize();
372 }
373 }
374
375
376 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
377 Real s = 0.0;
378 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
379 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
380 for( size_t beta = 0; beta < _Qb.size(); beta++ )
381 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
382 return s + _logZ;
383 }
384
385
386 } // end of namespace dai
387
388
389 #endif