First working version of EM
authorCharlie Vaske <cvaske@cvaske.gateway.2wire.net>
Mon, 29 Jun 2009 02:07:49 +0000 (19:07 -0700)
committerCharlie Vaske <cvaske@cvaske.gateway.2wire.net>
Mon, 29 Jun 2009 02:07:49 +0000 (19:07 -0700)
15 files changed:
Makefile
include/dai/alldai.h
include/dai/emalg.h [new file with mode: 0644]
include/dai/evidence.h [new file with mode: 0644]
include/dai/exceptions.h
include/dai/util.h
src/emalg.cpp [new file with mode: 0644]
src/evidence.cpp [new file with mode: 0644]
src/exceptions.cpp
src/util.cpp
tests/testem/2var.em [new file with mode: 0644]
tests/testem/2var.fg [new file with mode: 0644]
tests/testem/2var_data.tab [new file with mode: 0644]
tests/testem/hoi1_data.tab [new file with mode: 0644]
tests/testem/testem.cpp [new file with mode: 0644]

index bdfbdf1..7334c15 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -45,7 +45,7 @@ ifdef WITH_MATLAB
 endif
 
 # Define conditional build targets
-OBJECTS:=exactinf$(OE)
+OBJECTS:=exactinf$(OE) evidence$(OE) emalg$(OE)
 ifdef WITH_BP
   CCFLAGS:=$(CCFLAGS) -DDAI_WITH_BP
   OBJECTS:=$(OBJECTS) bp$(OE)
@@ -104,7 +104,7 @@ examples : examples/example$(EE) examples/example_bipgraph$(EE) examples/example
 
 matlabs : matlab/dai$(ME) matlab/dai_readfg$(ME) matlab/dai_writefg$(ME) matlab/dai_potstrength$(ME)
 
-tests : tests/testdai$(EE)
+tests : tests/testdai$(EE) tests/testem/testem$(EE)
 
 utils : utils/createfg$(EE) utils/fg2dot$(EE) utils/fginfo$(EE)
 
@@ -162,6 +162,12 @@ mr$(OE) : $(SRC)/mr.cpp $(INC)/mr.h $(HEADERS)
 gibbs$(OE) : $(SRC)/gibbs.cpp $(INC)/gibbs.h $(HEADERS)
        $(CC) -c $(SRC)/gibbs.cpp
 
+evidence$(OE) : $(SRC)/evidence.cpp $(INC)/evidence.h $(HEADERS)
+       $(CC) -c $(SRC)/evidence.cpp
+
+emalg$(OE) : $(SRC)/emalg.cpp $(INC)/emalg.h $(INC)/evidence.h $(HEADERS)
+       $(CC) -c $(SRC)/emalg.cpp
+
 properties$(OE) : $(SRC)/properties.cpp $(HEADERS)
        $(CC) -c $(SRC)/properties.cpp
 
