XGitUrl: http://git.tuebingen.mpg.de/?p=libdai.git;a=blobdiff_plain;f=include%2Fdai%2Fjtree.h;h=96ef98abb38f1e7d8e8e662eb3fb77f24c3a1183;hp=8107757bc565113767b52ad7546fad7b9a7e72ee;hb=48fcb551d982e6e2b95dcfa114aad66c3de6844f;hpb=03a2ade22d55a6723343bfe084e74a5a695378b4
diff git a/include/dai/jtree.h b/include/dai/jtree.h
index 8107757..96ef98a 100644
 a/include/dai/jtree.h
+++ b/include/dai/jtree.h
@@ 10,8 +10,7 @@
/// \file
/// \brief Defines class JTree
/// \todo Improve documentation
+/// \brief Defines class JTree, which implements the junction tree algorithm
#ifndef __defined_libdai_jtree_h
@@ 34,105 +33,183 @@ namespace dai {
/// Exact inference algorithm using junction tree
+/** The junction tree algorithm uses message passing on a junction tree to calculate
+ * exact marginal probability distributions ("beliefs") for specified cliques
+ * (outer regions) and separators (intersections of pairs of cliques).
+ *
+ * There are two variants, the sumproduct algorithm (corresponding to
+ * finite temperature) and the maxproduct algorithm (corresponding to
+ * zero temperature).
+ */
class JTree : public DAIAlgRG {
private:
+ /// Stores the messages
std::vector > _mes;
 double _logZ;
+
+ /// Stores the logarithm of the partition sum
+ Real _logZ;
public:
 /// Rooted tree
 DEdgeVec RTree;
+ /// The junction tree (stored as a rooted tree)
+ RootedTree RTree;
/// Outer region beliefs
 std::vector Qa;
+ std::vector Qa;
/// Inner region beliefs
 std::vector Qb;
+ std::vector Qb;
 /// Parameters of this inference algorithm
+ /// Parameters for JTree
struct Properties {
/// Enumeration of possible JTree updates
+ /** There are two types of updates:
+ *  HUGIN similar to those in HUGIN
+ *  SHSH ShaferShenoy type
+ */
DAI_ENUM(UpdateType,HUGIN,SHSH);
/// Enumeration of inference variants
+ /** There are two inference variants:
+ *  SUMPROD SumProduct
+ *  MAXPROD MaxProduct (equivalent to MinSum)
+ */
DAI_ENUM(InfType,SUMPROD,MAXPROD);
 /// Verbosity
+ /// Enumeration of elimination cost functions used for constructing the junction tree
+ /** The cost of eliminating a variable can be (\see [\ref KoF09], page 314)):
+ *  MINNEIGHBORS the number of neighbors it has in the current adjacency graph;
+ *  MINWEIGHT the product of the number of states of all neighbors in the current adjacency graph;
+ *  MINFILL the number of edges that need to be added to the adjacency graph due to the elimination;
+ *  WEIGHTEDMINFILL the sum of weights of the edges that need to be added to the adjacency graph
+ * due to the elimination, where a weight of an edge is the produt of weights of its constituent
+ * vertices.
+ * The elimination sequence is chosen greedily in order to minimize the cost.
+ */
+ DAI_ENUM(HeuristicType,MINNEIGHBORS,MINWEIGHT,MINFILL,WEIGHTEDMINFILL);
+
+ /// Verbosity (amount of output sent to stderr)
size_t verbose;
/// Type of updates
UpdateType updates;
 /// Type of inference: sumproduct or maxproduct?
+ /// Type of inference
InfType inference;
+
+ /// Heuristic to use for constructing the junction tree
+ HeuristicType heuristic;
+
+ /// Maximum memory to use in bytes (0 means unlimited)
+ size_t maxmem;
} props;
/// Name of this inference algorithm
static const char *Name;
public:
+ /// \name Constructors/destructors
+ //@{
/// Default constructor
JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
 /// Construct from FactorGraph fg and PropertySet opts
+ /// Construct from FactorGraph \a fg and PropertySet \a opts
+ /** \param fg factor graph
+ ** \param opts Parameters @see Properties
+ * \param automatic if \c true, construct the junction tree automatically, using the heuristic in opts['heuristic'].
+ */
JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
+ //@}
 /// @name General InfAlg interface
 //@{
+ /// \name General InfAlg interface
+ //@{
virtual JTree* clone() const { return new JTree(*this); }
virtual std::string identify() const;
 virtual Factor belief( const Var &n ) const;
 virtual Factor belief( const VarSet &ns ) const;
+ virtual Factor belief( const VarSet &vs ) const;
virtual std::vector beliefs() const;
virtual Real logZ() const;
virtual void init() {}
virtual void init( const VarSet &/*ns*/ ) {}
 virtual double run();
 virtual double maxDiff() const { return 0.0; }
+ virtual Real run();
+ virtual Real maxDiff() const { return 0.0; }
virtual size_t Iterations() const { return 1UL; }
 //@}

+ virtual void setProperties( const PropertySet &opts );
+ virtual PropertySet getProperties() const;
+ virtual std::string printProperties() const;
+ //@}
+
+
+ /// \name Additional interface specific for JTree
+ //@{
+ /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+ /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
+ * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
+ * Then, a maximal spanning tree for this weighted graph is calculated.
+ * Subsequently, a corresponding region graph is built:
+ *  the outer regions correspond with the cliques and have counting number 1;
+ *  the inner regions correspond with the seperators, i.e., the intersections of two
+ * cliques that are neighbors in the spanning tree, and have counting number 1
+ * (except empty ones, which have counting number 0);
+ *  inner and outer regions are connected by an edge if the inner region is a
+ * seperator for the outer region.
+ * Finally, Beliefs are constructed.
+ * If \a verify == \c true, checks whether each factor is subsumed by a clique.
+ */
+ void construct( const FactorGraph &fg, const std::vector &cl, bool verify=false );
 /// @name Additional interface specific for JTree
 //@{
 void GenerateJT( const std::vector &Cliques );
+ /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+ /** Invokes construct() and then constructs messages.
+ * \see construct()
+ */
+ void GenerateJT( const FactorGraph &fg, const std::vector &cl );
 /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region
 Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
 /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region
+ /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
+ /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
+ Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
 /// Runs junctiontree with HUGIN updates
+ /// Runs junction tree algorithm using HUGIN (messagefree) updates
+ /** \note The initial messages may be arbitrary; actually they are not used at all.
+ */
void runHUGIN();
 /// Runs junctiontree with ShaferShenoy updates
+ /// Runs junction tree algorithm using ShaferShenoy updates
+ /** \note The initial messages may be arbitrary.
+ */
void runShaferShenoy();
 /// Finds an efficient tree for calculating the marginal of some variables
 size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)1 ) const;
