/*  Copyright (C) 2009  Charles Vaske  [cvaske at soe dot ucsc dot edu]
    University of California Santa Cruz

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
23 #include <dai/emalg.h>
29 std::map
<std::string
, ParameterEstimation::ParamEstFactory
> *ParameterEstimation::_registry
= NULL
;
32 void ParameterEstimation::loadDefaultRegistry() {
33 _registry
= new std::map
< std::string
, ParamEstFactory
>();
34 (*_registry
)["ConditionalProbEstimation"] = CondProbEstimation::factory
;
38 ParameterEstimation
* ParameterEstimation::construct( const std::string
& method
, const PropertySet
& p
) {
39 if (_registry
== NULL
)
40 loadDefaultRegistry();
41 std::map
< std::string
, ParamEstFactory
>::iterator i
= _registry
->find(method
);
42 if (i
== _registry
->end())
43 DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD
);
44 ParamEstFactory factory
= i
->second
;
49 ParameterEstimation
* CondProbEstimation::factory( const PropertySet
& p
) {
50 size_t target_dimension
= p
.getStringAs
<size_t>("target_dim");
51 size_t total_dimension
= p
.getStringAs
<size_t>("total_dim");
52 Real pseudo_count
= 1;
53 if (p
.hasKey("pseudo_count"))
54 pseudo_count
= p
.getStringAs
<Real
>("pseudo_count");
55 Prob
counts_vec(total_dimension
, pseudo_count
);
56 return new CondProbEstimation(target_dimension
, counts_vec
);
60 CondProbEstimation::CondProbEstimation( size_t target_dimension
, Prob pseudocounts
)
61 : _target_dim(target_dimension
), _stats(pseudocounts
), _initial_stats(pseudocounts
)
63 if (_stats
.size() % _target_dim
)
64 DAI_THROW(MALFORMED_PROPERTY
);
// Receives a vector of sufficient statistics, p, for this estimator.
// NOTE(review): the body of this function is not present in this copy of the
// file; presumably p is accumulated into _stats -- confirm against upstream.
68 void CondProbEstimation::addSufficientStatistics( Prob
& p
) {
// Produces a Prob from the statistics collected so far and then resets
// _stats to _initial_stats.
// NOTE(review): the bodies of both inner loops and the return statement are
// missing from this copy of the file; presumably each group of _target_dim
// entries is summed and then normalised -- confirm against upstream.
73 Prob
CondProbEstimation::estimate() {
// Walk _stats in consecutive groups of _target_dim entries; `parent` is the
// index of the first entry of the current group.
74 for (size_t parent
= 0; parent
< _stats
.size(); parent
+= _target_dim
) {
// `top` is one past the last index of the current group.
76 size_t top
= parent
+ _target_dim
;
// First pass over the group [parent, top) -- loop body not visible here.
77 for (size_t i
= parent
; i
< top
; ++i
)
// Second pass over the same group -- loop body not visible here.
81 for (size_t i
= parent
; i
< top
; ++i
)
// Restore the statistics to their initial (pseudocount) values.
85 _stats
= _initial_stats
;
90 Permute
SharedParameters::calculatePermutation(const std::vector
< Var
>& varorder
, const std::vector
< size_t >& dims
, VarSet
& outVS
) {
91 std::vector
<long> labels(dims
.size());
93 // Check that the variable set is compatible
94 if (varorder
.size() != dims
.size())
95 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER
);
97 // Collect all labels, and order them in vs
98 for (size_t di
= 0; di
< dims
.size(); ++di
) {
99 if (dims
[di
] != varorder
[di
].states())
100 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER
);
101 outVS
|= varorder
[di
];
102 labels
[di
] = varorder
[di
].label();
105 // Construct the sigma array for the permutation object
106 std::vector
<size_t> sigma(dims
.size(), 0);
107 VarSet::iterator set_iterator
= outVS
.begin();
108 for (size_t vs_i
= 0; vs_i
< dims
.size(); ++vs_i
, ++set_iterator
) {
109 std::vector
< long >::iterator location
= find(labels
.begin(), labels
.end(), set_iterator
->label());
110 sigma
[vs_i
] = location
- labels
.begin();
113 return Permute(dims
, sigma
);
117 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
118 if (_varorders
.size() == 0)
120 FactorOrientations::const_iterator foi
= _varorders
.begin();
121 std::vector
< size_t > dims(foi
->second
.size());
122 size_t total_dim
= 1;
123 for (size_t i
= 0; i
< dims
.size(); ++i
) {
124 dims
[i
] = foi
->second
[i
].states();
125 total_dim
*= dims
[i
];
128 // Construct the permutation objects
129 for ( ; foi
!= _varorders
.end(); ++foi
) {
131 _perms
[foi
->first
] = calculatePermutation(foi
->second
, dims
, vs
);
132 _varsets
[foi
->first
] = vs
;
135 if (_estimation
== NULL
|| _estimation
->probSize() != total_dim
)
136 DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER
);
// Constructs a SharedParameters object from a textual description read from
// `is`; factor indices and variable labels in that description are resolved
// against the factor graph fg_varlookup. The estimation object is created
// here through ParameterEstimation::construct(), and _deleteEstimation is
// set to 1.
// NOTE(review): several declarations and stream reads (props, num_factors,
// line, factor, label, and the statement that exits the label search) are
// missing from this copy of the file -- confirm against upstream before
// relying on the details below.
140 SharedParameters::SharedParameters(std::istream
& is
, const FactorGraph
& fg_varlookup
)
141 : _varsets(), _perms(), _varorders(), _estimation(NULL
), _deleteEstimation(1)
143 std::string est_method
;
// Create the estimator named est_method, configured by props.
148 _estimation
= ParameterEstimation::construct(est_method
, props
);
// One input line is processed per shared factor.
152 for (size_t sp_i
= 0; sp_i
< num_factors
; ++sp_i
) {
154 std::vector
< std::string
> fields
;
156 std::vector
< Var
> var_order
;
157 std::istringstream iss
;
// Keep reading while the line is empty and the stream still delivers lines.
159 while(line
.size() == 0 && getline(is
, line
))
// Split the line on spaces and tabs.
161 tokenizeString(line
, fields
, " \t");
163 // Lookup the factor in the factorgraph
164 if (fields
.size() < 1)
165 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE
);
168 const VarSet
& vs
= fg_varlookup
.factor(factor
).vars();
// The line must hold one field more than the factor has variables.
169 if (fields
.size() != vs
.size() + 1)
170 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE
);
172 // Construct the vector of Vars
173 for (size_t fi
= 1; fi
< fields
.size(); ++fi
) {
174 // Lookup a single variable by label
176 std::istringstream
labelparse(fields
[fi
]);
// Scan the factor's variables for one whose label equals `label`.
178 VarSet::const_iterator vsi
= vs
.begin();
179 for ( ; vsi
!= vs
.end(); ++vsi
)
180 if (vsi
->label() == label
)
183 DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE
);
184 var_order
.push_back(*vsi
);
// Record the variable order collected for this factor.
186 _varorders
[factor
] = var_order
;
// Derive _perms and _varsets from the variable orders just read.
188 setPermsAndVarSetsFromVarOrders();
192 SharedParameters::SharedParameters( const SharedParameters
& sp
)
193 : _varsets(sp
._varsets
), _perms(sp
._perms
), _varorders(sp
._varorders
), _estimation(sp
._estimation
), _deleteEstimation(sp
._deleteEstimation
)
195 if (_deleteEstimation
)
196 _estimation
= _estimation
->clone();
200 SharedParameters::SharedParameters( const FactorOrientations
& varorders
, ParameterEstimation
* estimation
)
201 : _varsets(), _perms(), _varorders(varorders
), _estimation(estimation
), _deleteEstimation(0)
203 setPermsAndVarSetsFromVarOrders();
207 void SharedParameters::collectSufficientStatistics(InfAlg
& alg
) {
208 std::map
< FactorIndex
, Permute
>::iterator i
= _perms
.begin();
209 for ( ; i
!= _perms
.end(); ++i
) {
210 Permute
& perm
= i
->second
;
211 VarSet
& vs
= _varsets
[i
->first
];
213 Factor b
= alg
.belief(vs
);
214 Prob
p(b
.states(), 0.0);
215 for (size_t entry
= 0; entry
< b
.states(); ++entry
)
216 p
[entry
] = b
[perm
.convert_linear_index(entry
)];
217 _estimation
->addSufficientStatistics(p
);
222 void SharedParameters::setParameters(FactorGraph
& fg
) {
223 Prob p
= _estimation
->estimate();
224 std::map
< FactorIndex
, Permute
>::iterator i
= _perms
.begin();
225 for ( ; i
!= _perms
.end(); ++i
) {
226 Permute
& perm
= i
->second
;
227 VarSet
& vs
= _varsets
[i
->first
];
230 for (size_t entry
= 0; entry
< f
.states(); ++entry
)
231 f
[perm
.convert_linear_index(entry
)] = p
[entry
];
233 fg
.setFactor(i
->first
, f
);
238 MaximizationStep::MaximizationStep (std::istream
& is
, const FactorGraph
& fg_varlookup
)
241 size_t num_params
= -1;
243 _params
.reserve(num_params
);
244 for (size_t i
= 0; i
< num_params
; ++i
) {
245 SharedParameters
p(is
, fg_varlookup
);
246 _params
.push_back(p
);
251 void MaximizationStep::addExpectations(InfAlg
& alg
) {
252 for (size_t i
= 0; i
< _params
.size(); ++i
)
253 _params
[i
].collectSufficientStatistics(alg
);
257 void MaximizationStep::maximize(FactorGraph
& fg
) {
258 for (size_t i
= 0; i
< _params
.size(); ++i
)
259 _params
[i
].setParameters(fg
);
263 const std::string
EMAlg::MAX_ITERS_KEY("max_iters");
264 const std::string
EMAlg::LOG_Z_TOL_KEY("log_z_tol");
265 const size_t EMAlg::MAX_ITERS_DEFAULT
= 30;
266 const Real
EMAlg::LOG_Z_TOL_DEFAULT
= 0.01;
269 EMAlg::EMAlg(const Evidence
& evidence
, InfAlg
& estep
, std::istream
& msteps_file
)
270 : _evidence(evidence
), _estep(estep
), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT
), _log_z_tol(LOG_Z_TOL_DEFAULT
)
272 msteps_file
.exceptions( std::istream::eofbit
| std::istream::failbit
| std::istream::badbit
);
273 size_t num_msteps
= -1;
274 msteps_file
>> num_msteps
;
275 _msteps
.reserve(num_msteps
);
276 for (size_t i
= 0; i
< num_msteps
; ++i
) {
277 MaximizationStep
m(msteps_file
, estep
.fg());
278 _msteps
.push_back(m
);
283 void EMAlg::setTermConditions(const PropertySet
* p
) {
286 if (p
->hasKey(MAX_ITERS_KEY
))
287 _max_iters
= p
->getStringAs
<size_t>(MAX_ITERS_KEY
);
288 if (p
->hasKey(LOG_Z_TOL_KEY
))
289 _log_z_tol
= p
->getStringAs
<Real
>(LOG_Z_TOL_KEY
);
// Decides whether the EM iterations should stop, based on the iteration
// count (_iters against _max_iters) and on the relative change between the
// last two entries of _lastLogZ (against _log_z_tol).
// NOTE(review): the return statements of the first two branches and the
// condition guarding the error message are missing from this copy of the
// file -- confirm the exact results against upstream.
293 bool EMAlg::hasSatisfiedTermConditions() const {
// Branch 1: the iteration limit has been reached.
294 if (_iters
>= _max_iters
) {
// Branch 2: fewer than three log-Z values have been recorded so far.
296 } else if (_lastLogZ
.size() < 3) {
297 // need at least 2 to calculate ratio
298 // Also, throw away first iteration, as the parameters may not
299 // have been normalized according to the estimation method
// Branch 3: compare the two most recent log-Z values.
302 Real current
= _lastLogZ
[_lastLogZ
.size() - 1];
303 Real previous
= _lastLogZ
[_lastLogZ
.size() - 2];
// A previous value of exactly 0 cannot serve as the divisor below.
304 if (previous
== 0) return 0;
305 Real diff
= current
- previous
;
307 std::cerr
<< "Error: in EM log-likehood decreased from " << previous
<< " to " << current
<< std::endl
;
// NOTE(review): `abs` is unqualified here; depending on which headers are
// included it may resolve to the integer overload and truncate `previous`
// -- verify, and consider std::abs or fabs.
310 return diff
/ abs(previous
) <= _log_z_tol
;
// Performs one expectation/maximization pass for a single maximization
// step: for every evidence entry a clone of the E-step inference algorithm
// has that entry's evidence applied, its logZ() is added to `logZ`, and it
// is passed to mstep.addExpectations(); afterwards mstep.maximize() is
// called on the E-step's factor graph.
// NOTE(review): several statements are missing from this copy of the file
// (the declaration of logZ, whatever runs inference on the clone, any
// release of the clone, and the return) -- confirm against upstream, in
// particular that `clamped` is deleted.
315 Real
EMAlg::iterate(MaximizationStep
& mstep
) {
316 Evidence::const_iterator e
= _evidence
.begin();
319 // Expectation calculation
320 for ( ; e
!= _evidence
.end(); ++e
) {
// Work on a copy so that the evidence does not alter _estep itself.
321 InfAlg
* clamped
= _estep
.clone();
322 e
->second
.applyEvidence(*clamped
);
// Accumulate the log partition sum over all evidence entries.
325 logZ
+= clamped
->logZ();
327 mstep
.addExpectations(*clamped
);
332 // Maximization of parameters
333 mstep
.maximize(_estep
.fg());
// Runs iterate(MaximizationStep&) for every maximization step in _msteps,
// in index order, and appends the value returned for the last one to
// _lastLogZ.
// NOTE(review): the declaration of `likelihood`, the return statement and
// possibly further statements are missing from this copy of the file --
// confirm against upstream.
339 Real
EMAlg::iterate() {
// Each pass overwrites `likelihood`; only the final step's value is kept.
341 for (size_t i
= 0; i
< _msteps
.size(); ++i
)
342 likelihood
= iterate(_msteps
[i
]);
343 _lastLogZ
.push_back(likelihood
);
350 while(!hasSatisfiedTermConditions())
355 } // end of namespace dai