1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 #include <dai/emalg.h>
29 std::map
<std::string
, ParameterEstimation::ParamEstFactory
> *ParameterEstimation::_registry
= NULL
;
// Creates the estimator registry and registers the built-in
// "ConditionalProbEstimation" method, whose factory is
// CondProbEstimation::factory.
// NOTE(review): the map is heap-allocated and assigned to _registry
// unconditionally. construct() only calls this while _registry is NULL;
// a second call from anywhere else would leak the previous map.
32 void ParameterEstimation::loadDefaultRegistry() {
33 _registry
= new std::map
<std::string
, ParamEstFactory
>();
// Register the conditional-probability estimator under its public name.
34 (*_registry
)["ConditionalProbEstimation"] = CondProbEstimation::factory
;
// Creates a parameter estimator by name.
//   method - key into the registry (e.g. "ConditionalProbEstimation")
//   p      - properties forwarded to the registered factory
// The registry is populated lazily on the first call.
// Throws (via DAI_THROWE) UNKNOWN_PARAMETER_ESTIMATION_METHOD when no
// factory is registered under `method`.
// NOTE(review): presumably returns factory(p). CondProbEstimation::factory
// allocates its result with `new`, so the caller takes ownership — confirm.
38 ParameterEstimation
* ParameterEstimation::construct( const std::string
&method
, const PropertySet
&p
) {
// Build the default registry the first time any estimator is requested.
39 if( _registry
== NULL
)
40 loadDefaultRegistry();
// Look up the factory registered under this method name.
41 std::map
<std::string
, ParamEstFactory
>::iterator i
= _registry
->find(method
);
42 if( i
== _registry
->end() )
43 DAI_THROWE(UNKNOWN_PARAMETER_ESTIMATION_METHOD
, "Unknown parameter estimation method: " + method
);
44 ParamEstFactory factory
= i
->second
;
// Registry factory for CondProbEstimation. Reads from p:
//   "target_dim"   (size_t, read unconditionally) - block size that
//                  estimate() processes together
//   "total_dim"    (size_t, read unconditionally) - number of entries in
//                  the pseudocount vector
//   "pseudo_count" (Real, optional, default 1) - initial count per entry
// Returns a heap-allocated CondProbEstimation; the caller owns it.
// NOTE(review): presumably Prob(total_dim, pseudo_count) is a vector of
// total_dim entries, all equal to pseudo_count — confirm.
49 ParameterEstimation
* CondProbEstimation::factory( const PropertySet
&p
) {
50 size_t target_dimension
= p
.getStringAs
<size_t>("target_dim");
51 size_t total_dimension
= p
.getStringAs
<size_t>("total_dim");
// Default pseudocount of 1 unless the caller overrides it.
52 Real pseudo_count
= 1;
53 if( p
.hasKey("pseudo_count") )
54 pseudo_count
= p
.getStringAs
<Real
>("pseudo_count");
55 return new CondProbEstimation( target_dimension
, Prob( total_dimension
, pseudo_count
) );
// Constructs an estimator whose statistics start at `pseudocounts`.
//   target_dimension - number of consecutive entries that estimate()
//                      treats as one block
//   pseudocounts     - initial statistics; a copy is also kept in
//                      _initial_stats and restored after every estimate()
// Asserts that pseudocounts.size() is a multiple of target_dimension.
// NOTE(review): target_dimension == 0 makes the assert's modulo a division
// by zero.
59 CondProbEstimation::CondProbEstimation( size_t target_dimension
, const Prob
&pseudocounts
)
60 : _target_dim(target_dimension
), _stats(pseudocounts
), _initial_stats(pseudocounts
)
62 assert( !(_stats
.size() % _target_dim
) );
// Adds one set of expected sufficient statistics to the running totals.
// SharedParameters::collectSufficientStatistics passes `p` already
// permuted into the shared-parameter ordering.
// NOTE(review): presumably _stats += p, so p must have the same size as
// the pseudocounts — confirm.
66 void CondProbEstimation::addSufficientStatistics( const Prob
&p
) {
// Turns the accumulated statistics into conditional probabilities.
// _stats is processed in consecutive blocks of _target_dim entries
// (presumably one block per configuration of the conditioning variables).
// Afterwards the statistics are reset to the pseudocounts, so the next EM
// iteration starts from scratch.
// NOTE(review): presumably returns the normalized statistics — confirm.
71 Prob
CondProbEstimation::estimate() {
72 // Normalize each block of _target_dim consecutive entries of _stats
73 for( size_t parent
= 0; parent
< _stats
.size(); parent
+= _target_dim
) {
// [parent, top) is the current block.
76 size_t top
= parent
+ _target_dim
;
// Presumably accumulates the block total — confirm.
77 for( size_t i
= parent
; i
< top
; ++i
)
// Presumably divides each entry by the block total — confirm.
82 for( size_t i
= parent
; i
< top
; ++i
)
85 // reset _stats to _initial_stats
87 _stats
= _initial_stats
;
// Builds the permutation between a caller-chosen variable order and the
// canonical iteration order of a VarSet.
//   varorder - variables in the order used by the shared parameters
//   outVS    - receives (via |=) every variable of varorder. It is NOT
//              cleared first, so callers should pass an empty VarSet.
// Returns Permute(dims, sigma), where dims[k] is the number of states of
// varorder[k] and sigma[j] is the position in varorder of the j-th
// variable in outVS's iteration order.
// NOTE(review): if varorder lists the same variable twice, outVS ends up
// with fewer elements than dims, and the sigma loop below advances past
// outVS.end().
92 Permute
SharedParameters::calculatePermutation( const std::vector
<Var
> &varorder
, VarSet
&outVS
) {
93 // Collect all labels and dimensions (in varorder order) and add each variable to outVS
94 std::vector
<size_t> dims
;
95 dims
.reserve( varorder
.size() );
96 std::vector
<long> labels
;
97 labels
.reserve( varorder
.size() );
98 for( size_t i
= 0; i
< varorder
.size(); i
++ ) {
99 dims
.push_back( varorder
[i
].states() );
100 labels
.push_back( varorder
[i
].label() );
101 outVS
|= varorder
[i
];
104 // Construct the sigma array for the permutation object
// For each variable in VarSet order, record its index in varorder.
105 std::vector
<size_t> sigma
;
106 sigma
.reserve( dims
.size() );
107 for( VarSet::iterator set_iterator
= outVS
.begin(); sigma
.size() < dims
.size(); ++set_iterator
)
108 sigma
.push_back( find(labels
.begin(), labels
.end(), set_iterator
->label()) - labels
.begin() );
110 return Permute( dims
, sigma
);
// Recomputes _perms and _varsets from _varorders, one entry per factor
// index. It also checks that each factor's joint state space has exactly
// _estimation->probSize() entries.
// NOTE(review): presumably returns early when _varorders is empty.
// `vs` must be a fresh VarSet for every factor, because
// calculatePermutation() only adds to it — confirm.
114 void SharedParameters::setPermsAndVarSetsFromVarOrders() {
115 if( _varorders
.size() == 0 )
117 assert( _estimation
!= NULL
);
119 // Construct the permutation objects and the varsets
120 for( FactorOrientations::const_iterator foi
= _varorders
.begin(); foi
!= _varorders
.end(); ++foi
) {
122 _perms
[foi
->first
] = calculatePermutation( foi
->second
, vs
);
123 _varsets
[foi
->first
] = vs
;
// The estimator must hold exactly one parameter per joint state.
124 assert( _estimation
->probSize() == vs
.nrStates() );
// Reads one shared-parameter specification from `is`. The stream provides:
//   - the name of a parameter estimation method (plus, presumably, the
//     properties `props`), used to construct _estimation. This object owns
//     the estimator (_deleteEstimation = true).
//   - presumably a count `num_factors`, followed by one non-empty line per
//     factor of the form "<factor> <label_1> ... <label_k>". <factor>
//     identifies a factor of fg_varlookup, and the labels list exactly that
//     factor's variables in the desired parameter order.
// Throws INVALID_EMALG_FILE on a malformed factor line: no fields, the wrong
// number of labels, or a label that is not a variable of the factor.
129 SharedParameters::SharedParameters( std::istream
&is
, const FactorGraph
&fg_varlookup
)
130 : _varsets(), _perms(), _varorders(), _estimation(NULL
), _deleteEstimation(true)
132 // Read the desired parameter estimation method from the stream
133 std::string est_method
;
138 // Construct a corresponding object
139 _estimation
= ParameterEstimation::construct( est_method
, props
);
141 // Read in the factors that are to be estimated
144 for( size_t sp_i
= 0; sp_i
< num_factors
; ++sp_i
) {
// Skip empty lines until a non-empty one is read (or the stream ends).
146 while( line
.size() == 0 && getline(is
, line
) )
149 std::vector
<std::string
> fields
;
150 tokenizeString(line
, fields
, " \t");
152 // Lookup the factor in the factorgraph
153 if( fields
.size() < 1 )
154 DAI_THROW(INVALID_EMALG_FILE
);
155 std::istringstream iss
;
156 iss
.str( fields
[0] );
// Every variable of the factor must be listed exactly once after <factor>.
159 const VarSet
&vs
= fg_varlookup
.factor(factor
).vars();
160 if( fields
.size() != vs
.size() + 1 )
161 DAI_THROW(INVALID_EMALG_FILE
);
163 // Construct the vector of Vars
164 std::vector
<Var
> var_order
;
165 var_order
.reserve( vs
.size() );
166 for( size_t fi
= 1; fi
< fields
.size(); ++fi
) {
167 // Lookup a single variable by label
169 std::istringstream
labelparse( fields
[fi
] );
171 VarSet::const_iterator vsi
= vs
.begin();
172 for( ; vsi
!= vs
.end(); ++vsi
)
173 if( vsi
->label() == label
)
175 if( vsi
== vs
.end() )
176 DAI_THROW(INVALID_EMALG_FILE
);
177 var_order
.push_back( *vsi
);
179 _varorders
[factor
] = var_order
;
182 // Calculate the necessary permutations
183 setPermsAndVarSetsFromVarOrders();
// Copy constructor. Copies the orientation, permutation and varset tables.
// If `sp` owns its estimator (_deleteEstimation), the copy receives its own
// clone, so each object owns a distinct estimator (presumably released by
// the destructor). Otherwise both objects share the same non-owned
// estimator pointer.
187 SharedParameters::SharedParameters( const SharedParameters
&sp
)
188 : _varsets(sp
._varsets
), _perms(sp
._perms
), _varorders(sp
._varorders
), _estimation(sp
._estimation
), _deleteEstimation(sp
._deleteEstimation
)
190 // If sp owns its _estimation object, we should clone it instead
191 if( _deleteEstimation
)
192 _estimation
= _estimation
->clone();
// Constructs shared parameters from explicit factor orientations.
//   varorders  - for each factor index, the variable order in which the
//                shared parameter table is expressed
//   estimation - estimator shared by all these factors; it should be
//                non-NULL (it is asserted during permutation setup)
//   deletePE   - whether this object takes ownership of `estimation`
196 SharedParameters::SharedParameters( const FactorOrientations
&varorders
, ParameterEstimation
*estimation
, bool deletePE
)
197 : _varsets(), _perms(), _varorders(varorders
), _estimation(estimation
), _deleteEstimation(deletePE
)
199 // Calculate the necessary permutations
200 setPermsAndVarSetsFromVarOrders();
// E-step accumulation. For every factor that shares these parameters, it:
//   1. queries `alg` for the belief over the factor's variables;
//   2. reorders the belief's entries into the shared-parameter order using
//      that factor's Permute;
//   3. adds the result to the estimator's sufficient statistics.
204 void SharedParameters::collectSufficientStatistics( InfAlg
&alg
) {
205 for( std::map
< FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
206 Permute
&perm
= i
->second
;
207 VarSet
&vs
= _varsets
[i
->first
];
209 Factor b
= alg
.belief(vs
);
// p[entry] is the belief entry that corresponds to shared index `entry`.
210 Prob
p( b
.states(), 0.0 );
211 for( size_t entry
= 0; entry
< b
.states(); ++entry
)
212 p
[entry
] = b
[perm
.convert_linear_index(entry
)];
213 _estimation
->addSufficientStatistics( p
);
// M-step. Asks the estimator for new parameters and writes them into fg,
// one factor per orientation. Each factor's Permute maps the shared
// parameter order back to the factor's own entry order; this is the inverse
// of the mapping used in collectSufficientStatistics().
// estimate() is called once, so every factor receives the same (tied)
// parameter vector.
// NOTE(review): `f` is presumably a Factor over `vs` — confirm.
218 void SharedParameters::setParameters( FactorGraph
&fg
) {
219 Prob p
= _estimation
->estimate();
220 for( std::map
<FactorIndex
, Permute
>::iterator i
= _perms
.begin(); i
!= _perms
.end(); ++i
) {
221 Permute
&perm
= i
->second
;
222 VarSet
&vs
= _varsets
[i
->first
];
225 for( size_t entry
= 0; entry
< f
.states(); ++entry
)
226 f
[perm
.convert_linear_index(entry
)] = p
[entry
];
228 fg
.setFactor( i
->first
, f
);
// Reports the current values of these shared parameters, read from fg.
//   outVals     - receives the factor's values in shared-parameter order
//                 (appended)
//   outVarOrder - receives the variable order those values are expressed
//                 in (appended)
// Only the first entry of _varorders is read, because the parameters are
// tied across all factors.
// NOTE(review): presumably returns without output when _varorders is
// empty — confirm.
233 void SharedParameters::collectParameters( const FactorGraph
&fg
, std::vector
<Real
> &outVals
, std::vector
<Var
> &outVarOrder
) {
234 FactorOrientations::iterator it
= _varorders
.begin();
235 if( it
== _varorders
.end() )
237 FactorIndex I
= it
->first
;
238 for( std::vector
<Var
>::const_iterator var_it
= _varorders
[I
].begin(); var_it
!= _varorders
[I
].end(); ++var_it
)
239 outVarOrder
.push_back( *var_it
);
241 const Factor
&f
= fg
.factor(I
);
242 assert( f
.vars() == _varsets
[I
] );
243 const Permute
&perm
= _perms
[I
];
// Read factor entries in shared-parameter order.
244 for( size_t val_index
= 0; val_index
< f
.states(); ++val_index
)
245 outVals
.push_back( f
[perm
.convert_linear_index(val_index
)] );
// Reads one M-step specification from `is`: presumably a count
// `num_params`, followed by that many SharedParameters specifications, each
// parsed by SharedParameters(is, fg_varlookup).
// NOTE(review): num_params is initialized to SIZE_MAX (-1 converted to
// size_t) as a sentinel. It must be overwritten by a successful read before
// reserve(); otherwise reserve() throws std::length_error.
249 MaximizationStep::MaximizationStep( std::istream
&is
, const FactorGraph
&fg_varlookup
) : _params() {
250 size_t num_params
= -1;
252 _params
.reserve( num_params
);
253 for( size_t i
= 0; i
< num_params
; ++i
)
254 _params
.push_back( SharedParameters( is
, fg_varlookup
) );
// E-step for this M-step: every SharedParameters accumulates sufficient
// statistics from `alg`. EMAlg::iterate passes an evidence-clamped clone of
// the E-step algorithm.
258 void MaximizationStep::addExpectations( InfAlg
&alg
) {
259 for( size_t i
= 0; i
< _params
.size(); ++i
)
260 _params
[i
].collectSufficientStatistics( alg
);
// M-step: every SharedParameters in this step writes its newly estimated
// parameters into fg.
264 void MaximizationStep::maximize( FactorGraph
&fg
) {
265 for( size_t i
= 0; i
< _params
.size(); ++i
)
266 _params
[i
].setParameters( fg
);
270 const std::string
EMAlg::MAX_ITERS_KEY("max_iters");
271 const std::string
EMAlg::LOG_Z_TOL_KEY("log_z_tol");
272 const size_t EMAlg::MAX_ITERS_DEFAULT
= 30;
273 const Real
EMAlg::LOG_Z_TOL_DEFAULT
= 0.01;
// Constructs an EM driver.
//   evidence    - observed samples; iterate() applies each one to a clone
//                 of estep
//   estep       - inference algorithm used for the E-step; its factor graph
//                 (estep.fg()) is what the M-steps are parsed against
//   msteps_file - stream holding the number of M-steps, followed by that
//                 many MaximizationStep specifications
// Termination starts at MAX_ITERS_DEFAULT / LOG_Z_TOL_DEFAULT; override
// with setTermConditions().
// NOTE(review): the exception mask set below remains on msteps_file after
// construction, and it includes eofbit. A read that reaches end-of-file
// (e.g. a final token with no trailing newline, or getline at EOF) throws
// std::ios_base::failure instead of merely setting the flag. Confirm this
// is intended.
276 EMAlg::EMAlg( const Evidence
&evidence
, InfAlg
&estep
, std::istream
&msteps_file
)
277 : _evidence(evidence
), _estep(estep
), _msteps(), _iters(0), _lastLogZ(), _max_iters(MAX_ITERS_DEFAULT
), _log_z_tol(LOG_Z_TOL_DEFAULT
)
// Make every parse problem in msteps_file throw.
279 msteps_file
.exceptions( std::istream::eofbit
| std::istream::failbit
| std::istream::badbit
);
280 size_t num_msteps
= -1;
281 msteps_file
>> num_msteps
;
282 _msteps
.reserve(num_msteps
);
283 for( size_t i
= 0; i
< num_msteps
; ++i
)
284 _msteps
.push_back( MaximizationStep( msteps_file
, estep
.fg() ) );
// Overrides the termination criteria from p:
//   MAX_ITERS_KEY ("max_iters") - maximum number of EM iterations
//   LOG_Z_TOL_KEY ("log_z_tol") - tolerance on the relative log-likelihood
//                                 change between iterations
// Absent keys leave the current values unchanged.
288 void EMAlg::setTermConditions( const PropertySet
&p
) {
289 if( p
.hasKey(MAX_ITERS_KEY
) )
290 _max_iters
= p
.getStringAs
<size_t>(MAX_ITERS_KEY
);
291 if( p
.hasKey(LOG_Z_TOL_KEY
) )
292 _log_z_tol
= p
.getStringAs
<Real
>(LOG_Z_TOL_KEY
);
// Decides whether EM should stop. The criteria are:
//   - the iteration cap has been reached (_iters >= _max_iters), or
//   - the relative change in recorded log-likelihood between the last two
//     iterations, (current - previous) / |previous|, is at most _log_z_tol.
// The ratio test is only evaluated once at least three log-likelihoods have
// been recorded (presumably returning false before that — confirm).
// NOTE(review): previous == 0 would make the ratio divide by zero;
// presumably it is guarded before the ratio is computed — confirm.
296 bool EMAlg::hasSatisfiedTermConditions() const {
297 if( _iters
>= _max_iters
)
299 else if( _lastLogZ
.size() < 3 )
300 // The ratio needs at least 2 values; a third is required because we
301 // also throw away the first iteration, as the parameters may not
302 // have been normalized according to the estimation method
305 Real current
= _lastLogZ
[_lastLogZ
.size() - 1];
306 Real previous
= _lastLogZ
[_lastLogZ
.size() - 2];
309 Real diff
= current
- previous
;
311 std::cerr
<< "Error: in EM log-likelihood decreased from " << previous
<< " to " << current
<< std::endl
;
314 return (diff
/ fabs(previous
)) <= _log_z_tol
;
// Runs one EM iteration for a single M-step.
//   E-step: for every evidence sample,
//     - clone the E-step algorithm and clamp the sample into the clone;
//     - add (clamped logZ - unclamped logZ) to the likelihood;
//     - let `mstep` accumulate expected statistics from the clone.
//   M-step: re-estimate the parameters into _estep.fg().
// NOTE(review): presumably returns the accumulated likelihood. Each
// `clamped` comes from clone() and must be deleted after use — confirm that
// it is.
319 Real
EMAlg::iterate( MaximizationStep
&mstep
) {
// Log partition sum of the unclamped model, the baseline for each sample.
324 logZ
= _estep
.logZ();
326 // Expectation calculation
327 for( Evidence::const_iterator e
= _evidence
.begin(); e
!= _evidence
.end(); ++e
) {
328 InfAlg
* clamped
= _estep
.clone();
329 e
->applyEvidence( *clamped
);
333 likelihood
+= clamped
->logZ() - logZ
;
335 mstep
.addExpectations( *clamped
);
340 // Maximization of parameters
341 mstep
.maximize( _estep
.fg() );
// Runs one full EM iteration. Each M-step is applied in turn, and the
// likelihood returned by the last one is appended to _lastLogZ, which
// hasSatisfiedTermConditions() reads.
// NOTE(review): presumably also increments _iters and returns likelihood —
// confirm.
347 Real
EMAlg::iterate() {
349 for( size_t i
= 0; i
< _msteps
.size(); ++i
)
350 likelihood
= iterate( _msteps
[i
] );
351 _lastLogZ
.push_back( likelihood
);
// EM driver loop (fragment of run()): keeps going while
// hasSatisfiedTermConditions() reports that neither the iteration cap nor
// the log-likelihood tolerance has been reached. The loop body presumably
// calls iterate() — confirm.
358 while( !hasSatisfiedTermConditions() )
363 } // end of namespace dai