/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
 */
10 /// \brief Defines class JTree, which implements the junction tree algorithm
13 #ifndef __defined_libdai_jtree_h
14 #define __defined_libdai_jtree_h
17 #include <dai/dai_config.h>
23 #include <dai/daialg.h>
24 #include <dai/varset.h>
25 #include <dai/regiongraph.h>
26 #include <dai/factorgraph.h>
27 #include <dai/clustergraph.h>
28 #include <dai/weightedgraph.h>
30 #include <dai/properties.h>
/// Exact inference algorithm using junction tree
/** The junction tree algorithm uses message passing on a junction tree to calculate
 *  exact marginal probability distributions ("beliefs") for specified cliques
 *  (outer regions) and separators (intersections of pairs of cliques).
 *
 *  There are two variants, the sum-product algorithm (corresponding to
 *  finite temperature) and the max-product algorithm (corresponding to
 *  zero temperature).
 */
45 class JTree
: public DAIAlgRG
{
47 /// Stores the messages
48 std::vector
<std::vector
<Factor
> > _mes
;
50 /// Stores the logarithm of the partition sum
54 /// The junction tree (stored as a rooted tree)
57 /// Outer region beliefs
58 std::vector
<Factor
> Qa
;
60 /// Inner region beliefs
61 std::vector
<Factor
> Qb
;
63 /// Parameters for JTree
65 /// Enumeration of possible JTree updates
66 /** There are two types of updates:
67 * - HUGIN similar to those in HUGIN
68 * - SHSH Shafer-Shenoy type
70 DAI_ENUM(UpdateType
,HUGIN
,SHSH
);
72 /// Enumeration of inference variants
73 /** There are two inference variants:
74 * - SUMPROD Sum-Product
75 * - MAXPROD Max-Product (equivalent to Min-Sum)
77 DAI_ENUM(InfType
,SUMPROD
,MAXPROD
);
79 /// Enumeration of elimination cost functions used for constructing the junction tree
80 /** The cost of eliminating a variable can be (\see [\ref KoF09], page 314)):
81 * - MINNEIGHBORS the number of neighbors it has in the current adjacency graph;
82 * - MINWEIGHT the product of the number of states of all neighbors in the current adjacency graph;
83 * - MINFILL the number of edges that need to be added to the adjacency graph due to the elimination;
84 * - WEIGHTEDMINFILL the sum of weights of the edges that need to be added to the adjacency graph
85 * due to the elimination, where a weight of an edge is the produt of weights of its constituent
87 * The elimination sequence is chosen greedily in order to minimize the cost.
89 DAI_ENUM(HeuristicType
,MINNEIGHBORS
,MINWEIGHT
,MINFILL
,WEIGHTEDMINFILL
);
91 /// Verbosity (amount of output sent to stderr)
100 /// Heuristic to use for constructing the junction tree
101 HeuristicType heuristic
;
103 /// Maximum memory to use in bytes (0 means unlimited)
108 /// \name Constructors/destructors
110 /// Default constructor
111 JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
113 /// Construct from FactorGraph \a fg and PropertySet \a opts
114 /** \param fg factor graph
115 ** \param opts Parameters @see Properties
116 * \param automatic if \c true, construct the junction tree automatically, using the heuristic in opts['heuristic'].
118 JTree( const FactorGraph
&fg
, const PropertySet
&opts
, bool automatic
=true );
122 /// \name General InfAlg interface
124 virtual JTree
* clone() const { return new JTree(*this); }
125 virtual JTree
* construct( const FactorGraph
&fg
, const PropertySet
&opts
) const { return new JTree( fg
, opts
); }
126 virtual std::string
name() const { return "JTREE"; }
127 virtual Factor
belief( const VarSet
&vs
) const;
128 virtual std::vector
<Factor
> beliefs() const;
129 virtual Real
logZ() const;
130 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
132 std::vector
<size_t> findMaximum() const;
133 virtual void init() {}
134 virtual void init( const VarSet
&/*ns*/ ) {}
136 virtual Real
maxDiff() const { return 0.0; }
137 virtual size_t Iterations() const { return 1UL; }
138 virtual void setProperties( const PropertySet
&opts
);
139 virtual PropertySet
getProperties() const;
140 virtual std::string
printProperties() const;
144 /// \name Additional interface specific for JTree
146 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
147 /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and
148 * each edge is weighted with the cardinality of the intersection of the state spaces of the nodes.
149 * Then, a maximal spanning tree for this weighted graph is calculated.
150 * Subsequently, a corresponding region graph is built:
151 * - the outer regions correspond with the cliques and have counting number 1;
152 * - the inner regions correspond with the seperators, i.e., the intersections of two
153 * cliques that are neighbors in the spanning tree, and have counting number -1
154 * (except empty ones, which have counting number 0);
155 * - inner and outer regions are connected by an edge if the inner region is a
156 * seperator for the outer region.
157 * Finally, Beliefs are constructed.
158 * If \a verify == \c true, checks whether each factor is subsumed by a clique.
160 void construct( const FactorGraph
&fg
, const std::vector
<VarSet
> &cl
, bool verify
=false );
162 /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
163 /** Invokes construct() and then constructs messages.
166 void GenerateJT( const FactorGraph
&fg
, const std::vector
<VarSet
> &cl
);
168 /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
169 const Factor
& message( size_t alpha
, size_t _beta
) const { return _mes
[alpha
][_beta
]; }
170 /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
171 Factor
& message( size_t alpha
, size_t _beta
) { return _mes
[alpha
][_beta
]; }
173 /// Runs junction tree algorithm using HUGIN (message-free) updates
174 /** \note The initial messages may be arbitrary; actually they are not used at all.
/// Runs junction tree algorithm using Shafer-Shenoy updates
/** \note The initial messages may be arbitrary.
 */
void runShaferShenoy();
183 /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
184 /** First, the current junction tree is reordered such that it gets as root the clique
185 * that has maximal state space overlap with \a vs. Then, the minimal subtree
186 * (starting from the root) is identified that contains all the variables in \a vs
187 * and also the outer region with index \a PreviousRoot (if specified). Finally,
188 * the current junction tree is reordered such that this minimal subtree comes
189 * before the other edges, and the size of the minimal subtree is returned.
191 size_t findEfficientTree( const VarSet
& vs
, RootedTree
&Tree
, size_t PreviousRoot
=(size_t)-1 ) const;
193 /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
194 /** \pre assumes that run() has been called already
196 Factor
calcMarginal( const VarSet
& vs
);
201 /// Calculates upper bound to the treewidth of a FactorGraph, using the specified heuristic
203 * \param fg the factor graph for which the treewidth should be bounded
204 * \param fn the heuristic cost function used for greedy variable elimination
205 * \param maxStates maximum total number of states in outer regions of junction tree (0 means no limit)
206 * \throws OUT_OF_MEMORY if the total number of states becomes larger than maxStates
207 * \return a pair (number of variables in largest clique, number of states in largest clique)
209 std::pair
<size_t,BigInt
> boundTreewidth( const FactorGraph
&fg
, greedyVariableElimination::eliminationCostFunction fn
, size_t maxStates
=0 );
212 } // end of namespace dai