Improved TreeEP (added 'maxtime' property)
[libdai.git] / include / dai / gibbs.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class Gibbs, which implements Gibbs sampling
13
14
15 #ifndef __defined_libdai_gibbs_h
16 #define __defined_libdai_gibbs_h
17
18
19 #include <dai/daialg.h>
20 #include <dai/factorgraph.h>
21 #include <dai/properties.h>
22
23
24 namespace dai {
25
26
27 /// Approximate inference algorithm "Gibbs sampling"
28 /** \author Frederik Eaton
29 */
30 class Gibbs : public DAIAlgFG {
31 private:
32 /// Type used to store the counts of various states
33 typedef std::vector<size_t> _count_t;
34 /// Type used to store the joint state of all variables
35 typedef std::vector<size_t> _state_t;
36 /// Number of samples counted so far (excluding burn-in periods)
37 size_t _sample_count;
38 /// State counts for each variable
39 std::vector<_count_t> _var_counts;
40 /// State counts for each factor
41 std::vector<_count_t> _factor_counts;
42 /// Number of iterations done (including burn-in periods)
43 size_t _iters;
44 /// Current joint state of all variables
45 _state_t _state;
46 /// Joint state with maximum probability seen so far
47 _state_t _max_state;
48 /// Highest score so far
49 Real _max_score;
50
51 public:
52 /// Parameters for Gibbs
53 struct Properties {
54 /// Maximum number of iterations
55 size_t maxiter;
56
57 /// Maximum time (in seconds)
58 double maxtime;
59
60 /// Number of iterations after which a random restart is made
61 size_t restart;
62
63 /// Number of "burn-in" iterations after each (re)start (for which no statistics are gathered)
64 size_t burnin;
65
66 /// Verbosity (amount of output sent to stderr)
67 size_t verbose;
68 } props;
69
70 /// Name of this inference algorithm
71 static const char *Name;
72
73 public:
74 /// Default constructor
75 Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {}
76
77 /// Construct from FactorGraph \a fg and PropertySet \a opts
78 /** \param fg Factor graph.
79 * \param opts Parameters @see Properties
80 */
81 Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {
82 setProperties( opts );
83 construct();
84 }
85
86
87 /// \name General InfAlg interface
88 //@{
89 virtual Gibbs* clone() const { return new Gibbs(*this); }
90 virtual std::string identify() const { return std::string(Name) + printProperties(); }
91 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
92 virtual Factor belief( const VarSet &vs ) const;
93 virtual Factor beliefV( size_t i ) const;
94 virtual Factor beliefF( size_t I ) const;
95 virtual std::vector<Factor> beliefs() const;
96 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
97 virtual void init();
98 virtual void init( const VarSet &/*ns*/ ) { init(); }
99 virtual Real run();
100 virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
101 virtual size_t Iterations() const { return _iters; }
102 std::vector<std::size_t> findMaximum() const { return _max_state; }
103 virtual void setProperties( const PropertySet &opts );
104 virtual PropertySet getProperties() const;
105 virtual std::string printProperties() const;
106 //@}
107
108
109 /// \name Additional interface specific for Gibbs
110 //@{
111 /// Draw the current joint state of all variables from a uniform random distribution
112 void randomizeState();
113 /// Return reference to current state of all variables
114 std::vector<size_t>& state() { return _state; }
115 /// Return constant reference to current state of all variables
116 const std::vector<size_t>& state() const { return _state; }
117 //@}
118
119 private:
120 /// Helper function for constructors
121 void construct();
122 /// Updates all counts (_sample_count, _var_counts, _factor_counts) based on current state
123 void updateCounts();
124 /// Calculate conditional distribution of variable \a i, given the current state
125 Prob getVarDist( size_t i );
126 /// Draw state of variable \a i randomly from its conditional distribution and update the current state
127 void resampleVar( size_t i );
128 /// Calculates linear index into factor \a I corresponding to the current state
129 size_t getFactorEntry( size_t I );
130 /// Calculates the differences between linear indices into factor \a I corresponding with a state change of variable \a i
131 size_t getFactorEntryDiff( size_t I, size_t i );
132 };
133
134
135 /// Runs Gibbs sampling for \a iters iterations (of which \a burnin for burn-in) on FactorGraph \a fg, and returns the resulting state
136 /** \relates Gibbs
137 */
138 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters );
139
140
141 } // end of namespace dai
142
143
144 /** \example example_sprinkler_gibbs.cpp
145 * This example shows how to use the Gibbs class.
146 */
147
148
149 #endif