New git HEAD version
[libdai.git] / include / dai / treeep.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class TreeEP, which implements Tree Expectation Propagation
11 /// \todo Clean up the TreeEP code (exploiting that a large part of the code
12 /// is just a special case of JTree).
13
14
15 #ifndef __defined_libdai_treeep_h
16 #define __defined_libdai_treeep_h
17
18
19 #include <dai/dai_config.h>
20 #ifdef DAI_WITH_TREEEP
21
22
#include <map>
#include <string>
#include <vector>
#include <dai/daialg.h>
#include <dai/varset.h>
#include <dai/regiongraph.h>
#include <dai/factorgraph.h>
#include <dai/clustergraph.h>
#include <dai/weightedgraph.h>
#include <dai/jtree.h>
#include <dai/properties.h>
#include <dai/enum.h>
34
35
36 namespace dai {
37
38
39 /// Approximate inference algorithm "Tree Expectation Propagation" [\ref MiQ04]
40 class TreeEP : public JTree {
41 private:
42 /// Maximum difference encountered so far
43 Real _maxdiff;
44 /// Number of iterations needed
45 size_t _iters;
46
47 public:
48 /// Parameters for TreeEP
49 struct Properties {
50 /// Enumeration of possible choices for the tree
51 /** The two possibilities are:
52 * - \c ORG: take the maximum spanning tree where the weights are crude
53 * estimates of the mutual information between the nodes;
54 * - \c ALT: take the maximum spanning tree where the weights are upper
55 * bounds on the effective interaction strengths between pairs of nodes.
56 */
57 DAI_ENUM(TypeType,ORG,ALT);
58
59 /// Verbosity (amount of output sent to stderr)
60 size_t verbose;
61
62 /// Maximum number of iterations
63 size_t maxiter;
64
65 /// Maximum time (in seconds)
66 double maxtime;
67
68 /// Tolerance for convergence test
69 Real tol;
70
71 /// How to choose the tree
72 TypeType type;
73 } props;
74
75 private:
76 /// Stores the data structures needed to efficiently update the approximation of an off-tree factor.
77 /** The TreeEP object stores a TreeEPSubTree object for each off-tree factor.
78 * It stores the approximation of that off-tree factor, which is represented
79 * as a distribution on a subtree of the main tree.
80 */
81 class TreeEPSubTree {
82 private:
83 /// Outer region pseudomarginals (corresponding with the \f$\tilde f_i(x_j,x_k)\f$ in [\ref MiQ04])
84 std::vector<Factor> _Qa;
85 /// Inner region pseudomarginals (corresponding with the \f$\tilde f_i(x_s)\f$ in [\ref MiQ04])
86 std::vector<Factor> _Qb;
87 /// The junction tree (stored as a rooted tree)
88 RootedTree _RTree;
89 /// Index conversion table for outer region indices (_Qa[alpha] corresponds with Qa[_a[alpha]] of the supertree)
90 std::vector<size_t> _a;
91 /// Index conversion table for inner region indices (_Qb[beta] corresponds with Qb[_b[beta]] of the supertree)
92 std::vector<size_t> _b;
93 /// Pointer to off-tree factor
94 const Factor * _I;
95 /// Variables in off-tree factor
96 VarSet _ns;
97 /// Variables in off-tree factor which are not in the root of this subtree
98 VarSet _nsrem;
99 /// Used for calculating the free energy
100 Real _logZ;
101
102 public:
103 /// \name Constructors/destructors
104 //@{
105 /// Default constructor
106 TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
107
108 /// Copy constructor
109 TreeEPSubTree( const TreeEPSubTree &x ) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
110
111 /// Assignment operator
112 TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
113 if( this != &x ) {
114 _Qa = x._Qa;
115 _Qb = x._Qb;
116 _RTree = x._RTree;
117 _a = x._a;
118 _b = x._b;
119 _I = x._I;
120 _ns = x._ns;
121 _nsrem = x._nsrem;
122 _logZ = x._logZ;
123 }
124 return *this;
125 }
126
127 /// Construct from \a subRTree, which is a subtree of the main tree \a jt_RTree, with distribution represented by \a jt_Qa and \a jt_Qb, for off-tree factor \a I
128 TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
129 //@}
130
131 /// Initializes beliefs of this subtree
132 void init();
133
134 /// Inverts this approximation and multiplies it by the (super) junction tree marginals \a Qa and \a Qb
135 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
136
137 /// Runs junction tree algorithm (including off-tree factor I) storing the results in the (super) junction tree \a Qa and \a Qb
138 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
139
140 /// Returns energy (?) of this subtree
141 Real logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
142
143 /// Returns constant reference to the pointer to the off-tree factor
144 const Factor *& I() { return _I; }
145 };
146
147 /// Stores a TreeEPSubTree object for each off-tree factor
148 std::map<size_t, TreeEPSubTree> _Q;
149
150 public:
151 /// Default constructor
152 TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {}
153
154 /// Copy constructor
155 TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) {
156 for( size_t I = 0; I < nrFactors(); I++ )
157 if( offtree( I ) )
158 _Q[I].I() = &factor(I);
159 }
160
161 /// Assignment operator
162 TreeEP& operator=( const TreeEP &x ) {
163 if( this != &x ) {
164 JTree::operator=( x );
165 _maxdiff = x._maxdiff;
166 _iters = x._iters;
167 props = x.props;
168 _Q = x._Q;
169 for( size_t I = 0; I < nrFactors(); I++ )
170 if( offtree( I ) )
171 _Q[I].I() = &factor(I);
172 }
173 return *this;
174 }
175
176 /// Construct from FactorGraph \a fg and PropertySet \a opts
177 /** \param fg Factor graph.
178 * \param opts Parameters @see Properties
179 */
180 TreeEP( const FactorGraph &fg, const PropertySet &opts );
181
182
183 /// \name General InfAlg interface
184 //@{
185 virtual TreeEP* clone() const { return new TreeEP(*this); }
186 virtual TreeEP* construct( const FactorGraph &fg, const PropertySet &opts ) const { return new TreeEP( fg, opts ); }
187 virtual std::string name() const { return "TREEEP"; }
188 virtual Real logZ() const;
189 virtual void init();
190 virtual void init( const VarSet &/*ns*/ ) { init(); }
191 virtual Real run();
192 virtual Real maxDiff() const { return _maxdiff; }
193 virtual size_t Iterations() const { return _iters; }
194 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
195 virtual void setProperties( const PropertySet &opts );
196 virtual PropertySet getProperties() const;
197 virtual std::string printProperties() const;
198 //@}
199
200 private:
201 /// Helper function for constructors
202 void construct( const FactorGraph& fg, const RootedTree& tree );
203 /// Returns \c true if factor \a I is not part of the tree
204 bool offtree( size_t I ) const { return (fac2OR(I) == -1U); }
205 };
206
207
208 } // end of namespace dai
209
210
211 #endif
212
213
214 #endif