Cleanup of BBP code
[libdai.git] / include / dai / bbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
24 /// \todo Clean up
25 /// \todo Debug clean_updates
26
27
28 #ifndef ___defined_libdai_bbp_h
29 #define ___defined_libdai_bbp_h
30
31
32 #include <vector>
33 #include <utility>
34
35 #include <dai/prob.h>
36 #include <dai/daialg.h>
37 #include <dai/factorgraph.h>
38 #include <dai/enum.h>
39 #include <dai/bp_dual.h>
40
41
42 namespace dai {
43
44
45 /// Returns a vector of zero-filled Probs whose state spaces correspond to the factors of the factor graph \a fg
46 std::vector<Prob> get_zero_adj_F( const FactorGraph& fg );
47
48 /// Returns a vector of zero-filled Probs whose state spaces correspond to the variables of the factor graph \a fg
49 std::vector<Prob> get_zero_adj_V( const FactorGraph& fg );
50
51
52 /// Implements BBP (Back-Belief-Propagation) [\ref EaG09]
53 class BBP {
54 protected:
55 /// @name Inputs
56 //@{
57 BP_dual _bp_dual;
58 const FactorGraph *_fg;
59 const InfAlg *_ia;
60 //@}
61
62 /// Number of iterations done
63 size_t _iters;
64
65 /// @name Outputs
66 //@{
67 /// Variable factor adjoints
68 std::vector<Prob> _adj_psi_V;
69 /// Factor adjoints
70 std::vector<Prob> _adj_psi_F;
71 /// Variable->factor message adjoints (indexed [i][_I])
72 std::vector<std::vector<Prob> > _adj_n;
73 /// Factor->variable message adjoints (indexed [i][_I])
74 std::vector<std::vector<Prob> > _adj_m;
75 /// Normalized variable belief adjoints
76 std::vector<Prob> _adj_b_V;
77 /// Normalized factor belief adjoints
78 std::vector<Prob> _adj_b_F;
79 //@}
80
81 /// @name Helper quantities computed from the BP messages
82 //@{
83 /// _T[i][_I] (see eqn. (41) in [\ref EaG09])
84 std::vector<std::vector<Prob > > _T;
85 /// _U[I][_i] (see eqn. (42) in [\ref EaG09])
86 std::vector<std::vector<Prob > > _U;
87 /// _S[i][_I][_j] (see eqn. (43) in [\ref EaG09])
88 std::vector<std::vector<std::vector<Prob > > > _S;
89 /// _R[I][_i][_J] (see eqn. (44) in [\ref EaG09])
90 std::vector<std::vector<std::vector<Prob > > > _R;
91 //@}
92
93 /// Unnormalized variable belief adjoints
94 std::vector<Prob> _adj_b_V_unnorm;
95 /// Unnormalized factor belief adjoints
96 std::vector<Prob> _adj_b_F_unnorm;
97
98 /// Initial variable factor adjoints
99 std::vector<Prob> _init_adj_psi_V;
100 /// Initial factor adjoints
101 std::vector<Prob> _init_adj_psi_F;
102
103 /// Unnormalized variable->factor message adjoint (indexed [i][_I])
104 std::vector<std::vector<Prob> > _adj_n_unnorm;
105 /// Unnormalized factor->variable message adjoint (indexed [i][_I])
106 std::vector<std::vector<Prob> > _adj_m_unnorm;
107 /// Updated normalized variable->factor message adjoint (indexed [i][_I])
108 std::vector<std::vector<Prob> > _new_adj_n;
109 /// Updated normalized factor->variable message adjoint (indexed [i][_I])
110 std::vector<std::vector<Prob> > _new_adj_m;
111
112 /// @name Optimized indexing (for performance)
113 //@{
114 /// Calculates _indices, which is a cache of IndexFor @see bp.cpp
115 void RegenerateInds();
116
117 /// Index type
118 typedef std::vector<size_t> _ind_t;
119 /// Cached indices (indexed [i][_I])
120 std::vector<std::vector<_ind_t> > _indices;
121 /// Returns an index from the cache
122 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
123 //@}
124
125 /// @name Initialization
126 //@{
127 /// Calculate T values; see eqn. (41) in [\ref EaG09]
128 void RegenerateT();
129 /// Calculate U values; see eqn. (42) in [\ref EaG09]
130 void RegenerateU();
131 /// Calculate S values; see eqn. (43) in [\ref EaG09]
132 void RegenerateS();
133 /// Calculate R values; see eqn. (44) in [\ref EaG09]
134 void RegenerateR();
135 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
136 void RegenerateInputs();
137 /// Initialise members for factor adjoints (call after RegenerateInputs)
138 void RegeneratePsiAdjoints();
139 /// Initialise members for message adjoints (call after RegenerateInputs) for parallel algorithm
140 void RegenerateParMessageAdjoints();
141 /// Initialise members for message adjoints (call after RegenerateInputs) for sequential algorithm
142 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
143 * than updating _adj_n (and friends) which are unused in the sequential algorithm.
144 */
145 void RegenerateSeqMessageAdjoints();
146 //@}
147
148 /// Returns T value; see eqn. (41) in [\ref EaG09]
149 DAI_ACCMUT(Prob & T(size_t i, size_t _I), { return _T[i][_I]; });
150 /// Retunrs U value; see eqn. (42) in [\ref EaG09]
151 DAI_ACCMUT(Prob & U(size_t I, size_t _i), { return _U[I][_i]; });
152 /// Returns S value; see eqn. (43) in [\ref EaG09]
153 DAI_ACCMUT(Prob & S(size_t i, size_t _I, size_t _j), { return _S[i][_I][_j]; });
154 /// Returns R value; see eqn. (44) in [\ref EaG09]
155 DAI_ACCMUT(Prob & R(size_t I, size_t _i, size_t _J), { return _R[I][_i][_J]; });
156
157 /// @name Parallel algorithm
158 //@{
159 /// Calculates new variable->factor message adjoint
160 /** Increases variable factor adjoint according to eqn. (27) in [\ref EaG09] and
161 * calculates the new variable->factor message adjoint according to eqn. (29) in [\ref EaG09].
162 */
163 void calcNewN( size_t i, size_t _I );
164 /// Calculates new factor->variable message adjoint
165 /** Increases factor adjoint according to eqn. (28) in [\ref EaG09] and
166 * calculates the new factor->variable message adjoint according to the r.h.s. of eqn. (30) in [\ref EaG09].
167 */
168 void calcNewM( size_t i, size_t _I );
169 /// Calculates unnormalized variable->factor message adjoint from the normalized one
170 void calcUnnormMsgN( size_t i, size_t _I );
171 /// Calculates unnormalized factor->variable message adjoint from the normalized one
172 void calcUnnormMsgM( size_t i, size_t _I );
173 /// Updates (un)normalized variable->factor message adjoints
174 void upMsgN( size_t i, size_t _I );
175 /// Updates (un)normalized factor->variable message adjoints
176 void upMsgM( size_t i, size_t _I );
177 /// Do one parallel update of all message adjoints
178 void doParUpdate();
179 //@}
180
181 /// Calculates averaged L-1 norm of unnormalized message adjoints
182 Real getUnMsgMag();
183 /// Calculates averaged L-1 norms of current and new normalized message adjoints
184 void getMsgMags( Real &s, Real &new_s );
185
186 /// Sets all vectors _adj_b_F to zero
187 void zero_adj_b_F() {
188 _adj_b_F.clear();
189 _adj_b_F.reserve( _fg->nrFactors() );
190 for( size_t I = 0; I < _fg->nrFactors(); I++ )
191 _adj_b_F.push_back( Prob( _fg->factor(I).states(), Real( 0.0 ) ) );
192 }
193
194 /// @name Sequential algorithm
195 //@{
196 /// Helper function for sendSeqMsgM: increases factor->variable message adjoint by p and calculates the corresponding unnormalized adjoint
197 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
198 void updateSeqMsgM( size_t i, size_t _I );
199 /// Implements routine Send-n in Figure 5 in [\ref EaG09]
200 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
201 /// Implements routine Send-m in Figure 5 in [\ref EaG09]
202 void sendSeqMsgM( size_t i, size_t _I );
203 /// Sets normalized factor->variable message adjoint and calculates the corresponding unnormalized adjoint
204 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
205 //@}
206
207 /// Returns indices and magnitude of the largest normalized factor->variable message adjoint
208 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
209
210 /// Returns magnitude of the largest (in L1-norm) normalized factor->variable message adjoint
211 Real getMaxMsgM();
212 /// Calculates sum of L1 norms of all normalized factor->variable message adjoints
213 Real getTotalMsgM();
214 /// Calculates sum of L1 norms of all updated normalized factor->variable message adjoints
215 Real getTotalNewMsgM();
216 /// Calculates sum of L1 norms of all normalized variable->factor message adjoints
217 Real getTotalMsgN();
218
219 public:
220 /// Called by \a init, recalculates intermediate values
221 void Regenerate();
222
223 /// Constructor
224 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
225 props.set(opts);
226 }
227
228 /// Initializes belief adjoints and initial factor adjoints and regenerates
229 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
230 _adj_b_V = adj_b_V;
231 _adj_b_F = adj_b_F;
232 _init_adj_psi_V = adj_psi_V;
233 _init_adj_psi_F = adj_psi_F;
234 Regenerate();
235 }
236
237 /// Initializes belief adjoints and with zero initial factor adjoints and regenerates
238 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
239 init( adj_b_V, adj_b_F, get_zero_adj_V(*_fg), get_zero_adj_F(*_fg) );
240 }
241
242 /// Initializes variable belief adjoints (and sets factor belief adjoints to zero) and with zero initial factor adjoints and regenerates
243 void init( const std::vector<Prob> &adj_b_V ) {
244 init(adj_b_V, get_zero_adj_F(*_fg));
245 }
246
247 /// Run until change is less than given tolerance
248 void run();
249
250 /// Return number of iterations done so far
251 size_t doneIters() { return _iters; }
252
253 /// Returns variable factor adjoint
254 DAI_ACCMUT(Prob& adj_psi_V(size_t i), { return _adj_psi_V[i]; });
255 /// Returns factor adjoint
256 DAI_ACCMUT(Prob& adj_psi_F(size_t I), { return _adj_psi_F[I]; });
257 /// Returns variable belief adjoint
258 DAI_ACCMUT(Prob& adj_b_V(size_t i), { return _adj_b_V[i]; });
259 /// Returns factor belief adjoint
260 DAI_ACCMUT(Prob& adj_b_F(size_t I), { return _adj_b_F[I]; });
261
262 protected:
263 /// Returns variable->factor message adjoint
264 DAI_ACCMUT(Prob& adj_n(size_t i, size_t _I), { return _adj_n[i][_I]; });
265 /// Returns factor->variable message adjoint
266 DAI_ACCMUT(Prob& adj_m(size_t i, size_t _I), { return _adj_m[i][_I]; });
267
268 public:
269 /// Parameters of this algorithm
270 /* PROPERTIES(props,BBP) {
271 /// Enumeration of possible update schedules
272 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
273
274 /// Verbosity
275 size_t verbose;
276
277 /// Maximum number of iterations
278 size_t maxiter;
279
280 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
281 double tol;
282
283 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
284 double damping;
285
286 /// Update schedule
287 UpdateType updates;
288
289 bool clean_updates;
290 }
291 */
292 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
293 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
294 */
295 struct Properties {
296 /// Enumeration of possible update schedules
297 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
298 /// Verbosity
299 size_t verbose;
300 /// Maximum number of iterations
301 size_t maxiter;
302 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
303 double tol;
304 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
305 double damping;
306 /// Update schedule
307 UpdateType updates;
308 bool clean_updates;
309
310 /// Set members from PropertySet
311 void set(const PropertySet &opts);
312 /// Get members into PropertySet
313 PropertySet get() const;
314 /// Convert to a string which can be parsed as a PropertySet
315 std::string toString() const;
316 } props;
317 /* }}} END OF GENERATED CODE */
318 };
319
320 /// Enumeration of cost functions. Not used by the BBP class itself, only by the functions declared below.
321 DAI_ENUM(bbp_cfn_t,cfn_gibbs_b,cfn_gibbs_b2,cfn_gibbs_exp,cfn_gibbs_b_factor,cfn_gibbs_b2_factor,cfn_gibbs_exp_factor,cfn_var_ent,cfn_factor_ent,cfn_bethe_ent);
322
323 /// Initialises a BBP object from an InfAlg, a cost function type and a Gibbs state (stateP)
324 /** Calls bbp.init with adjoints calculated from ia.beliefV and
325 * ia.beliefF. stateP is a Gibbs state and may be NULL, in which case it is
326 * initialised using a Gibbs run of 2*fg.Iterations() iterations.
327 */
328 void initBBPCostFnAdj(BBP& bbp, const InfAlg& ia, bbp_cfn_t cfn_type, const std::vector<size_t>* stateP);
329
330 /// Returns whether the given cost function depends on having a Gibbs state
331 bool needGibbsState(bbp_cfn_t cfn);
332
333 /// Calculates the actual value of the cost function (cfn_type, stateP)
334 /** Returns the actual value of the cost function whose
335 * gradient with respect to singleton beliefs is given by
336 * gibbsToB1Adj on the same arguments. Note that the first parameter is an InfAlg, even though it is named 'fg'.
337 */
338 Real getCostFn(const InfAlg& fg, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP);
339
340 /// Function to test the validity of adjoints computed by BBP
341 /** Given a state for each variable, uses numerical derivatives
342 * (multiplying a factor containing a variable by psi_1 adjustments)
343 * to verify the accuracy of _adj_psi_V.
344 * 'h' controls the size of the perturbation.
345 * NOTE(review): this comment used to describe a 'bbpTol' tolerance, but the signature has no such parameter; presumably the BBP options (including any tolerance) are supplied through 'bbp_props' -- confirm against the implementation.
346 */
347 double numericBBPTest(const InfAlg& bp, const std::vector<size_t> *state, const PropertySet& bbp_props, bbp_cfn_t cfn, double h);
348
349 // ----------------------------------------------------------------
350 // Utility functions, some of which are used elsewhere
351
352 /// Computes the adjoint of the unnormalized probability vector from the normalizer \a Z_w and the adjoint \a adj_w of the normalized probability vector @see eqn. (13) in [\ref EaG09]
353 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
354
355 /// Runs Gibbs sampling for \a iters iterations on ia.fg(), and returns the resulting state
356 std::vector<size_t> getGibbsState(const InfAlg& ia, size_t iters);
357
358
359 } // end of namespace dai
360
361
362 #endif