Added TProb<T>::operator==( const TProb<T> & ) and added some unit tests for prob...
[libdai.git] / include / dai / mr.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2007 Bastian Wemmenhove
8 * Copyright (C) 2007-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 /// \file
14 /// \brief Defines class MR, which implements loop corrections as proposed by Montanari and Rizzo
15
16
17 #ifndef __defined_libdai_mr_h
18 #define __defined_libdai_mr_h
19
20
21 #include <vector>
22 #include <string>
23 #include <dai/factorgraph.h>
24 #include <dai/daialg.h>
25 #include <dai/enum.h>
26 #include <dai/properties.h>
27 #include <dai/exceptions.h>
28 #include <dai/graph.h>
29 #include <boost/dynamic_bitset.hpp>
30
31
32 namespace dai {
33
34
35 /// Approximate inference algorithm by Montanari and Rizzo [\ref MoR05]
36 /** \author Bastian Wemmenhove wrote the original implementation before it was merged into libDAI
37 */
38 class MR : public DAIAlgFG {
39 private:
40 /// Is the underlying factor graph supported?
41 bool supported;
42
43 /// The interaction graph (Markov graph)
44 GraphAL G;
45
46 /// Convenience shortcut
47 typedef GraphAL::Neighbor Neighbor;
48
49 /// tJ[i][_j] is the hyperbolic tangent of the interaction between spin \a i and its neighbour G.nb(i,_j)
50 std::vector<std::vector<Real> > tJ;
51 /// theta[i] is the local field on spin \a i
52 std::vector<Real> theta;
53
54 /// M[i][_j] is \f$ M^{(i)}_j \f$
55 std::vector<std::vector<Real> > M;
56 /// Cavity correlations
57 std::vector<std::vector<std::vector<Real> > > cors;
58
59 /// Type used for managing a subset of neighbors
60 typedef boost::dynamic_bitset<> sub_nb;
61
62 /// Magnetizations
63 std::vector<Real> Mag;
64
65 /// Maximum difference encountered so far
66 Real _maxdiff;
67
68 /// Number of iterations needed
69 size_t _iters;
70
71 public:
72 /// Parameters for MR
73 struct Properties {
74 /// Enumeration of different types of update equations
75 /** The possible update equations are:
76 * - FULL full updates, slow but accurate
77 * - LINEAR linearized updates, faster but less accurate
78 */
79 DAI_ENUM(UpdateType,FULL,LINEAR);
80
81 /// Enumeration of different ways of initializing the cavity correlations
82 /** The possible cavity initializations are:
83 * - RESPPROP using response propagation ("linear response")
84 * - CLAMPING using clamping and BP
85 * - EXACT using JunctionTree
86 */
87 DAI_ENUM(InitType,RESPPROP,CLAMPING,EXACT);
88
89 /// Verbosity (amount of output sent to stderr)
90 size_t verbose;
91
92 /// Tolerance for convergence test
93 Real tol;
94
95 /// Update equations
96 UpdateType updates;
97
98 /// How to initialize the cavity correlations
99 InitType inits;
100 } props;
101
102 /// Name of this inference method
103 static const char *Name;
104
105 public:
106 /// Default constructor
107 MR() : DAIAlgFG(), supported(), G(), tJ(), theta(), M(), cors(), Mag(), _maxdiff(), _iters(), props() {}
108
109 /// Construct from FactorGraph \a fg and PropertySet \a opts
110 /** \param fg Factor graph.
111 * \param opts Parameters @see Properties
112 * \note This implementation only deals with binary variables and pairwise interactions.
113 * \throw NOT_IMPLEMENTED if \a fg has factors depending on three or more variables or has variables with more than two possible states.
114 */
115 MR( const FactorGraph &fg, const PropertySet &opts );
116
117
118 /// \name General InfAlg interface
119 //@{
120 virtual MR* clone() const { return new MR(*this); }
121 virtual std::string identify() const;
122 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
123 virtual Factor belief( const VarSet &/*vs*/ ) const { DAI_THROW(NOT_IMPLEMENTED); return Factor(); }
124 virtual Factor beliefV( size_t i ) const;
125 virtual std::vector<Factor> beliefs() const;
126 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
127 virtual void init() {}
128 virtual void init( const VarSet &/*ns*/ ) { DAI_THROW(NOT_IMPLEMENTED); }
129 virtual Real run();
130 virtual Real maxDiff() const { return _maxdiff; }
131 virtual size_t Iterations() const { return _iters; }
132 virtual void setProperties( const PropertySet &opts );
133 virtual PropertySet getProperties() const;
134 virtual std::string printProperties() const;
135 //@}
136
137 private:
138 /// Initialize cors
139 Real calcCavityCorrelations();
140
141 /// Iterate update equations for cavity fields
142 void propagateCavityFields();
143
144 /// Calculate magnetizations
145 void calcMagnetizations();
146
147 /// Calculate the product of all tJ[i][_j] for _j in A
148 /** \param i variable index
149 * \param A subset of neighbors of variable \a i
150 */
151 Real _tJ(size_t i, sub_nb A);
152
153 /// Calculate \f$ \Omega^{(i)}_{j,l} \f$ as defined in [\ref MoR05] eqn. (2.15)
154 Real Omega(size_t i, size_t _j, size_t _l);
155
156 /// Calculate \f$ T^{(i)}_A \f$ as defined in [\ref MoR05] eqn. (2.17) with \f$ A = \{l_1,l_2,\dots\} \f$
157 /** \param i variable index
158 * \param A subset of neighbors of variable \a i
159 */
160 Real T(size_t i, sub_nb A);
161
162 /// Calculates \f$ T^{(i)}_j \f$ where \a j is the \a _j 'th neighbor of \a i
163 Real T(size_t i, size_t _j);
164
165 /// Calculates \f$ \Gamma^{(i)}_{j,l_1l_2} \f$ as defined in [\ref MoR05] eqn. (2.16)
166 Real Gamma(size_t i, size_t _j, size_t _l1, size_t _l2);
167
168 /// Calculates \f$ \Gamma^{(i)}_{l_1l_2} \f$ as defined in [\ref MoK07] on page 1141
169 Real Gamma(size_t i, size_t _l1, size_t _l2);
170
171 /// Approximates moments of variables in \a A
172 /** Calculate the moment of variables in \a A from M and cors, neglecting higher order cumulants,
173 * defined as the sum over all partitions of A into subsets of cardinality two at most of the
174 * product of the cumulants (either first order, i.e. M, or second order, i.e. cors) of the
175 * entries of the partitions.
176 *
177 * \param i variable index
178 * \param A subset of neighbors of variable \a i
179 */
180 Real appM(size_t i, sub_nb A);
181
182 /// Calculate sum over all even/odd subsets B of \a A of _tJ(j,B) appM(j,B)
183 /** \param j variable index
184 * \param A subset of neighbors of variable \a j
185 * \param sum_even on return, will contain the sum over all even subsets
186 * \param sum_odd on return, will contain the sum over all odd subsets
187 */
188 void sum_subs(size_t j, sub_nb A, Real *sum_even, Real *sum_odd);
189 };
190
191
192 } // end of namespace dai
193
194
195 #endif