Fixed a Windows build issue
[libdai.git] / include / dai / bbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
24 /// \todo Clean up
25
26
27 #ifndef ___defined_libdai_bbp_h
28 #define ___defined_libdai_bbp_h
29
30
31 #include <vector>
32 #include <utility>
33
34 #include <dai/prob.h>
35 #include <dai/daialg.h>
36 #include <dai/factorgraph.h>
37 #include <dai/enum.h>
38
39 #include <dai/bp_dual.h>
40
41
42 namespace dai {
43
44
45 std::vector<Prob> get_zero_adj_F(const FactorGraph&);
46 std::vector<Prob> get_zero_adj_V(const FactorGraph&);
47
48
49 /// Implements BBP (Back-belief-propagation) [\ref EaG09]
50 class BBP {
51 protected:
52 // ----------------------------------------------------------------
53 // inputs
54 BP_dual _bp_dual;
55 const FactorGraph* _fg;
56 const InfAlg *_ia;
57
58 // iterations
59 size_t _iters;
60
61 // ----------------------------------------------------------------
62 // Outputs
63 std::vector<Prob> _adj_psi_V, _adj_psi_F;
64 // The following vectors are indexed [i][_I]
65 std::vector<std::vector<Prob> > _adj_n, _adj_m;
66 std::vector<Prob> _adj_b_V, _adj_b_F;
67
68 // Helper quantities computed from the BP messages:
69 // _T[i][_I]
70 std::vector<std::vector<Prob > > _T;
71 // _U[I][_i]
72 std::vector<std::vector<Prob > > _U;
73 // _S[i][_I][_j]
74 std::vector<std::vector<std::vector<Prob > > > _S;
75 // _R[I][_i][_J]
76 std::vector<std::vector<std::vector<Prob > > > _R;
77
78 std::vector<Prob> _adj_b_V_unnorm, _adj_b_F_unnorm;
79 std::vector<Prob> _init_adj_psi_V;
80 std::vector<Prob> _init_adj_psi_F;
81
82 std::vector<std::vector<Prob> > _adj_n_unnorm, _adj_m_unnorm;
83 std::vector<std::vector<Prob> > _new_adj_n, _new_adj_m;
84
85 // ----------------------------------------------------------------
86 // Indexing for performance
87
88 /// Calculates _indices, which is a cache of IndexFor (see bp.cpp)
89 void RegenerateInds();
90
91 typedef std::vector<size_t> _ind_t;
92 std::vector<std::vector<_ind_t> > _indices;
93 const _ind_t& _index(size_t i, size_t _I) const { return _indices[i][_I]; }
94
95 // ----------------------------------------------------------------
96 // Initialization
97
98 /// Calculate T values (see paper)
99 void RegenerateT();
100 /// Calculate U values (see paper)
101 void RegenerateU();
102 /// Calculate S values (see paper)
103 void RegenerateS();
104 /// Calculate R values (see paper)
105 void RegenerateR();
106 /// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
107 void RegenerateInputs();
108 /// Initialise members for factor adjoints (call after RegenerateInputs)
109 void RegeneratePsiAdjoints();
110 /// Initialise members for messages adjoints (call after RegenerateInputs)
111 void RegenerateParMessageAdjoints();
112 /** Same as RegenerateMessageAdjoints, but calls sendSeqMsgN rather
113 * than updating _adj_n (and friends) which are unused in sequential algorithm
114 */
115 void RegenerateSeqMessageAdjoints();
116
117 DAI_ACCMUT(Prob & T(size_t i, size_t _I), { return _T[i][_I]; });
118 DAI_ACCMUT(Prob & U(size_t I, size_t _i), { return _U[I][_i]; });
119 DAI_ACCMUT(Prob & S(size_t i, size_t _I, size_t _j), { return _S[i][_I][_j]; });
120 DAI_ACCMUT(Prob & R(size_t I, size_t _i, size_t _J), { return _R[I][_i][_J]; });
121
122 void calcNewN(size_t i, size_t _I);
123 void calcNewM(size_t i, size_t _I);
124 void calcUnnormMsgM(size_t i, size_t _I);
125 void calcUnnormMsgN(size_t i, size_t _I);
126 void upMsgM(size_t i, size_t _I);
127 void upMsgN(size_t i, size_t _I);
128 void doParUpdate();
129 Real getUnMsgMag();
130 void getMsgMags(Real &s, Real &new_s);
131
132 void zero_adj_b_F() {
133 _adj_b_F.clear();
134 _adj_b_F.reserve(_fg->nrFactors());
135 for(size_t I=0; I<_fg->nrFactors(); I++) {
136 _adj_b_F.push_back(Prob(_fg->factor(I).states(),Real(0.0)));
137 }
138 }
139
140 //----------------------------------------------------------------
141 // new interface
142
143 void incrSeqMsgM(size_t i, size_t _I, const Prob& p);
144 void updateSeqMsgM(size_t i, size_t _I);
145 void sendSeqMsgN(size_t i, size_t _I, const Prob &f);
146 void sendSeqMsgM(size_t i, size_t _I);
147 /// used instead of upMsgM / calcNewM, calculates adj_m_unnorm as well
148 void setSeqMsgM(size_t i, size_t _I, const Prob &p);
149
150 Real getMaxMsgM();
151 Real getTotalMsgM();
152 Real getTotalNewMsgM();
153 Real getTotalMsgN();
154
155 void getArgmaxMsgM(size_t &i, size_t &_I, Real &mag);
156
157 public:
158 /// Called by 'init', recalculates intermediate values
159 void Regenerate();
160
161 BBP(const InfAlg *ia, const PropertySet &opts) :
162 _bp_dual(ia), _fg(&(ia->fg())), _ia(ia)
163 {
164 props.set(opts);
165 }
166
167 void init(const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F,
168 const std::vector<Prob> &adj_psi_V, const std::vector<Prob> &adj_psi_F) {
169 _adj_b_V = adj_b_V;
170 _adj_b_F = adj_b_F;
171 _init_adj_psi_V = adj_psi_V;
172 _init_adj_psi_F = adj_psi_F;
173 Regenerate();
174 }
175 void init(const std::vector<Prob> &adj_b_V, const std::vector<Prob> &adj_b_F) {
176 init(adj_b_V, adj_b_F, get_zero_adj_V(*_fg), get_zero_adj_F(*_fg));
177 }
178 void init(const std::vector<Prob> &adj_b_V) {
179 init(adj_b_V, get_zero_adj_F(*_fg));
180 }
181
182 /// run until change is less than given tolerance
183 void run();
184
185 size_t doneIters() { return _iters; }
186
187 DAI_ACCMUT(Prob& adj_psi_V(size_t i), { return _adj_psi_V[i]; });
188 DAI_ACCMUT(Prob& adj_psi_F(size_t I), { return _adj_psi_F[I]; });
189 DAI_ACCMUT(Prob& adj_b_V(size_t i), { return _adj_b_V[i]; });
190 DAI_ACCMUT(Prob& adj_b_F(size_t I), { return _adj_b_F[I]; });
191 protected:
192 DAI_ACCMUT(Prob& adj_n(size_t i, size_t _I), { return _adj_n[i][_I]; });
193 DAI_ACCMUT(Prob& adj_m(size_t i, size_t _I), { return _adj_m[i][_I]; });
194 public:
195
196 /* PROPERTIES(props,BBP) {
197 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
198 size_t verbose;
199 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
200 double tol;
201 size_t maxiter;
202 /// damping (0 for none)
203 double damping;
204 UpdateType updates;
205 bool clean_updates;
206 }
207 */
208 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
209 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
210 */
211 struct Properties {
212 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
213 size_t verbose;
214 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
215 double tol;
216 size_t maxiter;
217 /// damping (0 for none)
218 double damping;
219 UpdateType updates;
220 bool clean_updates;
221
222 /// Set members from PropertySet
223 void set(const PropertySet &opts);
224 /// Get members into PropertySet
225 PropertySet get() const;
226 /// Convert to a string which can be parsed as a PropertySet
227 std::string toString() const;
228 } props;
229 /* }}} END OF GENERATED CODE */
230 };
231
232 /// Cost functions. Not used by BBP class, only used by following functions.
233 DAI_ENUM(bbp_cfn_t,cfn_gibbs_b,cfn_gibbs_b2,cfn_gibbs_exp,cfn_gibbs_b_factor,cfn_gibbs_b2_factor,cfn_gibbs_exp_factor,cfn_var_ent,cfn_factor_ent,cfn_bethe_ent);
234
235 /// Initialise BBP using InfAlg, cost function, and stateP
236 /** Calls bbp.init with adjoints calculated from ia.beliefV and
237 * ia.beliefF. stateP is a Gibbs state and can be NULL, it will be
238 * initialised using a Gibbs run of 2*fg.Iterations() iterations.
239 */
240 void initBBPCostFnAdj(BBP& bbp, const InfAlg& ia, bbp_cfn_t cfn_type, const std::vector<size_t>* stateP);
241
242 /// Answers question: Does given cost function depend on having a Gibbs state?
243 bool needGibbsState(bbp_cfn_t cfn);
244
245 /// Calculate actual value of cost function (cfn_type, stateP)
246 /** This function returns the actual value of the cost function whose
247 * gradient with respect to singleton beliefs is given by
248 * gibbsToB1Adj on the same arguments
249 */
250 Real getCostFn(const InfAlg& fg, bbp_cfn_t cfn_type, const std::vector<size_t> *stateP);
251
252 /// Function to test the validity of adjoints computed by BBP
253 /** given a state for each variable, use numerical derivatives
254 * (multiplying a factor containing a variable by psi_1 adjustments)
255 * to verify accuracy of _adj_psi_V.
256 * 'h' controls size of perturbation.
257 * 'bbpTol' controls tolerance of BBP run.
258 */
259 double numericBBPTest(const InfAlg& bp, const std::vector<size_t> *state, const PropertySet& bbp_props, bbp_cfn_t cfn, double h);
260
261 // ----------------------------------------------------------------
262 // Utility functions, some of which are used elsewhere
263
/// Decrements a size_t without wrapping around
/** \param v value to decrement
 *  \return v-1 if v is positive, and 0 if v is 0
 */
inline size_t oneLess(size_t v) {
    // guard against unsigned underflow: 0 stays 0
    if( v == 0 )
        return 0;
    return v - 1;
}
266
267 /// function to compute adj_w_unnorm from w, Z_w, adj_w
268 Prob unnormAdjoint(const Prob &w, Real Z_w, const Prob &adj_w);
269
270 /// Runs Gibbs sampling for 'iters' iterations on ia.fg(), and returns state
271 std::vector<size_t> getGibbsState(const InfAlg& ia, size_t iters);
272
273
274 } // end of namespace dai
275
276
277 #endif