Replaced doubles by Reals, fixed two bugs
[libdai.git] / include / dai / bp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class BP
14 /// \todo Improve documentation
15
16
17 #ifndef __defined_libdai_bp_h
18 #define __defined_libdai_bp_h
19
20
21 #include <string>
22 #include <dai/daialg.h>
23 #include <dai/factorgraph.h>
24 #include <dai/properties.h>
25 #include <dai/enum.h>
26
27
28 namespace dai {
29
30
31 /// Approximate inference algorithm "(Loopy) Belief Propagation"
32 class BP : public DAIAlgFG {
33 private:
34 typedef std::vector<size_t> ind_t;
35 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
36 struct EdgeProp {
37 ind_t index;
38 Prob message;
39 Prob newMessage;
40 Real residual;
41 };
42 std::vector<std::vector<EdgeProp> > _edges;
43 std::vector<std::vector<LutType::iterator> > _edge2lut;
44 LutType _lut;
45 /// Maximum difference encountered so far
46 Real _maxdiff;
47 /// Number of iterations needed
48 size_t _iters;
49 /// The history of message updates (only recorded if recordSentMessages is true)
50 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
51
52 public:
53 /// Parameters of this inference algorithm
54 struct Properties {
55 /// Enumeration of possible update schedules
56 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
57
58 /// Enumeration of inference variants
59 DAI_ENUM(InfType,SUMPROD,MAXPROD);
60
61 /// Verbosity
62 size_t verbose;
63
64 /// Maximum number of iterations
65 size_t maxiter;
66
67 /// Tolerance
68 Real tol;
69
70 /// Do updates in logarithmic domain?
71 bool logdomain;
72
73 /// Damping constant
74 Real damping;
75
76 /// Update schedule
77 UpdateType updates;
78
79 /// Type of inference: sum-product or max-product?
80 InfType inference;
81 } props;
82
83 /// Name of this inference algorithm
84 static const char *Name;
85
86 /// Specifies whether the history of message updates should be recorded
87 bool recordSentMessages;
88
89 public:
90 /// Default constructor
91 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
92
93 /// Copy constructor
94 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
95 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
96 props(x.props), recordSentMessages(x.recordSentMessages)
97 {
98 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
99 _edge2lut[l->second.first][l->second.second] = l;
100 }
101
102 /// Assignment operator
103 BP& operator=( const BP &x ) {
104 if( this != &x ) {
105 DAIAlgFG::operator=( x );
106 _edges = x._edges;
107 _lut = x._lut;
108 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
109 _edge2lut[l->second.first][l->second.second] = l;
110 _maxdiff = x._maxdiff;
111 _iters = x._iters;
112 _sentMessages = x._sentMessages;
113 props = x.props;
114 recordSentMessages = x.recordSentMessages;
115 }
116 return *this;
117 }
118
119 /// Construct from FactorGraph fg and PropertySet opts
120 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
121 setProperties( opts );
122 construct();
123 }
124
125
126 /// @name General InfAlg interface
127 //@{
128 virtual BP* clone() const { return new BP(*this); }
129 virtual std::string identify() const;
130 virtual Factor belief( const Var &n ) const;
131 virtual Factor belief( const VarSet &ns ) const;
132 virtual std::vector<Factor> beliefs() const;
133 virtual Real logZ() const;
134 virtual void init();
135 virtual void init( const VarSet &ns );
136 virtual Real run();
137 virtual Real maxDiff() const { return _maxdiff; }
138 virtual size_t Iterations() const { return _iters; }
139 //@}
140
141
142 /// @name Additional interface specific for BP
143 //@{
144 Factor beliefV( size_t i ) const;
145 Factor beliefF( size_t I ) const;
146 //@}
147
148 /// Calculates the joint state of all variables that has maximum probability
149 /** Assumes that run() has been called and that props.inference == MAXPROD
150 */
151 std::vector<std::size_t> findMaximum() const;
152
153 /// Returns history of sent messages
154 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
155 return _sentMessages;
156 }
157
158 /// Clears history of sent messages
159 void clearSentMessages() {
160 _sentMessages.clear();
161 }
162
163 private:
164 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
165 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
166 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
167 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
168 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
169 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
170 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
171 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
172
173 void calcNewMessage( size_t i, size_t _I );
174 void updateMessage( size_t i, size_t _I );
175 void updateResidual( size_t i, size_t _I, Real r );
176 void findMaxResidual( size_t &i, size_t &_I );
177 /// Calculates unnormalized belief of variable
178 void calcBeliefV( size_t i, Prob &p ) const;
179 /// Calculates unnormalized belief of factor
180 void calcBeliefF( size_t I, Prob &p ) const;
181
182 void construct();
183 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
184 void setProperties( const PropertySet &opts );
185 PropertySet getProperties() const;
186 std::string printProperties() const;
187 };
188
189
190 } // end of namespace dai
191
192
193 #endif