ddfeb4298bb6cb75f1675dfe8b970624cc420532
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
3 This file is part of libDAI.
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
22 /// \brief Defines class BBP [\ref EaG09]
23 /// \todo Improve documentation
27 #ifndef ___defined_libdai_bbp_h
28 #define ___defined_libdai_bbp_h
35 #include <dai/daialg.h>
36 #include <dai/factorgraph.h>
39 #include <dai/bp_dual.h>
45 std::vector
<Prob
> get_zero_adj_F(const FactorGraph
&);
46 std::vector
<Prob
> get_zero_adj_V(const FactorGraph
&);
49 /// Implements BBP (Back-belief-propagation) [\ref EaG09]
52 // ----------------------------------------------------------------
55 const FactorGraph
* _fg
;
61 // ----------------------------------------------------------------
63 std::vector
<Prob
> _adj_psi_V
, _adj_psi_F
;
64 // The following vectors are indexed [i][_I]
65 std::vector
<std::vector
<Prob
> > _adj_n
, _adj_m
;
66 std::vector
<Prob
> _adj_b_V
, _adj_b_F
;
68 // Helper quantities computed from the BP messages:
70 std::vector
<std::vector
<Prob
> > _T
;
72 std::vector
<std::vector
<Prob
> > _U
;
74 std::vector
<std::vector
<std::vector
<Prob
> > > _S
;
76 std::vector
<std::vector
<std::vector
<Prob
> > > _R
;
78 std::vector
<Prob
> _adj_b_V_unnorm
, _adj_b_F_unnorm
;
79 std::vector
<Prob
> _init_adj_psi_V
;
80 std::vector
<Prob
> _init_adj_psi_F
;
82 std::vector
<std::vector
<Prob
> > _adj_n_unnorm
, _adj_m_unnorm
;
83 std::vector
<std::vector
<Prob
> > _new_adj_n
, _new_adj_m
;
85 // ----------------------------------------------------------------
86 // Indexing for performance
/// Calculates _indices, which is a cache of IndexFor (see bp.cpp)
/** The cached entries are what _index(i,_I) returns.
 *  NOTE(review): presumably has to run before the first _index() call --
 *  confirm the call order in bbp.cpp.
 */
void RegenerateInds();
91 typedef std::vector
<size_t> _ind_t
;
92 std::vector
<std::vector
<_ind_t
> > _indices
;
93 const _ind_t
& _index(size_t i
, size_t _I
) const { return _indices
[i
][_I
]; }
95 // ----------------------------------------------------------------
98 /// Calculate T values (see paper)
100 /// Calculate U values (see paper)
102 /// Calculate S values (see paper)
104 /// Calculate R values (see paper)
/// Calculate _adj_b_V_unnorm and _adj_b_F_unnorm from _adj_b_V and _adj_b_F
void RegenerateInputs();
/// Initialise members for factor adjoints (call after RegenerateInputs)
void RegeneratePsiAdjoints();
/// Initialise members for message adjoints (call after RegenerateInputs)
void RegenerateParMessageAdjoints();
/** Same as RegenerateParMessageAdjoints, but calls sendSeqMsgN rather
 *  than updating _adj_n (and friends), which are unused in the sequential
 *  algorithm
 */
