/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
    Radboud University Nijmegen, The Netherlands /
    Max Planck Institute for Biological Cybernetics, Germany
    This file is part of libDAI.
    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.
    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
/// \brief Defines class BP
/// \todo Improve documentation
28 #ifndef __defined_libdai_bp_h
29 #define __defined_libdai_bp_h
33 #include <dai/daialg.h>
34 #include <dai/factorgraph.h>
35 #include <dai/properties.h>
/// Approximate inference algorithm "(Loopy) Belief Propagation"
43 class BP
: public DAIAlgFG
{
45 typedef std::vector
<size_t> ind_t
;
46 typedef std::multimap
<double, std::pair
<std::size_t, std::size_t> > LutType
;
53 std::vector
<std::vector
<EdgeProp
> > _edges
;
54 std::vector
<std::vector
<LutType::iterator
> > _edge2lut
;
56 /// Maximum difference encountered so far
58 /// Number of iterations needed
60 /// The history of message updates (only recorded if recordSentMessages is true)
61 std::vector
<std::pair
<std::size_t, std::size_t> > _sentMessages
;
// NOTE(review): only fragments of the algorithm's parameter set are visible
// here. The constructors initialize a member named props and the
// findMaximum() documentation refers to props.inference, but the declaration
// enclosing the two enumerations below, the members documented by the last
// three comments (maximum number of iterations, log-domain flag, inference
// type) and the props member itself are not present in this view -- confirm
// against the complete header before editing.
64 /// Parameters of this inference algorithm
66 /// Enumeration of possible update schedules
// Presumed meanings (verify): SEQFIX = sequential, fixed order;
// SEQRND = sequential, random order; SEQMAX = sequential, largest residual
// first (cf. LutType); PARALL = parallel updates.
67 DAI_ENUM(UpdateType
,SEQFIX
,SEQRND
,SEQMAX
,PARALL
);
69 /// Enumeration of inference variants: SUMPROD (sum-product) or MAXPROD (max-product)
70 DAI_ENUM(InfType
,SUMPROD
,MAXPROD
);
75 /// Maximum number of iterations
81 /// Do updates in logarithmic domain?
90 /// Type of inference: sum-product or max-product?
94 /// Name of this inference algorithm
95 static const char *Name
;
97 /// Specifies whether the history of message updates should be recorded
98 bool recordSentMessages
;
101 /// Default constructor
102 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
105 BP( const BP
&x
) : DAIAlgFG(x
), _edges(x
._edges
), _edge2lut(x
._edge2lut
),
106 _lut(x
._lut
), _maxdiff(x
._maxdiff
), _iters(x
._iters
), _sentMessages(x
._sentMessages
),
107 props(x
.props
), recordSentMessages(x
.recordSentMessages
)
109 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
110 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
113 /// Assignment operator
114 BP
& operator=( const BP
&x
) {
116 DAIAlgFG::operator=( x
);
119 for( LutType::iterator l
= _lut
.begin(); l
!= _lut
.end(); ++l
)
120 _edge2lut
[l
->second
.first
][l
->second
.second
] = l
;
121 _maxdiff
= x
._maxdiff
;
123 _sentMessages
= x
._sentMessages
;
125 recordSentMessages
= x
.recordSentMessages
;
130 /// Construct from FactorGraph fg and PropertySet opts
/// Initializes the base class from fg, sets _maxdiff = 0.0, _iters = 0 and
/// recordSentMessages = false, and value-initializes _edges, _sentMessages and
/// props; _edge2lut and _lut are not named in the initializer list and are
/// therefore default-constructed. The entries of opts are then applied with
/// setProperties().
// NOTE(review): the remainder of this constructor (including its closing
// brace) is not visible here, and nothing visible builds _edges, _edge2lut or
// _lut from fg -- presumably a statement after setProperties() does that;
// confirm before editing.
131 BP( const FactorGraph
& fg
, const PropertySet
&opts
) : DAIAlgFG(fg
), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
132 setProperties( opts
);
137 /// @name General InfAlg interface
139 virtual BP
* clone() const { return new BP(*this); }
140 virtual std::string
identify() const;
141 virtual Factor
belief( const Var
&n
) const;
142 virtual Factor
belief( const VarSet
&ns
) const;
143 virtual std::vector
<Factor
> beliefs() const;
144 virtual Real
logZ() const;
146 virtual void init( const VarSet
&ns
);
147 virtual double run();
148 virtual double maxDiff() const { return _maxdiff
; }
149 virtual size_t Iterations() const { return _iters
; }
153 /// @name Additional interface specific for BP
155 Factor
beliefV( size_t i
) const;
156 Factor
beliefF( size_t I
) const;
159 /// Calculates the joint state of all variables that has maximum probability
160 /** Assumes that run() has been called and that props.inference == MAXPROD
162 std::vector
<std::size_t> findMaximum() const;
164 /// Returns history of sent messages
165 const std::vector
<std::pair
<std::size_t, std::size_t> >& getSentMessages() const {
166 return _sentMessages
;
169 /// Clears history of sent messages
170 void clearSentMessages() {
171 _sentMessages
.clear();
175 const Prob
& message(size_t i
, size_t _I
) const { return _edges
[i
][_I
].message
; }
176 Prob
& message(size_t i
, size_t _I
) { return _edges
[i
][_I
].message
; }
177 Prob
& newMessage(size_t i
, size_t _I
) { return _edges
[i
][_I
].newMessage
; }
178 const Prob
& newMessage(size_t i
, size_t _I
) const { return _edges
[i
][_I
].newMessage
; }
179 ind_t
& index(size_t i
, size_t _I
) { return _edges
[i
][_I
].index
; }
180 const ind_t
& index(size_t i
, size_t _I
) const { return _edges
[i
][_I
].index
; }
181 double & residual(size_t i
, size_t _I
) { return _edges
[i
][_I
].residual
; }
182 const double & residual(size_t i
, size_t _I
) const { return _edges
[i
][_I
].residual
; }
184 void calcNewMessage( size_t i
, size_t _I
);
185 void updateMessage( size_t i
, size_t _I
);
186 void updateResidual( size_t i
, size_t _I
, double r
);
187 void findMaxResidual( size_t &i
, size_t &_I
);
188 /// Calculates unnormalized belief of variable
189 void calcBeliefV( size_t i
, Prob
&p
) const;
190 /// Calculates unnormalized belief of factor
191 void calcBeliefF( size_t I
, Prob
&p
) const;
194 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
195 void setProperties( const PropertySet
&opts
);
196 PropertySet
getProperties() const;
197 std::string
printProperties() const;
201 } // end of namespace dai