Fixed tabs and trailing whitespaces
[libdai.git] / include / dai / bbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
24
25
26 #ifndef ___defined_libdai_bbp_h
27 #define ___defined_libdai_bbp_h
28
29
30 #include <vector>
31 #include <utility>
32
33 #include <dai/prob.h>
34 #include <dai/daialg.h>
35 #include <dai/factorgraph.h>
36 #include <dai/enum.h>
37 #include <dai/bp_dual.h>
38
39
40 namespace dai {
41
42
43 /// Computes the adjoint of the unnormed probability vector from the normalizer and the adjoint of the normalized probability vector @see eqn. (13) in [\ref EaG09]
44 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
45
46 /// Runs Gibbs sampling for \a iters iterations on ia.fg(), and returns state
47 std::vector<size_t> getGibbsState( const InfAlg &ia, size_t iters );
48
49
50 /// Implements BBP (Back-Belief-Propagation) [\ref EaG09]
51 class BBP {
52 protected:
53 /// @name Inputs
54 //@{
55 BP_dual _bp_dual;
56 const FactorGraph *_fg;
57 const InfAlg *_ia;
58 //@}
59
60 /// Number of iterations done
61 size_t _iters;
62
63 /// @name Outputs
64 //@{
65 /// Variable factor adjoints
66 std::vector<Prob> _adj_psi_V;
67 /// Factor adjoints
68 std::vector<Prob> _adj_psi_F;
69 /// Variable->factor message adjoints (indexed [i][_I])
70 std::vector<std::vector<Prob> > _adj_n;
71 /// Factor->variable message adjoints (indexed [i][_I])
72 std::vector<std::vector<Prob> > _adj_m;
73 /// Normalized variable belief adjoints
74 std::vector<Prob> _adj_b_V;
75 /// Normalized factor belief adjoints
76 std::vector<Prob> _adj_b_F;
77 //@}
78
79 /// @name Helper quantities computed from the BP messages
80 //@{
81 /// _T[i][_I] (see eqn. (41) in [\ref EaG09])
82 std::vector<std::vector<Prob > > _T;
83 /// _U[I][_i] (see eqn. (42) in [\ref EaG09])
84 std::vector<std::vector<Prob > > _U;
85 /// _S[i][_I][_j] (see eqn. (43) in [\ref EaG09])
86 std::vector<std::vector<std::vector<Prob > > > _S;
87 /// _R[I][_i][_J] (see eqn. (44) in [\ref EaG09])
88 std::vector<std::vector<std::vector<Prob > > > _R;
89 //@}
90
91 /// Unnormalized variable belief adjoints
92 std::vector<Prob> _adj_b_V_unnorm;
93 /// Unnormalized factor belief adjoints
94 std::vector<Prob> _adj_b_F_unnorm;
95
96 /// Initial variable factor adjoints
97 std::vector<Prob> _init_adj_psi_V;
98 /// Initial factor adjoints
99 std::vector<Prob> _init_adj_psi_F;
100
101 /// Unnormalized variable->factor message adjoint (indexed [i][_I])
102 std::vector<std::vector<Prob> > _adj_n_unnorm;
103 /// Unnormalized factor->variable message adjoint (indexed [i][_I])
104 std::vector<std::vector<Prob> > _adj_m_unnorm;
105 /// Updated normalized variable->factor message adjoint (indexed [i][_I])
106 std::vector<std::vector<Prob> > _new_adj_n;
107 /// Updated normalized factor->variable message adjoint (indexed [i][_I])
108 std::vector<std::vector<Prob> > _new_adj_m;
109
110 /// @name Optimized indexing (for performance)
111 //@{
112 /// Calculates _indices, which is a cache of IndexFor @see bp.cpp
113 void RegenerateInds();
114
115 /// Index type
116 typedef std::vector<size_t> _ind_t;
117 /// Cached indices (indexed [i][_I])
118 std::vector<std::vector<_ind_t> > _indices;
119 /// Returns an index from the cache
120 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
121 //@}
122
123 /// @name Initialization
124 //@{
125 /// Calculate T values; see eqn. (41) in [\ref EaG09]
126 void RegenerateT();
127 /// Calculate U values; see eqn. (42) in [\ref EaG09]
128 void RegenerateU();
129 /// Calculate S values; see eqn. (43) in [\ref EaG09]
130 void RegenerateS();
131 /// Calculate R values; see eqn. (44) in [\ref EaG09]
132 void RegenerateR();
133 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
134 void RegenerateInputs();
135 /// Initialise members for factor adjoints (call after RegenerateInputs)
136 void RegeneratePsiAdjoints();
137 /// Initialise members for message adjoints (call after RegenerateInputs) for parallel algorithm
138 void RegenerateParMessageAdjoints();
139 /// Initialise members for message adjoints (call after RegenerateInputs) for sequential algorithm
140 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
141 * than updating _adj_n (and friends) which are unused in the sequential algorithm.
142 */
143 void RegenerateSeqMessageAdjoints();
144 //@}
145
146 /// Returns T value; see eqn. (41) in [\ref EaG09]
147 DAI_ACCMUT(Prob & T(size_t i, size_t _I), { return _T[i][_I]; });
148 /// Retunrs U value; see eqn. (42) in [\ref EaG09]
149 DAI_ACCMUT(Prob & U(size_t I, size_t _i), { return _U[I][_i]; });
150 /// Returns S value; see eqn. (43) in [\ref EaG09]
151 DAI_ACCMUT(Prob & S(size_t i, size_t _I, size_t _j), { return _S[i][_I][_j]; });
152 /// Returns R value; see eqn. (44) in [\ref EaG09]
153 DAI_ACCMUT(Prob & R(size_t I, size_t _i, size_t _J), { return _R[I][_i][_J]; });
154
155 /// @name Parallel algorithm
156 //@{
157 /// Calculates new variable->factor message adjoint
158 /** Increases variable factor adjoint according to eqn. (27) in [\ref EaG09] and
159 * calculates the new variable->factor message adjoint according to eqn. (29) in [\ref EaG09].
160 */
161 void calcNewN( size_t i, size_t _I );
162 /// Calculates new factor->variable message adjoint
163 /** Increases factor adjoint according to eqn. (28) in [\ref EaG09] and
164 * calculates the new factor->variable message adjoint according to the r.h.s. of eqn. (30) in [\ref EaG09].
165 */
166 void calcNewM( size_t i, size_t _I );
167 /// Calculates unnormalized variable->factor message adjoint from the normalized one
168 void calcUnnormMsgN( size_t i, size_t _I );
169 /// Calculates unnormalized factor->variable message adjoint from the normalized one
170 void calcUnnormMsgM( size_t i, size_t _I );
171 /// Updates (un)normalized variable->factor message adjoints
172 void upMsgN( size_t i, size_t _I );
173 /// Updates (un)normalized factor->variable message adjoints
174 void upMsgM( size_t i, size_t _I );
175 /// Do one parallel update of all message adjoints
176 void doParUpdate();
177 //@}
178
179 /// @name Sequential algorithm
180 //@{
181 /// Helper function for sendSeqMsgM: increases factor->variable message adjoint by p and calculates the corresponding unnormalized adjoint
182 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
183 // DISABLED BECAUSE IT IS BUGGY:
184 // void updateSeqMsgM( size_t i, size_t _I );
185 /// Sets normalized factor->variable message adjoint and calculates the corresponding unnormalized adjoint
186 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
187 /// Implements routine Send-n in Figure 5 in [\ref EaG09]
188 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
189 /// Implements routine Send-m in Figure 5 in [\ref EaG09]
190 void sendSeqMsgM( size_t i, size_t _I );
191 //@}
192
193 /// Calculates averaged L-1 norm of unnormalized message adjoints
194 Real getUnMsgMag();
195 /// Calculates averaged L-1 norms of current and new normalized message adjoints
196 void getMsgMags( Real &s, Real &new_s );
197
198 /// Sets all vectors _adj_b_F to zero
199 void zero_adj_b_F() {
200 _adj_b_F.clear();
201 _adj_b_F.reserve( _fg->nrFactors() );
202 for( size_t I = 0; I < _fg->nrFactors(); I++ )
203 _adj_b_F.push_back( Prob( _fg->factor(I).states(), Real( 0.0 ) ) );
204 }
205
206 /// Returns indices and magnitude of the largest normalized factor->variable message adjoint
207 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
208 /// Returns magnitude of the largest (in L1-norm) normalized factor->variable message adjoint
209 Real getMaxMsgM();
210 /// Calculates sum of L1 norms of all normalized factor->variable message adjoints
211 Real getTotalMsgM();
212 /// Calculates sum of L1 norms of all updated normalized factor->variable message adjoints
213 Real getTotalNewMsgM();
214 /// Calculates sum of L1 norms of all normalized variable->factor message adjoints
215 Real getTotalMsgN();
216
217 public:
218 /// Called by \a init, recalculates intermediate values
219 void Regenerate();
220
221 /// Constructor
222 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
223 props.set(opts);
224 }
225
226 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the factors in the factor graph fg
227 std::vector<Prob> getZeroAdjF( const FactorGraph &fg );
228 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the variables in the factor graph fg
229 std::vector<Prob> getZeroAdjV( const FactorGraph &fg );
230
231 /// Initializes belief adjoints and initial factor adjoints and regenerates
232 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
233 _adj_b_V = adj_b_V;
234 _adj_b_F = adj_b_F;
235 _init_adj_psi_V = adj_psi_V;
236 _init_adj_psi_F = adj_psi_F;
237 Regenerate();
238 }
239
240 /// Initializes belief adjoints and with zero initial factor adjoints and regenerates
241 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
242 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) );
243 }
244
245 /// Initializes variable belief adjoints (and sets factor belief adjoints to zero) and with zero initial factor adjoints and regenerates
246 void init( const std::vector<Prob> &adj_b_V ) {
247 init(adj_b_V, getZeroAdjF(*_fg));
248 }
249
250 /// Run until change is less than given tolerance
251 void run();
252
253 /// Return number of iterations done so far
254 size_t doneIters() { return _iters; }
255
256 /// Returns variable factor adjoint
257 DAI_ACCMUT(Prob& adj_psi_V(size_t i), { return _adj_psi_V[i]; });
258 /// Returns factor adjoint
259 DAI_ACCMUT(Prob& adj_psi_F(size_t I), { return _adj_psi_F[I]; });
260 /// Returns variable belief adjoint
261 DAI_ACCMUT(Prob& adj_b_V(size_t i), { return _adj_b_V[i]; });
262 /// Returns factor belief adjoint
263 DAI_ACCMUT(Prob& adj_b_F(size_t I), { return _adj_b_F[I]; });
264
265 protected:
266 /// Returns variable->factor message adjoint
267 DAI_ACCMUT(Prob& adj_n(size_t i, size_t _I), { return _adj_n[i][_I]; });
268 /// Returns factor->variable message adjoint
269 DAI_ACCMUT(Prob& adj_m(size_t i, size_t _I), { return _adj_m[i][_I]; });
270
271 public:
272 /// Parameters of this algorithm
273 /* PROPERTIES(props,BBP) {
274 /// Enumeration of possible update schedules
275 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
276
277 /// Verbosity
278 size_t verbose;
279
280 /// Maximum number of iterations
281 size_t maxiter;
282
283 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
284 double tol;
285
286 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
287 double damping;
288
289 /// Update schedule
290 UpdateType updates;
291
292 // DISABLED BECAUSE IT IS BUGGY:
293 // bool clean_updates;
294 }
295 */
296 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
297 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
298 */
299 struct Properties {
300 /// Enumeration of possible update schedules
301 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
302 /// Verbosity
303 size_t verbose;
304 /// Maximum number of iterations
305 size_t maxiter;
306 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
307 double tol;
308 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
309 double damping;
310 /// Update schedule
311 UpdateType updates;
312
313 /// Set members from PropertySet
314 void set(const PropertySet &opts);
315 /// Get members into PropertySet
316 PropertySet get() const;
317 /// Convert to a string which can be parsed as a PropertySet
318 std::string toString() const;
319 } props;
320 /* }}} END OF GENERATED CODE */
321 };
322
323
324 /// Enumeration of several cost functions that can be used with BBP.
325 DAI_ENUM(bbp_cfn_t,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT);
326
327 /// Initialise BBP using InfAlg, cost function, and stateP
328 /** Calls bbp.init with adjoints calculated from ia.beliefV and
329 * ia.beliefF. stateP is a Gibbs state and can be NULL, it will be
330 * initialised using a Gibbs run of 2*fg.Iterations() iterations.
331 */
332 void initBBPCostFnAdj( BBP &bbp, const InfAlg &ia, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP );
333
334 /// Answers question: does the given cost function depend on having a Gibbs state?
335 bool needGibbsState( bbp_cfn_t cfn );
336
337 /// Calculate actual value of cost function (cfn_type, stateP)
338 /** This function returns the actual value of the cost function whose
339 * gradient with respect to singleton beliefs is given by
340 * gibbsToB1Adj on the same arguments
341 */
342 Real getCostFn( const InfAlg &fg, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP );
343
344 /// Function to test the validity of adjoints computed by BBP given a state for each variable using numerical derivatives.
345 /** Factors containing a variable are multiplied by psi_1 adjustments to verify accuracy of _adj_psi_V.
346 * \param bp BP object.
347 * \param state Global state of all variables.
348 * \param bbp_props BBP Properties.
349 * \param cfn Cost function to be used.
350 * \param h controls size of perturbation.
351 */
352 double numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, bbp_cfn_t cfn, double h );
353
354
355 } // end of namespace dai
356
357
358 #endif