2b41568e37e4a89e453c32e8c8d1ae0d7988735a
[libdai.git] / include / dai / bbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
24 /// \todo Debug clean_updates
25
26
27 #ifndef ___defined_libdai_bbp_h
28 #define ___defined_libdai_bbp_h
29
30
31 #include <vector>
32 #include <utility>
33
34 #include <dai/prob.h>
35 #include <dai/daialg.h>
36 #include <dai/factorgraph.h>
37 #include <dai/enum.h>
38 #include <dai/bp_dual.h>
39
40
41 namespace dai {
42
43
44 /// Computes the adjoint of the unnormalized probability vector from the normalizer and the adjoint of the normalized probability vector; see eqn. (13) in [\ref EaG09]. NOTE(review): presumably \a Z_w is the normalizer and \a adj_w the adjoint of the normalized vector; confirm the roles of \a w, \a Z_w and \a adj_w against the definition.
45 Prob unnormAdjoint( const Prob &w, Real Z_w, const Prob &adj_w );
46
47 /// Runs Gibbs sampling for \a iters iterations on the factor graph ia.fg() and returns the resulting state as a vector of size_t (presumably one entry per variable of the factor graph; TODO confirm against the definition)
48 std::vector<size_t> getGibbsState( const InfAlg &ia, size_t iters );
49
50
51 /// Implements BBP (Back-Belief-Propagation) [\ref EaG09]
52 class BBP {
53 protected:
54 /// @name Inputs
55 //@{
56 BP_dual _bp_dual;
57 const FactorGraph *_fg;
58 const InfAlg *_ia;
59 //@}
60
61 /// Number of iterations done
62 size_t _iters;
63
64 /// @name Outputs
65 //@{
66 /// Variable factor adjoints
67 std::vector<Prob> _adj_psi_V;
68 /// Factor adjoints
69 std::vector<Prob> _adj_psi_F;
70 /// Variable->factor message adjoints (indexed [i][_I])
71 std::vector<std::vector<Prob> > _adj_n;
72 /// Factor->variable message adjoints (indexed [i][_I])
73 std::vector<std::vector<Prob> > _adj_m;
74 /// Normalized variable belief adjoints
75 std::vector<Prob> _adj_b_V;
76 /// Normalized factor belief adjoints
77 std::vector<Prob> _adj_b_F;
78 //@}
79
80 /// @name Helper quantities computed from the BP messages
81 //@{
82 /// _T[i][_I] (see eqn. (41) in [\ref EaG09])
83 std::vector<std::vector<Prob > > _T;
84 /// _U[I][_i] (see eqn. (42) in [\ref EaG09])
85 std::vector<std::vector<Prob > > _U;
86 /// _S[i][_I][_j] (see eqn. (43) in [\ref EaG09])
87 std::vector<std::vector<std::vector<Prob > > > _S;
88 /// _R[I][_i][_J] (see eqn. (44) in [\ref EaG09])
89 std::vector<std::vector<std::vector<Prob > > > _R;
90 //@}
91
92 /// Unnormalized variable belief adjoints
93 std::vector<Prob> _adj_b_V_unnorm;
94 /// Unnormalized factor belief adjoints
95 std::vector<Prob> _adj_b_F_unnorm;
96
97 /// Initial variable factor adjoints
98 std::vector<Prob> _init_adj_psi_V;
99 /// Initial factor adjoints
100 std::vector<Prob> _init_adj_psi_F;
101
102 /// Unnormalized variable->factor message adjoint (indexed [i][_I])
103 std::vector<std::vector<Prob> > _adj_n_unnorm;
104 /// Unnormalized factor->variable message adjoint (indexed [i][_I])
105 std::vector<std::vector<Prob> > _adj_m_unnorm;
106 /// Updated normalized variable->factor message adjoint (indexed [i][_I])
107 std::vector<std::vector<Prob> > _new_adj_n;
108 /// Updated normalized factor->variable message adjoint (indexed [i][_I])
109 std::vector<std::vector<Prob> > _new_adj_m;
110
111 /// @name Optimized indexing (for performance)
112 //@{
113 /// Calculates _indices, which is a cache of IndexFor @see bp.cpp
114 void RegenerateInds();
115
116 /// Index type
117 typedef std::vector<size_t> _ind_t;
118 /// Cached indices (indexed [i][_I])
119 std::vector<std::vector<_ind_t> > _indices;
120 /// Returns an index from the cache
121 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
122 //@}
123
124 /// @name Initialization
125 //@{
126 /// Calculate T values; see eqn. (41) in [\ref EaG09]
127 void RegenerateT();
128 /// Calculate U values; see eqn. (42) in [\ref EaG09]
129 void RegenerateU();
130 /// Calculate S values; see eqn. (43) in [\ref EaG09]
131 void RegenerateS();
132 /// Calculate R values; see eqn. (44) in [\ref EaG09]
133 void RegenerateR();
134 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
135 void RegenerateInputs();
136 /// Initialise members for factor adjoints (call after RegenerateInputs)
137 void RegeneratePsiAdjoints();
138 /// Initialise members for message adjoints (call after RegenerateInputs) for parallel algorithm
139 void RegenerateParMessageAdjoints();
140 /// Initialise members for message adjoints (call after RegenerateInputs) for sequential algorithm
141 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
142 * than updating _adj_n (and friends) which are unused in the sequential algorithm.
143 */
144 void RegenerateSeqMessageAdjoints();
145 //@}
146
147 /// Returns T value; see eqn. (41) in [\ref EaG09]
148 DAI_ACCMUT(Prob & T(size_t i, size_t _I), { return _T[i][_I]; });
149 /// Retunrs U value; see eqn. (42) in [\ref EaG09]
150 DAI_ACCMUT(Prob & U(size_t I, size_t _i), { return _U[I][_i]; });
151 /// Returns S value; see eqn. (43) in [\ref EaG09]
152 DAI_ACCMUT(Prob & S(size_t i, size_t _I, size_t _j), { return _S[i][_I][_j]; });
153 /// Returns R value; see eqn. (44) in [\ref EaG09]
154 DAI_ACCMUT(Prob & R(size_t I, size_t _i, size_t _J), { return _R[I][_i][_J]; });
155
156 /// @name Parallel algorithm
157 //@{
158 /// Calculates new variable->factor message adjoint
159 /** Increases variable factor adjoint according to eqn. (27) in [\ref EaG09] and
160 * calculates the new variable->factor message adjoint according to eqn. (29) in [\ref EaG09].
161 */
162 void calcNewN( size_t i, size_t _I );
163 /// Calculates new factor->variable message adjoint
164 /** Increases factor adjoint according to eqn. (28) in [\ref EaG09] and
165 * calculates the new factor->variable message adjoint according to the r.h.s. of eqn. (30) in [\ref EaG09].
166 */
167 void calcNewM( size_t i, size_t _I );
168 /// Calculates unnormalized variable->factor message adjoint from the normalized one
169 void calcUnnormMsgN( size_t i, size_t _I );
170 /// Calculates unnormalized factor->variable message adjoint from the normalized one
171 void calcUnnormMsgM( size_t i, size_t _I );
172 /// Updates (un)normalized variable->factor message adjoints
173 void upMsgN( size_t i, size_t _I );
174 /// Updates (un)normalized factor->variable message adjoints
175 void upMsgM( size_t i, size_t _I );
176 /// Do one parallel update of all message adjoints
177 void doParUpdate();
178 //@}
179
180 /// @name Sequential algorithm
181 //@{
182 /// Helper function for sendSeqMsgM: increases factor->variable message adjoint by p and calculates the corresponding unnormalized adjoint
183 void incrSeqMsgM( size_t i, size_t _I, const Prob& p );
184 void updateSeqMsgM( size_t i, size_t _I );
185 /// Sets normalized factor->variable message adjoint and calculates the corresponding unnormalized adjoint
186 void setSeqMsgM( size_t i, size_t _I, const Prob &p );
187 /// Implements routine Send-n in Figure 5 in [\ref EaG09]
188 void sendSeqMsgN( size_t i, size_t _I, const Prob &f );
189 /// Implements routine Send-m in Figure 5 in [\ref EaG09]
190 void sendSeqMsgM( size_t i, size_t _I );
191 //@}
192
193 /// Calculates averaged L-1 norm of unnormalized message adjoints
194 Real getUnMsgMag();
195 /// Calculates averaged L-1 norms of current and new normalized message adjoints
196 void getMsgMags( Real &s, Real &new_s );
197
198 /// Sets all vectors _adj_b_F to zero
199 void zero_adj_b_F() {
200 _adj_b_F.clear();
201 _adj_b_F.reserve( _fg->nrFactors() );
202 for( size_t I = 0; I < _fg->nrFactors(); I++ )
203 _adj_b_F.push_back( Prob( _fg->factor(I).states(), Real( 0.0 ) ) );
204 }
205
206 /// Returns indices and magnitude of the largest normalized factor->variable message adjoint
207 void getArgmaxMsgM( size_t &i, size_t &_I, Real &mag );
208 /// Returns magnitude of the largest (in L1-norm) normalized factor->variable message adjoint
209 Real getMaxMsgM();
210 /// Calculates sum of L1 norms of all normalized factor->variable message adjoints
211 Real getTotalMsgM();
212 /// Calculates sum of L1 norms of all updated normalized factor->variable message adjoints
213 Real getTotalNewMsgM();
214 /// Calculates sum of L1 norms of all normalized variable->factor message adjoints
215 Real getTotalMsgN();
216
217 public:
218 /// Called by \a init, recalculates intermediate values
219 void Regenerate();
220
221 /// Constructor
222 BBP( const InfAlg *ia, const PropertySet &opts ) : _bp_dual(ia), _fg(&(ia->fg())), _ia(ia) {
223 props.set(opts);
224 }
225
226 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the factors in the factor graph fg
227 std::vector<Prob> getZeroAdjF( const FactorGraph &fg );
228 /// Returns a vector of Probs (filled with zeroes) with state spaces corresponding to the variables in the factor graph fg
229 std::vector<Prob> getZeroAdjV( const FactorGraph &fg );
230
231 /// Initializes belief adjoints and initial factor adjoints and regenerates
232 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F, const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F ) {
233 _adj_b_V = adj_b_V;
234 _adj_b_F = adj_b_F;
235 _init_adj_psi_V = adj_psi_V;
236 _init_adj_psi_F = adj_psi_F;
237 Regenerate();
238 }
239
240 /// Initializes belief adjoints and with zero initial factor adjoints and regenerates
241 void init( const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F ) {
242 init( adj_b_V, adj_b_F, getZeroAdjV(*_fg), getZeroAdjF(*_fg) );
243 }
244
245 /// Initializes variable belief adjoints (and sets factor belief adjoints to zero) and with zero initial factor adjoints and regenerates
246 void init( const std::vector<Prob> &adj_b_V ) {
247 init(adj_b_V, getZeroAdjF(*_fg));
248 }
249
250 /// Run until change is less than given tolerance
251 void run();
252
253 /// Return number of iterations done so far
254 size_t doneIters() { return _iters; }
255
256 /// Returns variable factor adjoint
257 DAI_ACCMUT(Prob& adj_psi_V(size_t i), { return _adj_psi_V[i]; });
258 /// Returns factor adjoint
259 DAI_ACCMUT(Prob& adj_psi_F(size_t I), { return _adj_psi_F[I]; });
260 /// Returns variable belief adjoint
261 DAI_ACCMUT(Prob& adj_b_V(size_t i), { return _adj_b_V[i]; });
262 /// Returns factor belief adjoint
263 DAI_ACCMUT(Prob& adj_b_F(size_t I), { return _adj_b_F[I]; });
264
265 protected:
266 /// Returns variable->factor message adjoint
267 DAI_ACCMUT(Prob& adj_n(size_t i, size_t _I), { return _adj_n[i][_I]; });
268 /// Returns factor->variable message adjoint
269 DAI_ACCMUT(Prob& adj_m(size_t i, size_t _I), { return _adj_m[i][_I]; });
270
271 public:
272 /// Parameters of this algorithm
273 /* PROPERTIES(props,BBP) {
274 /// Enumeration of possible update schedules
275 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
276
277 /// Verbosity
278 size_t verbose;
279
280 /// Maximum number of iterations
281 size_t maxiter;
282
283 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
284 double tol;
285
286 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
287 double damping;
288
289 /// Update schedule
290 UpdateType updates;
291
292 bool clean_updates;
293 }
294 */
295 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
296 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
297 */
298 struct Properties {
299 /// Enumeration of possible update schedules
300 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
301 /// Verbosity
302 size_t verbose;
303 /// Maximum number of iterations
304 size_t maxiter;
305 /// Tolerance (not used for updates = SEQ_BP_REV, SEQ_BP_FWD)
306 double tol;
307 /// Damping constant (0 for none); damping = 1 - lambda where lambda is the damping constant used in [\ref EaG09]
308 double damping;
309 /// Update schedule
310 UpdateType updates;
311 bool clean_updates;
312
313 /// Set members from PropertySet
314 void set(const PropertySet &opts);
315 /// Get members into PropertySet
316 PropertySet get() const;
317 /// Convert to a string which can be parsed as a PropertySet
318 std::string toString() const;
319 } props;
320 /* }}} END OF GENERATED CODE */
321 };
322
323
324 /// Enumeration of the cost functions that can be used with BBP; a value of this type selects the cost function in initBBPCostFnAdj(), needGibbsState(), getCostFn() and numericBBPTest().
325 DAI_ENUM(bbp_cfn_t,CFN_GIBBS_B,CFN_GIBBS_B2,CFN_GIBBS_EXP,CFN_GIBBS_B_FACTOR,CFN_GIBBS_B2_FACTOR,CFN_GIBBS_EXP_FACTOR,CFN_VAR_ENT,CFN_FACTOR_ENT,CFN_BETHE_ENT);
326
327 /// Initialises \a bbp from the inference algorithm \a ia, the cost function \a cfn_type and the state \a stateP
328 /** Calls bbp.init with adjoints calculated from ia.beliefV and
329 * ia.beliefF. \a stateP is a Gibbs state and may be NULL; in that case a state is said to be
330 * initialised using a Gibbs run of 2*fg.Iterations() iterations (NOTE(review): confirm against the definition).
331 */
332 void initBBPCostFnAdj( BBP &bbp, const InfAlg &ia, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP );
333
334 /// Returns whether the given cost function \a cfn depends on having a Gibbs state
335 bool needGibbsState( bbp_cfn_t cfn );
336
337 /// Calculates the actual value of the cost function selected by (\a cfn_type, \a stateP)
338 /** This function returns the actual value of the cost function whose
339 * gradient with respect to singleton beliefs is given by
340 * gibbsToB1Adj on the same arguments. NOTE(review): gibbsToB1Adj is not declared in this header, and the first parameter is an InfAlg although it is named \a fg.
341 */
342 Real getCostFn( const InfAlg &fg, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP );
343
344 /// Tests the validity of the adjoints computed by BBP, given a state for each variable, using numerical derivatives.
345 /** Factors containing a variable are multiplied by psi_1 adjustments to verify accuracy of _adj_psi_V.
346 * \param h controls size of perturbation. NOTE(review): the meaning of the returned double is not stated in this header; see the definition.
347 */
348 double numericBBPTest( const InfAlg &bp, const std::vector<size_t> *state, const PropertySet &bbp_props, bbp_cfn_t cfn, double h );
349
350
351 } // end of namespace dai
352
353
354 #endif