e8958c28c1345c46d5b794aa6e326922269eb199
[libdai.git] / src / emalg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/util.h>
10 #include <dai/emalg.h>
11
12
13 namespace dai {
14
15
16 // Initialize static private member of ParameterEstimation
17 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
18
19
20 void ParameterEstimation::loadDefaultRegistry() {
21 _registry = new std::map<std::string, ParamEstFactory>();
22 (*_registry)["CondProbEstimation"] = CondProbEstimation::factory;
23 }
24
25
26 ParameterEstimation* ParameterEstimation::construct( const std::string &method, const PropertySet &p ) {
27 if( _registry == NULL )
28 loadDefaultRegistry();
29 std::map<std::string, ParamEstFactory>::iterator i = _registry->find(method);
30 if( i == _registry->end() )
31 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD, "Unknown parameter estimation method: " + method);
32 ParamEstFactory factory = i->second;
33 return factory(p);
34 }
35
36
37 ParameterEstimation* CondProbEstimation::factory( const PropertySet &p ) {
38 size_t target_dimension = p.getStringAs<size_t>("target_dim");
39 size_t total_dimension = p.getStringAs<size_t>("total_dim");
40 Real pseudo_count = 1;
41 if( p.hasKey("pseudo_count") )
42 pseudo_count = p.getStringAs<Real>("pseudo_count");
43 return new CondProbEstimation( target_dimension, Prob( total_dimension, pseudo_count ) );
44 }
45
46
47 CondProbEstimation::CondProbEstimation( size_t target_dimension, const Prob &pseudocounts )
48 : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
49 {
50 DAI_ASSERT( !(_stats.size() % _target_dim) );
51 }
52
53
54 void CondProbEstimation::addSufficientStatistics( const Prob &p ) {
55 _stats += p;
56 }
57
58
59 Prob CondProbEstimation::estimate() {
60 // normalize pseudocounts
61 for( size_t parent = 0; parent < _stats.size(); parent += _target_dim ) {
62 // calculate norm
63 size_t top = parent + _target_dim;
64 Real norm = 0.0;
65 for( size_t i = parent; i < top; ++i )
66 norm += _stats[i];
67 if( norm != 0.0 )
68 norm = 1.0 / norm;
69 // normalize
70 for( size_t i = parent; i < top; ++i )
71 _stats.set( i, _stats[i] * norm );
72 }
73 // reset _stats to _initial_stats
74 Prob result = _stats;
75 _stats = _initial_stats;
76 return result;
77 }
78
79
80 Permute SharedParameters::calculatePermutation( const std::vector<Var> &varOrder, VarSet &outVS ) {
81 outVS = VarSet( varOrder.begin(), varOrder.end(), varOrder.size() );
82 return Permute( varOrder );
83 }
84
85
86 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
87 if( _varorders.size() == 0 )
88 return;
89 DAI_ASSERT( _estimation != NULL );
90
91 // Construct the permutation objects and the varsets
92 for( FactorOrientations::const_iterator foi = _varorders.begin(); foi != _varorders.end(); ++foi ) {
93 VarSet vs;
94 _perms[foi->first] = calculatePermutation( foi->second, vs );
95 _varsets[foi->first] = vs;
96 DAI_ASSERT( (BigInt)_estimation->probSize() == vs.nrStates() );
97 }
98 }
99
100
101 SharedParameters::SharedParameters( const FactorOrientations &varorders, ParameterEstimation *estimation, bool ownPE )
102 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _ownEstimation(ownPE)
103 {
104 // Calculate the necessary permutations and varsets
105 setPermsAndVarSetsFromVarOrders();
106 }
107
108
109 SharedParameters::SharedParameters( std::istream &is, const FactorGraph &fg )
110 : _varsets(), _perms(), _varorders(), _estimation(NULL), _ownEstimation(true)
111 {
112 // Read the desired parameter estimation method from the stream
113 std::string est_method;
114 PropertySet props;
115 is >> est_method;
116 is >> props;
117
118 // Construct a corresponding object
119 _estimation = ParameterEstimation::construct( est_method, props );
120
121 // Read in the factors that are to be estimated
122 size_t num_factors;
123 is >> num_factors;
124 for( size_t sp_i = 0; sp_i < num_factors; ++sp_i ) {
125 std::string line;
126 while( line.size() == 0 && getline(is, line) )
127 ;
128
129 std::vector<std::string> fields = tokenizeString( line, true, " \t" );
130
131 // Lookup the factor in the factorgraph
132 if( fields.size() < 1 )
133 DAI_THROWE(INVALID_EMALG_FILE,"Empty line unexpected");
134 std::istringstream iss;
135 iss.str( fields[0] );
136 size_t factor;
137 iss >> factor;
138 const VarSet &vs = fg.factor(factor).vars();
139 if( fields.size() != vs.size() + 1 )
140 DAI_THROWE(INVALID_EMALG_FILE,"Number of fields does not match factor size");
141
142 // Construct the vector of Vars
143 std::vector<Var> var_order;
144 var_order.reserve( vs.size() );
145 for( size_t fi = 1; fi < fields.size(); ++fi ) {
146 // Lookup a single variable by label
147 size_t label;
148 std::istringstream labelparse( fields[fi] );
149 labelparse >> label;
150 VarSet::const_iterator vsi = vs.begin();
151 for( ; vsi != vs.end(); ++vsi )
152 if( vsi->label() == label )
153 break;
154 if( vsi == vs.end() )
155 DAI_THROWE(INVALID_EMALG_FILE,"Specified variables do not match the factor variables");
156 var_order.push_back( *vsi );
157 }
158 _varorders[factor] = var_order;
159 }
160
161 // Calculate the necessary permutations
162 setPermsAndVarSetsFromVarOrders();
163 }
164
165
166 void SharedParameters::collectSufficientStatistics( InfAlg &alg ) {
167 for( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
168 Permute &perm = i->second;
169 VarSet &vs = _varsets[i->first];
170
171 Factor b = alg.belief(vs);
172 Prob p( b.nrStates(), 0.0 );
173 for( size_t entry = 0; entry < b.nrStates(); ++entry )
174 p.set( entry, b[perm.convertLinearIndex(entry)] ); // apply inverse permutation
175 _estimation->addSufficientStatistics( p );
176 }
177 }
178
179
180 void SharedParameters::setParameters( FactorGraph &fg ) {
181 Prob p = _estimation->estimate();
182 for( std::map<FactorIndex, Permute>::iterator i = _perms.begin(); i != _perms.end(); ++i ) {
183 Permute &perm = i->second;
184 VarSet &vs = _varsets[i->first];
185
186 Factor f( vs, 0.0 );
187 for( size_t entry = 0; entry < f.nrStates(); ++entry )
188 f.set( perm.convertLinearIndex(entry), p[entry] );
189
190 fg.setFactor( i->first, f );
191 }
192 }
193
194
195 MaximizationStep::MaximizationStep( std::istream &is, const FactorGraph &fg_varlookup ) : _params() {
196 size_t num_params = -1;
197 is >> num_params;
198 _params.reserve( num_params );
199 for( size_t i = 0; i < num_params; ++i )
200 _params.push_back( SharedParameters( is, fg_varlookup ) );
201 }
202
203
204 void MaximizationStep::addExpectations( InfAlg &alg ) {
205 for( size_t i = 0; i < _params.size(); ++i )
206 _params[i].collectSufficientStatistics( alg );
207 }
208
209
210 void MaximizationStep::maximize( FactorGraph &fg ) {
211 for( size_t i = 0; i < _params.size(); ++i )
212 _params[i].setParameters( fg );
213 }
214
215
216 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
217 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
218 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
219 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
220
221
222 EMAlg::EMAlg( const Evidence &evidence, InfAlg &estep, std::istream &msteps_file )
223 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
224 {
225 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
226 size_t num_msteps = -1;
227 msteps_file >> num_msteps;
228 _msteps.reserve(num_msteps);
229 for( size_t i = 0; i < num_msteps; ++i )
230 _msteps.push_back( MaximizationStep( msteps_file, estep.fg() ) );
231 }
232
233
234 void EMAlg::setTermConditions( const PropertySet &p ) {
235 if( p.hasKey(MAX_ITERS_KEY) )
236 _max_iters = p.getStringAs<size_t>(MAX_ITERS_KEY);
237 if( p.hasKey(LOG_Z_TOL_KEY) )
238 _log_z_tol = p.getStringAs<Real>(LOG_Z_TOL_KEY);
239 }
240
241
242 bool EMAlg::hasSatisfiedTermConditions() const {
243 if( _iters >= _max_iters )
244 return true;
245 else if( _lastLogZ.size() < 3 )
246 // need at least 2 to calculate ratio
247 // Also, throw away first iteration, as the parameters may not
248 // have been normalized according to the estimation method
249 return false;
250 else {
251 Real current = _lastLogZ[_lastLogZ.size() - 1];
252 Real previous = _lastLogZ[_lastLogZ.size() - 2];
253 if( previous == 0 )
254 return false;
255 Real diff = current - previous;
256 if( diff < 0 ) {
257 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
258 return true;
259 }
260 return (diff / fabs(previous)) <= _log_z_tol;
261 }
262 }
263
264
265 Real EMAlg::iterate( MaximizationStep &mstep ) {
266 Real logZ = 0;
267 Real likelihood = 0;
268
269 _estep.run();
270 logZ = _estep.logZ();
271
272 // Expectation calculation
273 for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
274 InfAlg* clamped = _estep.clone();
275 // Apply evidence
276 for( Evidence::Observation::const_iterator i = e->begin(); i != e->end(); ++i )
277 clamped->clamp( clamped->fg().findVar(i->first), i->second );
278 clamped->init();
279 clamped->run();
280
281 likelihood += clamped->logZ() - logZ;
282
283 mstep.addExpectations( *clamped );
284
285 delete clamped;
286 }
287
288 // Maximization of parameters
289 mstep.maximize( _estep.fg() );
290
291 return likelihood;
292 }
293
294
295 Real EMAlg::iterate() {
296 Real likelihood;
297 for( size_t i = 0; i < _msteps.size(); ++i )
298 likelihood = iterate( _msteps[i] );
299 _lastLogZ.push_back( likelihood );
300 ++_iters;
301 return likelihood;
302 }
303
304
305 void EMAlg::run() {
306 while( !hasSatisfiedTermConditions() )
307 iterate();
308 }
309
310
311 } // end of namespace dai