/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2009  Charles Vaske  [cvaske at soe dot ucsc dot edu]
 *  Copyright (C) 2009  University of California, Santa Cruz
 */
13 #include <dai/emalg.h>
19 std::map
<std::string
, ParameterEstimation::ParamEstFactory
> *ParameterEstimation::_registry
= NULL
;
22 void ParameterEstimation::loadDefaultRegistry() {
23 _registry
= new std::map
<std::string
, ParamEstFactory
>();
24 (*_registry
)["ConditionalProbEstimation"] = CondProbEstimation::factory
;
28 ParameterEstimation
* ParameterEstimation::construct( const std::string
&method
, const PropertySet
&p
) {
29 if( _registry
== NULL
)
30 loadDefaultRegistry();
31 std::map
<std::string
, ParamEstFactory
>::iterator i
= _registry
->find(method
);
32 if( i
== _registry
->end() )
33 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD
, "Unknown parameter estimation method: " + method
);
34 ParamEstFactory factory
= i
->second
;
39 ParameterEstimation
* CondProbEstimation::factory( const PropertySet
&p
) {
40 size_t target_dimension
= p
.getStringAs
<size_t>("target_dim");
41 size_t total_dimension
= p
.getStringAs
<size_t>("total_dim");
42 Real pseudo_count
= 1;
43 if( p
.hasKey("pseudo_count") )
44 pseudo_count
= p
.getStringAs
<Real
>("pseudo_count");
45 return new CondProbEstimation( target_dimension
, Prob( total_dimension
, pseudo_count
) );
49 CondProbEstimation::CondProbEstimation( size_t target_dimension
, const Prob
&pseudocounts
)
50 : _target_dim(target_dimension
), _stats(pseudocounts
), _initial_stats(pseudocounts
)
52 assert( !(_stats
.size() % _target_dim
) );
// Receives a vector of sufficient statistics p for this estimator.
// NOTE(review): only the signature is visible here; the statement that folds
// p into the stored statistics (presumably an addition to _stats) is not
// shown -- confirm against the upstream source.
56 void CondProbEstimation::addSufficientStatistics( const Prob
&p
) {
// Produces the parameter estimate from the accumulated statistics and
// resets the accumulator.
//
// _stats is walked in consecutive blocks of _target_dim entries; each block
// covers the indices [parent, top). Afterwards _stats is overwritten with
// _initial_stats.
// NOTE(review): the bodies of the two inner loops and the return statement
// are not visible here. Presumably the first loop sums a block, the second
// divides the block by that sum, and the normalized vector is returned --
// confirm against the upstream source.
61 Prob
CondProbEstimation::estimate() {
62 // normalize pseudocounts
63 for( size_t parent
= 0; parent
< _stats
.size(); parent
+= _target_dim
) {
// One past the last index of the current block
66 size_t top
= parent
+ _target_dim
;
67 for( size_t i
= parent
; i
< top
; ++i
)
72 for( size_t i
= parent
; i
< top
; ++i
)
75 // reset _stats to _initial_stats
77 _stats
= _initial_stats
;
// Builds the Permute object for one variable ordering.
//
// dims receives states() and labels receives label() of every variable in
// varorder, in the given order. sigma[k] is then the position within
// varorder of the label of the k-th variable reached by iterating outVS.
// The result is Permute( dims, sigma ).
// NOTE(review): no statement that inserts the variables into outVS is
// visible here; presumably the collection loop also adds varorder[i] to
// outVS -- confirm against the upstream source.
82 Permute
SharedParameters::calculatePermutation( const std::vector
<Var
> &varorder
, VarSet
&outVS
) {
83 // Collect all labels and dimensions, and order them in vs
84 std::vector
<size_t> dims
;
85 dims
.reserve( varorder
.size() );
86 std::vector
<long> labels
;
87 labels
.reserve( varorder
.size() );
88 for( size_t i
= 0; i
< varorder
.size(); i
++ ) {
89 dims
.push_back( varorder
[i
].states() );
90 labels
.push_back( varorder
[i
].label() );
94 // Construct the sigma array for the permutation object
95 std::vector
<size_t> sigma
;
96 sigma
.reserve( dims
.size() );
// NOTE(review): this loop stops once sigma has as many entries as dims; it
// does not test set_iterator against outVS.end(), so it relies on outVS
// holding at least dims.size() variables -- confirm that callers never pass
// an ordering with repeated variables.
97 for( VarSet::iterator set_iterator
= outVS
.begin(); sigma
.size() < dims
.size(); ++set_iterator
)
98 sigma
.push_back( find(labels
.begin(), labels
.end(), set_iterator
->label()) - labels
.begin() );
100 return Permute( dims
, sigma
);
// Derives _perms and _varsets from _varorders.
//
// For every (factor index, variable order) entry of _varorders,
// calculatePermutation() supplies the Permute stored in _perms and the
// VarSet stored in _varsets under the same factor index. Asserts that an
// estimation object is set and that its probSize() equals nrStates() of
// each variable set.
// NOTE(review): the statement guarded by the empty-_varorders test and the
// declaration of vs are not visible here; presumably an early return and a
// fresh VarSet per iteration -- confirm against the upstream source.
104 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
105 if( _varorders
.size() == 0 )
107 assert( _estimation
!= NULL
);
109 // Construct the permutation objects and the varsets
110 for( FactorOrientations::const_iterator foi
= _varorders
.begin(); foi
!= _varorders
.end(); ++foi
) {
112 _perms
[foi
->first
] = calculatePermutation( foi
->second
, vs
);
113 _varsets
[foi
->first
] = vs
;
114 assert( _estimation
->probSize() == vs
.nrStates() );
// Constructs a SharedParameters object from a textual description.
//
// The estimation object is created with ParameterEstimation::construct()
// from a method name and properties, and is marked as owned
// (_deleteEstimation is true). Then num_factors factor descriptions are
// processed. Each description is a non-empty line split on spaces and tabs:
// the first field is parsed to identify a factor of fg_varlookup, and every
// further field is parsed and matched against the label() of a variable of
// that factor. The matched variables, in file order, become the variable
// order stored in _varorders for that factor. Finally the permutations and
// variable sets are derived with setPermsAndVarSetsFromVarOrders().
//
// INVALID_EMALG_FILE is raised through DAI_THROW when a line has no fields,
// when the number of fields is not one more than the number of variables of
// the factor, or when a label matches no variable of the factor.
// NOTE(review): the declarations and reads of props, num_factors, line,
// factor and label are not visible here, nor is the statement that leaves
// the label search loop on a match -- confirm against the upstream source.
119 SharedParameters::SharedParameters( std::istream
&is
, const FactorGraph
&fg_varlookup
)
120 : _varsets(), _perms(), _varorders(), _estimation(NULL
), _deleteEstimation(true)
122 // Read the desired parameter estimation method from the stream
123 std::string est_method
;
128 // Construct a corresponding object
129 _estimation
= ParameterEstimation::construct( est_method
, props
);
131 // Read in the factors that are to be estimated
134 for( size_t sp_i
= 0; sp_i
< num_factors
; ++sp_i
) {
// Keep reading while the current line is empty and the stream still
// delivers lines
136 while( line
.size() == 0 && getline(is
, line
) )
139 std::vector
<std::string
> fields
;
140 tokenizeString(line
, fields
, " \t");
142 // Lookup the factor in the factorgraph
143 if( fields
.size() < 1 )
144 DAI_THROW(INVALID_EMALG_FILE
);
145 std::istringstream iss
;
146 iss
.str( fields
[0] );
149 const VarSet
&vs
= fg_varlookup
.factor(factor
).vars();
// One field for the factor plus one per variable of the factor
150 if( fields
.size() != vs
.size() + 1 )
151 DAI_THROW(INVALID_EMALG_FILE
);
153 // Construct the vector of Vars
154 std::vector
<Var
> var_order
;
155 var_order
.reserve( vs
.size() );
156 for( size_t fi
= 1; fi
< fields
.size(); ++fi
) {
157 // Lookup a single variable by label
159 std::istringstream
labelparse( fields
[fi
] );
161 VarSet::const_iterator vsi
= vs
.begin();
162 for( ; vsi
!= vs
.end(); ++vsi
)
163 if( vsi
->label() == label
)
// Reaching the end means no variable of the factor carries this label
165 if( vsi
== vs
.end() )
166 DAI_THROW(INVALID_EMALG_FILE
);
167 var_order
.push_back( *vsi
);
169 _varorders
[factor
] = var_order
;
172 // Calculate the necessary permutations
173 setPermsAndVarSetsFromVarOrders();
177 SharedParameters::SharedParameters( const SharedParameters
&sp
)
178 : _varsets(sp
._varsets
), _perms(sp
._perms
), _varorders(sp
._varorders
), _estimation(sp
._estimation
), _deleteEstimation(sp
._deleteEstimation
)
180 // If sp owns its _estimation object, we should clone it instead
181 if( _deleteEstimation
)
182 _estimation
= _estimation
->clone();
186 SharedParameters::SharedParameters( const FactorOrientations
&varorders
, ParameterEstimation
*estimation
, bool deletePE
)
187 : _varsets(), _perms(), _varorders(varorders
), _estimation(estimation
), _deleteEstimation(deletePE
)
189 // Calculate the necessary permutations
190 setPermsAndVarSetsFromVarOrders();
194 void SharedParameters::collectSufficientStatistics( InfAlg
&alg
) {
195 for( std::map
< FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
196 Permute
&perm
= i
->second
;
197 VarSet
&vs
= _varsets
[i
->first
];
199 Factor b
= alg
.belief(vs
);
200 Prob
p( b
.states(), 0.0 );
201 for( size_t entry
= 0; entry
< b
.states(); ++entry
)
202 p
[entry
] = b
[perm
.convert_linear_index(entry
)];
203 _estimation
->addSufficientStatistics( p
);
// Writes the current parameter estimate into every factor that shares it.
//
// The estimate is obtained once from _estimation->estimate(). For every
// factor index in _perms, entry `entry` of the estimate is stored at
// position perm.convert_linear_index(entry) of a factor f, which is then
// installed in fg with setFactor() under that factor index.
// NOTE(review): the declaration of f is not visible here; presumably it is
// a Factor over vs -- confirm against the upstream source.
208 void SharedParameters::setParameters( FactorGraph
&fg
) {
209 Prob p
= _estimation
->estimate();
210 for( std::map
<FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
211 Permute
&perm
= i
->second
;
212 VarSet
&vs
= _varsets
[i
->first
];
215 for( size_t entry
= 0; entry
< f
.states(); ++entry
)
216 f
[perm
.convert_linear_index(entry
)] = p
[entry
];
218 fg
.setFactor( i
->first
, f
);
// Extracts the current parameter values from a factor graph.
//
// Only the first entry of _varorders is used: its variable order is appended
// to outVarOrder, and the values of the corresponding factor of fg are
// appended to outVals in the order given by that factor's permutation.
// Asserts that the factor's variables equal the variable set stored for it.
// NOTE(review): the statement guarded by the empty-_varorders test is not
// visible here; presumably an early return -- confirm against the upstream
// source.
223 void SharedParameters::collectParameters( const FactorGraph
&fg
, std::vector
<Real
> &outVals
, std::vector
<Var
> &outVarOrder
) {
224 FactorOrientations::iterator it
= _varorders
.begin();
225 if( it
== _varorders
.end() )
227 FactorIndex I
= it
->first
;
228 for( std::vector
<Var
>::const_iterator var_it
= _varorders
[I
].begin(); var_it
!= _varorders
[I
].end(); ++var_it
)
229 outVarOrder
.push_back( *var_it
);
231 const Factor
&f
= fg
.factor(I
);
232 assert( f
.vars() == _varsets
[I
] );
233 const Permute
&perm
= _perms
[I
];
234 for( size_t val_index
= 0; val_index
< f
.states(); ++val_index
)
235 outVals
.push_back( f
[perm
.convert_linear_index(val_index
)] );
// Constructs a maximization step from a textual description: num_params
// SharedParameters objects are constructed in turn from the stream (with
// fg_varlookup for the factor lookup) and appended to _params.
// NOTE(review): the statement that reads num_params from `is` is not visible
// here. As shown, num_params would still hold the value converted from -1
// when it is passed to reserve() -- confirm against the upstream source.
239 MaximizationStep::MaximizationStep( std::istream
&is
, const FactorGraph
&fg_varlookup
) : _params() {
240 size_t num_params
= -1;
242 _params
.reserve( num_params
);
243 for( size_t i
= 0; i
< num_params
; ++i
)
244 _params
.push_back( SharedParameters( is
, fg_varlookup
) );
248 void MaximizationStep::addExpectations( InfAlg
&alg
) {
249 for( size_t i
= 0; i
< _params
.size(); ++i
)
250 _params
[i
].collectSufficientStatistics( alg
);
254 void MaximizationStep::maximize( FactorGraph
&fg
) {
255 for( size_t i
= 0; i
< _params
.size(); ++i
)
256 _params
[i
].setParameters( fg
);
260 const std::string
EMAlg::MAX_ITERS_KEY("max_iters");
261 const std::string
EMAlg::LOG_Z_TOL_KEY("log_z_tol");
262 const size_t EMAlg::MAX_ITERS_DEFAULT
= 30;
263 const Real
EMAlg::LOG_Z_TOL_DEFAULT
= 0.01;
266 EMAlg::EMAlg( const Evidence
&evidence
, InfAlg
&estep
, std::istream
&msteps_file
)
267 : _evidence(evidence
), _estep(estep
), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT
), _log_z_tol(LOG_Z_TOL_DEFAULT
)
269 msteps_file
.exceptions( std::istream::eofbit
| std::istream::failbit
| std::istream::badbit
);
270 size_t num_msteps
= -1;
271 msteps_file
>> num_msteps
;
272 _msteps
.reserve(num_msteps
);
273 for( size_t i
= 0; i
< num_msteps
; ++i
)
274 _msteps
.push_back( MaximizationStep( msteps_file
, estep
.fg() ) );
278 void EMAlg::setTermConditions( const PropertySet
&p
) {
279 if( p
.hasKey(MAX_ITERS_KEY
) )
280 _max_iters
= p
.getStringAs
<size_t>(MAX_ITERS_KEY
);
281 if( p
.hasKey(LOG_Z_TOL_KEY
) )
282 _log_z_tol
= p
.getStringAs
<Real
>(LOG_Z_TOL_KEY
);
// Decides whether the EM iterations should stop.
//
// Three cases are distinguished: the iteration count has reached _max_iters;
// fewer than three values are recorded in _lastLogZ; otherwise the two most
// recent values are compared, and the result is whether their difference
// divided by the absolute value of the older one is at most _log_z_tol. A
// message is written to std::cerr about a decrease of the value.
// NOTE(review): the results returned by the first two cases and the
// condition guarding the std::cerr message are not visible here; presumably
// true, false and a negative difference respectively -- confirm against the
// upstream source.
286 bool EMAlg::hasSatisfiedTermConditions() const {
287 if( _iters
>= _max_iters
)
289 else if( _lastLogZ
.size() < 3 )
290 // need at least 2 to calculate ratio
291 // Also, throw away first iteration, as the parameters may not
292 // have been normalized according to the estimation method
295 Real current
= _lastLogZ
[_lastLogZ
.size() - 1];
296 Real previous
= _lastLogZ
[_lastLogZ
.size() - 2];
299 Real diff
= current
- previous
;
301 std::cerr
<< "Error: in EM log-likehood decreased from " << previous
<< " to " << current
<< std::endl
;
304 return (diff
/ fabs(previous
)) <= _log_z_tol
;
// Performs one EM iteration for a single maximization step.
//
// logZ is taken from the E-step algorithm. For every evidence entry, a
// clone of the E-step algorithm receives that evidence, the difference
// between the clone's logZ() and logZ is added to likelihood, and the clone
// is passed to mstep.addExpectations(). Afterwards mstep.maximize() updates
// the factor graph of the E-step algorithm.
// NOTE(review): the declarations of logZ and likelihood, any init/run calls
// on _estep and on the clone, the disposal of the clone and the return
// statement are not visible here -- confirm against the upstream source, in
// particular that each clone is deleted.
309 Real
EMAlg::iterate( MaximizationStep
&mstep
) {
314 logZ
= _estep
.logZ();
316 // Expectation calculation
317 for( Evidence::const_iterator e
= _evidence
.begin(); e
!= _evidence
.end(); ++e
) {
318 InfAlg
* clamped
= _estep
.clone();
319 e
->applyEvidence( *clamped
);
323 likelihood
+= clamped
->logZ() - logZ
;
325 mstep
.addExpectations( *clamped
);
330 // Maximization of parameters
331 mstep
.maximize( _estep
.fg() );
// Performs one EM iteration over all maximization steps, in order. The
// value produced by the last step's iterate() call is appended to _lastLogZ.
// NOTE(review): the declaration of likelihood, any update of _iters and the
// return statement are not visible here -- confirm against the upstream
// source.
337 Real
EMAlg::iterate() {
339 for( size_t i
= 0; i
< _msteps
.size(); ++i
)
340 likelihood
= iterate( _msteps
[i
] );
341 _lastLogZ
.push_back( likelihood
);
348 while( !hasSatisfiedTermConditions() )
353 } // end of namespace dai