/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
 *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
 */
/// \file
/// \brief Defines class BP
/// \todo Improve documentation
17 #ifndef __defined_libdai_bp_h
18 #define __defined_libdai_bp_h
22 #include <dai/daialg.h>
23 #include <dai/factorgraph.h>
24 #include <dai/properties.h>
31 /// Approximate inference algorithm "(Loopy) Belief Propagation"
32 class BP
: public DAIAlgFG
{
34 typedef std::vector
<size_t> ind_t
;
35 typedef std::multimap
<Real
, std::pair
<std::size_t, std::size_t> > LutType
;
42 std::vector
<std::vector
<EdgeProp
> > _edges
;
43 std::vector
<std::vector
<LutType::iterator
> > _edge2lut
;
45 /// Maximum difference encountered so far
47 /// Number of iterations needed
49 /// The history of message updates (only recorded if recordSentMessages is true)
50 std::vector
<std::pair
<std::size_t, std::size_t> > _sentMessages
;
// NOTE(review): this region looks like the interior of the algorithm's parameter
// struct (findMaximum()'s documentation refers to `props.inference`, and the
// constructors initialise a member `props`), but neither the enclosing struct
// declaration nor the `props` member itself is present in this file. The leading
// numbers on the lines below appear to be stray extraction artifacts.
// TODO confirm against the upstream libDAI header and restore.
53 /// Parameters of this inference algorithm
55 /// Enumeration of possible update schedules
56 DAI_ENUM(UpdateType
,SEQFIX
,SEQRND
,SEQMAX
,PARALL
);
58 /// Enumeration of inference variants
59 DAI_ENUM(InfType
,SUMPROD
,MAXPROD
);
// NOTE(review): the member declarations that the three comments below describe
// (maximum iteration count, log-domain flag, inference type -- the latter
// presumably of type InfType) are missing; only their comments survived.
64 /// Maximum number of iterations
70 /// Do updates in logarithmic domain?
79 /// Type of inference: sum-product or max-product?
83 /// Name of this inference algorithm
84 static const char *Name
;
86 /// Specifies whether the history of message updates should be recorded
87 bool recordSentMessages
;
90 /// Default constructor
91 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
94 BP( const BP
&x
) : DAIAlgFG(x
), _edges(x
._edges
), _edge2lut(x
._edge2lut
),
95 _lut(x
._lut
), _maxdiff(x
._maxdiff
), _iters(x
._iters
), _sentMessages(x
._sentMessages
),
96 props(x
.props
), recordSentMessages(x
.recordSentMessages
)
98 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
99 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
102 /// Assignment operator
103 BP
& operator=( const BP
&x
) {
105 DAIAlgFG::operator=( x
);
108 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
109 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
110 _maxdiff
= x
._maxdiff
;
112 _sentMessages
= x
._sentMessages
;
114 recordSentMessages
= x
.recordSentMessages
;
119 /// Construct from FactorGraph fg and PropertySet opts
120 BP( const FactorGraph
& fg
, const PropertySet
&opts
) : DAIAlgFG(fg
), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
121 setProperties( opts
);
126 /// @name General InfAlg interface
128 virtual BP
* clone() const { return new BP(*this); }
129 virtual std::string
identify() const;
130 virtual Factor
belief( const Var
&n
) const;
131 virtual Factor
belief( const VarSet
&ns
) const;
132 virtual std::vector
<Factor
> beliefs() const;
133 virtual Real
logZ() const;
135 virtual void init( const VarSet
&ns
);
137 virtual Real
maxDiff() const { return _maxdiff
; }
138 virtual size_t Iterations() const { return _iters
; }
142 /// @name Additional interface specific for BP
144 Factor
beliefV( size_t i
) const;
145 Factor
beliefF( size_t I
) const;
148 /// Calculates the joint state of all variables that has maximum probability
149 /** Assumes that run() has been called and that props.inference == MAXPROD
151 std::vector
<std::size_t> findMaximum() const;
153 /// Returns history of sent messages
154 const std::vector
<std::pair
<std::size_t, std::size_t> >& getSentMessages() const {
155 return _sentMessages
;
158 /// Clears history of sent messages
159 void clearSentMessages() {
160 _sentMessages
.clear();
164 const Prob
& message(size_t i
, size_t _I
) const { return _edges
[i
][_I
].message
; }
165 Prob
& message(size_t i
, size_t _I
) { return _edges
[i
][_I
].message
; }
166 Prob
& newMessage(size_t i
, size_t _I
) { return _edges
[i
][_I
].newMessage
; }
167 const Prob
& newMessage(size_t i
, size_t _I
) const { return _edges
[i
][_I
].newMessage
; }
168 ind_t
& index(size_t i
, size_t _I
) { return _edges
[i
][_I
].index
; }
169 const ind_t
& index(size_t i
, size_t _I
) const { return _edges
[i
][_I
].index
; }
170 Real
& residual(size_t i
, size_t _I
) { return _edges
[i
][_I
].residual
; }
171 const Real
& residual(size_t i
, size_t _I
) const { return _edges
[i
][_I
].residual
; }
173 void calcNewMessage( size_t i
, size_t _I
);
174 void updateMessage( size_t i
, size_t _I
);
175 void updateResidual( size_t i
, size_t _I
, Real r
);
176 void findMaxResidual( size_t &i
, size_t &_I
);
177 /// Calculates unnormalized belief of variable
178 void calcBeliefV( size_t i
, Prob
&p
) const;
179 /// Calculates unnormalized belief of factor
180 void calcBeliefF( size_t I
, Prob
&p
) const;
183 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
184 void setProperties( const PropertySet
&opts
);
185 PropertySet
getProperties() const;
186 std::string
printProperties() const;
190 } // end of namespace dai