1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
13 /// \brief Defines class JTree, which implements the junction tree algorithm
16 #ifndef __defined_libdai_jtree_h
17 #define __defined_libdai_jtree_h
22 #include <dai/daialg.h>
23 #include <dai/varset.h>
24 #include <dai/regiongraph.h>
25 #include <dai/factorgraph.h>
26 #include <dai/clustergraph.h>
27 #include <dai/weightedgraph.h>
29 #include <dai/properties.h>
/// Exact inference algorithm using junction tree
/** The junction tree algorithm uses message passing on a junction tree to calculate
 *  exact marginal probability distributions ("beliefs") for specified cliques
 *  (outer regions) and separators (intersections of pairs of cliques).
 *
 *  There are two variants, the sum-product algorithm (corresponding to
 *  finite temperature) and the max-product algorithm (corresponding to
 *  zero temperature).
 */
44 class JTree
: public DAIAlgRG
{
46 /// Stores the messages
47 std::vector
<std::vector
<Factor
> > _mes
;
49 /// Stores the logarithm of the partition sum
53 /// The junction tree (stored as a rooted tree)
56 /// Outer region beliefs
57 std::vector
<Factor
> Qa
;
59 /// Inner region beliefs
60 std::vector
<Factor
> Qb
;
62 /// Parameters of this inference algorithm
64 /// Enumeration of possible JTree updates
65 DAI_ENUM(UpdateType
,HUGIN
,SHSH
);
67 /// Enumeration of inference variants
68 DAI_ENUM(InfType
,SUMPROD
,MAXPROD
);
73 /// Type of updates: HUGIN or Shafer-Shenoy
76 /// Type of inference: sum-product or max-product
/// Identifying name of the JTree inference algorithm
static const char* Name;
84 /// \name Constructors/destructors
86 /// Default constructor
87 JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
89 /// Construct from FactorGraph \a fg and PropertySet \a opts
90 /** \param fg factor graph (which has to be connected);
91 * \param opts parameters;
92 * \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic.
93 * \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected
95 JTree( const FactorGraph
&fg
, const PropertySet
&opts
, bool automatic
=true );
99 /// \name General InfAlg interface
101 virtual JTree
* clone() const { return new JTree(*this); }
102 virtual std::string
identify() const;
103 virtual Factor
belief( const Var
&n
) const;
104 virtual Factor
belief( const VarSet
&ns
) const;
105 virtual std::vector
<Factor
> beliefs() const;
106 virtual Real
logZ() const;
107 virtual void init() {}
108 virtual void init( const VarSet
&/*ns*/ ) {}
110 virtual Real
maxDiff() const { return 0.0; }
111 virtual size_t Iterations() const { return 1UL; }
112 virtual void setProperties( const PropertySet
&opts
);
113 virtual PropertySet
getProperties() const;
114 virtual std::string
printProperties() const;
118 /// \name Additional interface specific for JTree
120 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
121 /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
122 * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
123 * Then, a maximal spanning tree for this weighted graph is calculated.
124 * Finally, a corresponding region graph is built:
125 * - the outer regions correspond with the cliques and have counting number 1;
126 * - the inner regions correspond with the seperators, i.e., the intersections of two
127 * cliques that are neighbors in the spanning tree, and have counting number -1;
128 * - inner and outer regions are connected by an edge if the inner region is a
129 * seperator for the outer region.
131 void GenerateJT( const std::vector
<VarSet
> &cl
);
133 /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
134 const Factor
& message( size_t alpha
, size_t _beta
) const { return _mes
[alpha
][_beta
]; }
135 /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
136 Factor
& message( size_t alpha
, size_t _beta
) { return _mes
[alpha
][_beta
]; }
// "Runs junction tree algorithm using HUGIN updates" (note: the initial messages may be arbitrary).
// NOTE(review): this documentation had no matching declaration in this listing; it is kept as a
// plain comment so doxygen does not attach it to runShaferShenoy() -- confirm runHUGIN() is declared.

/// Runs junction tree algorithm using Shafer-Shenoy updates
/** \note The initial messages may be arbitrary.
 */
void runShaferShenoy();
148 /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
149 /** First, the current junction tree is reordered such that it gets as root the clique
150 * that has maximal state space overlap with \a vs. Then, the minimal subtree
151 * (starting from the root) is identified that contains all the variables in \a vs
152 * and also the outer region with index \a PreviousRoot (if specified). Finally,
153 * the current junction tree is reordered such that this minimal subtree comes
154 * before the other edges, and the size of the minimal subtree is returned.
156 size_t findEfficientTree( const VarSet
& vs
, RootedTree
&Tree
, size_t PreviousRoot
=(size_t)-1 ) const;
158 /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
159 /** \pre assumes that run() has been called already
161 Factor
calcMarginal( const VarSet
& vs
);
163 /// Calculates the joint state of all variables that has maximum probability
164 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
166 std::vector
<std::size_t> findMaximum() const;
171 /// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
173 * \return a pair (number of variables in largest clique, number of states in largest clique)
175 std::pair
<size_t,size_t> boundTreewidth( const FactorGraph
& fg
);
178 /// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
179 /** \deprecated Renamed into boundTreewidth()
181 std::pair
<size_t,size_t> treewidth( const FactorGraph
& fg
);
184 } // end of namespace dai