Fixed tabs and trailing whitespaces
[libdai.git] / include / dai / bp.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 /// \file
24 /// \brief Defines class BP
25 /// \todo Improve documentation
26
27
28 #ifndef __defined_libdai_bp_h
29 #define __defined_libdai_bp_h
30
31
32 #include <string>
33 #include <dai/daialg.h>
34 #include <dai/factorgraph.h>
35 #include <dai/properties.h>
36 #include <dai/enum.h>
37
38
39 namespace dai {
40
41
42 /// Approximate inference algorithm "(Loopy) Belief Propagation"
43 class BP : public DAIAlgFG {
44 private:
45 typedef std::vector<size_t> ind_t;
46 typedef std::multimap<double, std::pair<std::size_t, std::size_t> > LutType;
47 struct EdgeProp {
48 ind_t index;
49 Prob message;
50 Prob newMessage;
51 double residual;
52 };
53 std::vector<std::vector<EdgeProp> > _edges;
54 std::vector<std::vector<LutType::iterator> > _edge2lut;
55 LutType _lut;
56 /// Maximum difference encountered so far
57 double _maxdiff;
58 /// Number of iterations needed
59 size_t _iters;
60 /// The history of message updates (only recorded if recordSentMessages is true)
61 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
62
63 public:
64 /// Parameters of this inference algorithm
65 struct Properties {
66 /// Enumeration of possible update schedules
67 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
68
69 /// Enumeration of inference variants
70 DAI_ENUM(InfType,SUMPROD,MAXPROD);
71
72 /// Verbosity
73 size_t verbose;
74
75 /// Maximum number of iterations
76 size_t maxiter;
77
78 /// Tolerance
79 double tol;
80
81 /// Do updates in logarithmic domain?
82 bool logdomain;
83
84 /// Damping constant
85 double damping;
86
87 /// Update schedule
88 UpdateType updates;
89
90 /// Type of inference: sum-product or max-product?
91 InfType inference;
92 } props;
93
94 /// Name of this inference algorithm
95 static const char *Name;
96
97 /// Specifies whether the history of message updates should be recorded
98 bool recordSentMessages;
99
100 public:
101 /// Default constructor
102 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
103
104 /// Copy constructor
105 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
106 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
107 props(x.props), recordSentMessages(x.recordSentMessages)
108 {
109 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
110 _edge2lut[l->second.first][l->second.second] = l;
111 }
112
113 /// Assignment operator
114 BP& operator=( const BP &x ) {
115 if( this != &x ) {
116 DAIAlgFG::operator=( x );
117 _edges = x._edges;
118 _lut = x._lut;
119 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
120 _edge2lut[l->second.first][l->second.second] = l;
121 _maxdiff = x._maxdiff;
122 _iters = x._iters;
123 _sentMessages = x._sentMessages;
124 props = x.props;
125 recordSentMessages = x.recordSentMessages;
126 }
127 return *this;
128 }
129
130 /// Construct from FactorGraph fg and PropertySet opts
131 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
132 setProperties( opts );
133 construct();
134 }
135
136
137 /// @name General InfAlg interface
138 //@{
139 virtual BP* clone() const { return new BP(*this); }
140 virtual std::string identify() const;
141 virtual Factor belief( const Var &n ) const;
142 virtual Factor belief( const VarSet &ns ) const;
143 virtual std::vector<Factor> beliefs() const;
144 virtual Real logZ() const;
145 virtual void init();
146 virtual void init( const VarSet &ns );
147 virtual double run();
148 virtual double maxDiff() const { return _maxdiff; }
149 virtual size_t Iterations() const { return _iters; }
150 //@}
151
152
153 /// @name Additional interface specific for BP
154 //@{
155 Factor beliefV( size_t i ) const;
156 Factor beliefF( size_t I ) const;
157 //@}
158
159 /// Calculates the joint state of all variables that has maximum probability
160 /** Assumes that run() has been called and that props.inference == MAXPROD
161 */
162 std::vector<std::size_t> findMaximum() const;
163
164 /// Returns history of sent messages
165 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
166 return _sentMessages;
167 }
168
169 /// Clears history of sent messages
170 void clearSentMessages() {
171 _sentMessages.clear();
172 }
173
174 private:
175 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
176 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
177 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
178 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
179 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
180 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
181 double & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
182 const double & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
183
184 void calcNewMessage( size_t i, size_t _I );
185 void updateMessage( size_t i, size_t _I );
186 void updateResidual( size_t i, size_t _I, double r );
187 void findMaxResidual( size_t &i, size_t &_I );
188 /// Calculates unnormalized belief of variable
189 void calcBeliefV( size_t i, Prob &p ) const;
190 /// Calculates unnormalized belief of factor
191 void calcBeliefF( size_t I, Prob &p ) const;
192
193 void construct();
194 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
195 void setProperties( const PropertySet &opts );
196 PropertySet getProperties() const;
197 std::string printProperties() const;
198 };
199
200
201 } // end of namespace dai
202
203
204 #endif