3ef4be468c230d10ee8e99b7d2a4aaa38840e73a
[libdai.git] / src / emalg.cpp
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cmath>

#include <dai/util.h>
#include <dai/emalg.h>
24
25
26 namespace dai {
27
28
29 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
30
31
32 void ParameterEstimation::loadDefaultRegistry() {
33 _registry = new std::map< std::string, ParamEstFactory>();
34 (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
35 }
36
37
38 ParameterEstimation* ParameterEstimation::construct( const std::string& method, const PropertySet& p ) {
39 if (_registry == NULL)
40 loadDefaultRegistry();
41 std::map< std::string, ParamEstFactory>::iterator i = _registry->find(method);
42 if (i == _registry->end())
43 DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD);
44 ParamEstFactory factory = i->second;
45 return factory(p);
46 }
47
48
49 ParameterEstimation* CondProbEstimation::factory( const PropertySet& p ) {
50 size_t target_dimension = p.getStringAs<size_t>("target_dim");
51 size_t total_dimension = p.getStringAs<size_t>("total_dim");
52 Real pseudo_count = 1;
53 if (p.hasKey("pseudo_count"))
54 pseudo_count = p.getStringAs<Real>("pseudo_count");
55 Prob counts_vec(total_dimension, pseudo_count);
56 return new CondProbEstimation(target_dimension, counts_vec);
57 }
58
59
60 CondProbEstimation::CondProbEstimation( size_t target_dimension, Prob pseudocounts )
61 : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
62 {
63 if (_stats.size() % _target_dim)
64 DAI_THROW(MALFORMED_PROPERTY);
65 }
66
67
68 void CondProbEstimation::addSufficientStatistics( Prob& p ) {
69 _stats += p;
70 }
71
72
73 Prob CondProbEstimation::estimate() {
74 for (size_t parent = 0; parent < _stats.size(); parent += _target_dim) {
75 Real norm = 0;
76 size_t top = parent + _target_dim;
77 for (size_t i = parent; i < top; ++i)
78 norm += _stats[i];
79 if (norm != 0)
80 norm = 1 / norm;
81 for (size_t i = parent; i < top; ++i)
82 _stats[i] *= norm;
83 }
84 Prob result = _stats;
85 _stats = _initial_stats;
86 return result;
87 }
88
89
90 Permute SharedParameters::calculatePermutation( const std::vector< Var >& varorder, VarSet& outVS ) {
91 std::vector<size_t> dims;
92 std::vector<long> labels;
93
94 // Collect all labels and dimensions, and order them in vs
95 dims.reserve( varorder.size() );
96 labels.reserve( varorder.size() );
97 for( size_t i = 0; i < varorder.size(); i++ ) {
98 dims.push_back( varorder[i].states() );
99 labels.push_back( varorder[i].label() );
100 outVS |= varorder[i];
101 }
102
103 // Construct the sigma array for the permutation object
104 std::vector<size_t> sigma;
105 sigma.reserve( dims.size() );
106 for (VarSet::iterator set_iterator = outVS.begin(); sigma.size() < dims.size(); ++set_iterator)
107 sigma.push_back( find(labels.begin(), labels.end(), set_iterator->label()) - labels.begin() );
108
109 return Permute(dims, sigma);
110 }
111
112
113 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
114 if (_varorders.size() == 0)
115 return;
116 if( _estimation == NULL )
117 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
118
119 // Construct the permutation objects and the varsets
120 for( FactorOrientations::const_iterator foi = _varorders.begin(); foi != _varorders.end(); ++foi ) {
121 VarSet vs;
122 _perms[foi->first] = calculatePermutation(foi->second, vs);
123 _varsets[foi->first] = vs;
124 if( _estimation->probSize() != vs.nrStates() )
125 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
126 }
127 }
128
129
130 SharedParameters::SharedParameters(std::istream& is, const FactorGraph& fg_varlookup)
131 : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(true)
132 {
133 std::string est_method;
134 PropertySet props;
135 is >> est_method;
136 is >> props;
137
138 _estimation = ParameterEstimation::construct(est_method, props);
139
140 size_t num_factors;
141 is >> num_factors;
142 for( size_t sp_i = 0; sp_i < num_factors; ++sp_i ) {
143 std::string line;
144 while(line.size() == 0 && getline(is, line))
145 ;
146
147 std::vector< std::string > fields;
148 tokenizeString(line, fields, " \t");
149
150 // Lookup the factor in the factorgraph
151 if (fields.size() < 1)
152 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
153 std::istringstream iss;
154 iss.str(fields[0]);
155 size_t factor;
156 iss >> factor;
157 const VarSet& vs = fg_varlookup.factor(factor).vars();
158 if (fields.size() != vs.size() + 1)
159 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
160
161 // Construct the vector of Vars
162 std::vector< Var > var_order;
163 var_order.reserve( vs.size() );
164 for (size_t fi = 1; fi < fields.size(); ++fi) {
165 // Lookup a single variable by label
166 long label;
167 std::istringstream labelparse(fields[fi]);
168 labelparse >> label;
169 VarSet::const_iterator vsi = vs.begin();
170 for ( ; vsi != vs.end(); ++vsi)
171 if (vsi->label() == label)
172 break;
173 if (vsi == vs.end())
174 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
175 var_order.push_back(*vsi);
176 }
177 _varorders[factor] = var_order;
178 }
179 setPermsAndVarSetsFromVarOrders();
180 }
181
182
183 SharedParameters::SharedParameters( const SharedParameters& sp )
184 : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _deleteEstimation(sp._deleteEstimation)
185 {
186 if (_deleteEstimation)
187 _estimation = _estimation->clone();
188 }
189
190
191 SharedParameters::SharedParameters( const FactorOrientations& varorders, ParameterEstimation* estimation )
192 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(false)
193 {
194 setPermsAndVarSetsFromVarOrders();
195 }
196
197
198 void SharedParameters::collectSufficientStatistics(InfAlg& alg) {
199 for ( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i) {
200 Permute& perm = i->second;
201 VarSet& vs = _varsets[i->first];
202
203 Factor b = alg.belief(vs);
204 Prob p(b.states(), 0.0);
205 for (size_t entry = 0; entry < b.states(); ++entry)
206 p[entry] = b[perm.convert_linear_index(entry)];
207 _estimation->addSufficientStatistics(p);
208 }
209 }
210
211
212 void SharedParameters::setParameters(FactorGraph& fg) {
213 Prob p = _estimation->estimate();
214 for ( std::map< FactorIndex, Permute >::iterator i = _perms.begin(); i != _perms.end(); ++i) {
215 Permute& perm = i->second;
216 VarSet& vs = _varsets[i->first];
217
218 Factor f(vs, 0.0);
219 for (size_t entry = 0; entry < f.states(); ++entry)
220 f[perm.convert_linear_index(entry)] = p[entry];
221
222 fg.setFactor(i->first, f);
223 }
224 }
225
226
227 MaximizationStep::MaximizationStep (std::istream& is, const FactorGraph& fg_varlookup )
228 : _params()
229 {
230 size_t num_params = -1;
231 is >> num_params;
232 _params.reserve(num_params);
233 for (size_t i = 0; i < num_params; ++i) {
234 SharedParameters p(is, fg_varlookup);
235 _params.push_back(p);
236 }
237 }
238
239
240 void MaximizationStep::addExpectations(InfAlg& alg) {
241 for (size_t i = 0; i < _params.size(); ++i)
242 _params[i].collectSufficientStatistics(alg);
243 }
244
245
246 void MaximizationStep::maximize(FactorGraph& fg) {
247 for (size_t i = 0; i < _params.size(); ++i)
248 _params[i].setParameters(fg);
249 }
250
251
252 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
253 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
254 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
255 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
256
257
258 EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file)
259 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
260 {
261 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
262 size_t num_msteps = -1;
263 msteps_file >> num_msteps;
264 _msteps.reserve(num_msteps);
265 for (size_t i = 0; i < num_msteps; ++i) {
266 MaximizationStep m(msteps_file, estep.fg());
267 _msteps.push_back(m);
268 }
269 }
270
271
272 void EMAlg::setTermConditions(const PropertySet &p) {
273 if (p.hasKey(MAX_ITERS_KEY))
274 _max_iters = p.getStringAs<size_t>(MAX_ITERS_KEY);
275 if (p.hasKey(LOG_Z_TOL_KEY))
276 _log_z_tol = p.getStringAs<Real>(LOG_Z_TOL_KEY);
277 }
278
279
280 bool EMAlg::hasSatisfiedTermConditions() const {
281 if (_iters >= _max_iters) {
282 return true;
283 } else if (_lastLogZ.size() < 3) {
284 // need at least 2 to calculate ratio
285 // Also, throw away first iteration, as the parameters may not
286 // have been normalized according to the estimation method
287 return false;
288 } else {
289 Real current = _lastLogZ[_lastLogZ.size() - 1];
290 Real previous = _lastLogZ[_lastLogZ.size() - 2];
291 if (previous == 0)
292 return false;
293 Real diff = current - previous;
294 if (diff < 0) {
295 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
296 return true;
297 }
298 return diff / abs(previous) <= _log_z_tol;
299 }
300 }
301
302
303 Real EMAlg::iterate( MaximizationStep& mstep ) {
304 Real logZ = 0;
305
306 // Expectation calculation
307 for( Evidence::const_iterator e = _evidence.begin(); e != _evidence.end(); ++e ) {
308 InfAlg* clamped = _estep.clone();
309 e->applyEvidence(*clamped);
310 clamped->init();
311 clamped->run();
312
313 logZ += clamped->logZ();
314
315 mstep.addExpectations(*clamped);
316
317 delete clamped;
318 }
319
320 // Maximization of parameters
321 mstep.maximize(_estep.fg());
322
323 return logZ;
324 }
325
326
327 Real EMAlg::iterate() {
328 Real likelihood;
329 for (size_t i = 0; i < _msteps.size(); ++i)
330 likelihood = iterate(_msteps[i]);
331 _lastLogZ.push_back(likelihood);
332 ++_iters;
333 return likelihood;
334 }
335
336
337 void EMAlg::run() {
338 while(!hasSatisfiedTermConditions())
339 iterate();
340 }
341
342
343 } // end of namespace dai