X-Git-Url: http://git.tuebingen.mpg.de/?p=libdai.git;a=blobdiff_plain;f=include%2Fdai%2Fjtree.h;h=aaff0e21adacb144366e8f2fb302b195a3018917;hp=db1ddd14f747b9dcd2f04109b87c3aeae228d157;hb=87c6827445f8fd67801efb6e818771e16229313b;hpb=1a6cf1800decab50998bfacabefd14101e9cc5a5 diff --git a/include/dai/jtree.h b/include/dai/jtree.h index db1ddd1..aaff0e2 100644 --- a/include/dai/jtree.h +++ b/include/dai/jtree.h @@ -20,6 +20,11 @@ */ +/// \file +/// \brief Defines class JTree +/// \todo Improve documentation + + #ifndef __defined_libdai_jtree_h #define __defined_libdai_jtree_h @@ -39,21 +44,35 @@ namespace dai { +/// Exact inference algorithm using junction tree class JTree : public DAIAlgRG { private: std::vector > _mes; double _logZ; public: - DEdgeVec RTree; // rooted tree + /// Rooted tree + DEdgeVec RTree; + + /// Outer region beliefs std::vector Qa; + + /// Inner region beliefs std::vector Qb; + + /// Parameters of this inference algorithm struct Properties { - size_t verbose; + /// Enumeration of possible JTree updates DAI_ENUM(UpdateType,HUGIN,SHSH) + + /// Verbosity + size_t verbose; + + /// Type of updates UpdateType updates; } props; - /// Name of this inference method + + /// Name of this inference algorithm static const char *Name; public: @@ -63,76 +82,57 @@ class JTree : public DAIAlgRG { /// Construct from FactorGraph fg and PropertySet opts JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true ); - /// Copy constructor - JTree( const JTree &x ) : DAIAlgRG(x), _mes(x._mes), _logZ(x._logZ), RTree(x.RTree), Qa(x.Qa), Qb(x.Qb), props(x.props) {} - /// Clone *this (virtual copy constructor) + /// @name General InfAlg interface + //@{ virtual JTree* clone() const { return new JTree(*this); } - - /// Create (virtual default constructor) virtual JTree* create() const { return new JTree(); } - - /// Assignment operator - JTree& operator=( const JTree &x ) { - if( this != &x ) { - DAIAlgRG::operator=( x ); - _mes = x._mes; - _logZ = x._logZ; - RTree = x.RTree; - Qa = x.Qa; - Qb = x.Qb; - props = x.props; - } - return *this; - } - - /// Identifies itself for logging purposes virtual std::string identify() const; - - /// Get single node belief virtual Factor belief( const Var &n ) const; - - /// Get general belief virtual Factor belief( const VarSet &ns ) const; - - /// Get all beliefs virtual std::vector beliefs() const; - - /// Get log partition sum virtual Real logZ() const; - - /// Clear messages and beliefs virtual void init() {} - - /// Clear messages and beliefs corresponding to the nodes in ns virtual void init( const VarSet &/*ns*/ ) {} - - /// The actual approximate inference algorithm virtual double run(); - - /// Return maximum difference between single node beliefs in the last pass virtual double maxDiff() const { return 0.0; } - - /// Return number of passes over the factorgraph virtual size_t Iterations() const { return 1UL; } + //@} + /// @name Additional interface specific for JTree + //@{ void GenerateJT( const std::vector &Cliques ); + /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; } + /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; } + /// Runs junction-tree with HUGIN updates void runHUGIN(); + + /// Runs junction-tree with Shafer-Shenoy updates void runShaferShenoy(); + + /// Finds an efficient tree for calculating the marginal of some variables size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const; + + /// Calculates the marginal of a set of variables Factor calcMarginal( const VarSet& ns ); + //@} + private: void setProperties( const PropertySet &opts ); PropertySet getProperties() const; std::string printProperties() const; }; +/// Calculates upper bound to the treewidth of a FactorGraph +/** \relates JTree + * \return a pair (number of variables in largest clique, number of states in largest clique) + */ std::pair treewidth( const FactorGraph & fg );