/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
 */
/// \file
/// \brief Defines class BP, which implements (Loopy) Belief Propagation
/// \todo Consider using a priority_queue for maximum residual schedule
14 #ifndef __defined_libdai_bp_h
15 #define __defined_libdai_bp_h
18 #include <dai/dai_config.h>
23 #include <dai/daialg.h>
24 #include <dai/factorgraph.h>
25 #include <dai/properties.h>
32 /// Approximate inference algorithm "(Loopy) Belief Propagation"
/** The Loopy Belief Propagation algorithm uses message passing
 *  to approximate marginal probability distributions ("beliefs") for variables
 *  and factors (more precisely, for the subset of variables depending on the factor).
 *  There are two variants, the sum-product algorithm (corresponding to
 *  finite temperature) and the max-product algorithm (corresponding to
 *  zero temperature).
 *  The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$.
 *  In case of the sum-product algorithm, the update equation is:
 *    \f[ m_{I\to i}(x_i) \propto \sum_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j)\f]
 *  and in case of the max-product algorithm:
 *    \f[ m_{I\to i}(x_i) \propto \max_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j)\f]
 *  In order to improve convergence, the updates can be damped. For improved numerical stability,
 *  the updates can alternatively be done in the log domain.
48 * After convergence, the variable beliefs are calculated by:
49 * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}(x_i)\f]
50 * and the factor beliefs are calculated by:
51 * \f[ b_I(x_I) \propto f_I(x_I) \prod_{j\in N_I} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}(x_j) \f]
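 *  The logarithm of the partition sum is calculated by:
 *    \f[ \log Z = \sum_i (|N_i| - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
 *  (the Bethe approximation: each variable entropy \f$H(b_i) = -\sum_{x_i} b_i(x_i) \log b_i(x_i)\f$
 *  enters with weight \f$1 - |N_i|\f$, hence the factor \f$|N_i| - 1\f$ in front of \f$\sum b_i \log b_i\f$).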
55 * For the max-product algorithm, a heuristic way of finding the MAP state (the
56 * joint configuration of all variables which has maximum probability) is provided
57 * by the findMaximum() method, which can be called after convergence.
 *  \note There are two implementations, an optimized one (the default) which caches IndexFor objects,
 *  and a slower, less complicated one which is easier to maintain/understand. The slower one can be
 *  enabled by defining DAI_BP_FAST as false in the source file.
 */
63 class BP
: public DAIAlgFG
{
65 /// Type used for index cache
66 typedef std::vector
<size_t> ind_t
;
67 /// Type used for storing edge properties
69 /// Index cached for this edge
71 /// Old message living on this edge
73 /// New message living on this edge
75 /// Residual for this edge
78 /// Stores all edge properties
79 std::vector
<std::vector
<EdgeProp
> > _edges
;
80 /// Type of lookup table (only used for maximum-residual BP)
81 typedef std::multimap
<Real
, std::pair
<size_t, size_t> > LutType
;
82 /// Lookup table (only used for maximum-residual BP)
83 std::vector
<std::vector
<LutType::iterator
> > _edge2lut
;
84 /// Lookup table (only used for maximum-residual BP)
86 /// Maximum difference between variable beliefs encountered so far
88 /// Number of iterations needed
90 /// The history of message updates (only recorded if \a recordSentMessages is \c true)
91 std::vector
<std::pair
<size_t, size_t> > _sentMessages
;
92 /// Stores variable beliefs of previous iteration
93 std::vector
<Factor
> _oldBeliefsV
;
94 /// Stores factor beliefs of previous iteration
95 std::vector
<Factor
> _oldBeliefsF
;
96 /// Stores the update schedule
97 std::vector
<Edge
> _updateSeq
;
100 /// Parameters for BP
102 /// Enumeration of possible update schedules
103 /** The following update schedules have been defined:
104 * - PARALL parallel updates
105 * - SEQFIX sequential updates using a fixed sequence
106 * - SEQRND sequential updates using a random sequence
107 * - SEQMAX maximum-residual updates [\ref EMK06]
109 DAI_ENUM(UpdateType
,SEQFIX
,SEQRND
,SEQMAX
,PARALL
);
111 /// Enumeration of inference variants
112 /** There are two inference variants:
113 * - SUMPROD Sum-Product
114 * - MAXPROD Max-Product (equivalent to Min-Sum)
116 DAI_ENUM(InfType
,SUMPROD
,MAXPROD
);
118 /// Verbosity (amount of output sent to stderr)
121 /// Maximum number of iterations
124 /// Maximum time (in seconds)
127 /// Tolerance for convergence test
130 /// Whether updates should be done in logarithmic domain or not
133 /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
136 /// Message update schedule
139 /// Inference variant
/// When \c true, every message update is appended to the history returned by getSentMessages()
bool recordSentMessages;
147 /// \name Constructors/destructors
149 /// Default constructor
150 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {}
152 /// Construct from FactorGraph \a fg and PropertySet \a opts
153 /** \param fg Factor graph.
154 * \param opts Parameters @see Properties
156 BP( const FactorGraph
& fg
, const PropertySet
&opts
) : DAIAlgFG(fg
), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {
157 setProperties( opts
);
162 BP( const BP
&x
) : DAIAlgFG(x
), _edges(x
._edges
), _edge2lut(x
._edge2lut
), _lut(x
._lut
), _maxdiff(x
._maxdiff
), _iters(x
._iters
), _sentMessages(x
._sentMessages
), _oldBeliefsV(x
._oldBeliefsV
), _oldBeliefsF(x
._oldBeliefsF
), _updateSeq(x
._updateSeq
), props(x
.props
), recordSentMessages(x
.recordSentMessages
) {
163 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
164 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
167 /// Assignment operator
168 BP
& operator=( const BP
&x
) {
170 DAIAlgFG::operator=( x
);
173 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
174 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
175 _maxdiff
= x
._maxdiff
;
177 _sentMessages
= x
._sentMessages
;
178 _oldBeliefsV
= x
._oldBeliefsV
;
179 _oldBeliefsF
= x
._oldBeliefsF
;
180 _updateSeq
= x
._updateSeq
;
182 recordSentMessages
= x
.recordSentMessages
;
188 /// \name General InfAlg interface
190 virtual BP
* clone() const { return new BP(*this); }
191 virtual BP
* construct( const FactorGraph
&fg
, const PropertySet
&opts
) const { return new BP( fg
, opts
); }
192 virtual std::string
name() const { return "BP"; }
193 virtual Factor
belief( const Var
&v
) const { return beliefV( findVar( v
) ); }
194 virtual Factor
belief( const VarSet
&vs
) const;
195 virtual Factor
beliefV( size_t i
) const;
196 virtual Factor
beliefF( size_t I
) const;
197 virtual std::vector
<Factor
> beliefs() const;
198 virtual Real
logZ() const;
199 /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
201 std::vector
<size_t> findMaximum() const { return dai::findMaximum( *this ); }
203 virtual void init( const VarSet
&ns
);
205 virtual Real
maxDiff() const { return _maxdiff
; }
206 virtual size_t Iterations() const { return _iters
; }
207 virtual void setMaxIter( size_t maxiter
) { props
.maxiter
= maxiter
; }
208 virtual void setProperties( const PropertySet
&opts
);
209 virtual PropertySet
getProperties() const;
210 virtual std::string
printProperties() const;
213 /// \name Additional interface specific for BP
215 /// Returns history of which messages have been updated
216 const std::vector
<std::pair
<size_t, size_t> >& getSentMessages() const {
217 return _sentMessages
;
220 /// Clears history of which messages have been updated
221 void clearSentMessages() { _sentMessages
.clear(); }
225 /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
226 const Prob
& message(size_t i
, size_t _I
) const { return _edges
[i
][_I
].message
; }
227 /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
228 Prob
& message(size_t i
, size_t _I
) { return _edges
[i
][_I
].message
; }
229 /// Returns constant reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
230 const Prob
& newMessage(size_t i
, size_t _I
) const { return _edges
[i
][_I
].newMessage
; }
231 /// Returns reference to updated message from the \a _I 'th neighbor of variable \a i to variable \a i
232 Prob
& newMessage(size_t i
, size_t _I
) { return _edges
[i
][_I
].newMessage
; }
233 /// Returns constant reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
234 const ind_t
& index(size_t i
, size_t _I
) const { return _edges
[i
][_I
].index
; }
235 /// Returns reference to cached index for the edge between variable \a i and its \a _I 'th neighbor
236 ind_t
& index(size_t i
, size_t _I
) { return _edges
[i
][_I
].index
; }
237 /// Returns constant reference to residual for the edge between variable \a i and its \a _I 'th neighbor
238 const Real
& residual(size_t i
, size_t _I
) const { return _edges
[i
][_I
].residual
; }
239 /// Returns reference to residual for the edge between variable \a i and its \a _I 'th neighbor
240 Real
& residual(size_t i
, size_t _I
) { return _edges
[i
][_I
].residual
; }
242 /// Calculate the product of factor \a I and the incoming messages
243 /** If \a without_i == \c true, the message coming from variable \a i is omitted from the product
244 * \note This function is used by calcNewMessage() and calcBeliefF()
246 virtual Prob
calcIncomingMessageProduct( size_t I
, bool without_i
, size_t i
) const;
247 /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
248 virtual void calcNewMessage( size_t i
, size_t _I
);
249 /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
250 void updateMessage( size_t i
, size_t _I
);
251 /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
252 void updateResidual( size_t i
, size_t _I
, Real r
);
253 /// Finds the edge which has the maximum residual (difference between new and old message)
254 void findMaxResidual( size_t &i
, size_t &_I
);
255 /// Calculates unnormalized belief of variable \a i
256 virtual void calcBeliefV( size_t i
, Prob
&p
) const;
257 /// Calculates unnormalized belief of factor \a I
258 virtual void calcBeliefF( size_t I
, Prob
&p
) const {
259 p
= calcIncomingMessageProduct( I
, false, 0 );
262 /// Helper function for constructors
263 virtual void construct();
267 } // end of namespace dai