Cleaned up BBP and improved documentation of include/dai/bbp.h
[libdai.git] / include / dai / cbp.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class CBP [\ref EaG09]
13 /// \author Frederik Eaton
14 /// \todo Improve documentation
15
16
17 #ifndef __defined_libdai_cbp_h
18 #define __defined_libdai_cbp_h
19
20
21 #include <fstream>
22 #include <boost/shared_ptr.hpp>
23
24 #include <dai/daialg.h>
25 #include <dai/bbp.h>
26
27
28 namespace dai {
29
30
31 /// Find a variable to clamp using BBP (goes with maximum adjoint)
32 /// \see BBP
33 std::pair<size_t, size_t> bbpFindClampVar( const InfAlg &in_bp, bool clampingVar, const PropertySet &bbp_props, const BBPCostFunction &cfn, Real *maxVarOut );
34
35
36 /// Class for CBP (Clamped Belief Propagation)
37 /** This algorithm uses configurable heuristics to choose a variable
38 * x_i and a state x_i*. Inference is done with x_i "clamped" to x_i*
39 * (i.e., conditional on x_i == x_i*), and also with the negation of this
40 * condition. Clamping is done recursively up to a fixed number of
41 * levels (other stopping criteria are also implemented, see
42 * \a recursion property). The resulting approximate marginals are
43 * combined using logZ estimates.
44 *
45 * \author Frederik Eaton
46 */
47 class CBP : public DAIAlgFG {
48 private:
49 /// Variable beliefs
50 std::vector<Factor> _beliefsV;
51 /// Factor beliefs
52 std::vector<Factor> _beliefsF;
53 /// Log-partition sum
54 Real _logZ;
55
56 /// Counts number of clampings at each leaf node
57 Real _sum_level;
58
59 /// Number of leaves of recursion tree
60 size_t _num_leaves;
61
62 /// Output stream where information about the clampings is written
63 boost::shared_ptr<std::ofstream> _clamp_ofstream;
64
65 /// Returns BBP cost function used
66 BBPCostFunction BBP_cost_function() { return props.bbp_cfn; }
67
68 /// Prints beliefs, variables and partition sum, in case of a debugging build
69 void printDebugInfo();
70
71 /// Called by 'run', and by itself. Implements the main algorithm.
72 /** Chooses a variable to clamp, recurses, combines the logZ and
73 * beliefs estimates of the children, and returns the improved
74 * estimates in \a lz_out and \a beliefs_out to its parent
75 */
76 void runRecurse( InfAlg *bp, Real orig_logZ, std::vector<size_t> clamped_vars_list, size_t &num_leaves,
77 size_t &choose_count, Real &sum_level, Real &lz_out, std::vector<Factor> &beliefs_out );
78
79 /// Choose the next variable to clamp
80 /** Choose the next variable to clamp, given a converged InfAlg (\a bp),
81 * and a vector of variables that are already clamped (\a
82 * clamped_vars_list). Returns the chosen variable in \a i, and
83 * the set of states in \a xis. If \a maxVarOut is non-NULL and
84 * props.choose==CHOOSE_BBP then it is used to store the
85 * adjoint of the chosen variable
86 */
87 virtual bool chooseNextClampVar( InfAlg* bp, std::vector<size_t> &clamped_vars_list, size_t &i, std::vector<size_t> &xis, Real *maxVarOut );
88
89 /// Return the InfAlg to use at each step of the recursion.
90 /// \todo At present, only returns a BP instance
91 InfAlg* getInfAlg();
92
93 /// Numer of iterations needed
94 size_t _iters;
95 /// Maximum difference encountered so far
96 Real _maxdiff;
97
98 /// Sets variable beliefs, factor beliefs and logZ
99 /** \param bs should be a concatenation of the variable beliefs followed by the factor beliefs
100 */
101 void setBeliefs( const std::vector<Factor> &bs, Real logZ );
102
103 /// Constructor helper function
104 void construct();
105
106 public:
107 /// Construct CBP object from FactorGraph fg and PropertySet opts
108 CBP( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg) {
109 props.set( opts );
110 construct();
111 }
112
113 /// Name of this inference algorithm
114 static const char *Name;
115
116 /// \name General InfAlg interface
117 //@{
118 virtual CBP* clone() const { return new CBP(*this); }
119 virtual std::string identify() const { return std::string(Name) + props.toString(); }
120 virtual Factor belief (const Var &n) const { return _beliefsV[findVar(n)]; }
121 virtual Factor belief (const VarSet &) const { DAI_THROW(NOT_IMPLEMENTED); }
122 virtual Factor beliefV( size_t i ) const { return _beliefsV[i]; }
123 virtual Factor beliefF( size_t I ) const { return _beliefsF[I]; }
124 virtual std::vector<Factor> beliefs() const { return concat(_beliefsV, _beliefsF); }
125 virtual Real logZ() const { return _logZ; }
126 virtual void init() {};
127 virtual void init( const VarSet & ) {};
128 virtual Real run();
129 virtual Real maxDiff() const { return _maxdiff; }
130 virtual size_t Iterations() const { return _iters; }
131 virtual void setProperties( const PropertySet &opts ) { props.set( opts ); }
132 virtual PropertySet getProperties() const { return props.get(); }
133 virtual std::string printProperties() const { return props.toString(); }
134 //@}
135
136 //----------------------------------------------------------------
137
138 /// Parameters of this inference algorithm
139 /* PROPERTIES(props,CBP) {
140 /// Enumeration of possible update schedules
141 typedef BP::Properties::UpdateType UpdateType;
142 /// Enumeration of possible methods for deciding when to stop recursing
143 DAI_ENUM(RecurseType,REC_FIXED,REC_LOGZ,REC_BDIFF);
144 /// Enumeration of possible heuristics for choosing clamping variable
145 DAI_ENUM(ChooseMethodType,CHOOSE_RANDOM,CHOOSE_MAXENT,CHOOSE_BBP,CHOOSE_BP_L1,CHOOSE_BP_CFN);
146 /// Enumeration of possible clampings: variables or factors
147 DAI_ENUM(ClampType,CLAMP_VAR,CLAMP_FACTOR);
148
149 /// Verbosity
150 size_t verbose = 0;
151
152 /// Tolerance to use in BP
153 Real tol;
154 /// Update style for BP
155 UpdateType updates;
156 /// Maximum number of iterations for BP
157 size_t maxiter;
158
159 /// Tolerance to use for controlling recursion depth (\a recurse is REC_LOGZ or REC_BDIFF)
160 Real rec_tol;
161 /// Maximum number of levels of recursion (\a recurse is REC_FIXED)
162 size_t max_levels = 10;
163 /// If choose==CHOOSE_BBP and maximum adjoint is less than this value, don't recurse
164 Real min_max_adj;
165 /// Heuristic for choosing clamping variable
166 ChooseMethodType choose;
167 /// Method for deciding when to stop recursing
168 RecurseType recursion;
169 /// Whether to clamp variables or factors
170 ClampType clamp;
171 /// Properties to pass to BBP
172 PropertySet bbp_props;
173 /// Cost function to use for BBP
174 BBPCostFunction bbp_cfn;
175 /// Random seed
176 size_t rand_seed = 0;
177
178 /// If non-empty, write clamping choices to this file
179 std::string clamp_outfile = "";
180 }
181 */
182 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
183 ./scripts/regenerate-properties include/dai/cbp.h src/cbp.cpp
184 */
185 struct Properties {
186 /// Enumeration of possible update schedules
187 typedef BP::Properties::UpdateType UpdateType;
188 /// Enumeration of possible methods for deciding when to stop recursing
189 DAI_ENUM(RecurseType,REC_FIXED,REC_LOGZ,REC_BDIFF);
190 /// Enumeration of possible heuristics for choosing clamping variable
191 DAI_ENUM(ChooseMethodType,CHOOSE_RANDOM,CHOOSE_MAXENT,CHOOSE_BBP,CHOOSE_BP_L1,CHOOSE_BP_CFN);
192 /// Enumeration of possible clampings: variables or factors
193 DAI_ENUM(ClampType,CLAMP_VAR,CLAMP_FACTOR);
194 /// Verbosity
195 size_t verbose;
196 /// Tolerance to use in BP
197 Real tol;
198 /// Update style for BP
199 UpdateType updates;
200 /// Maximum number of iterations for BP
201 size_t maxiter;
202 /// Tolerance to use for controlling recursion depth (\a recurse is REC_LOGZ or REC_BDIFF)
203 Real rec_tol;
204 /// Maximum number of levels of recursion (\a recurse is REC_FIXED)
205 size_t max_levels;
206 /// If choose==CHOOSE_BBP and maximum adjoint is less than this value, don't recurse
207 Real min_max_adj;
208 /// Heuristic for choosing clamping variable
209 ChooseMethodType choose;
210 /// Method for deciding when to stop recursing
211 RecurseType recursion;
212 /// Whether to clamp variables or factors
213 ClampType clamp;
214 /// Properties to pass to BBP
215 PropertySet bbp_props;
216 /// Cost function to use for BBP
217 BBPCostFunction bbp_cfn;
218 /// Random seed
219 size_t rand_seed;
220 /// If non-empty, write clamping choices to this file
221 std::string clamp_outfile;
222
223 /// Set members from PropertySet
224 void set(const PropertySet &opts);
225 /// Get members into PropertySet
226 PropertySet get() const;
227 /// Convert to a string which can be parsed as a PropertySet
228 std::string toString() const;
229 } props;
230 /* }}} END OF GENERATED CODE */
231
232 /// Returns heuristic used for clamping variable
233 Properties::ChooseMethodType ChooseMethod() { return props.choose; }
234 /// Returns method used for deciding when to stop recursing
235 Properties::RecurseType Recursion() { return props.recursion; }
236 /// Returns clamping type used
237 Properties::ClampType Clamping() { return props.clamp; }
238 /// Returns maximum number of levels of recursion
239 size_t maxClampLevel() { return props.max_levels; }
240 /// Returns props.min_max_adj @see CBP::Properties::min_max_adj
241 Real minMaxAdj() { return props.min_max_adj; }
242 /// Returns tolerance used for controlling recursion depth
243 Real recTol() { return props.rec_tol; }
244 };
245
246
/// Returns the states that are not listed in \a xis
/** \param xis      Sorted vector of states.
 *  \param n_states Total state count.
 *  \return Vector of the states that do not occur in \a xis.
 *
 *  NOTE(review): \a xis is taken by non-const reference; whether the
 *  implementation modifies it is not visible from this header -- confirm
 *  against src/cbp.cpp.
 */
std::vector<size_t> complement(
    std::vector<size_t>& xis,
    size_t n_states );
249
250 /// Computes \f$\frac{\exp(a)}{\exp(a)+\exp(b)}\f$
251 Real unSoftMax( Real a, Real b );
252
253 /// Computes log of sum of exponents, i.e., \f$\log\left(\exp(a) + \exp(b)\right)\f$
254 Real logSumExp( Real a, Real b );
255
256 /// Compute sum of pairwise L-infinity distances of the first \a nv factors in each vector
257 Real dist( const std::vector<Factor>& b1, const std::vector<Factor>& b2, size_t nv );
258
259
260 } // end of namespace dai
261
262
263 #endif