+ /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
+ /** First, the current junction tree is reordered such that it gets as root the clique
+ * that has maximal state space overlap with \a vs. Then, the minimal subtree
+ * (starting from the root) is identified that contains all the variables in \a vs
+ * and also the outer region with index \a PreviousRoot (if specified). Finally,
+ * the current junction tree is reordered such that this minimal subtree comes
+ * before the other edges, and the size of the minimal subtree is returned.
+ */
+ size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)1 ) const;
 /// Calculates the marginal of a set of variables
 Factor calcMarginal( const VarSet& ns );
+ /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
+ /** \pre assumes that run() has been called already
+ */
+ Factor calcMarginal( const VarSet& vs );
/// Calculates the joint state of all variables that has maximum probability
 /** Assumes that run() has been called and that props.inference == MAXPROD
+ /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
*/
std::vector findMaximum() const;
 //@}

 private:
 void setProperties( const PropertySet &opts );
 PropertySet getProperties() const;
 std::string printProperties() const;
+ //@}
};
/// Calculates upper bound to the treewidth of a FactorGraph
+/// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
/** \relates JTree
+ * \param fg the factor graph for which the treewidth should be bounded
+ * \param fn the heuristic cost function used for greedy variable elimination
+ * \param maxStates maximum total number of states in outer regions of junction tree (0 means no limit)
+ * \throws OUT_OF_MEMORY if the total number of states becomes larger than maxStates
* \return a pair (number of variables in largest clique, number of states in largest clique)
*/
std::pair treewidth( const FactorGraph & fg );
+std::pair boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates=0 );
} // end of namespace dai