Fixed example_imagesegmentation by adding InfAlg::setMaxIter(size_t)
[libdai.git] / include / dai / gibbs.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 /// \file
12 /// \brief Defines class Gibbs, which implements Gibbs sampling
13
14
15 #ifndef __defined_libdai_gibbs_h
16 #define __defined_libdai_gibbs_h
17
18
19 #include <dai/daialg.h>
20 #include <dai/factorgraph.h>
21 #include <dai/properties.h>
22
23
24 namespace dai {
25
26
27 /// Approximate inference algorithm "Gibbs sampling"
28 /** \author Frederik Eaton
29 */
30 class Gibbs : public DAIAlgFG {
31 private:
32 /// Type used to store the counts of various states
33 typedef std::vector<size_t> _count_t;
34 /// Type used to store the joint state of all variables
35 typedef std::vector<size_t> _state_t;
36 /// Number of samples counted so far (excluding burn-in periods)
37 size_t _sample_count;
38 /// State counts for each variable
39 std::vector<_count_t> _var_counts;
40 /// State counts for each factor
41 std::vector<_count_t> _factor_counts;
42 /// Number of iterations done (including burn-in periods)
43 size_t _iters;
44 /// Current joint state of all variables
45 _state_t _state;
46 /// Joint state with maximum probability seen so far
47 _state_t _max_state;
48 /// Highest score so far
49 Real _max_score;
50
51 public:
52 /// Parameters for Gibbs
53 struct Properties {
54 /// Maximum number of iterations
55 size_t maxiter;
56
57 /// Maximum time (in seconds)
58 double maxtime;
59
60 /// Number of iterations after which a random restart is made
61 size_t restart;
62
63 /// Number of "burn-in" iterations after each (re)start (for which no statistics are gathered)
64 size_t burnin;
65
66 /// Verbosity (amount of output sent to stderr)
67 size_t verbose;
68 } props;
69
70 /// Name of this inference algorithm
71 static const char *Name;
72
73 public:
74 /// Default constructor
75 Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {}
76
77 /// Construct from FactorGraph \a fg and PropertySet \a opts
78 /** \param fg Factor graph.
79 * \param opts Parameters @see Properties
80 */
81 Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {
82 setProperties( opts );
83 construct();
84 }
85
86
87 /// \name General InfAlg interface
88 //@{
89 virtual Gibbs* clone() const { return new Gibbs(*this); }
90 virtual std::string identify() const { return std::string(Name) + printProperties(); }
91 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
92 virtual Factor belief( const VarSet &vs ) const;
93 virtual Factor beliefV( size_t i ) const;
94 virtual Factor beliefF( size_t I ) const;
95 virtual std::vector<Factor> beliefs() const;
96 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
97 std::vector<std::size_t> findMaximum() const { return _max_state; }
98 virtual void init();
99 virtual void init( const VarSet &/*ns*/ ) { init(); }
100 virtual Real run();
101 virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
102 virtual size_t Iterations() const { return _iters; }
103 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
104 virtual void setProperties( const PropertySet &opts );
105 virtual PropertySet getProperties() const;
106 virtual std::string printProperties() const;
107 //@}
108
109
110 /// \name Additional interface specific for Gibbs
111 //@{
112 /// Draw the current joint state of all variables from a uniform random distribution
113 void randomizeState();
114 /// Return reference to current state of all variables
115 std::vector<size_t>& state() { return _state; }
116 /// Return constant reference to current state of all variables
117 const std::vector<size_t>& state() const { return _state; }
118 //@}
119
120 private:
121 /// Helper function for constructors
122 void construct();
123 /// Updates all counts (_sample_count, _var_counts, _factor_counts) based on current state
124 void updateCounts();
125 /// Calculate conditional distribution of variable \a i, given the current state
126 Prob getVarDist( size_t i );
127 /// Draw state of variable \a i randomly from its conditional distribution and update the current state
128 void resampleVar( size_t i );
129 /// Calculates linear index into factor \a I corresponding to the current state
130 size_t getFactorEntry( size_t I );
131 /// Calculates the differences between linear indices into factor \a I corresponding with a state change of variable \a i
132 size_t getFactorEntryDiff( size_t I, size_t i );
133 };
134
135
136 /// Runs Gibbs sampling for \a iters iterations (of which \a burnin for burn-in) on FactorGraph \a fg, and returns the resulting state
137 /** \relates Gibbs
138 */
139 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t iters );
140
141
142 } // end of namespace dai
143
144
145 /** \example example_sprinkler_gibbs.cpp
146 * This example shows how to use the Gibbs class.
147 */
148
149
150 #endif