0769fc3e6bead03a3f2d4a07c9e711620292dac2
[libdai.git] / include / dai / bp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
11 /// \todo Consider using a priority_queue for maximum residual schedule
12
13
14 #ifndef __defined_libdai_bp_h
15 #define __defined_libdai_bp_h
16
17
18 #include <string>
19 #include <dai/daialg.h>
20 #include <dai/factorgraph.h>
21 #include <dai/properties.h>
22 #include <dai/enum.h>
23
24
25 namespace dai {
26
27
28 /// Approximate inference algorithm "(Loopy) Belief Propagation"
29 /** The Loopy Belief Propagation algorithm uses message passing
30 * to approximate marginal probability distributions ("beliefs") for variables
31 * and factors (more precisely, for the subset of variables depending on the factor).
32 * There are two variants, the sum-product algorithm (corresponding to
33 * finite temperature) and the max-product algorithm (corresponding to
34 * zero temperature).
35 *
36 * The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$.
37 * In case of the sum-product algorithm, the update equation is:
38 * \f[ m_{I\to i}(x_i) \propto \sum_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j)\f]
39 * and in case of the max-product algorithm:
40 * \f[ m_{I\to i}(x_i) \propto \max_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j)\f]
41 * In order to improve convergence, the updates can be damped. For improved numerical stability,
42 * the updates can alternatively be done in the log-domain.
43 *
44 * After convergence, the variable beliefs are calculated by:
45 * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}(x_i)\f]
46 * and the factor beliefs are calculated by:
47 * \f[ b_I(x_I) \propto f_I(x_I) \prod_{j\in N_I} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j) \f]
48 * The logarithm of the partition sum is calculated by:
49 * \f[ \log Z = \sum_i (1 - |N_i|) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
50 *
51 * For the max-product algorithm, a heuristic way of finding the MAP state (the
52 * joint configuration of all variables which has maximum probability) is provided
53 * by the findMaximum() method, which can be called after convergence.
54 *
55 * \note There are two implementations, an optimized one (the default) which caches IndexFor objects,
56 * and a slower, less complicated one which is easier to maintain/understand. The slower one can be
57 * enabled by defining DAI_BP_FAST as false in the source file.
58 */
59 class BP : public DAIAlgFG {
60 protected:
61 /// Type used for index cache
62 typedef std::vector<size_t> ind_t;
63 /// Type used for storing edge properties
64 struct EdgeProp {
65 /// Index cached for this edge
66 ind_t index;
67 /// Old message living on this edge
68 Prob message;
69 /// New message living on this edge
70 Prob newMessage;
71 /// Residual for this edge
72 Real residual;
73 };
74 /// Stores all edge properties
75 std::vector<std::vector<EdgeProp> > _edges;
76 /// Type of lookup table (only used for maximum-residual BP)
77 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
78 /// Lookup table (only used for maximum-residual BP)
79 std::vector<std::vector<LutType::iterator> > _edge2lut;
80 /// Lookup table (only used for maximum-residual BP)
81 LutType _lut;
82 /// Maximum difference between variable beliefs encountered so far
83 Real _maxdiff;
84 /// Number of iterations needed
85 size_t _iters;
86 /// The history of message updates (only recorded if \a recordSentMessages is \c true)
87 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
88 /// Stores variable beliefs of previous iteration
89 std::vector<Factor> _oldBeliefsV;
90 /// Stores factor beliefs of previous iteration
91 std::vector<Factor> _oldBeliefsF;
92 /// Stores the update schedule
93 std::vector<Edge> _updateSeq;
94
95 public:
96 /// Parameters for BP
97 struct Properties {
98 /// Enumeration of possible update schedules
99 /** The following update schedules have been defined:
100 * - PARALL parallel updates
101 * - SEQFIX sequential updates using a fixed sequence
102 * - SEQRND sequential updates using a random sequence
103 * - SEQMAX maximum-residual updates [\ref EMK06]
104 */
105 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
106
107 /// Enumeration of inference variants
108 /** There are two inference variants:
109 * - SUMPROD Sum-Product
110 * - MAXPROD Max-Product (equivalent to Min-Sum)
111 */
112 DAI_ENUM(InfType,SUMPROD,MAXPROD);
113
114 /// Verbosity (amount of output sent to stderr)
115 size_t verbose;
116
117 /// Maximum number of iterations
118 size_t maxiter;
119
120 /// Maximum time (in seconds)
121 double maxtime;
122
123 /// Tolerance for convergence test
124 Real tol;
125
126 /// Whether updates should be done in logarithmic domain or not
127 bool logdomain;
128
129 /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
130 Real damping;
131
132 /// Message update schedule
133 UpdateType updates;
134
135 /// Inference variant
136 InfType inference;
137 } props;
138
139 /// Specifies whether the history of message updates should be recorded
140 bool recordSentMessages;
141
142 public:
143 /// \name Constructors/destructors
144 //@{
145 /// Default constructor
146 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {}
147
148 /// Construct from FactorGraph \a fg and PropertySet \a opts
149 /** \param fg Factor graph.
150 * \param opts Parameters @see Properties
151 */
152 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {
153 setProperties( opts );
154 construct();
155 }
156
157 /// Copy constructor
158 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut), _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages), _oldBeliefsV(x._oldBeliefsV), _oldBeliefsF(x._oldBeliefsF), _updateSeq(x._updateSeq), props(x.props), recordSentMessages(x.recordSentMessages) {
159 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
160 _edge2lut[l->second.first][l->second.second] = l;
161 }
162
163 /// Assignment operator
164 BP& operator=( const BP &x ) {
165 if( this != &x ) {
166 DAIAlgFG::operator=( x );
167 _edges = x._edges;
168 _lut = x._lut;
169 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
170 _edge2lut[l->second.first][l->second.second] = l;
171 _maxdiff = x._maxdiff;
172 _iters = x._iters;
173 _sentMessages = x._sentMessages;
174 _oldBeliefsV = x._oldBeliefsV;
175 _oldBeliefsF = x._oldBeliefsF;
176 _updateSeq = x._updateSeq;
177 props = x.props;
178 recordSentMessages = x.recordSentMessages;
179 }
180 return *this;
181 }
182 //@}
183
184 /// \name General InfAlg interface
185 //@{
186 virtual BP* clone() const { return new BP(*this); }
187 virtual BP* construct( const FactorGraph &fg, const PropertySet &opts ) const { return new BP( fg, opts ); }
188 virtual std::string name() const { return "BP"; }
189 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
190 virtual Factor belief( const VarSet &vs ) const;
191 virtual Factor beliefV( size_t i ) const;
192 virtual Factor beliefF( size_t I ) const;
193 virtual std::vector<Factor> beliefs() const;
194 virtual Real logZ() const;
195 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
196 */
197 std::vector<std::size_t> findMaximum() const { return dai::findMaximum( *this ); }
198 virtual void init();
199 virtual void init( const VarSet &ns );
200 virtual Real run();
201 virtual Real maxDiff() const { return _maxdiff; }
202 virtual size_t Iterations() const { return _iters; }
203 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
204 virtual void setProperties( const PropertySet &opts );
205 virtual PropertySet getProperties() const;
206 virtual std::string printProperties() const;
207 //@}
208
209 /// \name Additional interface specific for BP
210 //@{
211 /// Returns history of which messages have been updated
212 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
213 return _sentMessages;
214 }
215
216 /// Clears history of which messages have been updated
217 void clearSentMessages() { _sentMessages.clear(); }
218 //@}
219
220 protected:
221 /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
222 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
223 /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
224 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
225 /// Returns constant reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
226 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
227 /// Returns reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
228 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
229 /// Returns constant reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
230 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
231 /// Returns reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
232 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
233 /// Returns constant reference to residual for the edge between variable \a i and its \a _I 'th neighbor
234 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
235 /// Returns reference to residual for the edge between variable \a i and its \a _I 'th neighbor
236 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
237
238 /// Calculate the product of factor \a I and the incoming messages
239 /** If \a without_i == \c true, the message coming from variable \a i is omitted from the product
240 * \note This function is used by calcNewMessage() and calcBeliefF()
241 */
242 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
243 /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
244 virtual void calcNewMessage( size_t i, size_t _I );
245 /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
246 void updateMessage( size_t i, size_t _I );
247 /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
248 void updateResidual( size_t i, size_t _I, Real r );
249 /// Finds the edge which has the maximum residual (difference between new and old message)
250 void findMaxResidual( size_t &i, size_t &_I );
251 /// Calculates unnormalized belief of variable \a i
252 virtual void calcBeliefV( size_t i, Prob &p ) const;
253 /// Calculates unnormalized belief of factor \a I
254 virtual void calcBeliefF( size_t I, Prob &p ) const {
255 p = calcIncomingMessageProduct( I, false, 0 );
256 }
257
258 /// Helper function for constructors
259 virtual void construct();
260 };
261
262
263 } // end of namespace dai
264
265
266 #endif