Changed license from GPL v2+ to FreeBSD (aka BSD 2-clause) license
[libdai.git] / include / dai / gibbs.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class Gibbs, which implements Gibbs sampling
11
12
13 #ifndef __defined_libdai_gibbs_h
14 #define __defined_libdai_gibbs_h
15
16
17 #include <dai/daialg.h>
18 #include <dai/factorgraph.h>
19 #include <dai/properties.h>
20
21
22 namespace dai {
23
24
25 /// Approximate inference algorithm "Gibbs sampling"
26 /** \author Frederik Eaton
27 */
28 class Gibbs : public DAIAlgFG {
29 private:
30 /// Type used to store the counts of various states
31 typedef std::vector<size_t> _count_t;
32 /// Type used to store the joint state of all variables
33 typedef std::vector<size_t> _state_t;
34 /// Number of samples counted so far (excluding burn-in periods)
35 size_t _sample_count;
36 /// State counts for each variable
37 std::vector<_count_t> _var_counts;
38 /// State counts for each factor
39 std::vector<_count_t> _factor_counts;
40 /// Number of iterations done (including burn-in periods)
41 size_t _iters;
42 /// Current joint state of all variables
43 _state_t _state;
44 /// Joint state with maximum probability seen so far
45 _state_t _max_state;
46 /// Highest score so far
47 Real _max_score;
48
49 public:
50 /// Parameters for Gibbs
51 struct Properties {
52 /// Maximum number of iterations
53 size_t maxiter;
54
55 /// Maximum time (in seconds)
56 double maxtime;
57
58 /// Number of iterations after which a random restart is made
59 size_t restart;
60
61 /// Number of "burn-in" iterations after each (re)start (for which no statistics are gathered)
62 size_t burnin;
63
64 /// Verbosity (amount of output sent to stderr)
65 size_t verbose;
66 } props;
67
68 public:
69 /// Default constructor
70 Gibbs() : DAIAlgFG(), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {}
71
72 /// Construct from FactorGraph \a fg and PropertySet \a opts
73 /** \param fg Factor graph.
74 * \param opts Parameters @see Properties
75 */
76 Gibbs( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), _sample_count(0), _var_counts(), _factor_counts(), _iters(0), _state(), _max_state(), _max_score(-INFINITY) {
77 setProperties( opts );
78 construct();
79 }
80
81
82 /// \name General InfAlg interface
83 //@{
84 virtual Gibbs* clone() const { return new Gibbs(*this); }
85 virtual std::string name() const { return "GIBBS"; }
86 virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
87 virtual Factor belief( const VarSet &vs ) const;
88 virtual Factor beliefV( size_t i ) const;
89 virtual Factor beliefF( size_t I ) const;
90 virtual std::vector<Factor> beliefs() const;
91 virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
92 std::vector<std::size_t> findMaximum() const { return _max_state; }
93 virtual void init();
94 virtual void init( const VarSet &/*ns*/ ) { init(); }
95 virtual Real run();
96 virtual Real maxDiff() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
97 virtual size_t Iterations() const { return _iters; }
98 virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
99 virtual void setProperties( const PropertySet &opts );
100 virtual PropertySet getProperties() const;
101 virtual std::string printProperties() const;
102 //@}
103
104
105 /// \name Additional interface specific for Gibbs
106 //@{
107 /// Draw the current joint state of all variables from a uniform random distribution
108 void randomizeState();
109 /// Return reference to current state of all variables
110 std::vector<size_t>& state() { return _state; }
111 /// Return constant reference to current state of all variables
112 const std::vector<size_t>& state() const { return _state; }
113 //@}
114
115 private:
116 /// Helper function for constructors
117 void construct();
118 /// Updates all counts (_sample_count, _var_counts, _factor_counts) based on current state
119 void updateCounts();
120 /// Calculate conditional distribution of variable \a i, given the current state
121 Prob getVarDist( size_t i );
122 /// Draw state of variable \a i randomly from its conditional distribution and update the current state
123 void resampleVar( size_t i );
124 /// Calculates linear index into factor \a I corresponding to the current state
125 size_t getFactorEntry( size_t I );
126 /// Calculates the differences between linear indices into factor \a I corresponding with a state change of variable \a i
127 size_t getFactorEntryDiff( size_t I, size_t i );
128 };
129
130
131 /// Runs Gibbs sampling for \a maxiter iterations (of which \a burnin for burn-in) on FactorGraph \a fg, and returns the resulting state
132 /** \relates Gibbs
133 */
134 std::vector<size_t> getGibbsState( const FactorGraph &fg, size_t maxiter );
135
136
137 } // end of namespace dai
138
139
140 /** \example example_sprinkler_gibbs.cpp
141 * This example shows how to use the Gibbs class.
142 */
143
144
145 #endif