Added utils/uai2fg, ExactInf::findMaximum(), and fixed two bugs
[libdai.git] / src / treeep.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <fstream>
14 #include <vector>
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
/// Name of this inference algorithm; used by identify() and as the prefix of
/// the diagnostic messages printed by run().
const char *TreeEP::Name = "TREEEP";
27
28
29 void TreeEP::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32 DAI_ASSERT( opts.hasKey("type") );
33
34 props.tol = opts.getStringAs<Real>("tol");
35 props.maxiter = opts.getStringAs<size_t>("maxiter");
36 props.type = opts.getStringAs<Properties::TypeType>("type");
37 if( opts.hasKey("verbose") )
38 props.verbose = opts.getStringAs<size_t>("verbose");
39 else
40 props.verbose = 0;
41 }
42
43
44 PropertySet TreeEP::getProperties() const {
45 PropertySet opts;
46 opts.set( "tol", props.tol );
47 opts.set( "maxiter", props.maxiter );
48 opts.set( "verbose", props.verbose );
49 opts.set( "type", props.type );
50 return opts;
51 }
52
53
54 string TreeEP::printProperties() const {
55 stringstream s( stringstream::out );
56 s << "[";
57 s << "tol=" << props.tol << ",";
58 s << "maxiter=" << props.maxiter << ",";
59 s << "verbose=" << props.verbose << ",";
60 s << "type=" << props.type << "]";
61 return s.str();
62 }
63
64
65 TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
66 setProperties( opts );
67
68 if( opts.hasKey("tree") ) {
69 construct( fg, opts.getAs<RootedTree>("tree") );
70 } else {
71 if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
72 // ORG: construct weighted graph with as weights a crude estimate of the
73 // mutual information between the nodes
74 // ALT: construct weighted graph with as weights an upper bound on the
75 // effective interaction strength between pairs of nodes
76
77 WeightedGraph<Real> wg;
78 for( size_t i = 0; i < fg.nrVars(); ++i ) {
79 Var v_i = fg.var(i);
80 VarSet di = fg.delta(i);
81 if( i )
82 wg[UEdge(i,0)] = 0.0;
83 for( VarSet::const_iterator cit_j = di.begin(); cit_j != di.end(); cit_j++ )
84 if( v_i < *cit_j ) {
85 VarSet ij(v_i,*cit_j);
86 Factor piet;
87 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
88 VarSet Ivars = fg.factor(I).vars();
89 if( props.type == Properties::TypeType::ORG ) {
90 if( (Ivars == v_i) || (Ivars == *cit_j) )
91 piet *= fg.factor(I);
92 else if( Ivars >> ij )
93 piet *= fg.factor(I).marginal( ij );
94 } else {
95 if( Ivars >> ij )
96 piet *= fg.factor(I);
97 }
98 }
99 size_t j = fg.findVar( *cit_j );
100 if( props.type == Properties::TypeType::ORG ) {
101 if( piet.vars() >> ij ) {
102 piet = piet.marginal( ij );
103 Factor pietf = piet.marginal(v_i) * piet.marginal(*cit_j);
104 wg[UEdge(i,j)] = dist( piet, pietf, DISTKL );
105 } else {
106 // this should never happen...
107 DAI_ASSERT( 0 == 1 );
108 wg[UEdge(i,j)] = 0;
109 }
110 } else
111 wg[UEdge(i,j)] = piet.strength(v_i, *cit_j);
112 }
113 }
114
115 // find maximal spanning tree
116 if( props.verbose >= 3 )
117 cerr << "WeightedGraph: " << wg << endl;
118 RootedTree t = MaxSpanningTree( wg, true );
119 if( props.verbose >= 3 )
120 cerr << "Spanningtree: " << t << endl;
121 construct( fg, t );
122 } else
123 DAI_THROW(UNKNOWN_ENUM_VALUE);
124 }
125 }
126
127
/// (Re)builds the TreeEP structure for \a fg on top of the base tree \a tree.
/// Every edge of \a tree (a pair of variable indices) yields a two-variable
/// clique. JTree::construct() builds the junction tree from these cliques.
/// Every factor that ends up off-tree gets a TreeEPSubTree in _Q.
void TreeEP::construct( const FactorGraph& fg, const RootedTree& tree ) {
    // Copy the factor graph
    FactorGraph::operator=( fg );

    // One clique {var(first), var(second)} per edge of the base tree
    vector<VarSet> cl;
    for( size_t i = 0; i < tree.size(); i++ )
        cl.push_back( VarSet( var(tree[i].first), var(tree[i].second) ) );

    // If no outer region can be found subsuming that factor, label the
    // factor as off-tree.
    JTree::construct( *this, cl, false );

    if( props.verbose >= 1 )
        cerr << "TreeEP::construct: The tree has size " << JTree::RTree.size() << endl;
    if( props.verbose >= 3 )
        cerr << " it is " << JTree::RTree << " with cliques " << cl << endl;

    // Create factor approximations
    _Q.clear();
    // Root hint passed to findEfficientTree(). (size_t)-1 is passed for the
    // very first off-tree factor -- presumably meaning "no preferred root"
    // (TODO confirm against findEfficientTree's documentation).
    size_t PreviousRoot = (size_t)-1;
    // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
    for( size_t repeats = 0; repeats < 2; repeats++ )
        for( size_t I = 0; I < nrFactors(); I++ )
            if( offtree(I) ) {
                // find efficient subtree
                RootedTree subTree;
                size_t subTreeSize = findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
                // the root of this subtree is the hint for the next off-tree factor
                PreviousRoot = subTree[0].first;
                // keep only the first subTreeSize edges
                subTree.resize( subTreeSize );
                if( props.verbose >= 1 )
                    cerr << "Subtree " << I << " has size " << subTreeSize << endl;
                if( props.verbose >= 3 )
                    cerr << " it is " << subTree << endl;
                // The subtree keeps a pointer to this object's own copy of factor I
                _Q[I] = TreeEPSubTree( subTree, RTree, Qa, Qb, &factor(I) );
                // In the second repetition only the first off-tree factor is
                // rebuilt (with the last factor's root as hint); then stop.
                if( repeats == 1 )
                    break;
            }

    if( props.verbose >= 3 )
        cerr << "Resulting regiongraph: " << *this << endl;
}
169
170
171 string TreeEP::identify() const {
172 return string(Name) + printProperties();
173 }
174
175
176 void TreeEP::init() {
177 runHUGIN();
178
179 // Init factor approximations
180 for( size_t I = 0; I < nrFactors(); I++ )
181 if( offtree(I) )
182 _Q[I].init();
183 }
184
185
186 Real TreeEP::run() {
187 if( props.verbose >= 1 )
188 cerr << "Starting " << identify() << "...";
189 if( props.verbose >= 3 )
190 cerr << endl;
191
192 double tic = toc();
193
194 vector<Factor> oldBeliefs = beliefs();
195
196 // do several passes over the network until maximum number of iterations has
197 // been reached or until the maximum belief difference is smaller than tolerance
198 Real maxDiff = INFINITY;
199 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
200 for( size_t I = 0; I < nrFactors(); I++ )
201 if( offtree(I) ) {
202 _Q[I].InvertAndMultiply( Qa, Qb );
203 _Q[I].HUGIN_with_I( Qa, Qb );
204 _Q[I].InvertAndMultiply( Qa, Qb );
205 }
206
207 // calculate new beliefs and compare with old ones
208 vector<Factor> newBeliefs = beliefs();
209 maxDiff = -INFINITY;
210 for( size_t t = 0; t < oldBeliefs.size(); t++ )
211 maxDiff = std::max( maxDiff, dist( newBeliefs[t], oldBeliefs[t], DISTLINF ) );
212 swap( newBeliefs, oldBeliefs );
213
214 if( props.verbose >= 3 )
215 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
216 }
217
218 if( maxDiff > _maxdiff )
219 _maxdiff = maxDiff;
220
221 if( props.verbose >= 1 ) {
222 if( maxDiff > props.tol ) {
223 if( props.verbose == 1 )
224 cerr << endl;
225 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
226 } else {
227 if( props.verbose >= 3 )
228 cerr << Name << "::run: ";
229 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
230 }
231 }
232
233 return maxDiff;
234 }
235
236
237 Real TreeEP::logZ() const {
238 Real s = 0.0;
239
240 // entropy of the tree
241 for( size_t beta = 0; beta < nrIRs(); beta++ )
242 s -= Qb[beta].entropy();
243 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
244 s += Qa[alpha].entropy();
245
246 // energy of the on-tree factors
247 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
248 s += (OR(alpha).log(true) * Qa[alpha]).sum();
249
250 // energy of the off-tree factors
251 for( size_t I = 0; I < nrFactors(); I++ )
252 if( offtree(I) )
253 s += (_Q.find(I))->second.logZ( Qa, Qb );
254
255 return s;
256 }
257
258
259 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
260 _ns = _I->vars();
261
262 // Make _Qa, _Qb, _a and _b corresponding to the subtree
263 _b.reserve( subRTree.size() );
264 _Qb.reserve( subRTree.size() );
265 _RTree.reserve( subRTree.size() );
266 for( size_t i = 0; i < subRTree.size(); i++ ) {
267 size_t alpha1 = subRTree[i].first; // old index 1
268 size_t alpha2 = subRTree[i].second; // old index 2
269 size_t beta; // old sep index
270 for( beta = 0; beta < jt_RTree.size(); beta++ )
271 if( UEdge( jt_RTree[beta].first, jt_RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
272 break;
273 DAI_ASSERT( beta != jt_RTree.size() );
274
275 size_t newalpha1 = find(_a.begin(), _a.end(), alpha1) - _a.begin();
276 if( newalpha1 == _a.size() ) {
277 _Qa.push_back( Factor( jt_Qa[alpha1].vars(), 1.0 ) );
278 _a.push_back( alpha1 ); // save old index in index conversion table
279 }
280
281 size_t newalpha2 = find(_a.begin(), _a.end(), alpha2) - _a.begin();
282 if( newalpha2 == _a.size() ) {
283 _Qa.push_back( Factor( jt_Qa[alpha2].vars(), 1.0 ) );
284 _a.push_back( alpha2 ); // save old index in index conversion table
285 }
286
287 _RTree.push_back( DEdge( newalpha1, newalpha2 ) );
288 _Qb.push_back( Factor( jt_Qb[beta].vars(), 1.0 ) );
289 _b.push_back( beta );
290 }
291
292 // Find remaining variables (which are not in the new root)
293 _nsrem = _ns / _Qa[0].vars();
294 }
295
296
297 void TreeEP::TreeEPSubTree::init() {
298 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
299 _Qa[alpha].fill( 1.0 );
300 for( size_t beta = 0; beta < _Qb.size(); beta++ )
301 _Qb[beta].fill( 1.0 );
302 }
303
304
305 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
306 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
307 _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
308
309 for( size_t beta = 0; beta < _Qb.size(); beta++ )
310 _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
311 }
312
313
/// Runs HUGIN on the local sub-tree with the off-tree factor *_I included.
///
/// The global factors Qa[_a[.]] and Qb[_b[.]] of the sub-tree are first zeroed.
/// Then, for every joint state s of _nsrem (the variables of I that are not in
/// the root clique):
///  - the root clique is multiplied by the slice of I at s;
///  - a collect pass (edges in reverse order) is run, which first clamps the
///    child clique of each edge to s on the _nsrem variables it contains;
///  - a distribute pass (edges in forward order) is run;
///  - the resulting local factors are added to the global ones;
///  - _Qa and _Qb are restored to their values on entry.
/// Finally the global factors are normalized. _logZ is set to the sum of the
/// log clique totals minus the sum of the log separator totals before
/// normalization.
void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
    // Backup _Qa and _Qb
    vector<Factor> _Qa_old(_Qa);
    vector<Factor> _Qb_old(_Qb);

    // Clear Qa and Qb
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
        Qa[_a[alpha]].fill( 0.0 );
    for( size_t beta = 0; beta < _Qb.size(); beta++ )
        Qb[_b[beta]].fill( 0.0 );

    // For all states of _nsrem
    for( State s(_nsrem); s.valid(); s++ ) {
        // Multiply root with slice of I
        _Qa[0] *= _I->slice( _nsrem, s );

        // CollectEvidence (edge i sends from child _RTree[i].second to parent _RTree[i].first)
        for( size_t i = _RTree.size(); (i--) != 0; ) {
            // clamp variables in nsrem
            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                if( _Qa[_RTree[i].second].vars() >> *n )
                    _Qa[_RTree[i].second] *= createFactorDelta( *n, s(*n) );
            // unnormalized marginal of the child on the separator (false = no normalization)
            Factor new_Qb = _Qa[_RTree[i].second].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].first] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // DistributeEvidence (edge i sends from parent to child)
        for( size_t i = 0; i < _RTree.size(); i++ ) {
            Factor new_Qb = _Qa[_RTree[i].first].marginal( _Qb[i].vars(), false );
            _Qa[_RTree[i].second] *= new_Qb / _Qb[i];
            _Qb[i] = new_Qb;
        }

        // Store Qa's and Qb's (accumulate over the states of _nsrem)
        for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
            Qa[_a[alpha]].p() += _Qa[alpha].p();
        for( size_t beta = 0; beta < _Qb.size(); beta++ )
            Qb[_b[beta]].p() += _Qb[beta].p();

        // Restore _Qa and _Qb
        _Qa = _Qa_old;
        _Qb = _Qb_old;
    }

    // Normalize Qa and Qb; record the log normalization constants in _logZ
    _logZ = 0.0;
    for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
        _logZ += log(Qa[_a[alpha]].sum());
        Qa[_a[alpha]].normalize();
    }
    for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
        _logZ -= log(Qb[_b[beta]].sum());
        Qb[_b[beta]].normalize();
    }
}
370
371
372 Real TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
373 Real s = 0.0;
374 for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
375 s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
376 for( size_t beta = 0; beta < _Qb.size(); beta++ )
377 s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
378 return s + _logZ;
379 }
380
381
382 } // end of namespace dai