void RegenerateSeqMessageAdjoints();
117 DAI_ACCMUT(Prob
& T(size_t i
, size_t _I
), { return _T
[i
][_I
]; });
118 DAI_ACCMUT(Prob
& U(size_t I
, size_t _i
), { return _U
[I
][_i
]; });
119 DAI_ACCMUT(Prob
& S(size_t i
, size_t _I
, size_t _j
), { return _S
[i
][_I
][_j
]; });
120 DAI_ACCMUT(Prob
& R(size_t I
, size_t _i
, size_t _J
), { return _R
[I
][_i
][_J
]; });
122 void calcNewN(size_t i
, size_t _I
);
123 void calcNewM(size_t i
, size_t _I
);
124 void calcUnnormMsgM(size_t i
, size_t _I
);
125 void calcUnnormMsgN(size_t i
, size_t _I
);
126 void upMsgM(size_t i
, size_t _I
);
127 void upMsgN(size_t i
, size_t _I
);
130 void getMsgMags(Real
&s
, Real
&new_s
);
132 void zero_adj_b_F() {
134 _adj_b_F
.reserve(_fg
->nrFactors());
135 for(size_t I
=0; I
<_fg
->nrFactors(); I
++) {
136 _adj_b_F
.push_back(Prob(_fg
->factor(I
).states(),Real(0.0)));
140 //----------------------------------------------------------------
143 void incrSeqMsgM(size_t i
, size_t _I
, const Prob
& p
);
144 void updateSeqMsgM(size_t i
, size_t _I
);
145 void sendSeqMsgN(size_t i
, size_t _I
, const Prob
&f
);
146 void sendSeqMsgM(size_t i
, size_t _I
);
147 /// used instead of upMsgM / calcNewM, calculates adj_m_unnorm as well
148 void setSeqMsgM(size_t i
, size_t _I
, const Prob
&p
);
152 Real
getTotalNewMsgM();
155 void getArgmaxMsgM(size_t &i
, size_t &_I
, Real
&mag
);
158 /// Called by 'init', recalculates intermediate values
161 BBP(const InfAlg
*ia
, const PropertySet
&opts
) :
162 _bp_dual(ia
), _fg(&(ia
->fg())), _ia(ia
)
167 void init(const std::vector
<Prob
> &adj_b_V
, const std::vector
<Prob
> &adj_b_F
,
168 const std::vector
<Prob
> &adj_psi_V
, const std::vector
<Prob
> &adj_psi_F
) {
171 _init_adj_psi_V
= adj_psi_V
;
172 _init_adj_psi_F
= adj_psi_F
;
175 void init(const std::vector
<Prob
> &adj_b_V
, const std::vector
<Prob
> &adj_b_F
) {
176 init(adj_b_V
, adj_b_F
, get_zero_adj_V(*_fg
), get_zero_adj_F(*_fg
));
178 void init(const std::vector
<Prob
> &adj_b_V
) {
179 init(adj_b_V
, get_zero_adj_F(*_fg
));
182 /// run until change is less than given tolerance
185 size_t doneIters() { return _iters
; }
187 DAI_ACCMUT(Prob
& adj_psi_V(size_t i
), { return _adj_psi_V
[i
]; });
188 DAI_ACCMUT(Prob
& adj_psi_F(size_t I
), { return _adj_psi_F
[I
]; });
189 DAI_ACCMUT(Prob
& adj_b_V(size_t i
), { return _adj_b_V
[i
]; });
190 DAI_ACCMUT(Prob
& adj_b_F(size_t I
), { return _adj_b_F
[I
]; });
192 DAI_ACCMUT(Prob
& adj_n(size_t i
, size_t _I
), { return _adj_n
[i
][_I
]; });
193 DAI_ACCMUT(Prob
& adj_m(size_t i
, size_t _I
), { return _adj_m
[i
][_I
]; });
196 /* PROPERTIES(props,BBP) {
197 DAI_ENUM(UpdateType,SEQ_FIX,SEQ_MAX,SEQ_BP_REV,SEQ_BP_FWD,PAR);
199 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
202 /// damping (0 for none)
208 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
209 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
212 DAI_ENUM(UpdateType
,SEQ_FIX
,SEQ_MAX
,SEQ_BP_REV
,SEQ_BP_FWD
,PAR
);
214 /// tolerance (not used for updates=SEQ_BP_{REV,FWD})
217 /// damping (0 for none)
222 /// Set members from PropertySet
223 void set(const PropertySet
&opts
);
224 /// Get members into PropertySet
225 PropertySet
get() const;
226 /// Convert to a string which can be parsed as a PropertySet
227 std::string
toString() const;
229 /* }}} END OF GENERATED CODE */
232 /// Cost functions. Not used by BBP class, only used by following functions.
233 DAI_ENUM(bbp_cfn_t
,cfn_gibbs_b
,cfn_gibbs_b2
,cfn_gibbs_exp
,cfn_gibbs_b_factor
,cfn_gibbs_b2_factor
,cfn_gibbs_exp_factor
,cfn_var_ent
,cfn_factor_ent
,cfn_bethe_ent
);
235 /// Initialise BBP using InfAlg, cost function, and stateP
236 /** Calls bbp.init with adjoints calculated from ia.beliefV and
237 * ia.beliefF. stateP is a Gibbs state and can be NULL, it will be
238 * initialised using a Gibbs run of 2*fg.Iterations() iterations.
240 void initBBPCostFnAdj(BBP
& bbp
, const InfAlg
& ia
, bbp_cfn_t cfn_type
, const std::vector
<size_t>* stateP
);
242 /// Answers question: Does given cost function depend on having a Gibbs state?
243 bool needGibbsState(bbp_cfn_t cfn
);
245 /// Calculate actual value of cost function (cfn_type, stateP)
246 /** This function returns the actual value of the cost function whose
247 * gradient with respect to singleton beliefs is given by
248 * gibbsToB1Adj on the same arguments
250 Real
getCostFn(const InfAlg
& fg
, bbp_cfn_t cfn_type
, const std::vector
<size_t> *stateP
);
252 /// Function to test the validity of adjoints computed by BBP
253 /** given a state for each variable, use numerical derivatives
254 * (multiplying a factor containing a variable by psi_1 adjustments)
255 * to verify accuracy of _adj_psi_V.
256 * 'h' controls size of perturbation.
257 * 'bbpTol' controls tolerance of BBP run.
259 double numericBBPTest(const InfAlg
& bp
, const std::vector
<size_t> *state
, const PropertySet
& bbp_props
, bbp_cfn_t cfn
, double h
);
261 // ----------------------------------------------------------------
262 // Utility functions, some of which are used elsewhere
/// Subtract 1 from a size_t, or return 0 if the argument is 0
/** Saturating decrement: plain unsigned subtraction would wrap 0 around
 *  to the largest size_t value.
 *  \param v value to decrement
 *  \return v-1 if v is positive, 0 if v is 0
 */
inline size_t oneLess(size_t v) {
    return v == 0 ? v : v - 1;
}
267 /// function to compute adj_w_unnorm from w, Z_w, adj_w
268 Prob
unnormAdjoint(const Prob
&w
, Real Z_w
, const Prob
&adj_w
);
270 /// Runs Gibbs sampling for 'iters' iterations on ia.fg(), and returns state
271 std::vector
<size_t> getGibbsState(const InfAlg
& ia
, size_t iters
);
274 } // end of namespace dai