Cleanup of BBP code
[libdai.git] / include / dai / cbp.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class CBP [\ref EaG09]
23 /// \todo Improve documentation
24 /// \todo Clean up
25
26
27 #ifndef __defined_libdai_cbp_h
28 #define __defined_libdai_cbp_h
29
30
31 #include <fstream>
32
33 #include <boost/shared_ptr.hpp>
34
35 #include <dai/daialg.h>
36
37 #include <dai/cbp.h>
38 #include <dai/bbp.h>
39
40
41 namespace dai {
42
43
44 /// Find a variable to clamp using BBP (goes with maximum adjoint)
45 /// \see BBP
46 std::pair<size_t, size_t> bbpFindClampVar(const InfAlg &in_bp, bool clampingVar,
47 const PropertySet &bbp_props, bbp_cfn_t cfn,
48 Real *maxVarOut);
49
50
51 /// Class for CBP (Clamped Belief Propagation)
52 /** This algorithm uses configurable heuristics to choose a variable
53 * x_i and a state x_i*. Inference is done with x_i "clamped" to x_i*
54 * (e.g. conditional on x_i==x_i*), and also with the negation of this
55 * condition. Clamping is done recursively up to a fixed number of
56 * levels (other stopping criteria are also implemented, see
57 * "recursion" property). The resulting approximate marginals are
58 * combined using logZ estimates.
59 */
60 class CBP : public DAIAlgFG {
61 private:
62 std::vector<Factor> _beliefsV;
63 std::vector<Factor> _beliefsF;
64 double _logZ;
65 double _est_logZ;
66 double _old_est_logZ;
67
68 double _sum_level;
69 size_t _num_leaves;
70
71 boost::shared_ptr<std::ofstream> _clamp_ofstream;
72
73 bbp_cfn_t BBP_cost_function() {
74 return props.bbp_cfn;
75 }
76
77 void printDebugInfo();
78
79 /// Called by 'run', and by itself. Implements the main algorithm.
80 /** Chooses a variable to clamp, recurses, combines the logZ and
81 * beliefs estimates of the children, and returns the improved
82 * estimates in \a lz_out and \a beliefs_out to its parent
83 */
84 void runRecurse(InfAlg *bp,
85 double orig_logZ,
86 std::vector<size_t> clamped_vars_list,
87 size_t &num_leaves,
88 size_t &choose_count,
89 double &sum_level,
90 Real &lz_out,
91 std::vector<Factor>& beliefs_out);
92
93 /// Choose the next variable to clamp
94 /** Choose the next variable to clamp, given a converged InfAlg (\a bp),
95 * and a vector of variables that are already clamped (\a
96 * clamped_vars_list). Returns the chosen variable in \a i, and
97 * the set of states in \a xis. If \a maxVarOut is non-NULL and
98 * props.choose==CHOOSE_BBP then it is used to store the
99 * adjoint of the chosen variable
100 */
101 virtual bool chooseNextClampVar(InfAlg* bp,
102 std::vector<size_t> &clamped_vars_list,
103 size_t &i, std::vector<size_t> &xis, Real *maxVarOut);
104
105 /// Return the InfAlg to use at each step of the recursion.
106 /// \todo At present, only returns a BP instance
107 InfAlg* getInfAlg();
108
109 size_t _iters;
110 double _maxdiff;
111
112 void setBeliefs(const std::vector<Factor>& bs, double logZ) {
113 size_t i=0;
114 _beliefsV.clear(); _beliefsV.reserve(nrVars());
115 _beliefsF.clear(); _beliefsF.reserve(nrFactors());
116 for(i=0; i<nrVars(); i++) _beliefsV.push_back(bs[i]);
117 for(; i<nrVars()+nrFactors(); i++) _beliefsF.push_back(bs[i]);
118 _logZ = logZ;
119 }
120
121 void construct();
122
123 public:
124
125 //----------------------------------------------------------------
126
127 /// construct CBP object from FactorGraph
128 CBP(const FactorGraph &fg, const PropertySet &opts) : DAIAlgFG(fg) {
129 props.set(opts);
130 construct();
131 }
132
133 /// Name of this inference algorithm
134 static const char *Name;
135
136 /// @name General InfAlg interface
137 //@{
138 virtual CBP* clone() const { return new CBP(*this); }
139 virtual CBP* create() const { DAI_THROW(NOT_IMPLEMENTED); }
140 virtual std::string identify() const { return std::string(Name) + props.toString(); }
141 virtual Factor belief (const Var &n) const { return _beliefsV[findVar(n)]; }
142 virtual Factor belief (const VarSet &) const { DAI_THROW(NOT_IMPLEMENTED); }
143 virtual std::vector<Factor> beliefs() const { return concat(_beliefsV, _beliefsF); }
144 virtual Real logZ() const { return _logZ; }
145 virtual void init() {};
146 virtual void init( const VarSet & ) {};
147 virtual double run();
148 virtual double maxDiff() const { return _maxdiff; }
149 virtual size_t Iterations() const { return _iters; }
150 //@}
151
152 Factor beliefV (size_t i) const { return _beliefsV[i]; }
153 Factor beliefF (size_t I) const { return _beliefsF[I]; }
154
155 //----------------------------------------------------------------
156
157 /// Parameters of this inference algorithm
158 /* PROPERTIES(props,CBP) {
159 /// Enumeration of possible update schedules
160 typedef BP::Properties::UpdateType UpdateType;
161 /// Enumeration of possible methods for deciding when to stop recursing
162 DAI_ENUM(RecurseType,REC_FIXED,REC_LOGZ,REC_BDIFF);
163 /// Enumeration of possible heuristics for choosing clamping variable
164 DAI_ENUM(ChooseMethodType,CHOOSE_RANDOM,CHOOSE_MAXENT,CHOOSE_BBP,CHOOSE_BP_L1,CHOOSE_BP_CFN);
165 /// Enumeration of possible clampings: variables or factors
166 DAI_ENUM(ClampType,CLAMP_VAR,CLAMP_FACTOR);
167
168 /// Verbosity
169 size_t verbose = 0;
170
171 /// Tolerance to use in BP
172 double tol;
173 /// Update style for BP
174 UpdateType updates;
175 /// Maximum number of iterations for BP
176 size_t maxiter;
177
178 /// Tolerance to use for controlling recursion depth (\a recurse is REC_LOGZ or REC_BDIFF)
179 double rec_tol;
180 /// Maximum number of levels of recursion (\a recurse is REC_FIXED)
181 size_t max_levels = 10;
182 /// If choose=CHOOSE_BBP and maximum adjoint is less than this value, don't recurse
183 double min_max_adj;
184 /// Heuristic for choosing clamping variable
185 ChooseMethodType choose;
186 /// Method for deciding when to stop recursing
187 RecurseType recursion;
188 /// Whether to clamp variables or factors
189 ClampType clamp;
190 /// Properties to pass to BBP
191 PropertySet bbp_props;
192 /// Cost function to use for BBP
193 bbp_cfn_t bbp_cfn;
194 /// Random seed
195 size_t rand_seed = 0;
196
197 /// If non-empty, write clamping choices to this file
198 std::string clamp_outfile = "";
199 }
200 */
201 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
202 ./scripts/regenerate-properties include/dai/cbp.h src/cbp.cpp
203 */
204 struct Properties {
205 /// Enumeration of possible update schedules
206 typedef BP::Properties::UpdateType UpdateType;
207 /// Enumeration of possible methods for deciding when to stop recursing
208 DAI_ENUM(RecurseType,REC_FIXED,REC_LOGZ,REC_BDIFF);
209 /// Enumeration of possible heuristics for choosing clamping variable
210 DAI_ENUM(ChooseMethodType,CHOOSE_RANDOM,CHOOSE_MAXENT,CHOOSE_BBP,CHOOSE_BP_L1,CHOOSE_BP_CFN);
211 /// Enumeration of possible clampings: variables or factors
212 DAI_ENUM(ClampType,CLAMP_VAR,CLAMP_FACTOR);
213 /// Verbosity
214 size_t verbose;
215 /// Tolerance to use in BP
216 double tol;
217 /// Update style for BP
218 UpdateType updates;
219 /// Maximum number of iterations for BP
220 size_t maxiter;
221 /// Tolerance to use for controlling recursion depth (\a recurse is REC_LOGZ or REC_BDIFF)
222 double rec_tol;
223 /// Maximum number of levels of recursion (\a recurse is REC_FIXED)
224 size_t max_levels;
225 /// If choose=CHOOSE_BBP and maximum adjoint is less than this value, don't recurse
226 double min_max_adj;
227 /// Heuristic for choosing clamping variable
228 ChooseMethodType choose;
229 /// Method for deciding when to stop recursing
230 RecurseType recursion;
231 /// Whether to clamp variables or factors
232 ClampType clamp;
233 /// Properties to pass to BBP
234 PropertySet bbp_props;
235 /// Cost function to use for BBP
236 bbp_cfn_t bbp_cfn;
237 /// Random seed
238 size_t rand_seed;
239 /// If non-empty, write clamping choices to this file
240 std::string clamp_outfile;
241
242 /// Set members from PropertySet
243 void set(const PropertySet &opts);
244 /// Get members into PropertySet
245 PropertySet get() const;
246 /// Convert to a string which can be parsed as a PropertySet
247 std::string toString() const;
248 } props;
249 /* }}} END OF GENERATED CODE */
250
251 /// Returns heuristic used for clamping variable
252 Properties::ChooseMethodType ChooseMethod() { return props.choose; }
253
254 /// Returns method used for deciding when to stop recursing
255 Properties::RecurseType Recursion() { return props.recursion; }
256
257 /// Returns clamping type used
258 Properties::ClampType Clamping() { return props.clamp; }
259 /// Returns maximum number of levels of recursion
260 size_t maxClampLevel() { return props.max_levels; }
261 /// Returns props.min_max_adj @see CBP::Properties::min_max_adj
262 double minMaxAdj() { return props.min_max_adj; }
263 /// Returns tolerance used for controlling recursion depth
264 double recTol() { return props.rec_tol; }
265 };
266
/// Returns the states that do not occur in a given list of states
/** \param xis      states to leave out; must be sorted
 *  \param n_states total number of states
 *  \return a vector holding the states that are not in \a xis
 */
std::vector<size_t> complement( std::vector<size_t>& xis,
                                size_t n_states );
269
270 /// Computes \f$\frac{\exp(a)}{\exp(a)+\exp(b)}\f$
271 Real unSoftMax( Real a, Real b );
272
273 /// Computes log of sum of exponents, i.e., \f$\log\left(\exp(a) + \exp(b)\right)\f$
274 Real logSumExp( Real a, Real b );
275
276 /// Compute sum of pairwise L-infinity distances of the first \a nv factors in each vector
277 Real dist( const std::vector<Factor>& b1, const std::vector<Factor>& b2, size_t nv );
278
279
280 } // end of namespace dai
281
282
283 #endif