Replaced Name members by name() virtual functions (fixing a bug in matlab/dai.cpp)
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 void TreeEP::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("tol") );
28 DAI_ASSERT( opts.hasKey("type") );
29
30 props.tol = opts.getStringAs<Real>("tol");
31 props.type = opts.getStringAs<Properties::TypeType>("type");
32 if( opts.hasKey("maxiter") )
33 props.maxiter = opts.getStringAs<size_t>("maxiter");
34 else
35 props.maxiter = 10000;
36 if( opts.hasKey("maxtime") )
37 props.maxtime = opts.getStringAs<Real>("maxtime");
38 else
39 props.maxtime = INFINITY;
40 if( opts.hasKey("verbose") )
41 props.verbose = opts.getStringAs<size_t>("verbose");
42 else
43 props.verbose = 0;
44 }
45
46
47 PropertySet TreeEP::getProperties() const {
48 PropertySet opts;
49 opts.set( "tol", props.tol );
50 opts.set( "maxiter", props.maxiter );
51 opts.set( "maxtime", props.maxtime );
52 opts.set( "verbose", props.verbose );
53 opts.set( "type", props.type );
54 return opts;
55 }
56
57
58 string TreeEP::printProperties() const {
59 stringstream s( stringstream::out );
60 s << "[";
61 s << "tol=" << props.tol << ",";
62 s << "maxiter=" << props.maxiter << ",";
63 s << "maxtime=" << props.maxtime << ",";
64 s << "verbose=" << props.verbose << ",";
65 s << "type=" << props.type << "]";
66 return s.str();
67 }
68
69
70 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
71 setProperties( opts );
72
73 if( opts.hasKey("tree") ) {
74 construct( fg, opts.getAs<RootedTree>("tree") );
75 } else {
76 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
77 // ORG: construct weighted graph with as weights a crude estimate of the
78 // mutual information between the nodes
79 // ALT: construct weighted graph with as weights an upper bound on the
80 // effective interaction strength between pairs of nodes
81
82 WeightedGraph<Real> wg;
83 // in order to get a connected weighted graph, we start
84 // by connecting every variable to the zero'th variable with weight 0
85 for( size_t i = 1; i < fg.nrVars(); i++ )
86 wg[UEdge(i,0)] = 0.0;
87 for( size_t i = 0; i < fg.nrVars(); i++ ) {
88 SmallSet<size_t> delta_i = fg.bipGraph().delta1( i, false );
89 const Var& v_i = fg.var(i);
90 foreach( size_t j, delta_i )
91 if( i < j ) {
92 const Var& v_j = fg.var(j);
93 VarSet v_ij( v_i, v_j );
94 SmallSet<size_t> nb_ij = fg.bipGraph().nb1Set( i ) | fg.bipGraph().nb1Set( j );
95 Factor piet;
96 foreach( size_t I, nb_ij ) {
97 const VarSet& Ivars = fg.factor(I).vars();
98 if( props.type == Properties::TypeType::ORG ) {
99 if( (Ivars == v_i) || (Ivars == v_j) )
100 piet *= fg.factor(I);
101 else if( Ivars >> v_ij )
102 piet *= fg.factor(I).marginal( v_ij );
103 } else {
104 if( Ivars >> v_ij )
105 piet *= fg.factor(I);
106 }
107 }
108 if( props.type == Properties::TypeType::ORG ) {
109 if( piet.vars() >> v_ij ) {
110 piet = piet.marginal( v_ij );
111 Factor pietf = piet.marginal(v_i) * piet.marginal(v_j);
112 wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
113 } else {
114 // this should never happen...
115 DAI_ASSERT( 0 == 1 );
116 wg[UEdge(i,j)] = 0;
117 }
118 } else
119 wg[UEdge(i,j)] = piet.strength(v_i, v_j);
120 }
121 }
122
123 // find maximal spanning tree
124 if( props.verbose >= 3 )
125 cerr << "WeightedGraph: " << wg << endl;
126 RootedTree t = MaxSpanningTree( wg, true );
127 if( props.verbose >= 3 )
128 cerr << "Spanningtree: " << t << endl;
129 construct( fg, t );
130 } else
131 DAI_THROW(UNKNOWN_ENUM_VALUE);
132 }
133 }
134
135
136 void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
137 // Copy the factor graph
138 FactorGraph::operator=( fg );
139
140 vector<VarSet> cl;
141 for( size_t i = 0; i < tree.size(); i++ )
142 cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );
143
144 // If no outer region can be found subsuming that factor, label the
145 // factor as off-tree.
146 JTree::construct( *this, cl, false );
147
148 if( props.verbose >= 1 )
149 cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
150 if( props.verbose >= 3 )
151 cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;
152
153 // Create factor approximations
154 _Q.clear();
155 size_t PreviousRoot = (size_t)-1;
156 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
157 for( size_t repeats = 0; repeats < 2; repeats++ )
158 for( size_t I = 0; I < nrFactors(); I++ )
159 if( offtree(I) ) {
160 // find efficient subtree
161 RootedTree subTree;
162 size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
163 PreviousRoot = subTree[0].first;
164 subTree.resize( subTreeSize );
165 if( props.verbose >= 1 )
166 cerr << "Subtree " << I << " has size " << subTreeSize << endl;
167 if( props.verbose >= 3 )
168 cerr << " it is " << subTree << endl;
169 _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
170 if( repeats == 1 )
171 break;
172 }
173
174 if( props.verbose >= 3 )
175 cerr << "Resulting regiongraph: " << *this << endl;
176 }
177
178
179 void TreeEP::init() {
180 runHUGIN();
181
182 // Init factor approximations
183 for( size_t I = 0; I < nrFactors(); I++ )
184 if( offtree(I) )
185 _Q[I].init();
186 }
187
188
189 Real TreeEP::run() {
190 if( props.verbose >= 1 )
191 cerr << "Starting " << identify() << "...";
192 if( props.verbose >= 3 )
193 cerr << endl;
194
195 double tic = toc();
196
197 vector<Factor> oldBeliefs = beliefs();
198
199 // do several passes over the network until maximum number of iterations has
200 // been reached or until the maximum belief difference is smaller than tolerance
201 Real maxDiff = INFINITY;
202 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol && (toc() - tic) < props.maxtime; _iters++ ) {
203 for( size_t I = 0; I < nrFactors(); I++ )
204 if( offtree(I) ) {
205 _Q[I].InvertAndMultiply( Qa, Qb );
206 _Q[I].HUGIN_with_I( Qa, Qb );
207 _Q[I].InvertAndMultiply( Qa, Qb );
208 }
209
210 // calculate new beliefs and compare with old ones
211 vector<Factor> newBeliefs = beliefs();
212 maxDiff = -INFINITY;
213 for( size_t t = 0; t < oldBeliefs.size(); t++ )
214 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
215 swap( newBeliefs, oldBeliefs );
216
217 if( props.verbose >= 3 )
218 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
219 }
220
221 if( maxDiff > _maxdiff )
222 _maxdiff = maxDiff;
223
224 if( props.verbose >= 1 ) {
225 if( maxDiff > props.tol ) {
226 if( props.verbose == 1 )
227 cerr << endl;
228 cerr << name() << "::run: WARNING: not converged after " << _iters << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
229 } else {
230 if( props.verbose >= 3 )
231 cerr << name() << "::run: ";
232 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
233 }
234 }
235
236 return maxDiff;
237 }
238
239
240 Real TreeEP::logZ() const {
241 Real s = 0.0;
242
243 // entropy of the tree
244 for( size_t beta = 0; beta < nrIRs(); beta++ )
245 s -= Qb[beta].entropy();
246 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
247 s += Qa[alpha].entropy();
248
249 // energy of the on-tree factors
250 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
251 s += (OR(alpha).log(true) * Qa[alpha]).sum();
252
253 // energy of the off-tree factors
254 for( size_t I = 0; I < nrFactors(); I++ )
255 if( offtree(I) )
256 s += (_Q.find(I))->second.logZ( Qa, Qb );
257
258 return s;
259 }
260
261
262 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
263 _ns = _I->vars();
264
265 // Make _Qa, _Qb, _a and _b corresponding to the subtree
266 _b.reserve( subRTree.size() );
267 _Qb.reserve( subRTree.size() );
268 _RTree.reserve( subRTree.size() );
269 for( size_t i = 0; i < subRTree.size(); i++ ) {
270 size_t alpha1 = subRTree[i].first; // old index 1
271 size_t alpha2 = subRTree[i].second; // old index 2
272 size_t beta; // old sep index
273 for( beta = 0; beta < jt_RTree.size(); beta++ )
274 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
275 break;
276 DAI_ASSERT( beta != jt_RTree.size() );
277
278 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
279 if( newalpha1 == _a.size() ) {
280 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
281 _a.push_back( alpha1 ); // save old index in index conversion table
282 }
283
284 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
285 if( newalpha2 == _a.size() ) {
286 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
287 _a.push_back( alpha2 ); // save old index in index conversion table
288 }
289
290 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
291 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
292 _b.push_back( beta );
293 }
294
295 // Find remaining variables (which are not in the new root)
296 _nsrem = _ns / _Qa[0].vars();
297 }
298
299
300 void TreeEP::TreeEPSubTree::init() {
301 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
302 _Qa[alpha].fill( 1.0 );
303 for( size_t beta = 0; beta < _Qb.size(); beta++ )
304 _Qb[beta].fill( 1.0 );
305 }
306
307
308 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
309 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
310 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
311
312 for( size_t beta = 0; beta < _Qb.size(); beta++ )
313 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
314 }
315
316
317 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
318 // Backup _Qa and _Qb
319 vector<Factor> _Qa_old(_Qa);
320 vector<Factor> _Qb_old(_Qb);
321
322 // Clear Qa and Qb
323 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
324 Qa[_a[alpha]].fill( 0.0 );
325 for( size_t beta = 0; beta < _Qb.size(); beta++ )
326 Qb[_b[beta]].fill( 0.0 );
327
328 // For all states of _nsrem
329 for( State s(_nsrem); s.valid(); s++ ) {
330 // Multiply root with slice of I
331 _Qa[0] *= _I->slice( _nsrem, s );
332
333 // CollectEvidence
334 for( size_t i = _RTree.size(); (i--) != 0; ) {
335 // clamp variables in nsrem
336 for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
337 if( _Qa[_RTree[i].second].vars() >> *n )
338 _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
339 Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
340 _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
341 _Qb[i] = new_Qb;
342 }
343
344 // DistributeEvidence
345 for( size_t i = 0; i < _RTree.size(); i++ ) {
346 Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
347 _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
348 _Qb[i] = new_Qb;
349 }
350
351 // Store Qa's and Qb's
352 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
353 Qa[_a[alpha]].p() += _Qa[alpha].p();
354 for( size_t beta = 0; beta < _Qb.size(); beta++ )
355 Qb[_b[beta]].p() += _Qb[beta].p();
356
357 // Restore _Qa and _Qb
358 _Qa = _Qa_old;
359 _Qb = _Qb_old;
360 }
361
362 // Normalize Qa and Qb
363 _logZ = 0.0;
364 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
365 _logZ += log(Qa[_a[alpha]].sum());
366 Qa[_a[alpha]].normalize();
367 }
368 for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
369 _logZ -= log(Qb[_b[beta]].sum());
370 Qb[_b[beta]].normalize();
371 }
372 }
373
374
375 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
376 Real s = 0.0;
377 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
378 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
379 for( size_t beta = 0; beta < _Qb.size(); beta++ )
380 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
381 return s + _logZ;
382 }
383
384
385 } // end of namespace dai