Strengthened convergence criteria of various algorithms
[libdai.git] / include / dai / gibbs.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class Gibbs, which implements Gibbs sampling
13
14
15 #ifndef __defined_libdai_gibbs_h
16 #define __defined_libdai_gibbs_h
17
18
19 #include <dai/daialg.h>
20 #include <dai/factorgraph.h>
21 #include <dai/properties.h>
22
23
24 namespace dai {
25
26
27 /// Approximate inference algorithm "Gibbs sampling"
28 /** \author Frederik Eaton
29 */
30 class Gibbs : public DAIAlgFG {
31 private:
32 /// Type used to store the counts of various states
33 typedef std::vector<size_t> _count_t;
34 /// Type used to store the joint state of all variables
35 typedef std::vector<size_t> _state_t;
36 /// Number of samples counted so far (excluding burn-in)
37 size_t _sample_count;
38 /// State counts for each variable
39 std::vector<_count_t> _var_counts;
40 /// State counts for each factor
41 std::vector<_count_t> _factor_counts;
42 /// Current joint state of all variables
43 _state_t _state;
44
45 public:
46 /// Parameters for Gibbs
47 struct Properties {
48 /// Total number of iterations
49 size_t iters;
50
51 /// Number of "burn-in" iterations
52 size_t burnin;
53
54 /// Verbosity (amount of output sent to stderr)
55 size_t verbose;
56 } props;
57
58 /// Name of this inference algorithm
59 static const char *Name;
60
61 public:
62 /// Default constructor
63 Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _state() {}
64
65 /// Construct from FactorGraph \a fg and PropertySet \a opts
66 /** \param opts Parameters @see Properties
67 */
68 Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _state() {
69 setProperties( opts );
70 construct();
71 }
72
73
74 /// \name General InfAlg interface
75 //@{
76 virtual Gibbs* clone() const { return new Gibbs(*this); }
77 virtual std::string identify() const { return std::string(Name) + printProperties(); }
78 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
79 virtual Factor belief( const VarSet &vs ) const;
80 virtual Factor beliefV( size_t i ) const;
81 virtual Factor beliefF( size_t I ) const;
82 virtual std::vector<Factor> beliefs() const;
83 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
84 virtual void init();
85 virtual void init( const VarSet &/*ns*/ ) { init(); }
86 virtual Real run();
87 virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
88 virtual size_t Iterations() const { return props.iters; }
89 virtual void setProperties( const PropertySet &opts );
90 virtual PropertySet getProperties() const;
91 virtual std::string printProperties() const;
92 //@}
93
94
95 /// \name Additional interface specific for Gibbs
96 //@{
97 /// Draw the current joint state of all variables from a uniform random distribution
98 void randomizeState();
99 /// Return reference to current state of all variables
100 std::vector<size_t>& state() { return _state; }
101 /// Return constant reference to current state of all variables
102 const std::vector<size_t>& state() const { return _state; }
103 //@}
104
105 private:
106 /// Helper function for constructors
107 void construct();
108 /// Updates all counts (_sample_count, _var_counts, _factor_counts) based on current state
109 void updateCounts();
110 /// Calculate conditional distribution of variable \a i, given the current state
111 Prob getVarDist( size_t i );
112 /// Draw state of variable \a i randomly from its conditional distribution and update the current state
113 void resampleVar( size_t i );
114 /// Calculates linear index into factor \a I corresponding to the current state
115 size_t getFactorEntry( size_t I );
116 /// Calculates the differences between linear indices into factor \a I corresponding with a state change of variable \a i
117 size_t getFactorEntryDiff( size_t I, size_t i );
118 };
119
120
121 /// Runs Gibbs sampling for \a iters iterations (of which \a burnin for burn-in) on FactorGraph \a fg, and returns the resulting state
122 /** \relates Gibbs
123 */
124 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters, size_t burnin=0 );
125
126
127 } // end of namespace dai
128
129
130 #endif