EM bugfix. Convenience methods in Factor, Permute, Properties, EM.
[libdai.git] / src / emalg.cpp
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <dai/util.h>
#include <dai/emalg.h>
24
25
26 namespace dai {
27
28
29 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
30
31
32 void ParameterEstimation::loadDefaultRegistry() {
33 _registry = new std::map<std::string, ParamEstFactory>();
34 (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
35 }
36
37
38 ParameterEstimation* ParameterEstimation::construct( const std::string &method, const PropertySet &p ) {
39 if( _registry == NULL )
40 loadDefaultRegistry();
41 std::map<std::string, ParamEstFactory>::iterator i = _registry->find(method);
42 if( i == _registry->end() )
43 DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD);
44 ParamEstFactory factory = i->second;
45 return factory(p);
46 }
47
48
49 ParameterEstimation* CondProbEstimation::factory( const PropertySet &p ) {
50 size_t target_dimension = p.getStringAs<size_t>("target_dim");
51 size_t total_dimension = p.getStringAs<size_t>("total_dim");
52 Real pseudo_count = 1;
53 if( p.hasKey("pseudo_count") )
54 pseudo_count = p.getStringAs<Real>("pseudo_count");
55 return new CondProbEstimation( target_dimension, Prob( total_dimension, pseudo_count ) );
56 }
57
58
59 CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts )
60 : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
61 {
62 if( _stats.size() % _target_dim )
63 DAI_THROW(MALFORMED_PROPERTY);
64 }
65
66
67 void CondProbEstimation::addSufficientStatistics( const Prob &p ) {
68 _stats += p;
69 }
70
71
72 Prob CondProbEstimation::estimate() {
73 // normalize pseudocounts
74 for( size_t parent = 0; parent < _stats.size(); parent += _target_dim ) {
75 // calculate norm
76 Real norm = 0.0;
77 size_t top = parent + _target_dim;
78 for( size_t i = parent; i < top; ++i )
79 norm += _stats[i];
80 if( norm != 0.0 )
81 norm = 1.0 / norm;
82 // normalize
83 for( size_t i = parent; i < top; ++i )
84 _stats[i] *= norm;
85 }
86 // reset _stats to _initial_stats
87 Prob result = _stats;
88 _stats = _initial_stats;
89 return result;
90 }
91
92
93 Permute SharedParameters::calculatePermutation( const std::vector<Var> &varorder, VarSet &outVS ) {
94 // Collect all labels and dimensions, and order them in vs
95 std::vector<size_t> dims;
96 dims.reserve( varorder.size() );
97 std::vector<long> labels;
98 labels.reserve( varorder.size() );
99 for( size_t i = 0; i < varorder.size(); i++ ) {
100 dims.push_back( varorder[i].states() );
101 labels.push_back( varorder[i].label() );
102 outVS |= varorder[i];
103 }
104
105 // Construct the sigma array for the permutation object
106 std::vector<size_t> sigma;
107 sigma.reserve( dims.size() );
108 for( VarSet::iterator set_iterator = outVS.begin(); sigma.size() < dims.size(); ++set_iterator )
109 sigma.push_back( find(labels.begin(), labels.end(), set_iterator->label()) - labels.begin() );
110
111 return Permute( dims, sigma );
112 }
113
114
115 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
116 if( _varorders.size() == 0 )
117 return;
118 if( _estimation == NULL )
119 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
120
121 // Construct the permutation objects and the varsets
122 for( FactorOrientations::const_iterator foi = _varorders.begin(); foi != _varorders.end(); ++foi ) {
123 VarSet vs;
124 _perms[foi->first] = calculatePermutation( foi->second, vs );
125 _varsets[foi->first] = vs;
126 if( _estimation->probSize() != vs.nrStates() )
127 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
128 }
129 }
130
131
132 SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg_varlookup )
133 : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(true)
134 {
135 // Read the desired parameter estimation method from the stream
136 std::string est_method;
137 PropertySet props;
138 is >> est_method;
139 is >> props;
140
141 // Construct a corresponding object
142 _estimation = ParameterEstimation::construct( est_method, props );
143
144 // Read in the factors that are to be estimated
145 size_t num_factors;
146 is >> num_factors;
147 for( size_t sp_i = 0; sp_i < num_factors; ++sp_i ) {
148 std::string line;
149 while( line.size() == 0 && getline(is, line) )
150 ;
151
152 std::vector<std::string> fields;
153 tokenizeString(line, fields, " \t");
154
155 // Lookup the factor in the factorgraph
156 if( fields.size() < 1 )
157 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
158 std::istringstream iss;
159 iss.str( fields[0] );
160 size_t factor;
161 iss >> factor;
162 const VarSet &vs = fg_varlookup.factor(factor).vars();
163 if( fields.size() != vs.size() + 1 )
164 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
165
166 // Construct the vector of Vars
167 std::vector<Var> var_order;
168 var_order.reserve( vs.size() );
169 for( size_t fi = 1; fi < fields.size(); ++fi ) {
170 // Lookup a single variable by label
171 long label;
172 std::istringstream labelparse( fields[fi] );
173 labelparse >> label;
174 VarSet::const_iterator vsi = vs.begin();
175 for( ; vsi != vs.end(); ++vsi )
176 if( vsi->label() == label )
177 break;
178 if( vsi == vs.end() )
179 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
180 var_order.push_back( *vsi );
181 }
182 _varorders[factor] = var_order;
183 }
184
185 // Calculate the necessary permutations
186 setPermsAndVarSetsFromVarOrders();
187 }
188
189
190 SharedParameters::SharedParameters( const SharedParameters &sp )
191 : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _deleteEstimation(sp._deleteEstimation)
192 {
193 // If sp owns its _estimation object, we should clone it instead
194 if( _deleteEstimation )
195 _estimation = _estimation->clone();
196 }
197
198
199 SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool deleteParameterEstimationInDestructor )
200 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(deleteParameterEstimationInDestructor)
201 {
202 // Calculate the necessary permutations
203 setPermsAndVarSetsFromVarOrders();
204 }
205
206
207 void SharedParameters::collectSufficientStatistics( InfAlg &alg ) {
208 for( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
209 Permute &perm = i->second;
210 VarSet &vs = _varsets[i->first];
211
212 Factor b = alg.belief(vs);
213 Prob p( b.states(), 0.0 );
214 for( size_t entry = 0; entry < b.states(); ++entry )
215 p[entry] = b[perm.convert_linear_index(entry)];
216 _estimation->addSufficientStatistics( p );
217 }
218 }
219
220
221 void SharedParameters::setParameters( FactorGraph &fg ) {
222 Prob p = _estimation->estimate();
223 for( std::map<FactorIndex, Permute>::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
224 Permute &perm = i->second;
225 VarSet &vs = _varsets[i->first];
226
227 Factor f( vs, 0.0 );
228 for( size_t entry = 0; entry < f.states(); ++entry )
229 f[perm.convert_linear_index(entry)] = p[entry];
230
231 fg.setFactor( i->first, f );
232 }
233 }
234
235
236 void SharedParameters::collectParameters( const FactorGraph& fg, std::vector< Real >& outVals, std::vector< Var >& outVarOrder ) {
237 FactorOrientations::iterator it = _varorders.begin();
238 if (it == _varorders.end()) {
239 return;
240 }
241 FactorIndex i = it->first;
242 std::vector< Var >::iterator var_it = _varorders[i].begin();
243 std::vector< Var >::iterator var_stop = _varorders[i].end();
244 for ( ; var_it != var_stop; ++var_it) {
245 outVarOrder.push_back(*var_it);
246 }
247 const Factor& f = fg.factor(i);
248 assert(f.vars() == _varsets[i]);
249 const Permute& perm = _perms[i];
250 for (size_t val_index = 0; val_index < f.states(); ++val_index) {
251 outVals.push_back(f[perm.convert_linear_index(val_index)]);
252 }
253 }
254
255
256 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
257 size_t num_params = -1;
258 is >> num_params;
259 _params.reserve( num_params );
260 for( size_t i = 0; i < num_params; ++i )
261 _params.push_back( SharedParameters( is, fg_varlookup ) );
262 }
263
264
265 void MaximizationStep::addExpectations( InfAlg &alg ) {
266 for( size_t i = 0; i < _params.size(); ++i )
267 _params[i].collectSufficientStatistics( alg );
268 }
269
270
271 void MaximizationStep::maximize( FactorGraph &fg ) {
272 for( size_t i = 0; i < _params.size(); ++i )
273 _params[i].setParameters( fg );
274 }
275
276
277 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
278 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
279 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
280 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
281
282
283 EMAlg::EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &msteps_file )
284 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
285 {
286 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
287 size_t num_msteps = -1;
288 msteps_file >> num_msteps;
289 _msteps.reserve(num_msteps);
290 for( size_t i = 0; i < num_msteps; ++i )
291 _msteps.push_back( MaximizationStep( msteps_file, estep.fg() ) );
292 }
293
294
295 void EMAlg::setTermConditions( const PropertySet &p ) {
296 if( p.hasKey(MAX_ITERS_KEY) )
297 _max_iters = p.getStringAs<size_t>(MAX_ITERS_KEY);
298 if( p.hasKey(LOG_Z_TOL_KEY) )
299 _log_z_tol = p.getStringAs<Real>(LOG_Z_TOL_KEY);
300 }
301
302
303 bool EMAlg::hasSatisfiedTermConditions() const {
304 if( _iters >= _max_iters )
305 return true;
306 else if( _lastLogZ.size() < 3 )
307 // need at least 2 to calculate ratio
308 // Also, throw away first iteration, as the parameters may not
309 // have been normalized according to the estimation method
310 return false;
311 else {
312 Real current = _lastLogZ[_lastLogZ.size() - 1];
313 Real previous = _lastLogZ[_lastLogZ.size() - 2];
314 if( previous == 0 )
315 return false;
316 Real diff = current - previous;
317 if( diff < 0 ) {
318 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
319 return true;
320 }
321 return (diff / fabs(previous)) <= _log_z_tol;
322 }
323 }
324
325
326 Real EMAlg::iterate( MaximizationStep &mstep ) {
327 Real logZ = 0;
328 Real likelihood = 0;
329
330 _estep.run();
331 logZ = _estep.logZ();
332
333 // Expectation calculation
334 for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
335 InfAlg* clamped = _estep.clone();
336 e->applyEvidence( *clamped );
337 clamped->init();
338 clamped->run();
339
340 likelihood += clamped->logZ() - logZ;
341
342 mstep.addExpectations( *clamped );
343
344 delete clamped;
345 }
346
347 // Maximization of parameters
348 mstep.maximize( _estep.fg() );
349
350 return likelihood;
351 }
352
353
354 Real EMAlg::iterate() {
355 Real likelihood;
356 for( size_t i = 0; i < _msteps.size(); ++i )
357 likelihood = iterate( _msteps[i] );
358 _lastLogZ.push_back( likelihood );
359 ++_iters;
360 return likelihood;
361 }
362
363
364 void EMAlg::run() {
365 while( !hasSatisfiedTermConditions() )
366 iterate();
367 }
368
369
370 } // end of namespace dai