68969f83bbd6130419db4ca76d2090b31b71629e
[libdai.git] / src / emalg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/util.h>
10 #include <dai/emalg.h>
11
12
13 namespace dai {
14
15
16 // Initialize static private member of ParameterEstimation
17 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
18
19
20 void ParameterEstimation::loadDefaultRegistry() {
21 _registry = new std::map<std::string, ParamEstFactory>();
22 (*_registry)["CondProbEstimation"] = CondProbEstimation::factory;
23 }
24
25
26 ParameterEstimation* ParameterEstimation::construct( const std::string &method, const PropertySet &p ) {
27 if( _registry == NULL )
28 loadDefaultRegistry();
29 std::map<std::string, ParamEstFactory>::iterator i = _registry->find(method);
30 if( i == _registry->end() )
31 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD, "Unknown parameter estimation method: " + method);
32 ParamEstFactory factory = i->second;
33 return factory(p);
34 }
35
36
37 std::string CondProbEstimation::_name = "CondProbEstimation";
38
39 ParameterEstimation* CondProbEstimation::factory( const PropertySet &p ) {
40 size_t target_dimension = p.getStringAs<size_t>("target_dim");
41 size_t total_dimension = p.getStringAs<size_t>("total_dim");
42 Real pseudo_count = 1;
43 if( p.hasKey("pseudo_count") )
44 pseudo_count = p.getStringAs<Real>("pseudo_count");
45 return new CondProbEstimation( target_dimension, Prob( total_dimension, pseudo_count ) );
46 }
47
48 CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts )
49 : _target_dim(target_dimension), _initial_stats(pseudocounts), _props()
50 {
51 DAI_ASSERT( !(_initial_stats.size() % _target_dim) );
52 _props.setAsString<size_t>("target_dim", _target_dim);
53 _props.setAsString<size_t>("total_dim", _initial_stats.size());
54 _props.setAsString<Real>("pseudo_count", _initial_stats.get(0));
55 }
56
57
58 Prob CondProbEstimation::parametersToFactor(const Prob& p) {
59 Prob result(p);
60 return result;
61 }
62
63 // In the case of a conditional probability table, the
64 // parameters are identical to the estimated factor
65 Prob CondProbEstimation::parameters(const Prob& p) {
66 Prob stats = p + _initial_stats;
67 // normalize pseudocounts
68 for( size_t parent = 0; parent < stats.size(); parent += _target_dim ) {
69 // calculate norm
70 size_t top = parent + _target_dim;
71 Real norm = 0.0;
72 for( size_t i = parent; i < top; ++i )
73 norm += stats[i];
74 if( norm != 0.0 )
75 norm = 1.0 / norm;
76 // normalize
77 for( size_t i = parent; i < top; ++i )
78 stats.set( i, stats[i] * norm );
79 }
80 return stats;
81 }
82
83
84 Permute SharedParameters::calculatePermutation( const std::vector<Var> &varOrder, VarSet &outVS ) {
85 outVS = VarSet( varOrder.begin(), varOrder.end(), varOrder.size() );
86 return Permute( varOrder );
87 }
88
89
90 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
91 if( _varorders.size() == 0 )
92 return;
93 DAI_ASSERT( _estimation != NULL );
94 _expectations = new Prob(_estimation->probSize(), 0);
95
96 // Construct the permutation objects and the varsets
97 for( FactorOrientations::const_iterator foi = _varorders.begin(); foi != _varorders.end(); ++foi ) {
98 VarSet vs;
99 _perms[foi->first] = calculatePermutation( foi->second, vs );
100 _varsets[foi->first] = vs;
101 DAI_ASSERT( (BigInt)_estimation->probSize() == vs.nrStates() );
102 }
103 }
104
105
106 SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE )
107 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _ownEstimation(ownPE), _expectations(NULL)
108 {
109 // Calculate the necessary permutations and varsets
110 setPermsAndVarSetsFromVarOrders();
111 }
112
113
114 SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg )
115 : _varsets(), _perms(), _varorders(), _estimation(NULL), _ownEstimation(true), _expectations(NULL)
116 {
117 // Read the desired parameter estimation method from the stream
118 std::string est_method;
119 PropertySet props;
120 is >> est_method;
121 is >> props;
122
123 // Construct a corresponding object
124 _estimation = ParameterEstimation::construct( est_method, props );
125
126 // Read in the factors that are to be estimated
127 size_t num_factors;
128 is >> num_factors;
129 for( size_t sp_i = 0; sp_i < num_factors; ++sp_i ) {
130 std::string line;
131 while( line.size() == 0 && getline(is, line) )
132 ;
133
134 std::vector<std::string> fields = tokenizeString( line, true, " \t" );
135
136 // Lookup the factor in the factorgraph
137 if( fields.size() < 1 )
138 DAI_THROWE(INVALID_EMALG_FILE,"Empty line unexpected");
139 std::istringstream iss;
140 iss.str( fields[0] );
141 size_t factor;
142 iss >> factor;
143 const VarSet &vs = fg.factor(factor).vars();
144 if( fields.size() != vs.size() + 1 )
145 DAI_THROWE(INVALID_EMALG_FILE,"Number of fields does not match factor size");
146
147 // Construct the vector of Vars
148 std::vector<Var> var_order;
149 var_order.reserve( vs.size() );
150 for( size_t fi = 1; fi < fields.size(); ++fi ) {
151 // Lookup a single variable by label
152 size_t label;
153 std::istringstream labelparse( fields[fi] );
154 labelparse >> label;
155 VarSet::const_iterator vsi = vs.begin();
156 for( ; vsi != vs.end(); ++vsi )
157 if( vsi->label() == label )
158 break;
159 if( vsi == vs.end() )
160 DAI_THROWE(INVALID_EMALG_FILE,"Specified variables do not match the factor variables");
161 var_order.push_back( *vsi );
162 }
163 _varorders[factor] = var_order;
164 }
165
166 // Calculate the necessary permutations
167 setPermsAndVarSetsFromVarOrders();
168 }
169
170
171 void SharedParameters::collectExpectations( InfAlg &alg ) {
172 for( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
173 Permute &perm = i->second;
174 VarSet &vs = _varsets[i->first];
175
176 Factor b = alg.belief(vs);
177 Prob p( b.nrStates(), 0.0 );
178 for( size_t entry = 0; entry < b.nrStates(); ++entry )
179 p.set( entry, b[perm.convertLinearIndex(entry)] ); // apply inverse permutation
180 (*_expectations) += p;
181 }
182 }
183
184
185 void SharedParameters::setParameters( FactorGraph &fg ) {
186 Prob p = _estimation->estimate(this->currentExpectations());
187 for( std::map<FactorIndex, Permute>::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
188 Permute &perm = i->second;
189 VarSet &vs = _varsets[i->first];
190
191 Factor f( vs, 0.0 );
192 for( size_t entry = 0; entry < f.nrStates(); ++entry )
193 f.set( perm.convertLinearIndex(entry), p[entry] );
194
195 fg.setFactor( i->first, f );
196 }
197 }
198
199
200 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
201 size_t num_params = -1;
202 is >> num_params;
203 _params.reserve( num_params );
204 for( size_t i = 0; i < num_params; ++i )
205 _params.push_back( SharedParameters( is, fg_varlookup ) );
206 }
207
208
209 void MaximizationStep::addExpectations( InfAlg &alg ) {
210 for( size_t i = 0; i < _params.size(); ++i )
211 _params[i].collectExpectations( alg );
212 }
213
214
215 void MaximizationStep::maximize( FactorGraph &fg ) {
216 for( size_t i = 0; i < _params.size(); ++i )
217 _params[i].setParameters( fg );
218 }
219
220
221 void MaximizationStep::clear( ) {
222 for( size_t i = 0; i < _params.size(); ++i )
223 _params[i].clear( );
224 }
225
226 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
227 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
228 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
229 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
230
231
232 EMAlg::EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &msteps_file )
233 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
234 {
235 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
236 size_t num_msteps = -1;
237 msteps_file >> num_msteps;
238 _msteps.reserve(num_msteps);
239 for( size_t i = 0; i < num_msteps; ++i )
240 _msteps.push_back( MaximizationStep( msteps_file, estep.fg() ) );
241 }
242
243
244 void EMAlg::setTermConditions( const PropertySet &p ) {
245 if( p.hasKey(MAX_ITERS_KEY) )
246 _max_iters = p.getStringAs<size_t>(MAX_ITERS_KEY);
247 if( p.hasKey(LOG_Z_TOL_KEY) )
248 _log_z_tol = p.getStringAs<Real>(LOG_Z_TOL_KEY);
249 }
250
251
252 bool EMAlg::hasSatisfiedTermConditions() const {
253 if( _iters >= _max_iters )
254 return true;
255 else if( _lastLogZ.size() < 3 )
256 // need at least 2 to calculate ratio
257 // Also, throw away first iteration, as the parameters may not
258 // have been normalized according to the estimation method
259 return false;
260 else {
261 Real current = _lastLogZ[_lastLogZ.size() - 1];
262 Real previous = _lastLogZ[_lastLogZ.size() - 2];
263 if( previous == 0 )
264 return false;
265 Real diff = current - previous;
266 if( diff < 0 ) {
267 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
268 return true;
269 }
270 return (diff / fabs(previous)) <= _log_z_tol;
271 }
272 }
273
274
275 Real EMAlg::iterate( MaximizationStep &mstep ) {
276 Real logZ = 0;
277 Real likelihood = 0;
278
279 mstep.clear();
280
281 _estep.run();
282 logZ = _estep.logZ();
283
284 // Expectation calculation
285 for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
286 InfAlg* clamped = _estep.clone();
287 // Apply evidence
288 for( Evidence::Observation::const_iterator i = e->begin(); i != e->end(); ++i )
289 clamped->clamp( clamped->fg().findVar(i->first), i->second );
290 clamped->init();
291 clamped->run();
292
293 likelihood += clamped->logZ() - logZ;
294
295 mstep.addExpectations( *clamped );
296
297 delete clamped;
298 }
299
300 // Maximization of parameters
301 mstep.maximize( _estep.fg() );
302
303 return likelihood;
304 }
305
306
307 Real EMAlg::iterate() {
308 Real likelihood;
309 for( size_t i = 0; i < _msteps.size(); ++i )
310 likelihood = iterate( _msteps[i] );
311 _lastLogZ.push_back( likelihood );
312 ++_iters;
313 return likelihood;
314 }
315
316
317 void EMAlg::run() {
318 while( !hasSatisfiedTermConditions() )
319 iterate();
320 }
321
322
323 } // end of namespace dai