fe13a28643aa4b34c8b48ee9b6344a98367cafb4
[libdai.git] / src / emalg.cpp
#include <algorithm>
#include <cmath>

#include <dai/util.h>

#include <dai/emalg.h>
4
5 namespace dai{
6
7 std::map< std::string, ParameterEstimation::ParamEstFactory>*
8 ParameterEstimation::_registry = NULL;
9
10 void ParameterEstimation::loadDefaultRegistry() {
11 _registry = new std::map< std::string, ParamEstFactory>();
12 (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
13 }
14
15 ParameterEstimation* ParameterEstimation::construct(const std::string& method,
16 const PropertySet& p) {
17 if (_registry == NULL) {
18 loadDefaultRegistry();
19 }
20 std::map< std::string, ParamEstFactory>::iterator i = _registry->find(method);
21 if (i == _registry->end()) {
22 DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD);
23 }
24 ParamEstFactory factory = i->second;
25 return factory(p);
26 }
27
28 ParameterEstimation* CondProbEstimation::factory(const PropertySet& p) {
29 size_t target_dimension = p.getStringAs<size_t>("target_dim");
30 size_t total_dimension = p.getStringAs<size_t>("total_dim");
31 Real pseudo_count = 1;
32 if (p.hasKey("pseudo_count")) {
33 pseudo_count = p.getStringAs<Real>("pseudo_count");
34 }
35 Prob counts_vec(total_dimension, pseudo_count);
36 return new CondProbEstimation(target_dimension, counts_vec);
37 }
38
39 CondProbEstimation::CondProbEstimation(size_t target_dimension,
40 Prob pseudocounts)
41 : _target_dim(target_dimension),
42 _stats(pseudocounts),
43 _initial_stats(pseudocounts) {
44 if (_stats.size() % _target_dim) {
45 DAI_THROW(MALFORMED_PROPERTY);
46 }
47 }
48
49 void CondProbEstimation::addSufficientStatistics(Prob& p) {
50 _stats += p;
51 }
52
53 Prob CondProbEstimation::estimate() {
54 for (size_t parent = 0; parent < _stats.size(); parent += _target_dim) {
55 Real norm = 0;
56 size_t top = parent + _target_dim;
57 for (size_t i = parent; i < top; ++i) {
58 norm += _stats[i];
59 }
60 if (norm != 0) {
61 norm = 1 / norm;
62 }
63 for (size_t i = parent; i < top; ++i) {
64 _stats[i] *= norm;
65 }
66 }
67 Prob result = _stats;
68 _stats = _initial_stats;
69 return result;
70 }
71
72 Permute
73 SharedParameters::calculatePermutation(const std::vector< Var >& varorder,
74 const std::vector< size_t >& dims,
75 VarSet& outVS) {
76 std::vector<long> labels(dims.size());
77
78 // Check that the variable set is compatible
79 if (varorder.size() != dims.size()) {
80 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
81 }
82
83 // Collect all labels, and order them in vs
84 for (size_t di = 0; di < dims.size(); ++di) {
85 if (dims[di] != varorder[di].states()) {
86 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
87 }
88 outVS |= varorder[di];
89 labels[di] = varorder[di].label();
90 }
91
92 // Construct the sigma array for the permutation object
93 std::vector<size_t> sigma(dims.size(), 0);
94 VarSet::iterator set_iterator = outVS.begin();
95 for (size_t vs_i = 0; vs_i < dims.size(); ++vs_i, ++set_iterator) {
96 std::vector< long >::iterator location = find(labels.begin(), labels.end(),
97 set_iterator->label());
98 sigma[vs_i] = location - labels.begin();
99 }
100
101 return Permute(dims, sigma);
102 }
103
104 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
105 if (_varorders.size() == 0) {
106 return;
107 }
108 FactorOrientations::const_iterator foi = _varorders.begin();
109 std::vector< size_t > dims(foi->second.size());
110 size_t total_dim = 1;
111 for (size_t i = 0; i < dims.size(); ++i) {
112 dims[i] = foi->second[i].states();
113 total_dim *= dims[i];
114 }
115
116 // Construct the permutation objects
117 for ( ; foi != _varorders.end(); ++foi) {
118 VarSet vs;
119 _perms[foi->first] = calculatePermutation(foi->second, dims, vs);
120 _varsets[foi->first] = vs;
121 }
122
123 if (_estimation == NULL || _estimation->probSize() != total_dim) {
124 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
125 }
126 }
127
128 SharedParameters::SharedParameters(std::istream& is,
129 const FactorGraph& fg_varlookup)
130 : _varsets(),
131 _perms(),
132 _varorders(),
133 _estimation(NULL),
134 _deleteEstimation(1)
135 {
136 std::string est_method;
137 PropertySet props;
138 is >> est_method;
139 is >> props;
140
141 _estimation = ParameterEstimation::construct(est_method, props);
142
143 size_t num_factors;
144 is >> num_factors;
145 for (size_t sp_i = 0; sp_i < num_factors; ++sp_i) {
146 std::string line;
147 std::vector< std::string > fields;
148 size_t factor;
149 std::vector< Var > var_order;
150 std::istringstream iss;
151
152 while(line.size() == 0 && getline(is, line));
153 tokenizeString(line, fields, " \t");
154
155 // Lookup the factor in the factorgraph
156 if (fields.size() < 1) {
157 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
158 }
159 iss.str(fields[0]);
160 iss >> factor;
161 const VarSet& vs = fg_varlookup.factor(factor).vars();
162 if (fields.size() != vs.size() + 1) {
163 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
164 }
165
166 // Construct the vector of Vars
167 for (size_t fi = 1; fi < fields.size(); ++fi) {
168 // Lookup a single variable by label
169 long label;
170 std::istringstream labelparse(fields[fi]);
171 labelparse >> label;
172 VarSet::const_iterator vsi = vs.begin();
173 for ( ; vsi != vs.end(); ++vsi) {
174 if (vsi->label() == label) break;
175 }
176 if (vsi == vs.end()) {
177 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
178 }
179 var_order.push_back(*vsi);
180 }
181 _varorders[factor] = var_order;
182 }
183 setPermsAndVarSetsFromVarOrders();
184 }
185
186 SharedParameters::SharedParameters(const SharedParameters& sp)
187 : _varsets(sp._varsets),
188 _perms(sp._perms),
189 _varorders(sp._varorders),
190 _estimation(sp._estimation),
191 _deleteEstimation(sp._deleteEstimation)
192 {
193 if (_deleteEstimation) {
194 _estimation = _estimation->clone();
195 }
196 }
197
198 SharedParameters::SharedParameters(const FactorOrientations& varorders,
199 ParameterEstimation* estimation)
200 : _varsets(),
201 _perms(),
202 _varorders(varorders),
203 _estimation(estimation),
204 _deleteEstimation(0)
205 {
206 setPermsAndVarSetsFromVarOrders();
207 }
208
209 void SharedParameters::collectSufficientStatistics(InfAlg& alg) {
210 std::map< FactorIndex, Permute >::iterator i = _perms.begin();
211 for ( ; i != _perms.end(); ++i) {
212 Permute& perm = i->second;
213 VarSet& vs = _varsets[i->first];
214
215 Factor b = alg.belief(vs);
216 Prob p(b.states(), 0.0);
217 for (size_t entry = 0; entry < b.states(); ++entry) {
218 p[entry] = b[perm.convert_linear_index(entry)];
219 }
220 _estimation->addSufficientStatistics(p);
221 }
222 }
223
224 void SharedParameters::setParameters(FactorGraph& fg) {
225 Prob p = _estimation->estimate();
226 std::map< FactorIndex, Permute >::iterator i = _perms.begin();
227 for ( ; i != _perms.end(); ++i) {
228 Permute& perm = i->second;
229 VarSet& vs = _varsets[i->first];
230
231 Factor f(vs, 0.0);
232 for (size_t entry = 0; entry < f.states(); ++entry) {
233 f[perm.convert_linear_index(entry)] = p[entry];
234 }
235
236 fg.setFactor(i->first, f);
237 }
238 }
239
240 MaximizationStep::MaximizationStep (std::istream& is,
241 const FactorGraph& fg_varlookup )
242 : _params()
243 {
244 size_t num_params = -1;
245 is >> num_params;
246 _params.reserve(num_params);
247 for (size_t i = 0; i < num_params; ++i) {
248 SharedParameters p(is, fg_varlookup);
249 _params.push_back(p);
250 }
251 }
252
253
254 void MaximizationStep::addExpectations(InfAlg& alg) {
255 for (size_t i = 0; i < _params.size(); ++i) {
256 _params[i].collectSufficientStatistics(alg);
257 }
258 }
259
260 void MaximizationStep::maximize(FactorGraph& fg) {
261 for (size_t i = 0; i < _params.size(); ++i) {
262 _params[i].setParameters(fg);
263 }
264 }
265
266 const std::string EMAlg::MAX_ITERS_KEY("max_iters");
267 const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol");
268 const size_t EMAlg::MAX_ITERS_DEFAULT = 30;
269 const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01;
270
271 EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file)
272 : _evidence(evidence),
273 _estep(estep),
274 _msteps(),
275 _iters(0),
276 _lastLogZ(),
277 _max_iters(MAX_ITERS_DEFAULT),
278 _log_z_tol(LOG_Z_TOL_DEFAULT)
279 {
280 msteps_file.exceptions( std::istream::eofbit | std::istream::failbit
281 | std::istream::badbit );
282 size_t num_msteps = -1;
283 msteps_file >> num_msteps;
284 _msteps.reserve(num_msteps);
285 for (size_t i = 0; i < num_msteps; ++i) {
286 MaximizationStep m(msteps_file, estep.fg());
287 _msteps.push_back(m);
288 }
289 }
290
291 void EMAlg::setTermConditions(const PropertySet* p) {
292 if (NULL == p) {
293 return;
294 }
295 if (p->hasKey(MAX_ITERS_KEY)) {
296 _max_iters = p->getStringAs<size_t>(MAX_ITERS_KEY);
297 }
298 if (p->hasKey(LOG_Z_TOL_KEY)) {
299 _log_z_tol = p->getStringAs<Real>(LOG_Z_TOL_KEY);
300 }
301 }
302
303 bool EMAlg::hasSatisfiedTermConditions() const {
304 if (_iters >= _max_iters) {
305 return 1;
306 } else if (_lastLogZ.size() < 3) {
307 // need at least 2 to calculate ratio
308 // Also, throw away first iteration, as the parameters may not
309 // have been normalized according to the estimation method
310 return 0;
311 } else {
312 Real current = _lastLogZ[_lastLogZ.size() - 1];
313 Real previous = _lastLogZ[_lastLogZ.size() - 2];
314 if (previous == 0) return 0;
315 Real diff = current - previous;
316 if (diff < 0) {
317 std::cerr << "Error: in EM log-likehood decreased from " << previous
318 << " to " << current << std::endl;
319 return 1;
320 }
321 return diff / abs(previous) <= _log_z_tol;
322 }
323 }
324
325 Real EMAlg::iterate(MaximizationStep& mstep) {
326 Evidence::const_iterator e = _evidence.begin();
327 Real logZ = 0;
328
329 // Expectation calculation
330 for ( ; e != _evidence.end(); ++e) {
331 InfAlg* clamped = _estep.clone();
332 e->second.applyEvidence(*clamped);
333 clamped->run();
334
335 logZ += clamped->logZ();
336
337 mstep.addExpectations(*clamped);
338
339 delete clamped;
340 }
341
342 // Maximization of parameters
343 mstep.maximize(_estep.fg());
344
345 return logZ;
346 }
347
348 Real EMAlg::iterate() {
349 Real likelihood;
350 for (size_t i = 0; i < _msteps.size(); ++i) {
351 likelihood = iterate(_msteps[i]);
352 }
353 _lastLogZ.push_back(likelihood);
354 ++_iters;
355 return likelihood;
356 }
357
358 void EMAlg::run() {
359 while(!hasSatisfiedTermConditions()) {
360 iterate();
361 }
362 }
363
364 }