Renamed SampleData->Observation, changed Observation and Evidence
[libdai.git] / src / emalg.cpp
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <cmath>

#include <dai/util.h>
#include <dai/emalg.h>
24
25
26 namespace dai {
27
28
29 std::map<std::string, ParameterEstimation::ParamEstFactory> *ParameterEstimation::_registry = NULL;
30
31
32 void ParameterEstimation::loadDefaultRegistry() {
33 _registry = new std::map< std::string, ParamEstFactory>();
34 (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
35 }
36
37
38 ParameterEstimation* ParameterEstimation::construct( const std::string& method, const PropertySet& p ) {
39 if (_registry == NULL)
40 loadDefaultRegistry();
41 std::map< std::string, ParamEstFactory>::iterator i = _registry->find(method);
42 if (i == _registry->end())
43 DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD);
44 ParamEstFactory factory = i->second;
45 return factory(p);
46 }
47
48
49 ParameterEstimation* CondProbEstimation::factory( const PropertySet& p ) {
50 size_t target_dimension = p.getStringAs<size_t>("target_dim");
51 size_t total_dimension = p.getStringAs<size_t>("total_dim");
52 Real pseudo_count = 1;
53 if (p.hasKey("pseudo_count"))
54 pseudo_count = p.getStringAs<Real>("pseudo_count");
55 Prob counts_vec(total_dimension, pseudo_count);
56 return new CondProbEstimation(target_dimension, counts_vec);
57 }
58
59
60 CondProbEstimation::CondProbEstimation( size_t target_dimension, Prob pseudocounts )
61 : _target_dim(target_dimension), _stats(pseudocounts), _initial_stats(pseudocounts)
62 {
63 if (_stats.size() % _target_dim)
64 DAI_THROW(MALFORMED_PROPERTY);
65 }
66
67
68 void CondProbEstimation::addSufficientStatistics( Prob& p ) {
69 _stats += p;
70 }
71
72
73 Prob CondProbEstimation::estimate() {
74 for (size_t parent = 0; parent < _stats.size(); parent += _target_dim) {
75 Real norm = 0;
76 size_t top = parent + _target_dim;
77 for (size_t i = parent; i < top; ++i)
78 norm += _stats[i];
79 if (norm != 0)
80 norm = 1 / norm;
81 for (size_t i = parent; i < top; ++i)
82 _stats[i] *= norm;
83 }
84 Prob result = _stats;
85 _stats = _initial_stats;
86 return result;
87 }
88
89
90 Permute SharedParameters::calculatePermutation(const std::vector< Var >& varorder, const std::vector< size_t >& dims, VarSet& outVS) {
91 std::vector<long> labels(dims.size());
92
93 // Check that the variable set is compatible
94 if (varorder.size() != dims.size())
95 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
96
97 // Collect all labels, and order them in vs
98 for (size_t di = 0; di < dims.size(); ++di) {
99 if (dims[di] != varorder[di].states())
100 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
101 outVS |= varorder[di];
102 labels[di] = varorder[di].label();
103 }
104
105 // Construct the sigma array for the permutation object
106 std::vector<size_t> sigma(dims.size(), 0);
107 VarSet::iterator set_iterator = outVS.begin();
108 for (size_t vs_i = 0; vs_i < dims.size(); ++vs_i, ++set_iterator) {
109 std::vector< long >::iterator location = find(labels.begin(), labels.end(), set_iterator->label());
110 sigma[vs_i] = location - labels.begin();
111 }
112
113 return Permute(dims, sigma);
114 }
115
116
117 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
118 if (_varorders.size() == 0)
119 return;
120 FactorOrientations::const_iterator foi = _varorders.begin();
121 std::vector< size_t > dims(foi->second.size());
122 size_t total_dim = 1;
123 for (size_t i = 0; i < dims.size(); ++i) {
124 dims[i] = foi->second[i].states();
125 total_dim *= dims[i];
126 }
127
128 // Construct the permutation objects
129 for ( ; foi != _varorders.end(); ++foi) {
130 VarSet vs;
131 _perms[foi->first] = calculatePermutation(foi->second, dims, vs);
132 _varsets[foi->first] = vs;
133 }
134
135 if (_estimation == NULL || _estimation->probSize() != total_dim)
136 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
137 }
138
139
140 SharedParameters::SharedParameters(std::istream& is, const FactorGraph& fg_varlookup)
141 : _varsets(), _perms(), _varorders(), _estimation(NULL), _deleteEstimation(1)
142 {
143 std::string est_method;
144 PropertySet props;
145 is >> est_method;
146 is >> props;
147
148 _estimation = ParameterEstimation::construct(est_method, props);
149
150 size_t num_factors;
151 is >> num_factors;
152 for (size_t sp_i = 0; sp_i < num_factors; ++sp_i) {
153 std::string line;
154 std::vector< std::string > fields;
155 size_t factor;
156 std::vector< Var > var_order;
157 std::istringstream iss;
158
159 while(line.size() == 0 && getline(is, line))
160 ;
161 tokenizeString(line, fields, " \t");
162
163 // Lookup the factor in the factorgraph
164 if (fields.size() < 1)
165 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
166 iss.str(fields[0]);
167 iss >> factor;
168 const VarSet& vs = fg_varlookup.factor(factor).vars();
169 if (fields.size() != vs.size() + 1)
170 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
171
172 // Construct the vector of Vars
173 for (size_t fi = 1; fi < fields.size(); ++fi) {
174 // Lookup a single variable by label
175 long label;
176 std::istringstream labelparse(fields[fi]);
177 labelparse >> label;
178 VarSet::const_iterator vsi = vs.begin();
179 for ( ; vsi != vs.end(); ++vsi)
180 if (vsi->label() == label)
181 break;
182 if (vsi == vs.end())
183 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
184 var_order.push_back(*vsi);
185 }
186 _varorders[factor] = var_order;
187 }
188 setPermsAndVarSetsFromVarOrders();
189 }
190
191
192 SharedParameters::SharedParameters( const SharedParameters& sp )
193 : _varsets(sp._varsets), _perms(sp._perms), _varorders(sp._varorders), _estimation(sp._estimation), _deleteEstimation(sp._deleteEstimation)
194 {
195 if (_deleteEstimation)
196 _estimation = _estimation->clone();
197 }
198
199
200 SharedParameters::SharedParameters( const FactorOrientations& varorders, ParameterEstimation* estimation )
201 : _varsets(), _perms(), _varorders(varorders), _estimation(estimation), _deleteEstimation(0)
202 {
203 setPermsAndVarSetsFromVarOrders();
204 }
205
206
207 void SharedParameters::collectSufficientStatistics(InfAlg& alg) {
208 std::map< FactorIndex, Permute >::iterator i = _perms.begin();
209 for ( ; i != _perms.end(); ++i) {
210 Permute& perm = i->second;
211 VarSet& vs = _varsets[i->first];
212
213 Factor b = alg.belief(vs);
214 Prob p(b.states(), 0.0);
215 for (size_t entry = 0; entry < b.states(); ++entry)
216 p[entry] = b[perm.convert_linear_index(entry)];
217 _estimation->addSufficientStatistics(p);
218 }
219 }
220
221
222 void SharedParameters::setParameters(FactorGraph& fg) {
223 Prob p = _estimation->estimate();
224 std::map< FactorIndex, Permute >::iterator i = _perms.begin();
225 for ( ; i != _perms.end(); ++i) {
226 Permute& perm = i->second;
227 VarSet& vs = _varsets[i->first];
228
229 Factor f(vs, 0.0);
230 for (size_t entry = 0; entry < f.states(); ++entry)
231 f[perm.convert_linear_index(entry)] = p[entry];
232
233 fg.setFactor(i->first, f);
234 }
235 }
236
237
238 MaximizationStep::MaximizationStep (std::istream& is, const FactorGraph& fg_varlookup )
239 : _params()
240 {
241 size_t num_params = -1;
242 is >> num_params;
243 _params.reserve(num_params);
244 for (size_t i = 0; i < num_params; ++i) {
245 SharedParameters p(is, fg_varlookup);
246 _params.push_back(p);
247 }
248 }
249
250
251 void MaximizationStep::addExpectations(InfAlg& alg) {
252 for (size_t i = 0; i < _params.size(); ++i)
253 _params[i].collectSufficientStatistics(alg);
254 }
255
256
257 void MaximizationStep::maximize(FactorGraph& fg) {
258 for (size_t i = 0; i < _params.size(); ++i)
259 _params[i].setParameters(fg);
260 }
261
262
263 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
264 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
265 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
266 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
267
268
269 EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file)
270 : _evidence(evidence), _estep(estep), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT), _log_z_tol(LOG_Z_TOL_DEFAULT)
271 {
272 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit | std::istream::badbit );
273 size_t num_msteps = -1;
274 msteps_file >> num_msteps;
275 _msteps.reserve(num_msteps);
276 for (size_t i = 0; i < num_msteps; ++i) {
277 MaximizationStep m(msteps_file, estep.fg());
278 _msteps.push_back(m);
279 }
280 }
281
282
283 void EMAlg::setTermConditions(const PropertySet* p) {
284 if (NULL == p)
285 return;
286 if (p->hasKey(MAX_ITERS_KEY))
287 _max_iters = p->getStringAs<size_t>(MAX_ITERS_KEY);
288 if (p->hasKey(LOG_Z_TOL_KEY))
289 _log_z_tol = p->getStringAs<Real>(LOG_Z_TOL_KEY);
290 }
291
292
293 bool EMAlg::hasSatisfiedTermConditions() const {
294 if (_iters >= _max_iters) {
295 return 1;
296 } else if (_lastLogZ.size() < 3) {
297 // need at least 2 to calculate ratio
298 // Also, throw away first iteration, as the parameters may not
299 // have been normalized according to the estimation method
300 return 0;
301 } else {
302 Real current = _lastLogZ[_lastLogZ.size() - 1];
303 Real previous = _lastLogZ[_lastLogZ.size() - 2];
304 if (previous == 0) return 0;
305 Real diff = current - previous;
306 if (diff < 0) {
307 std::cerr << "Error: in EM log-likehood decreased from " << previous << " to " << current << std::endl;
308 return 1;
309 }
310 return diff / abs(previous) <= _log_z_tol;
311 }
312 }
313
314
315 Real EMAlg::iterate(MaximizationStep& mstep) {
316 Evidence::const_iterator e = _evidence.begin();
317 Real logZ = 0;
318
319 // Expectation calculation
320 for ( ; e != _evidence.end(); ++e) {
321 InfAlg* clamped = _estep.clone();
322 e->applyEvidence(*clamped);
323 clamped->run();
324
325 logZ += clamped->logZ();
326
327 mstep.addExpectations(*clamped);
328
329 delete clamped;
330 }
331
332 // Maximization of parameters
333 mstep.maximize(_estep.fg());
334
335 return logZ;
336 }
337
338
339 Real EMAlg::iterate() {
340 Real likelihood;
341 for (size_t i = 0; i < _msteps.size(); ++i)
342 likelihood = iterate(_msteps[i]);
343 _lastLogZ.push_back(likelihood);
344 ++_iters;
345 return likelihood;
346 }
347
348
349 void EMAlg::run() {
350 while(!hasSatisfiedTermConditions())
351 iterate();
352 }
353
354
355 } // end of namespace dai