@@ -193,7 +199,8 @@ examples/example_sprinkler$(EE) : examples/example_sprinkler.cpp $(HEADERS) $(LI
 
 tests/testdai$(EE) : tests/testdai.cpp $(HEADERS) $(LIB)/libdai$(LE)
        $(CC) $(CCO)tests/testdai$(EE) tests/testdai.cpp $(LIBS) $(BOOSTLIBS)
-
+tests/testem/testem$(EE): tests/testem/testem.cpp $(HEADERS) $(LIB)/libdai$(LE)
+       $(CC) $(CCO)$@ $< $(LIBS) $(BOOSTLIBS)
 
 # MATLAB INTERFACE
 ###################
index 430e8eb..c0010ca 100644 (file)
@@ -33,6 +33,8 @@
 #include <dai/daialg.h>
 #include <dai/properties.h>
 #include <dai/exactinf.h>
+#include <dai/evidence.h>
+#include <dai/emalg.h>
 #ifdef DAI_WITH_BP
     #include <dai/bp.h>
 #endif
diff --git a/include/dai/emalg.h b/include/dai/emalg.h
new file mode 100644 (file)
index 0000000..1ce9e00
--- /dev/null
@@ -0,0 +1,240 @@
+/*
+  Copyright 2009 Charles Vaske <cvaske@soe.ucsc.edu>
+  University of California Santa Cruz
+
+  This program is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+  
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+  
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef __defined_libdai_emalg_h
+#define __defined_libdai_emalg_h
+
+#include<vector>
+#include<map>
+
+#include <dai/factor.h>
+#include <dai/daialg.h>
+#include <dai/evidence.h>
+#include <dai/index.h>
+#include <dai/properties.h>
+
+
+/// \file 
+/** \brief Defines classes related to Expectation Maximization:
+ *  EMAlg, ParameterEstimate, and FactorOrientations
+ */
+
+namespace dai {
+
+///Interface for a parameter estimation method. 
+/** This parameter estimation interface is based on sufficient statistics. 
+ *  Implementations are responsible for collecting data from a probability 
+ *  vector passed to it from a SharedParameters container object.
+ *
+ *  Implementations of this interface should register a factory function
+ *  via the static ParameterEstimation::registerMethod function.
+ */
+class ParameterEstimation {
+public:  
+  /// A pointer to a factory function.
+  typedef ParameterEstimation* (*ParamEstFactory)(const PropertySet&);
+
+  /// General factory method for construction of ParameterEstimation subclasses.
+  static ParameterEstimation* construct(const std::string& method, 
+                                       const PropertySet& p);
+  /// Register a subclass with ParameterEstimation::construct.
+  static void registerMethod(const std::string method, 
+                            const ParamEstFactory f) {
+    if (_registry == NULL) {
+      loadDefaultRegistry();
+    }
+    (*_registry)[method] = f;
+  }
+  /// Virtual destructor for deleting pointers to derived classes.
+  virtual ~ParameterEstimation() {}
+  /// Estimate the factor using the accumulated sufficient statistics and reset.
+  virtual Prob estimate() = 0;
+  /// Accumulate the sufficient statistics for p.
+  virtual void addSufficientStatistics(Prob& p) = 0;
+  /// Returns the size of the Prob that is passed to addSufficientStatistics.
+  virtual size_t probSize() const = 0;
+  /// A virtual copy constructor.
+  virtual ParameterEstimation* clone() const= 0;
+private:
+  static std::map< std::string, ParamEstFactory >* _registry;
+  static void loadDefaultRegistry();
+};
+
+/// Estimates the parameters of a conditional probability, using pseudocounts.
+class CondProbEstimation : private ParameterEstimation {
+private:
+  size_t _target_dim;
+  Prob _stats;
+  Prob _initial_stats;
+public:
+  /** For a conditional probability \f$ Pr( X | Y ) \f$, 
+   *  \param target_dimension should equal \f$ | X | \f$
+   *  \param pseudocounts has length \f$ |X| \cdot |Y| \f$
+   */
+  CondProbEstimation(size_t target_dimension, Prob pseudocounts);
+
+  /// Virtual constructor, using a PropertySet.
+  /** Some keys in the PropertySet are required:
+   *     - target_dimension, which should be equal to \f$ | X | \f$
+   *     - total_dimension, which sholud be equal to \f$ |X| \cdot |Y| \f$
+   *  
+   *  An optional key is:
+   *     - pseudo_count which specifies the initial counts (defaults to 1)
+   */
+  static ParameterEstimation* factory(const PropertySet& p);
+  /// Virtual destructor
+  virtual ~CondProbEstimation() {}
+  /// Returns an estimate of the conditional probability distribution.
+  /** The format of the resulting Prob keeps all the values for 
+   *  \f$ P(X | Y=a) \f$ sequential in teh array.
+   */
+  virtual Prob estimate();
+  /// Accumulate sufficient statistics from the expectations in p.
+  virtual void addSufficientStatistics(Prob& p);
+  /// Returns the required size for arguments to addSufficientStatistics
+  virtual size_t probSize() const { return _stats.size(); }
+  /// Virtual copy constructor.
+  virtual ParameterEstimation* clone() const {
+    return new CondProbEstimation(_target_dim, _initial_stats);
+  }
+};
+
+/** A single factor or set of factors whose parameters should be
+ *  estimated.  Each factor's values are reordered to match a
+ *  canonical variable ordering.  This canonical variable ordering
+ *  will likely not be the order of variables required to make two
+ *  factors parameters isomorphic.  Therefore, this ordering of the
+ *  variables must be specified for ever factor to ensure that
+ *  parameters can be shared between different factors during EM.
+ */
+class SharedParameters {
+public:
+  /// Convenience label for an index into a FactorGraph to a factor.
+  typedef size_t FactorIndex;
+  /// Convenience label for a grouping of factor orientations.
+  typedef std::map< FactorIndex, std::vector< Var > > FactorOrientations;
+private:
+  std::map< FactorIndex, VarSet > _varsets;
+  std::map< FactorIndex, Permute > _perms;
+  FactorOrientations _varorders;
+  ParameterEstimation* _estimation;
+  bool _deleteEstimation;
+
+  static Permute calculatePermutation(const std::vector< Var >& varorder,
+                                     const std::vector< size_t >& dims,
+                                     VarSet& outVS);
+  void setPermsAndVarSetsFromVarOrders();
+public:
+  /// Copy constructor
+  SharedParameters(const SharedParameters& sp);
+  /// Constructor useful in programmatic settings 
+  /** \param varorders  all the factor orientations for this parameter
+      \param estimation a pointer to the parameter estimation method
+   */ 
+  SharedParameters(const FactorOrientations& varorders,
+                  ParameterEstimation* estimation);
+
+  /// Constructor for making an object from a stream
+  SharedParameters(std::istream& is, const FactorGraph& fg_varlookup);
+
+  /// Destructor
+  ~SharedParameters() { if (_deleteEstimation) delete _estimation; }
+
+  /// Collect the necessary statistics from expected values
+  void collectSufficientStatistics(InfAlg& alg);
+
+  /// Estimate and set the shared parameters
+  void setParameters(FactorGraph& fg);
+};
+
+/** A maximization step groups together several parameter estimation
+ * tasks into a single unit.
+ */
+class MaximizationStep { 
+private:
+  std::vector< SharedParameters > _params;
+public:
+  MaximizationStep() : _params() {}
+
+  /// Construct an step object taht contains all these estimation probelms
+  MaximizationStep(std::vector< SharedParameters >& maximizations) : 
+    _params(maximizations) {}  
+
+  /// Construct a step from an input stream
+  MaximizationStep(std::istream& is, const FactorGraph& fg_varlookup);
+  
+  /** Collect the beliefs from this InfAlg as expectations for
+   *  the next Maximization step.
+   */
+  void addExpectations(InfAlg& alg);
+
+  /** Using all of the currently added expectations, make new factors 
+   *  with maximized parameters and set them in the FactorGraph.
+   */
+  void maximize(FactorGraph& fg);
+};
+
+/// EMAlg performs Expectation Maximization to learn factor parameters.
+/** This requires specifying:
+ *     - Evidence (instances of observations from the graphical model),
+ *     - InfAlg for performing the E-step, which includes the factor graph
+ *     - a vector of MaximizationSteps steps to be performed
+ *
+ *  This implementation can peform incremental EM by using multiple 
+ *  MaximizationSteps.  An expectation step is performed between execution
+ *  of each MaximizationStep.  A call to iterate() will cycle through all
+ *  MaximizationSteps.
+ */  
+class EMAlg {
+private:
+  /// All the data samples used during learning
+  const Evidence& _evidence;
+  
+  /// How to do the expectation step
+  InfAlg& _estep;
+
+  /// The maximization steps to take
+  std::vector<MaximizationStep> _msteps;
+
+  size_t _iters;
+  std::vector<Real> _lastLogZ;
+
+public:
+  /// Construct an EMAlg from all these objects
+  EMAlg(const Evidence& evidence, InfAlg& estep, 
+       std::vector<MaximizationStep>& msteps) 
+    : _evidence(evidence),
+      _estep(estep),
+      _msteps(msteps),
+      _iters(0),
+      _lastLogZ() 
+  {}
+  
+  /// Construct an EMAlg from an input stream
+  EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& mstep_file);
+
+  /// Perform an iteration over all maximization steps
+  Real iterate();
+  /// Performs an iteration over a single MaximizationStep
+  Real iterate(const MaximizationStep& mstep);
+
+};
+
+} // namespace dai
+
+#endif
diff --git a/include/dai/evidence.h b/include/dai/evidence.h
new file mode 100644 (file)
index 0000000..4353c6b
--- /dev/null
@@ -0,0 +1,86 @@
+/*
+  Copyright 2009 Charles Vaske <cvaske@soe.ucsc.edu>
+  University of California Santa Cruz
+
+  This program is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+  
+  This program is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+  
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef __defined_libdai_evidence_h
+#define __defined_libdai_evidence_h
+
+#include <istream>
+
+#include <dai/daialg.h>
+
+namespace dai {
+
+/// Store joint observations on a graphical model.
+class SampleData {
+private:
+  std::string _name;
+  std::map<Var, size_t> _obs;
+public:
+  /// Construct an empty object
+  SampleData() : _name(), _obs() {}
+  /// Set the name of the sample
+  void name(const std::string& name) { _name = name; }
+  /// Get the name of the sample
+  const std::string& name() const { return _name; }
+  /// Read from the observation map
+  const std::map<Var, size_t>& observations() const { return _obs; }
+  /// Add an observation
+  void addObservation(Var node, size_t setting);
+  /// Add evidence by clamping variables to observed values.
+  void applyEvidence(InfAlg& alg) const;
+};
+
+/// Store observations from a graphical model.
+class Evidence {
+private:
+  std::map< std::string, SampleData > _samples;
+public:
+  /// Start with empty obects, then fill with calls to addEvidenceTabFile()
+ Evidence() : _samples() {}
+  
+  /** Read in tab-data from a stream. Each line contains one sample, and
+   * the first line is a header line with names. The first column contains
+   * names for each of the samples.
+   */
+  void addEvidenceTabFile(std::istream& is,
+                         std::map< std::string, Var >& varMap);
+
+  /** Read in tab-data from a stream. Each line contains one sample,
+   * and the first line is a header line with variable IDs. The first
+   * column contains names for each of the samples.
+   */
+  void addEvidenceTabFile(std::istream& is, FactorGraph& fg);
+  
+  /// Total number of samples in this evidence file
+  size_t nrSamples() const { return _samples.size(); }
+
+  /// @name iterator interface
+  //@{
+  typedef std::map< std::string, SampleData >::iterator iterator;
+  typedef std::map< std::string, SampleData >::const_iterator const_iterator;
+  iterator begin() { return _samples.begin(); }
+  const_iterator begin() const { return _samples.begin(); }
+  iterator end() { return _samples.end(); }
+  const_iterator end() const { return _samples.end(); }
+  //@}
+
+};
+  
+}
+
+#endif
index 5f5681c..86e9b1a 100644 (file)
@@ -72,6 +72,12 @@ class Exception : public std::runtime_error {
                    IMPOSSIBLE_TYPECAST,
                    INTERNAL_ERROR,
                    NOT_NORMALIZABLE,
+                  INVALID_EVIDENCE_FILE,
+                  INVALID_EVIDENCE_LINE,
+                  INVALID_EVIDENCE_OBSERVATION,
+                  INVALID_SHARED_PARAMETERS_ORDER,
+                  INVALID_SHARED_PARAMETERS_INPUT_LINE,
+                  UNKNOWN_PARAMETER_ESTIMATION_METHOD,
                    NUM_ERRORS};  // NUM_ERRORS should be the last entry
 
         /// Constructor
