X-Git-Url: http://git.tuebingen.mpg.de/?p=libdai.git;a=blobdiff_plain;f=include%2Fdai%2Fjtree.h;h=37f698173b4680f8e2060d442410a92d0c277428;hp=543b489b20b3cd7d00584528c6052ea7f0a46f5e;hb=b04b00f5eb8997766bef6f9d1b5dd105ff832645;hpb=49cd0d863897488f464ac2521ca5f612207814ba diff --git a/include/dai/jtree.h b/include/dai/jtree.h index 543b489..37f6981 100644 --- a/include/dai/jtree.h +++ b/include/dai/jtree.h @@ -10,8 +10,7 @@ /// \file -/// \brief Defines class JTree -/// \todo Improve documentation +/// \brief Defines class JTree, which implements the junction tree algorithm #ifndef __defined_libdai_jtree_h @@ -34,36 +33,55 @@ namespace dai { /// Exact inference algorithm using junction tree +/** The junction tree algorithm uses message passing on a junction tree to calculate + * exact marginal probability distributions ("beliefs") for specified cliques + * (outer regions) and separators (intersections of pairs of cliques). + * + * There are two variants, the sum-product algorithm (corresponding to + * finite temperature) and the max-product algorithm (corresponding to + * zero temperature). + */ class JTree : public DAIAlgRG { private: + /// Stores the messages std::vector > _mes; + + /// Stores the logarithm of the partition sum Real _logZ; public: - /// Rooted tree - DEdgeVec RTree; + /// The junction tree (stored as a rooted tree) + RootedTree RTree; /// Outer region beliefs - std::vector Qa; + std::vector Qa; /// Inner region beliefs - std::vector Qb; + std::vector Qb; - /// Parameters of this inference algorithm + /// Parameters for JTree struct Properties { /// Enumeration of possible JTree updates + /** There are two types of updates: + * - HUGIN similar to those in HUGIN + * - SHSH Shafer-Shenoy type + */ DAI_ENUM(UpdateType,HUGIN,SHSH); /// Enumeration of inference variants + /** There are two inference variants: + * - SUMPROD Sum-Product + * - MAXPROD Max-Product (equivalent to Min-Sum) + */ DAI_ENUM(InfType,SUMPROD,MAXPROD); - /// Verbosity + /// Verbosity (amount of output sent to stderr) size_t verbose; /// Type of updates UpdateType updates; - /// Type of inference: sum-product or max-product? + /// Type of inference InfType inference; } props; @@ -71,19 +89,26 @@ class JTree : public DAIAlgRG { static const char *Name; public: + /// \name Constructors/destructors + //@{ /// Default constructor JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {} - /// Construct from FactorGraph fg and PropertySet opts + /// Construct from FactorGraph \a fg and PropertySet \a opts + /** \param fg factor graph (which has to be connected); + ** \param opts Parameters @see Properties + * \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic. + * \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected + */ JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true ); + //@} /// \name General InfAlg interface //@{ virtual JTree* clone() const { return new JTree(*this); } virtual std::string identify() const; - virtual Factor belief( const Var &n ) const; - virtual Factor belief( const VarSet &ns ) const; + virtual Factor belief( const VarSet &vs ) const; virtual std::vector beliefs() const; virtual Real logZ() const; virtual void init() {} @@ -91,56 +116,78 @@ class JTree : public DAIAlgRG { virtual Real run(); virtual Real maxDiff() const { return 0.0; } virtual size_t Iterations() const { return 1UL; } + virtual void setProperties( const PropertySet &opts ); + virtual PropertySet getProperties() const; + virtual std::string printProperties() const; //@} /// \name Additional interface specific for JTree //@{ - void GenerateJT( const std::vector &Cliques ); + /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence). + /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and + * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes. + * Then, a maximal spanning tree for this weighted graph is calculated. + * Subsequently, a corresponding region graph is built: + * - the outer regions correspond with the cliques and have counting number 1; + * - the inner regions correspond with the seperators, i.e., the intersections of two + * cliques that are neighbors in the spanning tree, and have counting number -1; + * - inner and outer regions are connected by an edge if the inner region is a + * seperator for the outer region. + * Finally, Beliefs are constructed. + * If \a verify == \c true, checks whether each factor is subsumed by a clique. + */ + void construct( const std::vector &cl, bool verify=false ); - /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region - Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; } - /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region + /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence). + /** Invokes construct() and then constructs messages. + * \see construct() + */ + void GenerateJT( const std::vector &cl ); + + /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; } + /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region + Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; } - /// Runs junction-tree with HUGIN updates + /// Runs junction tree algorithm using HUGIN updates + /** \note The initial messages may be arbitrary. + */ void runHUGIN(); - /// Runs junction-tree with Shafer-Shenoy updates + /// Runs junction tree algorithm using Shafer-Shenoy updates + /** \note The initial messages may be arbitrary. + */ void runShaferShenoy(); - /// Finds an efficient tree for calculating the marginal of some variables - size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const; + /// Finds an efficient subtree for calculating the marginal of the variables in \a vs + /** First, the current junction tree is reordered such that it gets as root the clique + * that has maximal state space overlap with \a vs. Then, the minimal subtree + * (starting from the root) is identified that contains all the variables in \a vs + * and also the outer region with index \a PreviousRoot (if specified). Finally, + * the current junction tree is reordered such that this minimal subtree comes + * before the other edges, and the size of the minimal subtree is returned. + */ + size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const; - /// Calculates the marginal of a set of variables - Factor calcMarginal( const VarSet& ns ); + /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary) + /** \pre assumes that run() has been called already + */ + Factor calcMarginal( const VarSet& vs ); /// Calculates the joint state of all variables that has maximum probability - /** Assumes that run() has been called and that props.inference == MAXPROD + /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD */ std::vector findMaximum() const; //@} - - /// \name Managing parameters (which are stored in JTree::props) - //@{ - /// Set parameters of this inference algorithm. - /** The parameters are set according to \a opts. - * The values can be stored either as std::string or as the type of the corresponding JTree::props member. - */ - void setProperties( const PropertySet &opts ); - /// Returns parameters of this inference algorithm converted into a PropertySet. - PropertySet getProperties() const; - /// Returns parameters of this inference algorithm formatted as a string in the format "[key1=val1,key2=val2,...,keyn=valn]". - std::string printProperties() const; - //@} }; -/// Calculates upper bound to the treewidth of a FactorGraph +/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic /** \relates JTree * \return a pair (number of variables in largest clique, number of states in largest clique) */ -std::pair treewidth( const FactorGraph & fg ); +std::pair boundTreewidth( const FactorGraph & fg ); } // end of namespace dai