Cleaned up BBP and improved documentation of include/dai/bbp.h
[libdai.git] / include / dai / bbp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class BBP, which implements Back-Belief-Propagation
13
14
15 #ifndef ___defined_libdai_bbp_h
16 #define ___defined_libdai_bbp_h
17
18
19 #include <vector>
20 #include <utility>
21
22 #include <dai/prob.h>
23 #include <dai/daialg.h>
24 #include <dai/factorgraph.h>
25 #include <dai/enum.h>
26 #include <dai/bp_dual.h>
27
28
29 namespace dai {
30
31
32 /// Enumeration of several cost functions that can be used with BBP
33 /** \note This class is meant as a base class for BBPCostFunction, which provides additional functionality.
34 */
35 DAI_ENUM(BBPCostFunctionBase,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT);
36
37
38 /// Predefined cost functions that can be used with BBP
39 class BBPCostFunction : public BBPCostFunctionBase {
40 public:
41 /// Returns whether this cost function depends on having a Gibbs state
42 bool needGibbsState() const;
43
44 /// Evaluates cost function in state \a stateP using the information in inference algorithm \a ia
45 Real evaluate( const InfAlg &ia, const std::vector<size_t> *stateP ) const;
46
47 /// Assignment operator
48 BBPCostFunction& operator=( const BBPCostFunctionBase &x ) {
49 if( this != &x ) {
50 (BBPCostFunctionBase)*this = x;
51 }
52 return *this;
53 }
54 };
55
56
57 /// Implements BBP (Back-Belief-Propagation) [\ref EaG09]
58 /** \author Frederik Eaton
59 */
60 class BBP {
61 private:
62 /// \name Input variables
63 //@{
64 /// Stores a BP_dual helper object
65 BP_dual _bp_dual;
66 /// Pointer to the factor graph
67 const FactorGraph *_fg;
68 /// Pointer to the approximate inference algorithm
69 const InfAlg *_ia;
70 //@}
71
72 /// \name Output variables
73 //@{
74 /// Variable factor adjoints
75 std::vector<Prob> _adj_psi_V;
76 /// Factor adjoints
77 std::vector<Prob> _adj_psi_F;
78 /// Variable->factor message adjoints (indexed [i][_I])
79 std::vector<std::vector<Prob> > _adj_n;
80 /// Factor->variable message adjoints (indexed [i][_I])
81 std::vector<std::vector<Prob> > _adj_m;
82 /// Normalized variable belief adjoints
83 std::vector<Prob> _adj_b_V;
84 /// Normalized factor belief adjoints
85 std::vector<Prob> _adj_b_F;
86 //@}
87
88 /// \name Internal state variables
89 //@{
90 /// Initial variable factor adjoints
91 std::vector<Prob> _init_adj_psi_V;
92 /// Initial factor adjoints
93 std::vector<Prob> _init_adj_psi_F;
94
95 /// Unnormalized variable->factor message adjoint (indexed [i][_I])
96 std::vector<std::vector<Prob> > _adj_n_unnorm;
97 /// Unnormalized factor->variable message adjoint (indexed [i][_I])
98 std::vector<std::vector<Prob> > _adj_m_unnorm;
99 /// Updated normalized variable->factor message adjoint (indexed [i][_I])
100 std::vector<std::vector<Prob> > _new_adj_n;
101 /// Updated normalized factor->variable message adjoint (indexed [i][_I])
102 std::vector<std::vector<Prob> > _new_adj_m;
103 /// Unnormalized variable belief adjoints
104 std::vector<Prob> _adj_b_V_unnorm;
105 /// Unnormalized factor belief adjoints
106 std::vector<Prob> _adj_b_F_unnorm;
107
108 /// _T[i][_I] (see eqn. (41) in [\ref EaG09])
109 std::vector<std::vector<Prob > > _T;
110 /// _U[I][_i] (see eqn. (42) in [\ref EaG09])
111 std::vector<std::vector<Prob > > _U;
112 /// _S[i][_I][_j] (see eqn. (43) in [\ref EaG09])
113 std::vector<std::vector<std::vector<Prob > > > _S;
114 /// _R[I][_i][_J] (see eqn. (44) in [\ref EaG09])
115 std::vector<std::vector<std::vector<Prob > > > _R;
116
117 /// Number of iterations done
118 size_t _iters;
119 //@}
120
121 /// \name Index cache management (for performance)
122 //@{
123 /// Index type
124 typedef std::vector<size_t> _ind_t;
125 /// Cached indices (indexed [i][_I])
126 std::vector<std::vector<_ind_t> > _indices;
127 /// Prepares index cache _indices
128 /** \see bp.cpp
129 */
130 void RegenerateInds();
131 /// Returns an index from the cache
132 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
133 //@}
134
135 /// \name Initialization helper functions
136 //@{
137 /// Calculate T values; see eqn. (41) in [\ref EaG09]
138 void RegenerateT();
139 /// Calculate U values; see eqn. (42) in [\ref EaG09]
140 void RegenerateU();
141 /// Calculate S values; see eqn. (43) in [\ref EaG09]
142 void RegenerateS();
143 /// Calculate R values; see eqn. (44) in [\ref EaG09]
144 void RegenerateR();
145 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
146 void RegenerateInputs();
147 /// Initialise members for factor adjoints
148 /** \pre RegenerateInputs() should be called first
149 */
150 void RegeneratePsiAdjoints();
151 /// Initialise members for message adjoints for parallel algorithm
152 /** \pre RegenerateInputs() should be called first
153 */
154 void RegenerateParMessageAdjoints();
155 /// Initialise members for message adjoints for sequential algorithm
156 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
157 * than updating _adj_n (and friends) which are unused in the sequential algorithm.
158 * \pre RegenerateInputs() should be called first
159 */
160 void RegenerateSeqMessageAdjoints();
161 /// Called by \a init, recalculates intermediate values
162 void Regenerate();
163 //@}
164
165 /// \name Accessors/mutators
166 //@{
167 /// Returns reference to T value; see eqn. (41) in [\ref EaG09]
168 Prob & T(size_t i, size_t _I) { return _T[i][_I]; }
169 /// Returns constant reference to T value; see eqn. (41) in [\ref EaG09]
170 const Prob & T(size_t i, size_t _I) const { return _T[i][_I]; }
171 /// Returns reference to U value; see eqn. (42) in [\ref EaG09]
172 Prob & U(size_t I, size_t _i) { return _U[I][_i]; }
173 /// Returns constant reference to U value; see eqn. (42) in [\ref EaG09]
174 const Prob & U(size_t I, size_t _i) const { return _U[I][_i]; }
175 /// Returns reference to S value; see eqn. (43) in [\ref EaG09]
176 Prob & S(size_t i, size_t _I, size_t _j) { return _S[i][_I][_j]; }
177 /// Returns constant reference to S value; see eqn. (43) in [\ref EaG09]
178 const Prob & S(size_t i, size_t _I, size_t _j) const { return _S[i][_I][_j]; }
179 /// Returns reference to R value; see eqn. (44) in [\ref EaG09]
180 Prob & R(size_t I, size_t _i, size_t _J) { return _R[I][_i][_J]; }
181 /// Returns constant reference to R value; see eqn. (44) in [\ref EaG09]
182 const Prob & R(size_t I, size_t _i, size_t _J) const { return _R[I][_i][_J]; }
183
184 /// Returns reference to variable->factor message adjoint
185 Prob& adj_n(size_t i, size_t _I) { return _adj_n[i][_I]; }
186 /// Returns constant reference to variable->factor message adjoint
187 const Prob& adj_n(size_t i, size_t _I) const { return _adj_n[i][_I]; }
188 /// Returns reference to factor->variable message adjoint
189 Prob& adj_m(size_t i, size_t _I) { return _adj_m[i][_I]; }
190 /// Returns constant reference to factor->variable message adjoint
191 const Prob& adj_m(size_t i, size_t _I) const { return _adj_m[i][_I]; }
192 //@}
193
194 /// \name Parallel algorithm
195 //@{
196 /// Calculates new variable->factor message adjoint
197 /** Increases variable factor adjoint according to eqn. (27) in [\ref EaG09] and
198 * calculates the new variable->factor message adjoint according to eqn. (29) in [\ref EaG09].
199 */
200 void calcNewN( size_t i, size_t _I );
201 /// Calculates new factor->variable message adjoint
202 /** Increases factor adjoint according to eqn. (28) in [\ref EaG09] and
203 * calculates the new factor->variable message adjoint according to the r.h.s. of eqn. (30) in [\ref EaG09].
204 */
205 void calcNewM( size_t i, size_t _I );
206 /// Calculates unnormalized variable->factor message adjoint from the normalized one
207 void calcUnnormMsgN( size_t i, size_t _I );
208 /// Calculates unnormalized factor->variable message adjoint from the normalized one
209 void calcUnnormMsgM( size_t i, size_t _I );
210 /// Updates (un)normalized variable->factor message adjoints
211 void upMsgN( size_t i, size_t _I );
212 /// Updates (un)normalized factor->variable message adjoints
213 void upMsgM( size_t i, size_t _I );
214 /// Do one parallel update of all message adjoints
215 void doParUpdate();
216 //@}
217
218 /// \name Sequential algorithm
219 //@{
220 /// Helper function for sendSeqMsgM(): increases factor->variable message adjoint by \a p and calculates the corresponding unnormalized adjoint
221 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
222 // DISABLED BECAUSE IT IS BUGGY:
223 // void updateSeqMsgM( size_t i, size_t _I );
224 /// Sets normalized factor->variable message adjoint and calculates the corresponding unnormalized adjoint
225 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
226 /// Implements routine Send-n in Figure 5 in [\ref EaG09]
227 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
228 /// Implements routine Send-m in Figure 5 in [\ref EaG09]
229 void sendSeqMsgM( size_t i, size_t _I );
230 //@}
231
232 /// Computes the adjoint of the unnormed probability vector from the normalizer and the adjoint of the normalized probability vector
233 /** \see eqn. (13) in [\ref EaG09]
234 */
235 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
236
237 /// Calculates averaged L1 norm of unnormalized message adjoints
238 Real getUnMsgMag();
239 /// Calculates averaged L1 norms of current and new normalized message adjoints
240 void getMsgMags( Real &s, Real &new_s );
241 /// Returns indices and magnitude of the largest normalized factor->variable message adjoint
242 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
243 /// Returns magnitude of the largest (in L1-norm) normalized factor->variable message adjoint
244 Real getMaxMsgM();
245
246 /// Calculates sum of L1 norms of all normalized factor->variable message adjoints
247 Real getTotalMsgM();
248 /// Calculates sum of L1 norms of all updated normalized factor->variable message adjoints
249 Real getTotalNewMsgM();
250 /// Calculates sum of L1 norms of all normalized variable->factor message adjoints
251 Real getTotalMsgN();
252
253 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the factors in the factor graph \a fg
254 std::vector<Prob> getZeroAdjF( const FactorGraph &fg );
255 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the variables in the factor graph \a fg
256 std::vector<Prob> getZeroAdjV( const FactorGraph &fg );
257
258 public:
259 /// \name Constructors/destructors
260 //@{
261 /// Construct from a InfAlg \a ia and a PropertySet \a opts
262 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
263 props.set(opts);
264 }
265 //@}
266
267 /// \name Initialization
268 //@{
269 /// Initializes from given belief adjoints \a adj_b_V, \a adj_b_F and initial factor adjoints \a adj_psi_V, \a adj_psi_F
270 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
271 _adj_b_V = adj_b_V;
272 _adj_b_F = adj_b_F;
273 _init_adj_psi_V = adj_psi_V;
274 _init_adj_psi_F = adj_psi_F;
275 Regenerate();
276 }
277
278 /// Initializes from given belief adjoints \a adj_b_V and \a adj_b_F (setting initial factor adjoints to zero)
279 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
280 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) );
281 }
282
283 /// Initializes variable belief adjoints \a adj_b_V (and sets factor belief adjoints and initial factor adjoints to zero)
284 void init_V( const std::vector<Prob> &adj_b_V ) {
285 init( adj_b_V, getZeroAdjF(*_fg) );
286 }
287
288 /// Initializes factor belief adjoints \a adj_b_F (and sets variable belief adjoints and initial factor adjoints to zero)
289 void init_F( const std::vector<Prob> &adj_b_F ) {
290 init( getZeroAdjV(*_fg), adj_b_F );
291 }
292
293 /// Initializes with adjoints calculated from cost function \a cfn, and state \a stateP
294 /** Uses the internal pointer to an inference algorithm in combination with the cost function and state for initialization.
295 * \param cfn Cost function used for initialization;
296 * \param stateP is a Gibbs state and can be NULL; it will be initialised using a Gibbs run.
297 */
298 void initCostFnAdj( const BBPCostFunction &cfn, const std::vector<size_t> *stateP );
299 //@}
300
301 /// \name BBP Algorithm
302 //@{
303 /// Perform iterative updates until change is less than given tolerance
304 void run();
305 //@}
306
307 /// \name Query results
308 //@{
309 /// Returns reference to variable factor adjoint
310 Prob& adj_psi_V(size_t i) { return _adj_psi_V[i]; }
311 /// Returns constant reference to variable factor adjoint
312 const Prob& adj_psi_V(size_t i) const { return _adj_psi_V[i]; }
313 /// Returns reference to factor adjoint
314 Prob& adj_psi_F(size_t I) { return _adj_psi_F[I]; }
315 /// Returns constant reference to factor adjoint
316 const Prob& adj_psi_F(size_t I) const { return _adj_psi_F[I]; }
317 /// Returns reference to variable belief adjoint
318 Prob& adj_b_V(size_t i) { return _adj_b_V[i]; }
319 /// Returns constant reference to variable belief adjoint
320 const Prob& adj_b_V(size_t i) const { return _adj_b_V[i]; }
321 /// Returns reference to factor belief adjoint
322 Prob& adj_b_F(size_t I) { return _adj_b_F[I]; }
323 /// Returns constant reference to factor belief adjoint
324 const Prob& adj_b_F(size_t I) const { return _adj_b_F[I]; }
325 /// Return number of iterations done so far
326 size_t Iterations() { return _iters; }
327 //@}
328
329 public:
330 /// Parameters of this algorithm
331 /* PROPERTIES(props,BBP) {
332 /// Enumeration of possible update schedules
333 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
334
335 /// Verbosity
336 size_t verbose;
337
338 /// Maximum number of iterations
339 size_t maxiter;
340
341 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
342 Real tol;
343
344 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
345 Real damping;
346
347 /// Update schedule
348 UpdateType updates;
349
350 // DISABLED BECAUSE IT IS BUGGY:
351 // bool clean_updates;
352 }
353 */
354 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
355 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
356 */
357 struct Properties {
358 /// Enumeration of possible update schedules
359 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
360 /// Verbosity
361 size_t verbose;
362 /// Maximum number of iterations
363 size_t maxiter;
364 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
365 Real tol;
366 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
367 Real damping;
368 /// Update schedule
369 UpdateType updates;
370
371 /// Set members from PropertySet
372 void set(const PropertySet &opts);
373 /// Get members into PropertySet
374 PropertySet get() const;
375 /// Convert to a string which can be parsed as a PropertySet
376 std::string toString() const;
377 } props;
378 /* }}} END OF GENERATED CODE */
379 };
380
381
382 /// Function to verify the validity of adjoints computed by BBP using numerical differentiation.
383 /** Factors containing a variable are multiplied by small adjustments to verify accuracy of calculated variable factor adjoints.
384 * \param bp BP object;
385 * \param state Global state of all variables;
386 * \param bbp_props BBP parameters;
387 * \param cfn Cost function to be used;
388 * \param h Size of perturbation.
389 * \relates BBP
390 */
391 Real numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, const BBPCostFunction &cfn, Real h );
392
393
394 } // end of namespace dai
395
396
397 #endif