/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2009  Charles Vaske  [cvaske at soe dot ucsc dot edu]
 *  Copyright (C) 2009  University of California, Santa Cruz
 */
13 #include <dai/emalg.h>
19 // Initialize static private member of ParameterEstimation
20 std::map
<std::string
, ParameterEstimation::ParamEstFactory
> *ParameterEstimation::_registry
= NULL
;
23 void ParameterEstimation::loadDefaultRegistry() {
24 _registry
= new std::map
<std::string
, ParamEstFactory
>();
25 (*_registry
)["ConditionalProbEstimation"] = CondProbEstimation::factory
;
29 ParameterEstimation
* ParameterEstimation::construct( const std::string
&method
, const PropertySet
&p
) {
30 if( _registry
== NULL
)
31 loadDefaultRegistry();
32 std::map
<std::string
, ParamEstFactory
>::iterator i
= _registry
->find(method
);
33 if( i
== _registry
->end() )
34 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD
, "Unknown parameter estimation method: " + method
);
35 ParamEstFactory factory
= i
->second
;
40 ParameterEstimation
* CondProbEstimation::factory( const PropertySet
&p
) {
41 size_t target_dimension
= p
.getStringAs
<size_t>("target_dim");
42 size_t total_dimension
= p
.getStringAs
<size_t>("total_dim");
43 Real pseudo_count
= 1;
44 if( p
.hasKey("pseudo_count") )
45 pseudo_count
= p
.getStringAs
<Real
>("pseudo_count");
46 return new CondProbEstimation( target_dimension
, Prob( total_dimension
, pseudo_count
) );
50 CondProbEstimation::CondProbEstimation( size_t target_dimension
, const Prob
&pseudocounts
)
51 : _target_dim(target_dimension
), _stats(pseudocounts
), _initial_stats(pseudocounts
)
53 DAI_ASSERT( !(_stats
.size() % _target_dim
) );
57 void CondProbEstimation::addSufficientStatistics( const Prob
&p
) {
62 Prob
CondProbEstimation::estimate() {
63 // normalize pseudocounts
64 for( size_t parent
= 0; parent
< _stats
.size(); parent
+= _target_dim
) {
66 size_t top
= parent
+ _target_dim
;
67 Real norm
= std::accumulate( &(_stats
[parent
]), &(_stats
[top
]), 0.0 );
71 for( size_t i
= parent
; i
< top
; ++i
)
74 // reset _stats to _initial_stats
76 _stats
= _initial_stats
;
81 Permute
SharedParameters::calculatePermutation( const std::vector
<Var
> &varOrder
, VarSet
&outVS
) {
82 outVS
= VarSet( varOrder
.begin(), varOrder
.end(), varOrder
.size() );
83 return Permute( varOrder
);
87 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
88 if( _varorders
.size() == 0 )
90 DAI_ASSERT( _estimation
!= NULL
);
92 // Construct the permutation objects and the varsets
93 for( FactorOrientations::const_iterator foi
= _varorders
.begin(); foi
!= _varorders
.end(); ++foi
) {
95 _perms
[foi
->first
] = calculatePermutation( foi
->second
, vs
);
96 _varsets
[foi
->first
] = vs
;
97 DAI_ASSERT( _estimation
->probSize() == vs
.nrStates() );
102 SharedParameters::SharedParameters( const FactorOrientations
&varorders
, ParameterEstimation
*estimation
, bool ownPE
)
103 : _varsets(), _perms(), _varorders(varorders
), _estimation(estimation
), _ownEstimation(ownPE
)
105 // Calculate the necessary permutations and varsets
106 setPermsAndVarSetsFromVarOrders();
110 SharedParameters::SharedParameters( std::istream
&is
, const FactorGraph
&fg
)
111 : _varsets(), _perms(), _varorders(), _estimation(NULL
), _ownEstimation(true)
113 // Read the desired parameter estimation method from the stream
114 std::string est_method
;
119 // Construct a corresponding object
120 _estimation
= ParameterEstimation::construct( est_method
, props
);
122 // Read in the factors that are to be estimated
125 for( size_t sp_i
= 0; sp_i
< num_factors
; ++sp_i
) {
127 while( line
.size() == 0 && getline(is
, line
) )
130 std::vector
<std::string
> fields
;
131 tokenizeString(line
, fields
, " \t");
133 // Lookup the factor in the factorgraph
134 if( fields
.size() < 1 )
135 DAI_THROWE(INVALID_EMALG_FILE
,"Empty line unexpected");
136 std::istringstream iss
;
137 iss
.str( fields
[0] );
140 const VarSet
&vs
= fg
.factor(factor
).vars();
141 if( fields
.size() != vs
.size() + 1 )
142 DAI_THROWE(INVALID_EMALG_FILE
,"Number of fields does not match factor size");
144 // Construct the vector of Vars
145 std::vector
<Var
> var_order
;
146 var_order
.reserve( vs
.size() );
147 for( size_t fi
= 1; fi
< fields
.size(); ++fi
) {
148 // Lookup a single variable by label
150 std::istringstream
labelparse( fields
[fi
] );
152 VarSet::const_iterator vsi
= vs
.begin();
153 for( ; vsi
!= vs
.end(); ++vsi
)
154 if( vsi
->label() == label
)
156 if( vsi
== vs
.end() )
157 DAI_THROWE(INVALID_EMALG_FILE
,"Specified variables do not match the factor variables");
158 var_order
.push_back( *vsi
);
160 _varorders
[factor
] = var_order
;
163 // Calculate the necessary permutations
164 setPermsAndVarSetsFromVarOrders();
168 void SharedParameters::collectSufficientStatistics( InfAlg
&alg
) {
169 for( std::map
< FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
170 Permute
&perm
= i
->second
;
171 VarSet
&vs
= _varsets
[i
->first
];
173 Factor b
= alg
.belief(vs
);
174 Prob
p( b
.states(), 0.0 );
175 for( size_t entry
= 0; entry
< b
.states(); ++entry
)
176 p
[entry
] = b
[perm
.convertLinearIndex(entry
)]; // apply inverse permutation
177 _estimation
->addSufficientStatistics( p
);
182 void SharedParameters::setParameters( FactorGraph
&fg
) {
183 Prob p
= _estimation
->estimate();
184 for( std::map
<FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
185 Permute
&perm
= i
->second
;
186 VarSet
&vs
= _varsets
[i
->first
];
189 for( size_t entry
= 0; entry
< f
.states(); ++entry
)
190 f
[perm
.convertLinearIndex(entry
)] = p
[entry
];
192 fg
.setFactor( i
->first
, f
);
197 MaximizationStep::MaximizationStep( std::istream
&is
, const FactorGraph
&fg_varlookup
) : _params() {
198 size_t num_params
= -1;
200 _params
.reserve( num_params
);
201 for( size_t i
= 0; i
< num_params
; ++i
)
202 _params
.push_back( SharedParameters( is
, fg_varlookup
) );
206 void MaximizationStep::addExpectations( InfAlg
&alg
) {
207 for( size_t i
= 0; i
< _params
.size(); ++i
)
208 _params
[i
].collectSufficientStatistics( alg
);
212 void MaximizationStep::maximize( FactorGraph
&fg
) {
213 for( size_t i
= 0; i
< _params
.size(); ++i
)
214 _params
[i
].setParameters( fg
);
218 const std::string
EMAlg::MAX_ITERS_KEY("max_iters");
219 const std::string
EMAlg::LOG_Z_TOL_KEY("log_z_tol");
220 const size_t EMAlg::MAX_ITERS_DEFAULT
= 30;
221 const Real
EMAlg::LOG_Z_TOL_DEFAULT
= 0.01;
224 EMAlg::EMAlg( const Evidence
&evidence
, InfAlg
&estep
, std::istream
&msteps_file
)
225 : _evidence(evidence
), _estep(estep
), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT
), _log_z_tol(LOG_Z_TOL_DEFAULT
)
227 msteps_file
.exceptions( std::istream::eofbit
| std::istream::failbit
| std::istream::badbit
);
228 size_t num_msteps
= -1;
229 msteps_file
>> num_msteps
;
230 _msteps
.reserve(num_msteps
);
231 for( size_t i
= 0; i
< num_msteps
; ++i
)
232 _msteps
.push_back( MaximizationStep( msteps_file
, estep
.fg() ) );
236 void EMAlg::setTermConditions( const PropertySet
&p
) {
237 if( p
.hasKey(MAX_ITERS_KEY
) )
238 _max_iters
= p
.getStringAs
<size_t>(MAX_ITERS_KEY
);
239 if( p
.hasKey(LOG_Z_TOL_KEY
) )
240 _log_z_tol
= p
.getStringAs
<Real
>(LOG_Z_TOL_KEY
);
244 bool EMAlg::hasSatisfiedTermConditions() const {
245 if( _iters
>= _max_iters
)
247 else if( _lastLogZ
.size() < 3 )
248 // need at least 2 to calculate ratio
249 // Also, throw away first iteration, as the parameters may not
250 // have been normalized according to the estimation method
253 Real current
= _lastLogZ
[_lastLogZ
.size() - 1];
254 Real previous
= _lastLogZ
[_lastLogZ
.size() - 2];
257 Real diff
= current
- previous
;
259 std::cerr
<< "Error: in EM log-likehood decreased from " << previous
<< " to " << current
<< std::endl
;
262 return (diff
/ fabs(previous
)) <= _log_z_tol
;
267 Real
EMAlg::iterate( MaximizationStep
&mstep
) {
272 logZ
= _estep
.logZ();
274 // Expectation calculation
275 for( Evidence::const_iterator e
= _evidence
.begin(); e
!= _evidence
.end(); ++e
) {
276 InfAlg
* clamped
= _estep
.clone();
278 for( Evidence::Observation::const_iterator i
= e
->begin(); i
!= e
->end(); ++i
)
279 clamped
->clamp( clamped
->fg().findVar(i
->first
), i
->second
);
283 likelihood
+= clamped
->logZ() - logZ
;
285 mstep
.addExpectations( *clamped
);
290 // Maximization of parameters
291 mstep
.maximize( _estep
.fg() );
297 Real
EMAlg::iterate() {
299 for( size_t i
= 0; i
< _msteps
.size(); ++i
)
300 likelihood
= iterate( _msteps
[i
] );
301 _lastLogZ
.push_back( likelihood
);
308 while( !hasSatisfiedTermConditions() )
313 } // end of namespace dai