From: Charlie Vaske Date: Tue, 30 Jun 2009 02:00:48 +0000 (-0700) Subject: Full command line inference and tests X-Git-Tag: v0.2.3~79^2~6 X-Git-Url: http://git.tuebingen.mpg.de/?p=libdai.git;a=commitdiff_plain;h=29c3b85d4e90e1ce7a4194cec38cd72c551f2c24 Full command line inference and tests --- diff --git a/Makefile b/Makefile index 7334c15..a242850 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ SRC=src LIB=lib # Define build targets -TARGETS=tests utils lib examples testregression +TARGETS=tests utils lib examples testregression testem ifdef WITH_DOC TARGETS:=$(TARGETS) doc endif @@ -261,6 +261,9 @@ testregression : tests/testdai$(EE) cd tests && testregression.bat && cd .. endif +testem: tests/testem/testem$(EE) + @echo Starting EM tests + cd tests/testem && ./runtests && cd ../.. # DOCUMENTATION ################ diff --git a/include/dai/emalg.h b/include/dai/emalg.h index 1ce9e00..79f2e02 100644 --- a/include/dai/emalg.h +++ b/include/dai/emalg.h @@ -210,28 +210,68 @@ private: /// The maximization steps to take std::vector _msteps; - size_t _iters; std::vector _lastLogZ; + + size_t _max_iters; + Real _log_z_tol; public: + /// Key for setting maximum iterations @see setTermConditions + static const std::string MAX_ITERS_KEY;//("max_iters"); + /// Default maximum iterations + static const size_t MAX_ITERS_DEFAULT; + /// Key for setting likelihood termination condition @see setTermConditions + static const std::string LOG_Z_TOL_KEY; + /// Default log_z_tol + static const Real LOG_Z_TOL_DEFAULT; + /// Construct an EMAlg from all these objects EMAlg(const Evidence& evidence, InfAlg& estep, - std::vector& msteps) + std::vector& msteps, + PropertySet* termconditions=NULL) : _evidence(evidence), _estep(estep), _msteps(msteps), _iters(0), - _lastLogZ() - {} + _lastLogZ(), + _max_iters(MAX_ITERS_DEFAULT), + _log_z_tol(LOG_Z_TOL_DEFAULT) { + setTermConditions(termconditions); + } /// Construct an EMAlg from an input stream EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& mstep_file); + /// Change the coditions for termination + /** There are two possible parameters in the PropertySety + * - max_iters maximum number of iterations (default 30) + * - log_z_tol proportion of increase in logZ (default 0.01) + * + * \see hasSatisifiedTermConditions() + */ + void setTermConditions(const PropertySet* p); + + /// Determine if the termination conditions have been met. + /** There are two sufficient termination conditions: + * -# the maximum number of iterations has been performed + * -# the ratio of logZ increase over previous logZ is less than the + * tolerance. I.e. + \f$ \frac{\log(Z_{current}) - \log(Z_{previous})} + {| \log(Z_{previous}) | } < tol \f$. + */ + bool hasSatisfiedTermConditions() const; + + size_t getCurrentIters() const { return _iters; } + /// Perform an iteration over all maximization steps Real iterate(); + /// Performs an iteration over a single MaximizationStep - Real iterate(const MaximizationStep& mstep); + Real iterate(MaximizationStep& mstep); + + /// Iterate until termination conditions satisfied + void run(); }; diff --git a/src/emalg.cpp b/src/emalg.cpp index be7f417..fe13a28 100644 --- a/src/emalg.cpp +++ b/src/emalg.cpp @@ -263,14 +263,23 @@ void MaximizationStep::maximize(FactorGraph& fg) { } } +const std::string EMAlg::MAX_ITERS_KEY("max_iters"); +const std::string EMAlg::LOG_Z_TOL_KEY("log_z_tol"); +const size_t EMAlg::MAX_ITERS_DEFAULT = 30; +const Real EMAlg::LOG_Z_TOL_DEFAULT = 0.01; + EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file) - : _evidence(evidence), + : _evidence(evidence), _estep(estep), _msteps(), _iters(0), - _lastLogZ() + _lastLogZ(), + _max_iters(MAX_ITERS_DEFAULT), + _log_z_tol(LOG_Z_TOL_DEFAULT) { - size_t num_msteps; + msteps_file.exceptions( std::istream::eofbit | std::istream::failbit + | std::istream::badbit ); + size_t num_msteps = -1; msteps_file >> num_msteps; _msteps.reserve(num_msteps); for (size_t i = 0; i < num_msteps; ++i) { @@ -279,7 +288,41 @@ EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file) } } -Real EMAlg::iterate(const MaximizationStep& mstep) { +void EMAlg::setTermConditions(const PropertySet* p) { + if (NULL == p) { + return; + } + if (p->hasKey(MAX_ITERS_KEY)) { + _max_iters = p->getStringAs(MAX_ITERS_KEY); + } + if (p->hasKey(LOG_Z_TOL_KEY)) { + _log_z_tol = p->getStringAs(LOG_Z_TOL_KEY); + } +} + +bool EMAlg::hasSatisfiedTermConditions() const { + if (_iters >= _max_iters) { + return 1; + } else if (_lastLogZ.size() < 3) { + // need at least 2 to calculate ratio + // Also, throw away first iteration, as the parameters may not + // have been normalized according to the estimation method + return 0; + } else { + Real current = _lastLogZ[_lastLogZ.size() - 1]; + Real previous = _lastLogZ[_lastLogZ.size() - 2]; + if (previous == 0) return 0; + Real diff = current - previous; + if (diff < 0) { + std::cerr << "Error: in EM log-likehood decreased from " << previous + << " to " << current << std::endl; + return 1; + } + return diff / abs(previous) <= _log_z_tol; + } +} + +Real EMAlg::iterate(MaximizationStep& mstep) { Evidence::const_iterator e = _evidence.begin(); Real logZ = 0; @@ -312,4 +355,10 @@ Real EMAlg::iterate() { return likelihood; } +void EMAlg::run() { + while(!hasSatisfiedTermConditions()) { + iterate(); + } +} + } diff --git a/src/exceptions.cpp b/src/exceptions.cpp index 6a48ecc..25db247 100644 --- a/src/exceptions.cpp +++ b/src/exceptions.cpp @@ -45,7 +45,7 @@ namespace dai { "Can't parse Evidence line", "Invalid observation in Evidence file", "Invalid variable order in SharedParameters", - "Input line in variable order invalid", + "Variable order line in EM file is invalid", "Unrecognized parameter estimation method" }; diff --git a/tests/testem/3var.em b/tests/testem/3var.em new file mode 100644 index 0000000..dff5410 --- /dev/null +++ b/tests/testem/3var.em @@ -0,0 +1,6 @@ +1 + +1 +ConditionalProbEstimation [target_dim=2,total_dim=8,pseudo_count=1] +1 +0 1 0 2 diff --git a/tests/testem/3var.fg b/tests/testem/3var.fg new file mode 100644 index 0000000..69f0077 --- /dev/null +++ b/tests/testem/3var.fg @@ -0,0 +1,14 @@ +1 + +3 +0 1 2 +2 2 2 +8 +0 0.5 +1 0.5 +2 0.5 +3 0.5 +4 0.5 +5 0.5 +6 0.5 +7 0.5 \ No newline at end of file diff --git a/tests/testem/hoi1_infer_f2.em b/tests/testem/hoi1_infer_f2.em new file mode 100644 index 0000000..977eb5b --- /dev/null +++ b/tests/testem/hoi1_infer_f2.em @@ -0,0 +1,6 @@ +1 + +1 +ConditionalProbEstimation [target_dim=2,total_dim=8,pseudo_count=1] +1 +2 1 2 4 diff --git a/tests/testem/hoi1_share_f0_f1_f2.em b/tests/testem/hoi1_share_f0_f1_f2.em new file mode 100644 index 0000000..1f7756f --- /dev/null +++ b/tests/testem/hoi1_share_f0_f1_f2.em @@ -0,0 +1,8 @@ +1 + +1 +ConditionalProbEstimation [target_dim=2,total_dim=8,pseudo_count=1] +3 +2 1 2 4 +1 0 1 6 +0 6 2 7 diff --git a/tests/testem/hoi1_share_f0_f2.em b/tests/testem/hoi1_share_f0_f2.em new file mode 100644 index 0000000..fd656c6 --- /dev/null +++ b/tests/testem/hoi1_share_f0_f2.em @@ -0,0 +1,7 @@ +1 + +1 +ConditionalProbEstimation [target_dim=2,total_dim=8,pseudo_count=1] +2 +2 1 2 4 +0 6 2 7 diff --git a/tests/testem/runtests b/tests/testem/runtests new file mode 100755 index 0000000..2442db8 --- /dev/null +++ b/tests/testem/runtests @@ -0,0 +1,11 @@ +#!/bin/bash +TMPFILE1=`mktemp /var/tmp/testem.XXXXXX` +trap 'rm -f $TMPFILE1' 0 1 15 + +./testem 2var.fg 2var_data.tab 2var.em > $TMPFILE1 +./testem 3var.fg 2var_data.tab 3var.em >> $TMPFILE1 +./testem ../hoi1.fg hoi1_data.tab hoi1_share_f0_f2.em >> $TMPFILE1 +./testem ../hoi1.fg hoi1_data.tab hoi1_share_f0_f1_f2.em >> $TMPFILE1 +diff -s $TMPFILE1 testem.out || exit 1 + +rm -f $TMPFILE1 diff --git a/tests/testem/testem.cpp b/tests/testem/testem.cpp index 27e7104..f974111 100644 --- a/tests/testem/testem.cpp +++ b/tests/testem/testem.cpp @@ -45,10 +45,15 @@ int main(int argc, char** argv) { ifstream emstream(argv[3]); EMAlg em(e, *inf, emstream); - for (size_t i = 0; i < 10; ++i) { + while(!em.hasSatisfiedTermConditions()) { Real l = em.iterate(); - cout << "Iteration " << i << " likelihood: " << l <fg(); + return 0; } diff --git a/tests/testem/testem.out b/tests/testem/testem.out new file mode 100644 index 0000000..9637392 --- /dev/null +++ b/tests/testem/testem.out @@ -0,0 +1,184 @@ +Number of samples: 20 +Sample sample.0 has 2 observations. +Sample sample.1 has 2 observations. +Sample sample.2 has 2 observations. +Sample sample.3 has 2 observations. +Sample sample.4 has 2 observations. +Sample sample.5 has 2 observations. +Sample sample.6 has 2 observations. +Sample sample.7 has 2 observations. +Sample sample.8 has 2 observations. +Sample sample.9 has 2 observations. +Sample sample_0 has 2 observations. +Sample sample_1 has 2 observations. +Sample sample_2 has 2 observations. +Sample sample_3 has 2 observations. +Sample sample_4 has 2 observations. +Sample sample_5 has 2 observations. +Sample sample_6 has 2 observations. +Sample sample_7 has 2 observations. +Sample sample_8 has 2 observations. +Sample sample_9 has 2 observations. +Iteration 1 likelihood: -13.8629 +Iteration 2 likelihood: -9.56675 +Iteration 3 likelihood: -9.56675 + +Inferred Factor Graph: +###################### +1 + +2 +0 1 +2 2 +4 +0 0.16666666666667 +1 0.66666666666667 +2 0.83333333333333 +3 0.33333333333333 +Number of samples: 20 +Sample sample.0 has 2 observations. +Sample sample.1 has 2 observations. +Sample sample.2 has 2 observations. +Sample sample.3 has 2 observations. +Sample sample.4 has 2 observations. +Sample sample.5 has 2 observations. +Sample sample.6 has 2 observations. +Sample sample.7 has 2 observations. +Sample sample.8 has 2 observations. +Sample sample.9 has 2 observations. +Sample sample_0 has 2 observations. +Sample sample_1 has 2 observations. +Sample sample_2 has 2 observations. +Sample sample_3 has 2 observations. +Sample sample_4 has 2 observations. +Sample sample_5 has 2 observations. +Sample sample_6 has 2 observations. +Sample sample_7 has 2 observations. +Sample sample_8 has 2 observations. +Sample sample_9 has 2 observations. +Iteration 1 likelihood: 0 +Iteration 2 likelihood: 3.97035 +Iteration 3 likelihood: 3.97035 + +Inferred Factor Graph: +###################### +1 + +3 +0 1 2 +2 2 2 +8 +0 0.21428571428571 +1 0.64285714285714 +2 0.78571428571429 +3 0.35714285714286 +4 0.21428571428571 +5 0.64285714285714 +6 0.78571428571429 +7 0.35714285714286 +Number of samples: 5 +Sample sample_0 has 5 observations. +Sample sample_1 has 4 observations. +Sample sample_2 has 6 observations. +Sample sample_3 has 6 observations. +Sample sample_4 has 5 observations. +Iteration 1 likelihood: 11.1646 +Iteration 2 likelihood: 1.53723 +Iteration 3 likelihood: 1.64691 +Iteration 4 likelihood: 1.67497 +Iteration 5 likelihood: 1.68191 + +Inferred Factor Graph: +###################### +3 + +3 +2 6 7 +2 2 2 +8 +0 0.39834336080566 +1 0.35146414656547 +2 0.60165663919434 +3 0.64853585343454 +4 0.80484468795651 +5 0.67374245802992 +6 0.19515531204349 +7 0.32625754197008 + +3 +0 1 6 +2 2 2 +8 +0 1.0352133626924 +1 1.547478952122 +2 2.3176521897449 +3 1.2804190071868 +4 4.9220798130027 +5 2.5272557501946 +6 0.83127929631575 +7 0.26280563080263 + +3 +1 2 4 +2 2 2 +8 +0 0.39834336080566 +1 0.60165663919434 +2 0.35146414656547 +3 0.64853585343454 +4 0.80484468795651 +5 0.19515531204349 +6 0.67374245802992 +7 0.32625754197008 +Number of samples: 5 +Sample sample_0 has 5 observations. +Sample sample_1 has 4 observations. +Sample sample_2 has 6 observations. +Sample sample_3 has 6 observations. +Sample sample_4 has 5 observations. +Iteration 1 likelihood: 11.1646 +Iteration 2 likelihood: -7.29331 +Iteration 3 likelihood: -7.261 + +Inferred Factor Graph: +###################### +3 + +3 +2 6 7 +2 2 2 +8 +0 0.49531219972645 +1 0.49910794290825 +2 0.50468780027355 +3 0.50089205709175 +4 0.64041654995841 +5 0.49512229392399 +6 0.35958345004159 +7 0.50487770607601 + +3 +0 1 6 +2 2 2 +8 +0 0.49531219972645 +1 0.50468780027355 +2 0.49910794290825 +3 0.50089205709175 +4 0.64041654995841 +5 0.35958345004159 +6 0.49512229392399 +7 0.50487770607601 + +3 +1 2 4 +2 2 2 +8 +0 0.49531219972645 +1 0.50468780027355 +2 0.49910794290825 +3 0.50089205709175 +4 0.64041654995841 +5 0.35958345004159 +6 0.49512229392399 +7 0.50487770607601