Improved documentation of include/dai/bp.h
[libdai.git] / include / dai / bp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
14
15
16 #ifndef __defined_libdai_bp_h
17 #define __defined_libdai_bp_h
18
19
20 #include <string>
21 #include <dai/daialg.h>
22 #include <dai/factorgraph.h>
23 #include <dai/properties.h>
24 #include <dai/enum.h>
25
26
27 namespace dai {
28
29
30 /// Approximate inference algorithm "(Loopy) Belief Propagation"
31 /** The Loopy Belief Propagation algorithm uses message passing
32 * to approximate marginal probability distributions ("beliefs") for variables
33 * and factors (more precisely, for the subset of variables depending on the factor).
34 * There are two variants, the sum-product algorithm (corresponding to
35 * finite temperature) and the max-product algorithm (corresponding to
36 * zero temperature).
37 *
38 * The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$.
39 * In case of the sum-product algorith, the update equation is:
40 * \f[ m_{I\to i}(x_i) \propto \sum_{x_{I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
41 * and in case of the max-product algorithm:
42 * \f[ m_{I\to i}(x_i) \propto \max_{x_{I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
43 * In order to improve convergence, the updates can be damped. For improved numerical stability,
44 * the updates can be done in the log-domain alternatively.
45 *
46 * After convergence, the variable beliefs are calculated by:
47 * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}(x_i)\f]
48 * and the factor beliefs are calculated by:
49 * \f[ b_I(x_I) \propto f_I(x_I) \prod_{j\in N_I} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j) \f]
50 * The logarithm of the partition sum is calculated by:
51 * \f[ \log Z = \sum_i (1 - |N_i|) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
52 *
53 * There are several predefined update schedules:
54 * - PARALL parallel updates
55 * - SEQFIX sequential updates using a fixed sequence
56 * - SEQRND sequential updates using a random sequence
57 * - SEQMAX maximum-residual updates [\ref EMK06]
58 *
59 * For the max-product algorithm, a heuristic way of finding the MAP state (the
60 * joint configuration of all variables which has maximum probability) is provided
61 * by the findMaximum() method, which can be called after convergence.
62 *
63 * \note There are two implementations, an optimized one (the default) which caches IndexFor objects,
64 * and a slower, less complicated one which is easier to maintain/understand. The slower one can be
65 * enabled by defining DAI_BP_FAST as false in the source file.
66 */
67 class BP : public DAIAlgFG {
68 private:
69 /// Type used for index cache
70 typedef std::vector<size_t> ind_t;
71 /// Type used for storing edge properties
72 struct EdgeProp {
73 /// Index cached for this edge
74 ind_t index;
75 /// Old message living on this edge
76 Prob message;
77 /// New message living on this edge
78 Prob newMessage;
79 /// Residual for this edge
80 Real residual;
81 };
82 /// Stores all edge properties
83 std::vector<std::vector<EdgeProp> > _edges;
84 /// Type of lookup table (only used for maximum-residual BP)
85 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
86 /// Lookup table (only used for maximum-residual BP)
87 std::vector<std::vector<LutType::iterator> > _edge2lut;
88 /// Lookup table (only used for maximum-residual BP)
89 LutType _lut;
90 /// Maximum difference between variable beliefs encountered so far
91 Real _maxdiff;
92 /// Number of iterations needed
93 size_t _iters;
94 /// The history of message updates (only recorded if \a recordSentMessages is \c true)
95 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
96
97 public:
98 /// Parameters of this inference algorithm
99 struct Properties {
100 /// Enumeration of possible update schedules
101 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
102
103 /// Enumeration of inference variants
104 DAI_ENUM(InfType,SUMPROD,MAXPROD);
105
106 /// Verbosity
107 size_t verbose;
108
109 /// Maximum number of iterations
110 size_t maxiter;
111
112 /// Tolerance for convergence test
113 Real tol;
114
115 /// Whether updates should be done in logarithmic domain or not
116 bool logdomain;
117
118 /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
119 Real damping;
120
121 /// Message update schedule
122 UpdateType updates;
123
124 /// Type of inference: sum-product or max-product?
125 InfType inference;
126 } props;
127
128 /// Name of this inference algorithm
129 static const char *Name;
130
131 /// Specifies whether the history of message updates should be recorded
132 bool recordSentMessages;
133
134 public:
135 /// \name Constructors/destructors
136 //@{
137 /// Default constructor
138 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
139
140 /// Construct from FactorGraph \a fg and PropertySet \a opts
141 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
142 setProperties( opts );
143 construct();
144 }
145
146 /// Copy constructor
147 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
148 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
149 props(x.props), recordSentMessages(x.recordSentMessages)
150 {
151 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
152 _edge2lut[l->second.first][l->second.second] = l;
153 }
154
155 /// Assignment operator
156 BP& operator=( const BP &x ) {
157 if( this != &x ) {
158 DAIAlgFG::operator=( x );
159 _edges = x._edges;
160 _lut = x._lut;
161 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
162 _edge2lut[l->second.first][l->second.second] = l;
163 _maxdiff = x._maxdiff;
164 _iters = x._iters;
165 _sentMessages = x._sentMessages;
166 props = x.props;
167 recordSentMessages = x.recordSentMessages;
168 }
169 return *this;
170 }
171 //@}
172
173 /// \name General InfAlg interface
174 //@{
175 virtual BP* clone() const { return new BP(*this); }
176 virtual std::string identify() const;
177 virtual Factor belief( const Var &n ) const;
178 virtual Factor belief( const VarSet &ns ) const;
179 virtual Factor beliefV( size_t i ) const;
180 virtual Factor beliefF( size_t I ) const;
181 virtual std::vector<Factor> beliefs() const;
182 virtual Real logZ() const;
183 virtual void init();
184 virtual void init( const VarSet &ns );
185 virtual Real run();
186 virtual Real maxDiff() const { return _maxdiff; }
187 virtual size_t Iterations() const { return _iters; }
188 virtual void setProperties( const PropertySet &opts );
189 virtual PropertySet getProperties() const;
190 virtual std::string printProperties() const;
191 //@}
192
193 /// \name Additional interface specific for BP
194 //@{
195 /// Calculates the joint state of all variables that has maximum probability
196 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
197 */
198 std::vector<std::size_t> findMaximum() const;
199
200 /// Returns history of which messages have been updated
201 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
202 return _sentMessages;
203 }
204
205 /// Clears history of which messages have been updated
206 void clearSentMessages() { _sentMessages.clear(); }
207 //@}
208
209 private:
210 /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
211 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
212 /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
213 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
214 /// Returns constant reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
215 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
216 /// Returns reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
217 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
218 /// Returns constant reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
219 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
220 /// Returns reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
221 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
222 /// Returns constant reference to residual for the edge between variable \a i and its \a _I 'th neighbor
223 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
224 /// Returns reference to residual for the edge between variable \a i and its \a _I 'th neighbor
225 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
226
227 /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
228 void calcNewMessage( size_t i, size_t _I );
229 /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
230 void updateMessage( size_t i, size_t _I );
231 /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
232 void updateResidual( size_t i, size_t _I, Real r );
233 /// Finds the edge which has the maximum residual (difference between new and old message)
234 void findMaxResidual( size_t &i, size_t &_I );
235 /// Calculates unnormalized belief of variable \a i
236 void calcBeliefV( size_t i, Prob &p ) const;
237 /// Calculates unnormalized belief of factor \a I
238 void calcBeliefF( size_t I, Prob &p ) const;
239
240 /// Helper function for constructors
241 void construct();
242 };
243
244
245 } // end of namespace dai
246
247
248 #endif