c27be06c7f0ff8661022561773d7004ff4803b1d
[libdai.git] / include / dai / jtree.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class JTree, which implements the junction tree algorithm
14
15
16 #ifndef __defined_libdai_jtree_h
17 #define __defined_libdai_jtree_h
18
19
20 #include <vector>
21 #include <string>
22 #include <dai/daialg.h>
23 #include <dai/varset.h>
24 #include <dai/regiongraph.h>
25 #include <dai/factorgraph.h>
26 #include <dai/clustergraph.h>
27 #include <dai/weightedgraph.h>
28 #include <dai/enum.h>
29 #include <dai/properties.h>
30
31
32 namespace dai {
33
34
35 /// Exact inference algorithm using junction tree
36 /** The junction tree algorithm uses message passing on a junction tree to calculate
37 * exact marginal probability distributions ("beliefs") for specified cliques
38 * (outer regions) and separators (intersections of pairs of cliques).
39 *
40 * There are two variants, the sum-product algorithm (corresponding to
41 * finite temperature) and the max-product algorithm (corresponding to
42 * zero temperature).
43 */
44 class JTree : public DAIAlgRG {
45 private:
46 /// Stores the messages
47 std::vector<std::vector<Factor> > _mes;
48
49 /// Stores the logarithm of the partition sum
50 Real _logZ;
51
52 public:
53 /// The junction tree (stored as a rooted tree)
54 RootedTree RTree;
55
56 /// Outer region beliefs
57 std::vector<Factor> Qa;
58
59 /// Inner region beliefs
60 std::vector<Factor> Qb;
61
62 /// Parameters for JTree
63 struct Properties {
64 /// Enumeration of possible JTree updates
65 /** There are two types of updates:
66 * - HUGIN similar to those in HUGIN
67 * - SHSH Shafer-Shenoy type
68 */
69 DAI_ENUM(UpdateType,HUGIN,SHSH);
70
71 /// Enumeration of inference variants
72 /** There are two inference variants:
73 * - SUMPROD Sum-Product
74 * - MAXPROD Max-Product (equivalent to Min-Sum)
75 */
76 DAI_ENUM(InfType,SUMPROD,MAXPROD);
77
78 /// Enumeration of elimination cost functions used for constructing the junction tree
79 /** The cost of eliminating a variable can be (\see [\ref KoF09], page 314)):
80 * - MINNEIGHBORS the number of neighbors it has in the current adjacency graph;
81 * - MINWEIGHT the product of the number of states of all neighbors in the current adjacency graph;
82 * - MINFILL the number of edges that need to be added to the adjacency graph due to the elimination;
83 * - WEIGHTEDMINFILL the sum of weights of the edges that need to be added to the adjacency graph
84 * due to the elimination, where a weight of an edge is the produt of weights of its constituent
85 * vertices.
86 * The elimination sequence is chosen greedily in order to minimize the cost.
87 */
88 DAI_ENUM(HeuristicType,MINNEIGHBORS,MINWEIGHT,MINFILL,WEIGHTEDMINFILL);
89
90 /// Verbosity (amount of output sent to stderr)
91 size_t verbose;
92
93 /// Type of updates
94 UpdateType updates;
95
96 /// Type of inference
97 InfType inference;
98
99 /// Heuristic to use for constructing the junction tree
100 HeuristicType heuristic;
101
102 /// Maximum memory to use in bytes (0 means unlimited)
103 size_t maxmem;
104 } props;
105
106 /// Name of this inference algorithm
107 static const char *Name;
108
109 public:
110 /// \name Constructors/destructors
111 //@{
112 /// Default constructor
113 JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
114
115 /// Construct from FactorGraph \a fg and PropertySet \a opts
116 /** \param fg factor graph
117 ** \param opts Parameters @see Properties
118 * \param automatic if \c true, construct the junction tree automatically, using the heuristic in opts['heuristic'].
119 */
120 JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
121 //@}
122
123
124 /// \name General InfAlg interface
125 //@{
126 virtual JTree* clone() const { return new JTree(*this); }
127 virtual std::string identify() const;
128 virtual Factor belief( const VarSet &vs ) const;
129 virtual std::vector<Factor> beliefs() const;
130 virtual Real logZ() const;
131 virtual void init() {}
132 virtual void init( const VarSet &/*ns*/ ) {}
133 virtual Real run();
134 virtual Real maxDiff() const { return 0.0; }
135 virtual size_t Iterations() const { return 1UL; }
136 virtual void setProperties( const PropertySet &opts );
137 virtual PropertySet getProperties() const;
138 virtual std::string printProperties() const;
139 //@}
140
141
142 /// \name Additional interface specific for JTree
143 //@{
144 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
145 /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
146 * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
147 * Then, a maximal spanning tree for this weighted graph is calculated.
148 * Subsequently, a corresponding region graph is built:
149 * - the outer regions correspond with the cliques and have counting number 1;
150 * - the inner regions correspond with the seperators, i.e., the intersections of two
151 * cliques that are neighbors in the spanning tree, and have counting number -1
152 * (except empty ones, which have counting number 0);
153 * - inner and outer regions are connected by an edge if the inner region is a
154 * seperator for the outer region.
155 * Finally, Beliefs are constructed.
156 * If \a verify == \c true, checks whether each factor is subsumed by a clique.
157 */
158 void construct( const FactorGraph &fg, const std::vector<VarSet> &cl, bool verify=false );
159
160 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
161 /** Invokes construct() and then constructs messages.
162 * \see construct()
163 */
164 void GenerateJT( const FactorGraph &fg, const std::vector<VarSet> &cl );
165
166 /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
167 const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
168 /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
169 Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
170
171 /// Runs junction tree algorithm using HUGIN (message-free) updates
172 /** \note The initial messages may be arbitrary; actually they are not used at all.
173 */
174 void runHUGIN();
175
176 /// Runs junction tree algorithm using Shafer-Shenoy updates
177 /** \note The initial messages may be arbitrary.
178 */
179 void runShaferShenoy();
180
181 /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
182 /** First, the current junction tree is reordered such that it gets as root the clique
183 * that has maximal state space overlap with \a vs. Then, the minimal subtree
184 * (starting from the root) is identified that contains all the variables in \a vs
185 * and also the outer region with index \a PreviousRoot (if specified). Finally,
186 * the current junction tree is reordered such that this minimal subtree comes
187 * before the other edges, and the size of the minimal subtree is returned.
188 */
189 size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const;
190
191 /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
192 /** \pre assumes that run() has been called already
193 */
194 Factor calcMarginal( const VarSet& vs );
195
196 /// Calculates the joint state of all variables that has maximum probability
197 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
198 */
199 std::vector<std::size_t> findMaximum() const;
200 //@}
201 };
202
203
204 /// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
205 /** \relates JTree
206 * \param maxStates maximum total number of states in outer regions of junction tree (0 means no limit)
207 * \throws OUT_OF_MEMORY if the total number of states becomes larger than maxStates
208 * \return a pair (number of variables in largest clique, number of states in largest clique)
209 */
210 std::pair<size_t,double> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates=0 );
211
212
213 } // end of namespace dai
214
215
216 #endif