4ba6498e6de2c2081e993b4e3d5bb2059b04492b
[libdai.git] / include / dai / bbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
24 /// \todo Clean up
25
26
27 #ifndef ___defined_libdai_bbp_h
28 #define ___defined_libdai_bbp_h
29
30
31 #include <vector>
32 #include <utility>
33 #include <ext/hash_map>
34
35 #include <dai/prob.h>
36 #include <dai/daialg.h>
37 #include <dai/factorgraph.h>
38 #include <dai/enum.h>
39
40 #include <dai/bp_dual.h>
41
42
43 namespace dai {
44
45
46 std::vector<Prob> get_zero_adj_F(const FactorGraph&);
47 std::vector<Prob> get_zero_adj_V(const FactorGraph&);
48
49
50 /// Implements BBP (Back-belief-propagation) [\ref EaG09]
51 class BBP {
52 protected:
53 // ----------------------------------------------------------------
54 // inputs
55 BP_dual _bp_dual;
56 const FactorGraph* _fg;
57 const InfAlg *_ia;
58
59 // iterations
60 size_t _iters;
61
62 // ----------------------------------------------------------------
63 // Outputs
64 std::vector<Prob> _adj_psi_V, _adj_psi_F;
65 // The following vectors are indexed [i][_I]
66 std::vector<std::vector<Prob> > _adj_n, _adj_m;
67 std::vector<Prob> _adj_b_V, _adj_b_F;
68
69 // Helper quantities computed from the BP messages:
70 // _T[i][_I]
71 std::vector<std::vector<Prob > > _T;
72 // _U[I][_i]
73 std::vector<std::vector<Prob > > _U;
74 // _S[i][_I][_j]
75 std::vector<std::vector<std::vector<Prob > > > _S;
76 // _R[I][_i][_J]
77 std::vector<std::vector<std::vector<Prob > > > _R;
78
79 std::vector<Prob> _adj_b_V_unnorm, _adj_b_F_unnorm;
80 std::vector<Prob> _init_adj_psi_V;
81 std::vector<Prob> _init_adj_psi_F;
82
83 std::vector<std::vector<Prob> > _adj_n_unnorm, _adj_m_unnorm;
84 std::vector<std::vector<Prob> > _new_adj_n, _new_adj_m;
85
86 // ----------------------------------------------------------------
87 // Indexing for performance
88
89 /// Calculates _indices, which is a cache of IndexFor (see bp.cpp)
90 void RegenerateInds();
91
92 typedef std::vector<size_t> _ind_t;
93 std::vector<std::vector<_ind_t> > _indices;
94 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
95
96 // ----------------------------------------------------------------
97 // Initialization
98
99 /// Calculate T values (see paper)
100 void RegenerateT();
101 /// Calculate U values (see paper)
102 void RegenerateU();
103 /// Calculate S values (see paper)
104 void RegenerateS();
105 /// Calculate R values (see paper)
106 void RegenerateR();
107 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
108 void RegenerateInputs();
109 /// Initialise members for factor adjoints (call after RegenerateInputs)
110 void RegeneratePsiAdjoints();
111 /// Initialise members for messages adjoints (call after RegenerateInputs)
112 void RegenerateParMessageAdjoints();
113 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
114 * than updating _adj_n (and friends) which are unused in sequential algorithm
115 */
116 void RegenerateSeqMessageAdjoints();
117
118 DAI_ACCMUT(Prob & T(size_t i, size_t _I), { return _T[i][_I]; });
119 DAI_ACCMUT(Prob & U(size_t I, size_t _i), { return _U[I][_i]; });
120 DAI_ACCMUT(Prob & S(size_t i, size_t _I, size_t _j), { return _S[i][_I][_j]; });
121 DAI_ACCMUT(Prob & R(size_t I, size_t _i, size_t _J), { return _R[I][_i][_J]; });
122
123 void calcNewN(size_t i, size_t _I);
124 void calcNewM(size_t i, size_t _I);
125 void calcUnnormMsgM(size_t i, size_t _I);
126 void calcUnnormMsgN(size_t i, size_t _I);
127 void upMsgM(size_t i, size_t _I);
128 void upMsgN(size_t i, size_t _I);
129 void doParUpdate();
130 Real getUnMsgMag();
131 void getMsgMags(Real &s, Real &new_s);
132
133 void zero_adj_b_F() {
134 _adj_b_F.clear();
135 _adj_b_F.reserve(_fg->nrFactors());
136 for(size_t I=0; I<_fg->nrFactors(); I++) {
137 _adj_b_F.push_back(Prob(_fg->factor(I).states(),Real(0.0)));
138 }
139 }
140
141 //----------------------------------------------------------------
142 // new interface
143
144 void incrSeqMsgM(size_t i, size_t _I, const Prob& p);
145 void updateSeqMsgM(size_t i, size_t _I);
146 void sendSeqMsgN(size_t i, size_t _I, const Prob &f);
147 void sendSeqMsgM(size_t i, size_t _I);
148 /// used instead of upMsgM / calcNewM, calculates adj_m_unnorm as well
149 void setSeqMsgM(size_t i, size_t _I, const Prob &p);
150
151 Real getMaxMsgM();
152 Real getTotalMsgM();
153 Real getTotalNewMsgM();
154 Real getTotalMsgN();
155
156 void getArgmaxMsgM(size_t &i, size_t &_I, Real &mag);
157
158 public:
159 /// Called by 'init', recalculates intermediate values
160 void Regenerate();
161
162 BBP(const InfAlg *ia, const PropertySet &opts) :
163 _bp_dual(ia), _fg(&(ia->fg())), _ia(ia)
164 {
165 props.set(opts);
166 }
167
168 void init(const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F,
169 const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F) {
170 _adj_b_V = adj_b_V;
171 _adj_b_F = adj_b_F;
172 _init_adj_psi_V = adj_psi_V;
173 _init_adj_psi_F = adj_psi_F;
174 Regenerate();
175 }
176 void init(const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F) {
177 init(adj_b_V, adj_b_F, get_zero_adj_V(*_fg), get_zero_adj_F(*_fg));
178 }
179 void init(const std::vector<Prob> &adj_b_V) {
180 init(adj_b_V, get_zero_adj_F(*_fg));
181 }
182
183 /// run until change is less than given tolerance
184 void run();
185
186 size_t doneIters() { return _iters; }
187
188 DAI_ACCMUT(Prob& adj_psi_V(size_t i), { return _adj_psi_V[i]; });
189 DAI_ACCMUT(Prob& adj_psi_F(size_t I), { return _adj_psi_F[I]; });
190 DAI_ACCMUT(Prob& adj_b_V(size_t i), { return _adj_b_V[i]; });
191 DAI_ACCMUT(Prob& adj_b_F(size_t I), { return _adj_b_F[I]; });
192 protected:
193 DAI_ACCMUT(Prob& adj_n(size_t i, size_t _I), { return _adj_n[i][_I]; });
194 DAI_ACCMUT(Prob& adj_m(size_t i, size_t _I), { return _adj_m[i][_I]; });
195 public:
196
197 /* PROPERTIES(props,BBP) {
198 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
199 size_t verbose;
200 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
201 double tol;
202 size_t maxiter;
203 /// damping (0 for none)
204 double damping;
205 UpdateType updates;
206 bool clean_updates;
207 }
208 */
209 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
210 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
211 */
212 struct Properties {
213 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
214 size_t verbose;
215 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
216 double tol;
217 size_t maxiter;
218 /// damping (0 for none)
219 double damping;
220 UpdateType updates;
221 bool clean_updates;
222
223 /// Set members from PropertySet
224 void set(const PropertySet &opts);
225 /// Get members into PropertySet
226 PropertySet get() const;
227 /// Convert to a string which can be parsed as a PropertySet
228 std::string toString() const;
229 } props;
230 /* }}} END OF GENERATED CODE */
231 };
232
233 /// Cost functions. Not used by BBP class, only used by following functions.
/** The type \c bbp_cfn_t is declared through the DAI_ENUM macro (defined
 *  outside this header, presumably in dai/enum.h -- confirm); its values are
 *  the identifiers listed after the type name.
 *  NOTE(review): the argument list is deliberately left on a single line
 *  without extra whitespace, because the macro may stringify its arguments
 *  and reformatting could then change the value names -- verify before
 *  re-wrapping.
 */
234 DAI_ENUM(bbp_cfn_t,cfn_gibbs_b,cfn_gibbs_b2,cfn_gibbs_exp,cfn_gibbs_b_factor,cfn_gibbs_b2_factor,cfn_gibbs_exp_factor,cfn_var_ent,cfn_factor_ent,cfn_bethe_ent);
235
236 /// Initialise BBP using InfAlg, cost function, and stateP
237 /** Calls bbp.init with adjoints calculated from ia.beliefV and
238 * ia.beliefF. stateP is a Gibbs state and can be NULL, it will be
239 * initialised using a Gibbs run of 2*fg.Iterations() iterations.
240 */
241 void initBBPCostFnAdj(BBP& bbp, const InfAlg& ia, bbp_cfn_t cfn_type, const std::vector<size_t>* stateP);
242
243 /// Answers question: Does given cost function depend on having a Gibbs state?
244 bool needGibbsState(bbp_cfn_t cfn);
245
246 /// Calculate actual value of cost function (cfn_type, stateP)
247 /** This function returns the actual value of the cost function whose
248 * gradient with respect to singleton beliefs is given by
249 * gibbsToB1Adj on the same arguments
250 */
251 Real getCostFn(const InfAlg& fg, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP);
252
253 /// Function to test the validity of adjoints computed by BBP
254 /** given a state for each variable, use numerical derivatives
255 * (multiplying a factor containing a variable by psi_1 adjustments)
256 * to verify accuracy of _adj_psi_V.
257 * 'h' controls size of perturbation.
258 * 'bbpTol' controls tolerance of BBP run.
259 */
260 double numericBBPTest(const InfAlg& bp, const std::vector<size_t> *state, const PropertySet& bbp_props, bbp_cfn_t cfn, double h);
261
262 // ----------------------------------------------------------------
263 // Utility functions, some of which are used elsewhere
264
/// Decrements a size_t without wrapping around
/** \param v Value to decrement
 *  \return \a v - 1 if \a v is positive, and 0 if \a v is 0
 */
inline size_t oneLess(size_t v) {
    if( v != 0 )
        return v - 1;
    return 0;
}
267
268 /// function to compute adj_w_unnorm from w, Z_w, adj_w
269 Prob unnormAdjoint(const Prob &w, Real Z_w, const Prob &adj_w);
270
271 /// Runs Gibbs sampling for 'iters' iterations on ia.fg(), and returns state
272 std::vector<size_t> getGibbsState(const InfAlg& ia, size_t iters);
273
274
275 } // end of namespace dai
276
277
278 #endif