index ca3e173..6c2ee01 100644 (file)
@@ -195,6 +195,10 @@ class Diffs : public std::vector<double> {
         size_t maxSize() { return _maxsize; }
 };
 
+/// Split a string into tokens
+void tokenizeString(const std::string& s,
+                   std::vector<std::string>& outTokens,
+                   const std::string& delim="\t\n");
 
 } // end of namespace dai
 
diff --git a/src/emalg.cpp b/src/emalg.cpp
new file mode 100644 (file)
index 0000000..be7f417
--- /dev/null
@@ -0,0 +1,315 @@
+#include <dai/util.h>
+
+#include <dai/emalg.h>
+
+namespace dai{
+
+std::map< std::string, ParameterEstimation::ParamEstFactory>* 
+ParameterEstimation::_registry = NULL;
+
+void ParameterEstimation::loadDefaultRegistry() {
+  _registry = new std::map< std::string, ParamEstFactory>();
+  (*_registry)["ConditionalProbEstimation"] = CondProbEstimation::factory;
+}
+
+ParameterEstimation* ParameterEstimation::construct(const std::string& method,
+                                                   const PropertySet& p) {
+  if (_registry == NULL) {
+    loadDefaultRegistry();
+  }
+  std::map< std::string, ParamEstFactory>::iterator i = _registry->find(method);
+  if (i == _registry->end()) {
+    DAI_THROW(UNKNOWN_PARAMETER_ESTIMATION_METHOD);
+  }
+  ParamEstFactory factory = i->second;
+  return factory(p);
+}
+
+ParameterEstimation* CondProbEstimation::factory(const PropertySet& p) {
+  size_t target_dimension =  p.getStringAs<size_t>("target_dim");
+  size_t total_dimension = p.getStringAs<size_t>("total_dim");
+  Real pseudo_count = 1;
+  if (p.hasKey("pseudo_count")) {
+    pseudo_count = p.getStringAs<Real>("pseudo_count");
+  }
+  Prob counts_vec(total_dimension, pseudo_count);
+  return new CondProbEstimation(target_dimension, counts_vec);
+}
+
+CondProbEstimation::CondProbEstimation(size_t target_dimension, 
+                                      Prob pseudocounts) 
+  : _target_dim(target_dimension),
+    _stats(pseudocounts),
+    _initial_stats(pseudocounts) {
+  if (_stats.size() % _target_dim) {
+    DAI_THROW(MALFORMED_PROPERTY);
+  }
+}
+
+void CondProbEstimation::addSufficientStatistics(Prob& p) {
+  _stats += p;
+}
+
+Prob CondProbEstimation::estimate() {
+  for (size_t parent = 0; parent < _stats.size(); parent += _target_dim) {
+    Real norm = 0;
+    size_t top = parent + _target_dim;
+    for (size_t i = parent; i < top; ++i) {
+      norm += _stats[i];
+    }
+    if (norm != 0) {
+      norm = 1 / norm;
+    }
+    for (size_t i = parent; i < top; ++i) {
+      _stats[i] *= norm;
+    }
+  }
+  Prob result = _stats;
+  _stats = _initial_stats;
+  return result;
+}
+
+Permute
+SharedParameters::calculatePermutation(const std::vector< Var >& varorder,
+                                      const std::vector< size_t >& dims,
+                                      VarSet& outVS) {
+  std::vector<long> labels(dims.size());
+  
+  // Check that the variable set is compatible
+  if (varorder.size() != dims.size()) {
+    DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
+  }
+  
+  // Collect all labels, and order them in vs
+  for (size_t di = 0; di < dims.size(); ++di) {
+    if (dims[di] != varorder[di].states()) {
+      DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
+    }
+    outVS |= varorder[di];
+    labels[di] = varorder[di].label();
+  }
+  
+  // Construct the sigma array for the permutation object
+  std::vector<size_t> sigma(dims.size(), 0);
+  VarSet::iterator set_iterator = outVS.begin();
+  for  (size_t vs_i = 0; vs_i < dims.size(); ++vs_i, ++set_iterator) {
+    std::vector< long >::iterator location = find(labels.begin(), labels.end(),
+                                                 set_iterator->label());
+    sigma[vs_i] = location - labels.begin();
+  }
+  
+  return Permute(dims, sigma);
+}
+
+void SharedParameters::setPermsAndVarSetsFromVarOrders() {
+  if (_varorders.size() == 0) {
+    return;
+  }
+  FactorOrientations::const_iterator foi = _varorders.begin();
+  std::vector< size_t > dims(foi->second.size());
+  size_t total_dim = 1;
+  for (size_t i = 0; i < dims.size(); ++i) {
+    dims[i] = foi->second[i].states();
+    total_dim *= dims[i];
+  }
+  
+  // Construct the permutation objects
+  for ( ; foi != _varorders.end(); ++foi) {
+    VarSet vs;
+    _perms[foi->first] = calculatePermutation(foi->second, dims, vs);
+    _varsets[foi->first] = vs;
+  }
+  
+  if  (_estimation == NULL || _estimation->probSize() != total_dim) {
+    DAI_THROW(INVALID_SHARED_PARAMETERS_ORDER);
+  }
+}
+
+SharedParameters::SharedParameters(std::istream& is,
+                                  const FactorGraph& fg_varlookup)
+  : _varsets(),
+    _perms(),
+    _varorders(),
+    _estimation(NULL),
+    _deleteEstimation(1) 
+{
+  std::string est_method;
+  PropertySet props;
+  is >> est_method;
+  is >> props;
+
+  _estimation = ParameterEstimation::construct(est_method, props);
+
+  size_t num_factors;
+  is >> num_factors;
+  for (size_t sp_i = 0; sp_i < num_factors; ++sp_i) {
+    std::string line;
+    std::vector< std::string > fields;
+    size_t factor;
+    std::vector< Var > var_order;
+    std::istringstream iss;
+
+    while(line.size() == 0 && getline(is, line));
+    tokenizeString(line, fields, " \t");
+
+    // Lookup the factor in the factorgraph
+    if (fields.size() < 1) { 
+      DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
+    }
+    iss.str(fields[0]);
+    iss >> factor;
+    const VarSet& vs = fg_varlookup.factor(factor).vars();
+    if (fields.size() != vs.size() + 1) {
+      DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
+    }
+
+    // Construct the vector of Vars
+    for (size_t fi = 1; fi < fields.size(); ++fi) {
+      // Lookup a single variable by label
+      long label;
+      std::istringstream labelparse(fields[fi]);
+      labelparse >> label;
+      VarSet::const_iterator vsi = vs.begin();
+      for ( ; vsi != vs.end(); ++vsi) {
+       if (vsi->label() == label) break;
+      }
+      if (vsi == vs.end()) {
+       DAI_THROW(INVALID_SHARED_PARAMETERS_INPUT_LINE);
+      }
+      var_order.push_back(*vsi);
+    }
+    _varorders[factor] = var_order;
+  }
+  setPermsAndVarSetsFromVarOrders();
+}
+
+SharedParameters::SharedParameters(const SharedParameters& sp) 
+  : _varsets(sp._varsets),
+    _perms(sp._perms),
+    _varorders(sp._varorders),
+    _estimation(sp._estimation),
+    _deleteEstimation(sp._deleteEstimation)
+{
+  if (_deleteEstimation) {
+    _estimation = _estimation->clone();
+  }
+}
+
+SharedParameters::SharedParameters(const FactorOrientations& varorders,
+                                  ParameterEstimation* estimation) 
+  : _varsets(),
+    _perms(),
+    _varorders(varorders),
+    _estimation(estimation),
+    _deleteEstimation(0) 
+{
+  setPermsAndVarSetsFromVarOrders();
+}
+
+void SharedParameters::collectSufficientStatistics(InfAlg& alg) {
+  std::map< FactorIndex, Permute >::iterator i = _perms.begin();
+  for ( ; i != _perms.end(); ++i) {
+    Permute& perm = i->second;
+    VarSet& vs = _varsets[i->first];
+    
+    Factor b = alg.belief(vs);
+    Prob p(b.states(), 0.0);
+    for (size_t entry = 0; entry < b.states(); ++entry) {
+      p[entry] = b[perm.convert_linear_index(entry)];
+    }
+    _estimation->addSufficientStatistics(p);
+  }
+}
+
+void SharedParameters::setParameters(FactorGraph& fg) {
+  Prob p = _estimation->estimate();
+  std::map< FactorIndex, Permute >::iterator i = _perms.begin();
+  for ( ; i != _perms.end(); ++i) {
+    Permute& perm = i->second;
+    VarSet& vs = _varsets[i->first];
+    
+    Factor f(vs, 0.0);
+    for (size_t entry = 0; entry < f.states(); ++entry) {
+      f[perm.convert_linear_index(entry)] = p[entry];
+    }
+
+    fg.setFactor(i->first, f);
+  }
+}
+
+MaximizationStep::MaximizationStep (std::istream& is,
+                                   const FactorGraph& fg_varlookup ) 
+  : _params()
+{
+  size_t num_params = -1;
+  is >> num_params;
+  _params.reserve(num_params);
+  for (size_t i = 0; i < num_params; ++i) {
+    SharedParameters p(is, fg_varlookup);
+    _params.push_back(p);
+  }
+}
+
+
+void MaximizationStep::addExpectations(InfAlg& alg) {
+  for (size_t i = 0; i < _params.size(); ++i) {
+    _params[i].collectSufficientStatistics(alg);
+  }
+}
+
+void MaximizationStep::maximize(FactorGraph& fg) {
+  for (size_t i = 0; i < _params.size(); ++i) {
+    _params[i].setParameters(fg);
+  }
+}
+
+EMAlg::EMAlg(const Evidence& evidence, InfAlg& estep, std::istream& msteps_file)
+   : _evidence(evidence),
+    _estep(estep),
+    _msteps(),
+    _iters(0),
+    _lastLogZ()
+{
+  size_t num_msteps;
+  msteps_file >> num_msteps;
+  _msteps.reserve(num_msteps);
+  for (size_t i = 0; i < num_msteps; ++i) {
+    MaximizationStep m(msteps_file, estep.fg());
+    _msteps.push_back(m);
+  }
+}      
+
+Real EMAlg::iterate(const MaximizationStep& mstep) {
+  Evidence::const_iterator e = _evidence.begin();
+  Real logZ = 0;
+
+  // Expectation calculation
+  for ( ; e != _evidence.end(); ++e) {
+    InfAlg* clamped = _estep.clone();
+    e->second.applyEvidence(*clamped);
+    clamped->run();
+    
+    logZ += clamped->logZ();
+
+    mstep.addExpectations(*clamped);
+
+    delete clamped;
+  }
+  
+  // Maximization of parameters
+  mstep.maximize(_estep.fg());
+
+  return logZ;
+}
+
+Real EMAlg::iterate() {
+  Real likelihood;
+  for (size_t i = 0; i < _msteps.size(); ++i) {
+    likelihood = iterate(_msteps[i]);
+  }
+  _lastLogZ.push_back(likelihood);
+  ++_iters;
+  return likelihood;
+}
+
+}
diff --git a/src/evidence.cpp b/src/evidence.cpp
new file mode 100644 (file)
index 0000000..edafbdb
--- /dev/null
@@ -0,0 +1,83 @@
+#include <sstream>
+#include <string>
+#include <cstdlib>
+
+#include <dai/util.h>
+
+#include <dai/evidence.h>
+
+namespace dai {
+
+void SampleData::addObservation(Var node, size_t setting) {
+    _obs[node] = setting;
+}
+  
+void SampleData::applyEvidence(InfAlg& alg) const {
+  std::map< Var, size_t>::const_iterator i = _obs.begin();
+  for( ; i != _obs.end(); ++i) {
+    alg.clamp(i->first, i->second);
+  }
+}
+  
+void Evidence::addEvidenceTabFile(std::istream& is, FactorGraph& fg) {
+  std::map< std::string, Var > varMap;
+  std::vector< Var >::const_iterator v = fg.vars().begin();
+  for(; v != fg.vars().end(); ++v) {
+    std::stringstream s;
+    s << v->label();
+    varMap[s.str()] = *v;
+  }
+
+  addEvidenceTabFile(is, varMap);
+}
+
+void Evidence::addEvidenceTabFile(std::istream& is, 
+                                 std::map< std::string, Var >& varMap) {
+  
+  std::vector< std::string > header_fields;
+  std::vector< Var > vars;
+  std::string line;
+  getline(is, line);
+  
+  // Parse header
+  tokenizeString(line, header_fields);
+  std::vector< std::string >::const_iterator p_field = header_fields.begin();
+
+  if (p_field == header_fields.end()) { DAI_THROW(INVALID_EVIDENCE_LINE); }
+
+  ++p_field; // first column are sample labels
+  for ( ; p_field != header_fields.end(); ++p_field) {
+    std::map< std::string, Var >::iterator elem = varMap.find(*p_field);
+    if (elem == varMap.end()) {
+      DAI_THROW(INVALID_EVIDENCE_FILE);
+    }
+    vars.push_back(elem->second);
+  }
+  
+  
+  // Read samples
+  while(getline(is, line)) {
+    std::vector< std::string > fields;
+
+    tokenizeString(line, fields);
+    
+    if (fields.size() != vars.size() + 1) { DAI_THROW(INVALID_EVIDENCE_LINE); }
+    
+    SampleData& sampleData = _samples[fields[0]];
+    sampleData.name(fields[0]); // in case of a new sample
+    for (size_t i = 0; i < vars.size(); ++i) {
+      if (fields[i+1].size() > 0) { // skip if missing observation
+       if (fields[i+1].find_first_not_of("0123456789") != std::string::npos) {
+         DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
+       }
+       size_t state = atoi(fields[i+1].c_str());
+       if (state >= vars[i].states()) {
+         DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
+       }
+       sampleData.addObservation(vars[i], state);
+      }
+    }
+  } // finished sample line
+}
+
+}
index 738a4ad..6a48ecc 100644 (file)
@@ -40,8 +40,15 @@ namespace dai {
         "FactorGraph is not connected",
         "Impossible typecast",
         "Internal error",
-        "Quantity not normalizable"
+        "Quantity not normalizable",
+       "Can't parse Evidence file",
+       "Can't parse Evidence line",
+       "Invalid observation in Evidence file",
+       "Invalid variable order in SharedParameters",
+       "Input line in variable order invalid",
+       "Unrecognized parameter estimation method"
     }; 
 
 
 }
