X-Git-Url: http://git.tuebingen.mpg.de/?p=libdai.git;a=blobdiff_plain;f=include%2Fdai%2Fjtree.h;h=37f698173b4680f8e2060d442410a92d0c277428;hp=799cfe4428f85f86336a37cb920fd01cd8f26e19;hb=b04b00f5eb8997766bef6f9d1b5dd105ff832645;hpb=a7a8c134163b46d582eacde331d5f20ee9a0f435
diff --git a/include/dai/jtree.h b/include/dai/jtree.h
index 799cfe4..37f6981 100644
--- a/include/dai/jtree.h
+++ b/include/dai/jtree.h
@@ -10,8 +10,7 @@
/// \file
-/// \brief Defines class JTree
-/// \todo Improve documentation
+/// \brief Defines class JTree, which implements the junction tree algorithm
#ifndef __defined_libdai_jtree_h
@@ -34,36 +33,55 @@ namespace dai {
/// Exact inference algorithm using junction tree
+/** The junction tree algorithm uses message passing on a junction tree to calculate
+ * exact marginal probability distributions ("beliefs") for specified cliques
+ * (outer regions) and separators (intersections of pairs of cliques).
+ *
+ * There are two variants, the sum-product algorithm (corresponding to
+ * finite temperature) and the max-product algorithm (corresponding to
+ * zero temperature).
+ */
class JTree : public DAIAlgRG {
private:
+ /// Stores the messages
std::vector > _mes;
+
+ /// Stores the logarithm of the partition sum
Real _logZ;
public:
- /// Rooted tree
- DEdgeVec RTree;
+ /// The junction tree (stored as a rooted tree)
+ RootedTree RTree;
/// Outer region beliefs
- std::vector Qa;
+ std::vector Qa;
/// Inner region beliefs
- std::vector Qb;
+ std::vector Qb;
- /// Parameters of this inference algorithm
+ /// Parameters for JTree
struct Properties {
/// Enumeration of possible JTree updates
+ /** There are two types of updates:
+ * - HUGIN similar to those in HUGIN
+ * - SHSH Shafer-Shenoy type
+ */
DAI_ENUM(UpdateType,HUGIN,SHSH);
/// Enumeration of inference variants
+ /** There are two inference variants:
+ * - SUMPROD Sum-Product
+ * - MAXPROD Max-Product (equivalent to Min-Sum)
+ */
DAI_ENUM(InfType,SUMPROD,MAXPROD);
- /// Verbosity
+ /// Verbosity (amount of output sent to stderr)
size_t verbose;
/// Type of updates
UpdateType updates;
- /// Type of inference: sum-product or max-product?
+ /// Type of inference
InfType inference;
} props;
@@ -71,19 +89,26 @@ class JTree : public DAIAlgRG {
static const char *Name;
public:
+ /// \name Constructors/destructors
+ //@{
/// Default constructor
JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
- /// Construct from FactorGraph fg and PropertySet opts
+ /// Construct from FactorGraph \a fg and PropertySet \a opts
+ /** \param fg factor graph (which has to be connected);
+ ** \param opts Parameters @see Properties
+ * \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic.
+ * \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected
+ */
JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
+ //@}
/// \name General InfAlg interface
//@{
virtual JTree* clone() const { return new JTree(*this); }
virtual std::string identify() const;
- virtual Factor belief( const Var &n ) const;
- virtual Factor belief( const VarSet &ns ) const;
+ virtual Factor belief( const VarSet &vs ) const;
virtual std::vector beliefs() const;
virtual Real logZ() const;
virtual void init() {}
@@ -99,38 +124,70 @@ class JTree : public DAIAlgRG {
/// \name Additional interface specific for JTree
//@{
- void GenerateJT( const std::vector &Cliques );
+ /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+ /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
+ * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
+ * Then, a maximal spanning tree for this weighted graph is calculated.
+ * Subsequently, a corresponding region graph is built:
+ * - the outer regions correspond with the cliques and have counting number 1;
+ * - the inner regions correspond with the seperators, i.e., the intersections of two
+ * cliques that are neighbors in the spanning tree, and have counting number -1;
+ * - inner and outer regions are connected by an edge if the inner region is a
+ * seperator for the outer region.
+ * Finally, Beliefs are constructed.
+ * If \a verify == \c true, checks whether each factor is subsumed by a clique.
+ */
+ void construct( const std::vector &cl, bool verify=false );
- /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region
- Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
- /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region
+ /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+ /** Invokes construct() and then constructs messages.
+ * \see construct()
+ */
+ void GenerateJT( const std::vector &cl );
+
+ /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
+ /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
+ Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
- /// Runs junction-tree with HUGIN updates
+ /// Runs junction tree algorithm using HUGIN updates
+ /** \note The initial messages may be arbitrary.
+ */
void runHUGIN();
- /// Runs junction-tree with Shafer-Shenoy updates
+ /// Runs junction tree algorithm using Shafer-Shenoy updates
+ /** \note The initial messages may be arbitrary.
+ */
void runShaferShenoy();
- /// Finds an efficient tree for calculating the marginal of some variables
- size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const;
+ /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
+ /** First, the current junction tree is reordered such that it gets as root the clique
+ * that has maximal state space overlap with \a vs. Then, the minimal subtree
+ * (starting from the root) is identified that contains all the variables in \a vs
+ * and also the outer region with index \a PreviousRoot (if specified). Finally,
+ * the current junction tree is reordered such that this minimal subtree comes
+ * before the other edges, and the size of the minimal subtree is returned.
+ */
+ size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const;
- /// Calculates the marginal of a set of variables
- Factor calcMarginal( const VarSet& ns );
+ /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
+ /** \pre assumes that run() has been called already
+ */
+ Factor calcMarginal( const VarSet& vs );
/// Calculates the joint state of all variables that has maximum probability
- /** Assumes that run() has been called and that props.inference == MAXPROD
+ /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
*/
std::vector findMaximum() const;
//@}
};
-/// Calculates upper bound to the treewidth of a FactorGraph
+/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
/** \relates JTree
* \return a pair (number of variables in largest clique, number of states in largest clique)
*/
-std::pair treewidth( const FactorGraph & fg );
+std::pair boundTreewidth( const FactorGraph & fg );
} // end of namespace dai