Improved documentation of include/dai/evidence.h
[libdai.git] / src / emalg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
8 * Copyright (C) 2009 University of California, Santa Cruz
9 */
10
11
#include <algorithm>
#include <dai/util.h>
#include <dai/emalg.h>
14
15
16 namespace dai {
17
18
19 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
20
21
22 void ParameterEstimation::loadDefaultRegistry() {
23 _registry = new std::map<std::string, ParamEstFactory>();
24 (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
25 }
26
27
28 ParameterEstimation* ParameterEstimation::construct( const std::string &method, const PropertySet &p ) {
29 if( _registry == NULL )
30 loadDefaultRegistry();
31 std::map<std::string, ParamEstFactory>::iterator i = _registry->find(method);
32 if( i == _registry->end() )
33 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD, "Unknown parameter estimation method: " + method);
34 ParamEstFactory factory = i->second;
35 return factory(p);
36 }
37
38
39 ParameterEstimation* CondProbEstimation::factory( const PropertySet &p ) {
40 size_t target_dimension = p.getStringAs<size_t>("target_dim");
41 size_t total_dimension = p.getStringAs<size_t>("total_dim");
42 Real pseudo_count = 1;
43 if( p.hasKey("pseudo_count") )
44 pseudo_count = p.getStringAs<Real>("pseudo_count");
45 return new CondProbEstimation( target_dimension, Prob( total_dimension, pseudo_count ) );
46 }
47
48
49 CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts )
50 : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
51 {
52 DAI_ASSERT( !(_stats.size() % _target_dim) );
53 }
54
55
56 void CondProbEstimation::addSufficientStatistics( const Prob &p ) {
57 _stats += p;
58 }
59
60
61 Prob CondProbEstimation::estimate() {
62 // normalize pseudocounts
63 for( size_t parent = 0; parent < _stats.size(); parent += _target_dim ) {
64 // calculate norm
65 Real norm = 0.0;
66 size_t top = parent + _target_dim;
67 for( size_t i = parent; i < top; ++i )
68 norm += _stats[i];
69 if( norm != 0.0 )
70 norm = 1.0 / norm;
71 // normalize
72 for( size_t i = parent; i < top; ++i )
73 _stats[i] *= norm;
74 }
75 // reset _stats to _initial_stats
76 Prob result = _stats;
77 _stats = _initial_stats;
78 return result;
79 }
80
81
82 Permute SharedParameters::calculatePermutation( const std::vector<Var> &varorder, VarSet &outVS ) {
83 // Collect all labels and dimensions, and order them in vs
84 std::vector<size_t> dims;
85 dims.reserve( varorder.size() );
86 std::vector<long> labels;
87 labels.reserve( varorder.size() );
88 for( size_t i = 0; i < varorder.size(); i++ ) {
89 dims.push_back( varorder[i].states() );
90 labels.push_back( varorder[i].label() );
91 outVS |= varorder[i];
92 }
93
94 // Construct the sigma array for the permutation object
95 std::vector<size_t> sigma;
96 sigma.reserve( dims.size() );
97 for( VarSet::iterator set_iterator = outVS.begin(); sigma.size() < dims.size(); ++set_iterator )
98 sigma.push_back( find(labels.begin(), labels.end(), set_iterator->label()) - labels.begin() );
99
100 return Permute( dims, sigma );
101 }
102
103
104 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
105 if( _varorders.size() == 0 )
106 return;
107 DAI_ASSERT( _estimation != NULL );
108
109 // Construct the permutation objects and the varsets
110 for( FactorOrientations::const_iterator foi = _varorders.begin(); foi != _varorders.end(); ++foi ) {
111 VarSet vs;
112 _perms[foi->first] = calculatePermutation( foi->second, vs );
113 _varsets[foi->first] = vs;
114 DAI_ASSERT( _estimation->probSize() == vs.nrStates() );
115 }
116 }
117
118
119 SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg_varlookup )
120 : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(true)
121 {
122 // Read the desired parameter estimation method from the stream
123 std::string est_method;
124 PropertySet props;
125 is >> est_method;
126 is >> props;
127
128 // Construct a corresponding object
129 _estimation = ParameterEstimation::construct( est_method, props );
130
131 // Read in the factors that are to be estimated
132 size_t num_factors;
133 is >> num_factors;
134 for( size_t sp_i = 0; sp_i < num_factors; ++sp_i ) {
135 std::string line;
136 while( line.size() == 0 && getline(is, line) )
137 ;
138
139 std::vector<std::string> fields;
140 tokenizeString(line, fields, " \t");
141
142 // Lookup the factor in the factorgraph
143 if( fields.size() < 1 )
144 DAI_THROW(INVALID_EMALG_FILE);
145 std::istringstream iss;
146 iss.str( fields[0] );
147 size_t factor;
148 iss >> factor;
149 const VarSet &vs = fg_varlookup.factor(factor).vars();
150 if( fields.size() != vs.size() + 1 )
151 DAI_THROW(INVALID_EMALG_FILE);
152
153 // Construct the vector of Vars
154 std::vector<Var> var_order;
155 var_order.reserve( vs.size() );
156 for( size_t fi = 1; fi < fields.size(); ++fi ) {
157 // Lookup a single variable by label
158 long label;
159 std::istringstream labelparse( fields[fi] );
160 labelparse >> label;
161 VarSet::const_iterator vsi = vs.begin();
162 for( ; vsi != vs.end(); ++vsi )
163 if( vsi->label() == label )
164 break;
165 if( vsi == vs.end() )
166 DAI_THROW(INVALID_EMALG_FILE);
167 var_order.push_back( *vsi );
168 }
169 _varorders[factor] = var_order;
170 }
171
172 // Calculate the necessary permutations
173 setPermsAndVarSetsFromVarOrders();
174 }
175
176
177 SharedParameters::SharedParameters( const SharedParameters &sp )
178 : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _deleteEstimation(sp._deleteEstimation)
179 {
180 // If sp owns its _estimation object, we should clone it instead
181 if( _deleteEstimation )
182 _estimation = _estimation->clone();
183 }
184
185
186 SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deletePE )
187 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(deletePE)
188 {
189 // Calculate the necessary permutations
190 setPermsAndVarSetsFromVarOrders();
191 }
192
193
194 void SharedParameters::collectSufficientStatistics( InfAlg &alg ) {
195 for( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
196 Permute &perm = i->second;
197 VarSet &vs = _varsets[i->first];
198
199 Factor b = alg.belief(vs);
200 Prob p( b.states(), 0.0 );
201 for( size_t entry = 0; entry < b.states(); ++entry )
202 p[entry] = b[perm.convertLinearIndex(entry)];
203 _estimation->addSufficientStatistics( p );
204 }
205 }
206
207
208 void SharedParameters::setParameters( FactorGraph &fg ) {
209 Prob p = _estimation->estimate();
210 for( std::map<FactorIndex, Permute>::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
211 Permute &perm = i->second;
212 VarSet &vs = _varsets[i->first];
213
214 Factor f( vs, 0.0 );
215 for( size_t entry = 0; entry < f.states(); ++entry )
216 f[perm.convertLinearIndex(entry)] = p[entry];
217
218 fg.setFactor( i->first, f );
219 }
220 }
221
222
223 void SharedParameters::collectParameters( const FactorGraph &fg, std::vector<Real> &outVals, std::vector<Var> &outVarOrder ) {
224 FactorOrientations::iterator it = _varorders.begin();
225 if( it == _varorders.end() )
226 return;
227 FactorIndex I = it->first;
228 for( std::vector<Var>::const_iterator var_it = _varorders[I].begin(); var_it != _varorders[I].end(); ++var_it )
229 outVarOrder.push_back( *var_it );
230
231 const Factor &f = fg.factor(I);
232 DAI_ASSERT( f.vars() == _varsets[I] );
233 const Permute &perm = _perms[I];
234 for( size_t val_index = 0; val_index < f.states(); ++val_index )
235 outVals.push_back( f[perm.convertLinearIndex(val_index)] );
236 }
237
238
239 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
240 size_t num_params = -1;
241 is >> num_params;
242 _params.reserve( num_params );
243 for( size_t i = 0; i < num_params; ++i )
244 _params.push_back( SharedParameters( is, fg_varlookup ) );
245 }
246
247
248 void MaximizationStep::addExpectations( InfAlg &alg ) {
249 for( size_t i = 0; i < _params.size(); ++i )
250 _params[i].collectSufficientStatistics( alg );
251 }
252
253
254 void MaximizationStep::maximize( FactorGraph &fg ) {
255 for( size_t i = 0; i < _params.size(); ++i )
256 _params[i].setParameters( fg );
257 }
258
259
260 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
261 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
262 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
263 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
264
265
266 EMAlg::EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &msteps_file )
267 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
268 {
269 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
270 size_t num_msteps = -1;
271 msteps_file >> num_msteps;
272 _msteps.reserve(num_msteps);
273 for( size_t i = 0; i < num_msteps; ++i )
274 _msteps.push_back( MaximizationStep( msteps_file, estep.fg() ) );
275 }
276
277
278 void EMAlg::setTermConditions( const PropertySet &p ) {
279 if( p.hasKey(MAX_ITERS_KEY) )
280 _max_iters = p.getStringAs<size_t>(MAX_ITERS_KEY);
281 if( p.hasKey(LOG_Z_TOL_KEY) )
282 _log_z_tol = p.getStringAs<Real>(LOG_Z_TOL_KEY);
283 }
284
285
286 bool EMAlg::hasSatisfiedTermConditions() const {
287 if( _iters >= _max_iters )
288 return true;
289 else if( _lastLogZ.size() < 3 )
290 // need at least 2 to calculate ratio
291 // Also, throw away first iteration, as the parameters may not
292 // have been normalized according to the estimation method
293 return false;
294 else {
295 Real current = _lastLogZ[_lastLogZ.size() - 1];
296 Real previous = _lastLogZ[_lastLogZ.size() - 2];
297 if( previous == 0 )
298 return false;
299 Real diff = current - previous;
300 if( diff < 0 ) {
301 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
302 return true;
303 }
304 return (diff / fabs(previous)) <= _log_z_tol;
305 }
306 }
307
308
309 Real EMAlg::iterate( MaximizationStep &mstep ) {
310 Real logZ = 0;
311 Real likelihood = 0;
312
313 _estep.run();
314 logZ = _estep.logZ();
315
316 // Expectation calculation
317 for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
318 InfAlg* clamped = _estep.clone();
319 // Apply evidence
320 for( Observation::const_iterator i = e->begin(); i != e->end(); ++i )
321 clamped->clamp( clamped->fg().findVar(i->first), i->second );
322 clamped->init();
323 clamped->run();
324
325 likelihood += clamped->logZ() - logZ;
326
327 mstep.addExpectations( *clamped );
328
329 delete clamped;
330 }
331
332 // Maximization of parameters
333 mstep.maximize( _estep.fg() );
334
335 return likelihood;
336 }
337
338
339 Real EMAlg::iterate() {
340 Real likelihood;
341 for( size_t i = 0; i < _msteps.size(); ++i )
342 likelihood = iterate( _msteps[i] );
343 _lastLogZ.push_back( likelihood );
344 ++_iters;
345 return likelihood;
346 }
347
348
349 void EMAlg::run() {
350 while( !hasSatisfiedTermConditions() )
351 iterate();
352 }
353
354
355 } // end of namespace dai