+
index f15db26..7121e28 100644 (file)
@@ -105,5 +105,19 @@ int rnd_int( int min, int max ) {
     return (int)floor(_uni_rnd() * (max + 1 - min) + min);
 }
 
+void tokenizeString(const std::string& s,
+                   std::vector<std::string>& outTokens,
+                   const std::string& delim)
+{
+  size_t start = 0;
+  while (start < s.size()) {
+    size_t end = s.find_first_of(delim, start);
+    if (end > s.size()) {
+      end = s.size();
+    }
+    outTokens.push_back(s.substr(start, end - start));
+    start = end + 1;
+  }
+}
 
 } // end of namespace dai
diff --git a/tests/testem/2var.em b/tests/testem/2var.em
new file mode 100644 (file)
index 0000000..457ebe1
--- /dev/null
@@ -0,0 +1,6 @@
+1
+
+1
+ConditionalProbEstimation [target_dim=2,total_dim=4,pseudo_count=1]
+1
+0 1 0
diff --git a/tests/testem/2var.fg b/tests/testem/2var.fg
new file mode 100644 (file)
index 0000000..279870c
--- /dev/null
@@ -0,0 +1,10 @@
+1
+
+2
+0 1
+2 2
+4
+0      0.5
+1      0.5
+2      0.5
+3      0.5
\ No newline at end of file
diff --git a/tests/testem/2var_data.tab b/tests/testem/2var_data.tab
new file mode 100644 (file)
index 0000000..1b7593b
--- /dev/null
@@ -0,0 +1,21 @@
+sample_id      0       1
+sample_0       0       1
+sample_1       0       1
+sample_2       0       1
+sample_3       0       1
+sample_4       0       1
+sample_5       0       1
+sample_6       0       1
+sample_7       0       1
+sample_8       0       1
+sample_9       0       0
+sample.0       1       1
+sample.1       1       1
+sample.2       1       1
+sample.3       1       0
+sample.4       1       0
+sample.5       1       0
+sample.6       1       0
+sample.7       1       0
+sample.8       1       0
+sample.9       1       0
\ No newline at end of file
diff --git a/tests/testem/hoi1_data.tab b/tests/testem/hoi1_data.tab
new file mode 100644 (file)
index 0000000..b103867
--- /dev/null
@@ -0,0 +1,6 @@
+sample_id      0       1       2       4       6       7
+sample_0       0       0       1               1       0
+sample_1       1                       1       1       0
+sample_2       0       1       0       0       0       1
+sample_3       0       0       0       1       0       0
+sample_4       1       0       0               1       0
diff --git a/tests/testem/testem.cpp b/tests/testem/testem.cpp
new file mode 100644 (file)
index 0000000..27e7104
--- /dev/null
@@ -0,0 +1,54 @@
+#include<iostream>
+#include<fstream>
+#include<string>
+
+#include<dai/factorgraph.h>
+#include<dai/evidence.h>
+#include<dai/alldai.h>
+
+using namespace std;
+using namespace dai;
+
+void usage(const string& msg) {
+  cerr << msg << endl;
+  cerr << "Usage:" << endl;
+  cerr << " testem factorgraph.fg evidence.tab emconfig.em" << endl;
+  exit(1);
+}
+
+int main(int argc, char** argv) {
+  if (argc != 4) {
+    usage("Incorrect number of arguments.");
+  }
+  
+  FactorGraph fg;
+  ifstream fgstream(argv[1]);
+  fgstream >> fg;
+
+  PropertySet infprops;
+  infprops.Set("verbose", (size_t)1);
+  infprops.Set("updates", string("HUGIN"));
+  InfAlg* inf = newInfAlg("JTREE", fg, infprops);
+  inf->init();
+
+  Evidence e;
+  ifstream estream(argv[2]);
+  e.addEvidenceTabFile(estream, fg);
+
+  cout << "Number of samples: " << e.nrSamples() << endl;
+  Evidence::iterator ps = e.begin();
+  for (; ps != e.end(); ps++) {
+    cout << "Sample " << ps->first << " has " 
+        << ps->second.observations().size() << " observations." << endl;
+  }
+
+  ifstream emstream(argv[3]);
+  EMAlg em(e, *inf, emstream);
+
+  for (size_t i = 0; i < 10; ++i) {
+    Real l = em.iterate();
+    cout << "Iteration " << i << " likelihood: " << l <<endl;
+  }
+
+  return 0;
+}