Moved alias code from tests/testdai.cpp to src/alldai.cpp
[libdai.git] / include / dai / treeep.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class TreeEP, which implements Tree Expectation Propagation
14 /// \todo Clean up the TreeEP code (exploiting that a large part of the code
15 /// is just a special case of JTree).
16
17
18 #ifndef __defined_libdai_treeep_h
19 #define __defined_libdai_treeep_h
20
21
#include <map>
#include <string>
#include <vector>
#include <dai/daialg.h>
#include <dai/varset.h>
#include <dai/regiongraph.h>
#include <dai/factorgraph.h>
#include <dai/clustergraph.h>
#include <dai/weightedgraph.h>
#include <dai/jtree.h>
#include <dai/properties.h>
#include <dai/enum.h>
33
34
35 namespace dai {
36
37
38 /// Approximate inference algorithm "Tree Expectation Propagation" [\ref MiQ04]
39 class TreeEP : public JTree {
40 private:
41 /// Maximum difference encountered so far
42 Real _maxdiff;
43 /// Number of iterations needed
44 size_t _iters;
45
46 public:
47 /// Parameters for TreeEP
48 struct Properties {
49 /// Enumeration of possible choices for the tree
50 /** The two possibilities are:
51 * - \c ORG: take the maximum spanning tree where the weights are crude
52 * estimates of the mutual information between the nodes;
53 * - \c ALT: take the maximum spanning tree where the weights are upper
54 * bounds on the effective interaction strengths between pairs of nodes.
55 */
56 DAI_ENUM(TypeType,ORG,ALT);
57
58 /// Verbosity (amount of output sent to stderr)
59 size_t verbose;
60
61 /// Maximum number of iterations
62 size_t maxiter;
63
64 /// Tolerance for convergence test
65 Real tol;
66
67 /// How to choose the tree
68 TypeType type;
69 } props;
70
71 /// Name of this inference method
72 static const char *Name;
73
74 private:
75 /// Stores the data structures needed to efficiently update the approximation of an off-tree factor.
76 /** The TreeEP object stores a TreeEPSubTree object for each off-tree factor.
77 * It stores the approximation of that off-tree factor, which is represented
78 * as a distribution on a subtree of the main tree.
79 */
80 class TreeEPSubTree {
81 private:
82 /// Outer region pseudomarginals (corresponding with the \f$\tilde f_i(x_j,x_k)\f$ in [\ref MiQ04])
83 std::vector<Factor> _Qa;
84 /// Inner region pseudomarginals (corresponding with the \f$\tilde f_i(x_s)\f$ in [\ref MiQ04])
85 std::vector<Factor> _Qb;
86 /// The junction tree (stored as a rooted tree)
87 RootedTree _RTree;
88 /// Index conversion table for outer region indices (_Qa[alpha] corresponds with Qa[_a[alpha]] of the supertree)
89 std::vector<size_t> _a;
90 /// Index conversion table for inner region indices (_Qb[beta] corresponds with Qb[_b[beta]] of the supertree)
91 std::vector<size_t> _b;
92 /// Pointer to off-tree factor
93 const Factor * _I;
94 /// Variables in off-tree factor
95 VarSet _ns;
96 /// Variables in off-tree factor which are not in the root of this subtree
97 VarSet _nsrem;
98 /// Used for calculating the free energy
99 Real _logZ;
100
101 public:
102 /// \name Constructors/destructors
103 //@{
104 /// Default constructor
105 TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
106
107 /// Copy constructor
108 TreeEPSubTree( const TreeEPSubTree &x ) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
109
110 /// Assignment operator
111 TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
112 if( this != &x ) {
113 _Qa = x._Qa;
114 _Qb = x._Qb;
115 _RTree = x._RTree;
116 _a = x._a;
117 _b = x._b;
118 _I = x._I;
119 _ns = x._ns;
120 _nsrem = x._nsrem;
121 _logZ = x._logZ;
122 }
123 return *this;
124 }
125
126 /// Construct from \a subRTree, which is a subtree of the main tree \a jt_RTree, with distribution represented by \a jt_Qa and \a jt_Qb, for off-tree factor \a I
127 TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
128 //@}
129
130 /// Initializes beliefs of this subtree
131 void init();
132
133 /// Inverts this approximation and multiplies it by the (super) junction tree marginals \a Qa and \a Qb
134 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
135
136 /// Runs junction tree algorithm (including off-tree factor I) storing the results in the (super) junction tree \a Qa and \a Qb
137 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
138
139 /// Returns energy (?) of this subtree
140 Real logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
141
142 /// Returns constant reference to the pointer to the off-tree factor
143 const Factor *& I() { return _I; }
144 };
145
146 /// Stores a TreeEPSubTree object for each off-tree factor
147 std::map<size_t, TreeEPSubTree> _Q;
148
149 public:
150 /// Default constructor
151 TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {}
152
153 /// Copy constructor
154 TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) {
155 for( size_t I = 0; I < nrFactors(); I++ )
156 if( offtree( I ) )
157 _Q[I].I() = &factor(I);
158 }
159
160 /// Assignment operator
161 TreeEP& operator=( const TreeEP &x ) {
162 if( this != &x ) {
163 JTree::operator=( x );
164 _maxdiff = x._maxdiff;
165 _iters = x._iters;
166 props = x.props;
167 _Q = x._Q;
168 for( size_t I = 0; I < nrFactors(); I++ )
169 if( offtree( I ) )
170 _Q[I].I() = &factor(I);
171 }
172 return *this;
173 }
174
175 /// Construct from FactorGraph \a fg and PropertySet \a opts
176 /** \param opts Parameters @see Properties
177 * \note The factor graph has to be connected.
178 * \throw FACTORGRAPH_NOT_CONNECTED if \a fg is not connected
179 */
180 TreeEP( const FactorGraph &fg, const PropertySet &opts );
181
182
183 /// \name General InfAlg interface
184 //@{
185 virtual TreeEP* clone() const { return new TreeEP(*this); }
186 virtual std::string identify() const;
187 virtual Real logZ() const;
188 virtual void init();
189 virtual void init( const VarSet &/*ns*/ ) { init(); }
190 virtual Real run();
191 virtual Real maxDiff() const { return _maxdiff; }
192 virtual size_t Iterations() const { return _iters; }
193 virtual void setProperties( const PropertySet &opts );
194 virtual PropertySet getProperties() const;
195 virtual std::string printProperties() const;
196 //@}
197
198 private:
199 /// Helper function for constructors
200 void construct( const RootedTree &tree );
201 /// Returns \c true if factor \a I is not part of the tree
202 bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
203 };
204
205
206 } // end of namespace dai
207
208
209 #endif