Improved documentation of include/dai/exactinf.h
[libdai.git] / include / dai / bp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class BP
14 /// \todo Improve documentation
15
16
17 #ifndef __defined_libdai_bp_h
18 #define __defined_libdai_bp_h
19
20
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <dai/daialg.h>
#include <dai/factorgraph.h>
#include <dai/properties.h>
#include <dai/enum.h>
26
27
28 namespace dai {
29
30
31 /// Approximate inference algorithm "(Loopy) Belief Propagation"
32 class BP : public DAIAlgFG {
33 private:
34 typedef std::vector<size_t> ind_t;
35 typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
36 struct EdgeProp {
37 ind_t index;
38 Prob message;
39 Prob newMessage;
40 Real residual;
41 };
42 std::vector<std::vector<EdgeProp> > _edges;
43 std::vector<std::vector<LutType::iterator> > _edge2lut;
44 LutType _lut;
45 /// Maximum difference encountered so far
46 Real _maxdiff;
47 /// Number of iterations needed
48 size_t _iters;
49 /// The history of message updates (only recorded if recordSentMessages is true)
50 std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
51
52 public:
53 /// Parameters of this inference algorithm
54 struct Properties {
55 /// Enumeration of possible update schedules
56 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
57
58 /// Enumeration of inference variants
59 DAI_ENUM(InfType,SUMPROD,MAXPROD);
60
61 /// Verbosity
62 size_t verbose;
63
64 /// Maximum number of iterations
65 size_t maxiter;
66
67 /// Tolerance
68 Real tol;
69
70 /// Do updates in logarithmic domain?
71 bool logdomain;
72
73 /// Damping constant
74 Real damping;
75
76 /// Update schedule
77 UpdateType updates;
78
79 /// Type of inference: sum-product or max-product?
80 InfType inference;
81 } props;
82
83 /// Name of this inference algorithm
84 static const char *Name;
85
86 /// Specifies whether the history of message updates should be recorded
87 bool recordSentMessages;
88
89 public:
90 /// Default constructor
91 BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
92
93 /// Copy constructor
94 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
95 _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
96 props(x.props), recordSentMessages(x.recordSentMessages)
97 {
98 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
99 _edge2lut[l->second.first][l->second.second] = l;
100 }
101
102 /// Assignment operator
103 BP& operator=( const BP &x ) {
104 if( this != &x ) {
105 DAIAlgFG::operator=( x );
106 _edges = x._edges;
107 _lut = x._lut;
108 for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
109 _edge2lut[l->second.first][l->second.second] = l;
110 _maxdiff = x._maxdiff;
111 _iters = x._iters;
112 _sentMessages = x._sentMessages;
113 props = x.props;
114 recordSentMessages = x.recordSentMessages;
115 }
116 return *this;
117 }
118
119 /// Construct from FactorGraph fg and PropertySet opts
120 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
121 setProperties( opts );
122 construct();
123 }
124
125 /// \name General InfAlg interface
126 //@{
127 virtual BP* clone() const { return new BP(*this); }
128 virtual std::string identify() const;
129 virtual Factor belief( const Var &n ) const;
130 virtual Factor belief( const VarSet &ns ) const;
131 virtual Factor beliefV( size_t i ) const;
132 virtual Factor beliefF( size_t I ) const;
133 virtual std::vector<Factor> beliefs() const;
134 virtual Real logZ() const;
135 virtual void init();
136 virtual void init( const VarSet &ns );
137 virtual Real run();
138 virtual Real maxDiff() const { return _maxdiff; }
139 virtual size_t Iterations() const { return _iters; }
140 //@}
141
142 /// \name Additional interface specific for BP
143 //@{
144 /// Calculates the joint state of all variables that has maximum probability
145 /** Assumes that run() has been called and that props.inference == MAXPROD
146 */
147 std::vector<std::size_t> findMaximum() const;
148
149 /// Returns history of sent messages
150 const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
151 return _sentMessages;
152 }
153
154 /// Clears history of sent messages
155 void clearSentMessages() {
156 _sentMessages.clear();
157 }
158 //@}
159
160 /// \name Managing parameters (which are stored in BP::props)
161 //@{
162 /// Set parameters of this inference algorithm.
163 /** The parameters are set according to \a opts.
164 * The values can be stored either as std::string or as the type of the corresponding BP::props member.
165 */
166 void setProperties( const PropertySet &opts );
167 /// Returns parameters of this inference algorithm converted into a PropertySet.
168 PropertySet getProperties() const;
169 /// Returns parameters of this inference algorithm formatted as a string in the format "[key1=val1,key2=val2,...,keyn=valn]".
170 std::string printProperties() const;
171 //@}
172
173 private:
174 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
175 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
176 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
177 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
178 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
179 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
180 Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
181 const Real & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
182
183 void calcNewMessage( size_t i, size_t _I );
184 void updateMessage( size_t i, size_t _I );
185 void updateResidual( size_t i, size_t _I, Real r );
186 void findMaxResidual( size_t &i, size_t &_I );
187 /// Calculates unnormalized belief of variable
188 void calcBeliefV( size_t i, Prob &p ) const;
189 /// Calculates unnormalized belief of factor
190 void calcBeliefF( size_t I, Prob &p ) const;
191
192 void construct();
193 };
194
195
196 } // end of namespace dai
197
198
199 #endif