index 543b489..37f6981 100644 (file)
@@ -10,8 +10,7 @@

/// \file

/// \file
-/// \brief Defines class JTree
-/// \todo Improve documentation
+/// \brief Defines class JTree, which implements the junction tree algorithm

#ifndef __defined_libdai_jtree_h

#ifndef __defined_libdai_jtree_h
@@ -34,36 +33,55 @@ namespace dai {

/// Exact inference algorithm using junction tree

/// Exact inference algorithm using junction tree
+/** The junction tree algorithm uses message passing on a junction tree to calculate
+ *  exact marginal probability distributions ("beliefs") for specified cliques
+ *  (outer regions) and separators (intersections of pairs of cliques).
+ *
+ *  There are two variants, the sum-product algorithm (corresponding to
+ *  finite temperature) and the max-product algorithm (corresponding to
+ *  zero temperature).
+ */
class JTree : public DAIAlgRG {
private:
class JTree : public DAIAlgRG {
private:
+        /// Stores the messages
std::vector<std::vector<Factor> >  _mes;
std::vector<std::vector<Factor> >  _mes;
+
+        /// Stores the logarithm of the partition sum
Real _logZ;

public:
Real _logZ;

public:
-        /// Rooted tree
-        DEdgeVec             RTree;
+        /// The junction tree (stored as a rooted tree)
+        RootedTree RTree;

/// Outer region beliefs

/// Outer region beliefs
-        std::vector<Factor>  Qa;
+        std::vector<Factor> Qa;

/// Inner region beliefs

/// Inner region beliefs
-        std::vector<Factor>  Qb;
+        std::vector<Factor> Qb;

-        /// Parameters of this inference algorithm
+        /// Parameters for JTree
struct Properties {
/// Enumeration of possible JTree updates
struct Properties {
/// Enumeration of possible JTree updates
+            /** There are two types of updates:
+             *  - HUGIN similar to those in HUGIN
+             *  - SHSH Shafer-Shenoy type
+             */
DAI_ENUM(UpdateType,HUGIN,SHSH);

/// Enumeration of inference variants
DAI_ENUM(UpdateType,HUGIN,SHSH);

/// Enumeration of inference variants
+            /** There are two inference variants:
+             *  - SUMPROD Sum-Product
+             *  - MAXPROD Max-Product (equivalent to Min-Sum)
+             */
DAI_ENUM(InfType,SUMPROD,MAXPROD);

DAI_ENUM(InfType,SUMPROD,MAXPROD);

-            /// Verbosity
+            /// Verbosity (amount of output sent to stderr)
size_t verbose;

size_t verbose;

-            /// Type of inference: sum-product or max-product?
+            /// Type of inference
InfType inference;
} props;

InfType inference;
} props;

@@ -71,19 +89,26 @@ class JTree : public DAIAlgRG {
static const char *Name;

public:
static const char *Name;

public:
+    /// \name Constructors/destructors
+    //@{
/// Default constructor
JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}

/// Default constructor
JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}

-        /// Construct from FactorGraph fg and PropertySet opts
+        /// Construct from FactorGraph \a fg and PropertySet \a opts
+        /** \param fg factor graph (which has to be connected);
+         ** \param opts Parameters @see Properties
+         *  \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic.
+         *  \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected
+         */
JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
+    //@}

/// \name General InfAlg interface
//@{
virtual JTree* clone() const { return new JTree(*this); }
virtual std::string identify() const;

/// \name General InfAlg interface
//@{
virtual JTree* clone() const { return new JTree(*this); }
virtual std::string identify() const;
-        virtual Factor belief( const Var &n ) const;
-        virtual Factor belief( const VarSet &ns ) const;
+        virtual Factor belief( const VarSet &vs ) const;
virtual std::vector<Factor> beliefs() const;
virtual Real logZ() const;
virtual void init() {}
virtual std::vector<Factor> beliefs() const;
virtual Real logZ() const;
virtual void init() {}
@@ -91,56 +116,78 @@ class JTree : public DAIAlgRG {
virtual Real run();
virtual Real maxDiff() const { return 0.0; }
virtual size_t Iterations() const { return 1UL; }
virtual Real run();
virtual Real maxDiff() const { return 0.0; }
virtual size_t Iterations() const { return 1UL; }
+        virtual void setProperties( const PropertySet &opts );
+        virtual PropertySet getProperties() const;
+        virtual std::string printProperties() const;
//@}

/// \name Additional interface specific for JTree
//@{
//@}

/// \name Additional interface specific for JTree
//@{
-        void GenerateJT( const std::vector<VarSet> &Cliques );
+        /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+        /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
+         *  each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
+         *  Then, a maximal spanning tree for this weighted graph is calculated.
+         *  Subsequently, a corresponding region graph is built:
+         *    - the outer regions correspond with the cliques and have counting number 1;
+         *    - the inner regions correspond with the seperators, i.e., the intersections of two
+         *      cliques that are neighbors in the spanning tree, and have counting number -1;
+         *    - inner and outer regions are connected by an edge if the inner region is a
+         *      seperator for the outer region.
+         *  Finally, Beliefs are constructed.
+         *  If \a verify == \c true, checks whether each factor is subsumed by a clique.
+         */
+        void construct( const std::vector<VarSet> &cl, bool verify=false );

-        /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region
-        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
-        /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region
+        /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+        /** Invokes construct() and then constructs messages.
+         *  \see construct()
+         */
+        void GenerateJT( const std::vector<VarSet> &cl );
+
+        /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
+        /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
+        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }

-        /// Runs junction-tree with HUGIN updates
+        /// Runs junction tree algorithm using HUGIN updates
+        /** \note The initial messages may be arbitrary.
+         */
void runHUGIN();

void runHUGIN();

-        /// Runs junction-tree with Shafer-Shenoy updates
+        /// Runs junction tree algorithm using Shafer-Shenoy updates
+        /** \note The initial messages may be arbitrary.
+         */
void runShaferShenoy();

void runShaferShenoy();

-        /// Finds an efficient tree for calculating the marginal of some variables
-        size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const;
+        /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
+        /** First, the current junction tree is reordered such that it gets as root the clique
+         *  that has maximal state space overlap with \a vs. Then, the minimal subtree
+         *  (starting from the root) is identified that contains all the variables in \a vs
+         *  and also the outer region with index \a PreviousRoot (if specified). Finally,
+         *  the current junction tree is reordered such that this minimal subtree comes
+         *  before the other edges, and the size of the minimal subtree is returned.
+         */
+        size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const;

-        /// Calculates the marginal of a set of variables
-        Factor calcMarginal( const VarSet& ns );
+        /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
+        /** \pre assumes that run() has been called already
+         */
+        Factor calcMarginal( const VarSet& vs );

/// Calculates the joint state of all variables that has maximum probability

/// Calculates the joint state of all variables that has maximum probability
-        /** Assumes that run() has been called and that props.inference == MAXPROD
+        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
*/
std::vector<std::size_t> findMaximum() const;
//@}
*/
std::vector<std::size_t> findMaximum() const;
//@}
-
-    /// \name Managing parameters (which are stored in JTree::props)
-    //@{
-        /// Set parameters of this inference algorithm.
-        /** The parameters are set according to \a opts.
-         *  The values can be stored either as std::string or as the type of the corresponding JTree::props member.
-         */
-        void setProperties( const PropertySet &opts );
-        /// Returns parameters of this inference algorithm converted into a PropertySet.
-        PropertySet getProperties() const;
-        /// Returns parameters of this inference algorithm formatted as a string in the format "[key1=val1,key2=val2,...,keyn=valn]".
-        std::string printProperties() const;
-    //@}
};

};

-/// Calculates upper bound to the treewidth of a FactorGraph
+/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
/** \relates JTree
*  \return a pair (number of variables in largest clique, number of states in largest clique)
*/
/** \relates JTree
*  \return a pair (number of variables in largest clique, number of states in largest clique)
*/
-std::pair<size_t,size_t> treewidth( const FactorGraph & fg );
+std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg );

} // end of namespace dai

} // end of namespace dai