Added TProb<T>::operator==( const TProb<T> & ) and added some unit tests for prob...
[libdai.git] / include / dai / gibbs.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class Gibbs, which implements Gibbs sampling
13
14
15 #ifndef __defined_libdai_gibbs_h
16 #define __defined_libdai_gibbs_h
17
18
19 #include <dai/daialg.h>
20 #include <dai/factorgraph.h>
21 #include <dai/properties.h>
22
23
24 namespace dai {
25
26
27 /// Approximate inference algorithm "Gibbs sampling"
28 /** \author Frederik Eaton
29 */
30 class Gibbs : public DAIAlgFG {
31 private:
32 /// Type used to store the counts of various states
33 typedef std::vector<size_t> _count_t;
34 /// Type used to store the joint state of all variables
35 typedef std::vector<size_t> _state_t;
36 /// Number of samples counted so far (excluding burn-in)
37 size_t _sample_count;
38 /// State counts for each variable
39 std::vector<_count_t> _var_counts;
40 /// State counts for each factor
41 std::vector<_count_t> _factor_counts;
42 /// Current joint state of all variables
43 _state_t _state;
44
45 public:
46 /// Parameters for Gibbs
47 struct Properties {
48 /// Total number of iterations
49 size_t iters;
50
51 /// Number of "burn-in" iterations
52 size_t burnin;
53
54 /// Verbosity (amount of output sent to stderr)
55 size_t verbose;
56 } props;
57
58 /// Name of this inference algorithm
59 static const char *Name;
60
61 public:
62 /// Default constructor
63 Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _state() {}
64
65 /// Construct from FactorGraph \a fg and PropertySet \a opts
66 /** \param fg Factor graph.
67 * \param opts Parameters @see Properties
68 */
69 Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _state() {
70 setProperties( opts );
71 construct();
72 }
73
74
75 /// \name General InfAlg interface
76 //@{
77 virtual Gibbs* clone() const { return new Gibbs(*this); }
78 virtual std::string identify() const { return std::string(Name) + printProperties(); }
79 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
80 virtual Factor belief( const VarSet &vs ) const;
81 virtual Factor beliefV( size_t i ) const;
82 virtual Factor beliefF( size_t I ) const;
83 virtual std::vector<Factor> beliefs() const;
84 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
85 virtual void init();
86 virtual void init( const VarSet &/*ns*/ ) { init(); }
87 virtual Real run();
88 virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
89 virtual size_t Iterations() const { return props.iters; }
90 virtual void setProperties( const PropertySet &opts );
91 virtual PropertySet getProperties() const;
92 virtual std::string printProperties() const;
93 //@}
94
95
96 /// \name Additional interface specific for Gibbs
97 //@{
98 /// Draw the current joint state of all variables from a uniform random distribution
99 void randomizeState();
100 /// Return reference to current state of all variables
101 std::vector<size_t>& state() { return _state; }
102 /// Return constant reference to current state of all variables
103 const std::vector<size_t>& state() const { return _state; }
104 //@}
105
106 private:
107 /// Helper function for constructors
108 void construct();
109 /// Updates all counts (_sample_count, _var_counts, _factor_counts) based on current state
110 void updateCounts();
111 /// Calculate conditional distribution of variable \a i, given the current state
112 Prob getVarDist( size_t i );
113 /// Draw state of variable \a i randomly from its conditional distribution and update the current state
114 void resampleVar( size_t i );
115 /// Calculates linear index into factor \a I corresponding to the current state
116 size_t getFactorEntry( size_t I );
117 /// Calculates the differences between linear indices into factor \a I corresponding with a state change of variable \a i
118 size_t getFactorEntryDiff( size_t I, size_t i );
119 };
120
121
122 /// Runs Gibbs sampling for \a iters iterations (of which \a burnin for burn-in) on FactorGraph \a fg, and returns the resulting state
123 /** \relates Gibbs
124 */
125 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters );
126
127
128 } // end of namespace dai
129
130
131 /** \example example_sprinkler_gibbs.cpp
132 * This example shows how to use the Gibbs class.
133 */
134
135
136 #endif