Improved BP (added 'maxtime' property)
[libdai.git] / include / dai / bp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
14 /// \todo Consider using a priority_queue for maximum residual schedule
15
16
17 #ifndef __defined_libdai_bp_h
18 #define __defined_libdai_bp_h
19
20
21 #include <string>
22 #include <dai/daialg.h>
23 #include <dai/factorgraph.h>
24 #include <dai/properties.h>
25 #include <dai/enum.h>
26
27
28 namespace dai {
29
30
31 /// Approximate inference algorithm "(Loopy) Belief Propagation"
32 /** The Loopy Belief Propagation algorithm uses message passing
33 * to approximate marginal probability distributions ("beliefs") for variables
34 * and factors (more precisely, for the subset of variables depending on the factor).
35 * There are two variants, the sum-product algorithm (corresponding to
36 * finite temperature) and the max-product algorithm (corresponding to
37 * zero temperature).
38 *
39 * The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$.
40 * In case of the sum-product algorithm, the update equation is:
41 * \f[ m_{I\to i}(x_i) \propto \sum_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
42 * and in case of the max-product algorithm:
43 * \f[ m_{I\to i}(x_i) \propto \max_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
44 * In order to improve convergence, the updates can be damped. For improved numerical stability,
45 * the updates can be done in the log-domain alternatively.
46 *
47 * After convergence, the variable beliefs are calculated by:
48 * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}(x_i)\f]
49 * and the factor beliefs are calculated by:
50 * \f[ b_I(x_I) \propto f_I(x_I) \prod_{j\in N_I} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j) \f]
51 * The logarithm of the partition sum is calculated by:
52 * \f[ \log Z = \sum_i (1 - |N_i|) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
53 *
54 * For the max-product algorithm, a heuristic way of finding the MAP state (the
55 * joint configuration of all variables which has maximum probability) is provided
56 * by the findMaximum() method, which can be called after convergence.
57 *
58 * \note There are two implementations, an optimized one (the default) which caches IndexFor objects,
59 * and a slower, less complicated one which is easier to maintain/understand. The slower one can be
60 * enabled by defining DAI_BP_FAST as false in the source file.
61 */
62 class BP : public DAIAlgFG {
63 protected:
64 /// Type used for index cache
65 typedef std::vector<size_t> ind_t;
66 /// Type used for storing edge properties
67 struct EdgeProp {
68 /// Index cached for this edge
69 ind_t index;
70 /// Old message living on this edge
71 Prob message;
72 /// New message living on this edge
73 Prob newMessage;
74 /// Residual for this edge
75 Real residual;
76 };
77 /// Stores all edge properties
78 std::vector<std::vector<EdgeProp> > _edges;
79 /// Type of lookup table (only used for maximum-residual BP)
80 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
81 /// Lookup table (only used for maximum-residual BP)
82 std::vector<std::vector<LutType::iterator> > _edge2lut;
83 /// Lookup table (only used for maximum-residual BP)
84 LutType _lut;
85 /// Maximum difference between variable beliefs encountered so far
86 Real _maxdiff;
87 /// Number of iterations needed
88 size_t _iters;
89 /// The history of message updates (only recorded if \a recordSentMessages is \c true)
90 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
91
92 public:
93 /// Parameters for BP
94 struct Properties {
95 /// Enumeration of possible update schedules
96 /** The following update schedules have been defined:
97 * - PARALL parallel updates
98 * - SEQFIX sequential updates using a fixed sequence
99 * - SEQRND sequential updates using a random sequence
100 * - SEQMAX maximum-residual updates [\ref EMK06]
101 */
102 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
103
104 /// Enumeration of inference variants
105 /** There are two inference variants:
106 * - SUMPROD Sum-Product
107 * - MAXPROD Max-Product (equivalent to Min-Sum)
108 */
109 DAI_ENUM(InfType,SUMPROD,MAXPROD);
110
111 /// Verbosity (amount of output sent to stderr)
112 size_t verbose;
113
114 /// Maximum number of iterations
115 size_t maxiter;
116
117 /// Maximum time (in seconds)
118 double maxtime;
119
120 /// Tolerance for convergence test
121 Real tol;
122
123 /// Whether updates should be done in logarithmic domain or not
124 bool logdomain;
125
126 /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
127 Real damping;
128
129 /// Message update schedule
130 UpdateType updates;
131
132 /// Inference variant
133 InfType inference;
134 } props;
135
136 /// Name of this inference algorithm
137 static const char *Name;
138
139 /// Specifies whether the history of message updates should be recorded
140 bool recordSentMessages;
141
142 /// Stores variable beliefs of previous iteration
143 std::vector<Factor> oldBeliefsV;
144
145 /// Stores factor beliefs of previous iteration
146 std::vector<Factor> oldBeliefsF;
147
148 /// Stores the update schedule
149 std::vector<Edge> updateSeq;
150
151 public:
152 /// \name Constructors/destructors
153 //@{
154 /// Default constructor
155 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
156
157 /// Construct from FactorGraph \a fg and PropertySet \a opts
158 /** \param fg Factor graph.
159 * \param opts Parameters @see Properties
160 */
161 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
162 setProperties( opts );
163 construct();
164 }
165
166 /// Copy constructor
167 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
168 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
169 props(x.props), recordSentMessages(x.recordSentMessages)
170 {
171 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
172 _edge2lut[l->second.first][l->second.second] = l;
173 }
174
175 /// Assignment operator
176 BP& operator=( const BP &x ) {
177 if( this != &x ) {
178 DAIAlgFG::operator=( x );
179 _edges = x._edges;
180 _lut = x._lut;
181 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
182 _edge2lut[l->second.first][l->second.second] = l;
183 _maxdiff = x._maxdiff;
184 _iters = x._iters;
185 _sentMessages = x._sentMessages;
186 props = x.props;
187 recordSentMessages = x.recordSentMessages;
188 }
189 return *this;
190 }
191 //@}
192
193 /// \name General InfAlg interface
194 //@{
195 virtual BP* clone() const { return new BP(*this); }
196 virtual std::string identify() const;
197 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
198 virtual Factor belief( const VarSet &vs ) const;
199 virtual Factor beliefV( size_t i ) const;
200 virtual Factor beliefF( size_t I ) const;
201 virtual std::vector<Factor> beliefs() const;
202 virtual Real logZ() const;
203 virtual void init();
204 virtual void init( const VarSet &ns );
205 virtual Real run();
206 virtual Real maxDiff() const { return _maxdiff; }
207 virtual size_t Iterations() const { return _iters; }
208 virtual void setProperties( const PropertySet &opts );
209 virtual PropertySet getProperties() const;
210 virtual std::string printProperties() const;
211 //@}
212
213 /// \name Additional interface specific for BP
214 //@{
215 /// Calculates the joint state of all variables that has maximum probability
216 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
217 */
218 std::vector<std::size_t> findMaximum() const;
219
220 /// Returns history of which messages have been updated
221 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
222 return _sentMessages;
223 }
224
225 /// Clears history of which messages have been updated
226 void clearSentMessages() { _sentMessages.clear(); }
227 //@}
228
229 protected:
230 /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
231 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
232 /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
233 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
234 /// Returns constant reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
235 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
236 /// Returns reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
237 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
238 /// Returns constant reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
239 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
240 /// Returns reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
241 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
242 /// Returns constant reference to residual for the edge between variable \a i and its \a _I 'th neighbor
243 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
244 /// Returns reference to residual for the edge between variable \a i and its \a _I 'th neighbor
245 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
246
247 /// Calculate the product of factor \a I and the incoming messages
248 /** If \a without_i == \c true, the message coming from variable \a i is omitted from the product
249 * \note This function is used by calcNewMessage() and calcBeliefF()
250 */
251 virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
252 /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
253 virtual void calcNewMessage( size_t i, size_t _I );
254 /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
255 void updateMessage( size_t i, size_t _I );
256 /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
257 void updateResidual( size_t i, size_t _I, Real r );
258 /// Finds the edge which has the maximum residual (difference between new and old message)
259 void findMaxResidual( size_t &i, size_t &_I );
260 /// Calculates unnormalized belief of variable \a i
261 virtual void calcBeliefV( size_t i, Prob &p ) const;
262 /// Calculates unnormalized belief of factor \a I
263 virtual void calcBeliefF( size_t I, Prob &p ) const {
264 p = calcIncomingMessageProduct( I, false, 0 );
265 }
266
267 /// Helper function for constructors
268 virtual void construct();
269 };
270
271
272 } // end of namespace dai
273
274
275 #endif