b7041e24011d51b6f32f8d8e12534b03d91116a3
[libdai.git] / include / dai / jtree.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class JTree, which implements the junction tree algorithm
11
12
13 #ifndef __defined_libdai_jtree_h
14 #define __defined_libdai_jtree_h
15
16
17 #include <vector>
18 #include <string>
19 #include <dai/daialg.h>
20 #include <dai/varset.h>
21 #include <dai/regiongraph.h>
22 #include <dai/factorgraph.h>
23 #include <dai/clustergraph.h>
24 #include <dai/weightedgraph.h>
25 #include <dai/enum.h>
26 #include <dai/properties.h>
27
28
29 namespace dai {
30
31
32 /// Exact inference algorithm using junction tree
33 /** The junction tree algorithm uses message passing on a junction tree to calculate
34 * exact marginal probability distributions ("beliefs") for specified cliques
35 * (outer regions) and separators (intersections of pairs of cliques).
36 *
37 * There are two variants, the sum-product algorithm (corresponding to
38 * finite temperature) and the max-product algorithm (corresponding to
39 * zero temperature).
40 */
41 class JTree : public DAIAlgRG {
42 private:
43 /// Stores the messages
44 std::vector<std::vector<Factor> > _mes;
45
46 /// Stores the logarithm of the partition sum
47 Real _logZ;
48
49 public:
50 /// The junction tree (stored as a rooted tree)
51 RootedTree RTree;
52
53 /// Outer region beliefs
54 std::vector<Factor> Qa;
55
56 /// Inner region beliefs
57 std::vector<Factor> Qb;
58
59 /// Parameters for JTree
60 struct Properties {
61 /// Enumeration of possible JTree updates
62 /** There are two types of updates:
63 * - HUGIN similar to those in HUGIN
64 * - SHSH Shafer-Shenoy type
65 */
66 DAI_ENUM(UpdateType,HUGIN,SHSH);
67
68 /// Enumeration of inference variants
69 /** There are two inference variants:
70 * - SUMPROD Sum-Product
71 * - MAXPROD Max-Product (equivalent to Min-Sum)
72 */
73 DAI_ENUM(InfType,SUMPROD,MAXPROD);
74
75 /// Enumeration of elimination cost functions used for constructing the junction tree
76 /** The cost of eliminating a variable can be (\see [\ref KoF09], page 314)):
77 * - MINNEIGHBORS the number of neighbors it has in the current adjacency graph;
78 * - MINWEIGHT the product of the number of states of all neighbors in the current adjacency graph;
79 * - MINFILL the number of edges that need to be added to the adjacency graph due to the elimination;
80 * - WEIGHTEDMINFILL the sum of weights of the edges that need to be added to the adjacency graph
81 * due to the elimination, where a weight of an edge is the produt of weights of its constituent
82 * vertices.
83 * The elimination sequence is chosen greedily in order to minimize the cost.
84 */
85 DAI_ENUM(HeuristicType,MINNEIGHBORS,MINWEIGHT,MINFILL,WEIGHTEDMINFILL);
86
87 /// Verbosity (amount of output sent to stderr)
88 size_t verbose;
89
90 /// Type of updates
91 UpdateType updates;
92
93 /// Type of inference
94 InfType inference;
95
96 /// Heuristic to use for constructing the junction tree
97 HeuristicType heuristic;
98
99 /// Maximum memory to use in bytes (0 means unlimited)
100 size_t maxmem;
101 } props;
102
103 public:
104 /// \name Constructors/destructors
105 //@{
106 /// Default constructor
107 JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
108
109 /// Construct from FactorGraph \a fg and PropertySet \a opts
110 /** \param fg factor graph
111 ** \param opts Parameters @see Properties
112 * \param automatic if \c true, construct the junction tree automatically, using the heuristic in opts['heuristic'].
113 */
114 JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
115 //@}
116
117
118 /// \name General InfAlg interface
119 //@{
120 virtual JTree* clone() const { return new JTree(*this); }
121 virtual JTree* construct( const FactorGraph &fg, const PropertySet &opts ) const { return new JTree( fg, opts ); }
122 virtual std::string name() const { return "JTREE"; }
123 virtual Factor belief( const VarSet &vs ) const;
124 virtual std::vector<Factor> beliefs() const;
125 virtual Real logZ() const;
126 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
127 */
128 std::vector<std::size_t> findMaximum() const;
129 virtual void init() {}
130 virtual void init( const VarSet &/*ns*/ ) {}
131 virtual Real run();
132 virtual Real maxDiff() const { return 0.0; }
133 virtual size_t Iterations() const { return 1UL; }
134 virtual void setProperties( const PropertySet &opts );
135 virtual PropertySet getProperties() const;
136 virtual std::string printProperties() const;
137 //@}
138
139
140 /// \name Additional interface specific for JTree
141 //@{
142 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
143 /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
144 * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
145 * Then, a maximal spanning tree for this weighted graph is calculated.
146 * Subsequently, a corresponding region graph is built:
147 * - the outer regions correspond with the cliques and have counting number 1;
148 * - the inner regions correspond with the seperators, i.e., the intersections of two
149 * cliques that are neighbors in the spanning tree, and have counting number -1
150 * (except empty ones, which have counting number 0);
151 * - inner and outer regions are connected by an edge if the inner region is a
152 * seperator for the outer region.
153 * Finally, Beliefs are constructed.
154 * If \a verify == \c true, checks whether each factor is subsumed by a clique.
155 */
156 void construct( const FactorGraph &fg, const std::vector<VarSet> &cl, bool verify=false );
157
158 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
159 /** Invokes construct() and then constructs messages.
160 * \see construct()
161 */
162 void GenerateJT( const FactorGraph &fg, const std::vector<VarSet> &cl );
163
164 /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
165 const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
166 /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
167 Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
168
169 /// Runs junction tree algorithm using HUGIN (message-free) updates
170 /** \note The initial messages may be arbitrary; actually they are not used at all.
171 */
172 void runHUGIN();
173
174 /// Runs junction tree algorithm using Shafer-Shenoy updates
175 /** \note The initial messages may be arbitrary.
176 */
177 void runShaferShenoy();
178
179 /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
180 /** First, the current junction tree is reordered such that it gets as root the clique
181 * that has maximal state space overlap with \a vs. Then, the minimal subtree
182 * (starting from the root) is identified that contains all the variables in \a vs
183 * and also the outer region with index \a PreviousRoot (if specified). Finally,
184 * the current junction tree is reordered such that this minimal subtree comes
185 * before the other edges, and the size of the minimal subtree is returned.
186 */
187 size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const;
188
189 /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
190 /** \pre assumes that run() has been called already
191 */
192 Factor calcMarginal( const VarSet& vs );
193 //@}
194 };
195
196
197 /// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
198 /** \relates JTree
199 * \param fg the factor graph for which the treewidth should be bounded
200 * \param fn the heuristic cost function used for greedy variable elimination
201 * \param maxStates maximum total number of states in outer regions of junction tree (0 means no limit)
202 * \throws OUT_OF_MEMORY if the total number of states becomes larger than maxStates
203 * \return a pair (number of variables in largest clique, number of states in largest clique)
204 */
205 std::pair<size_t,BigInt> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates=0 );
206
207
208 } // end of namespace dai
209
210
211 #endif