/* This file is part of libDAI - http://www.libdai.org/
 *
 * libDAI is licensed under the terms of the GNU General Public License version
 * 2, or (at your option) any later version. libDAI is distributed without any
 * warranty. See the file COPYING for more details.
 *
 * Copyright (C) 2008 Frederik Eaton [frederik at ofb dot net]
 */
16 #include <dai/gibbs.h>
18 #include <dai/properties.h>
27 const char *Gibbs::Name
= "GIBBS";
30 void Gibbs::setProperties( const PropertySet
&opts
) {
31 DAI_ASSERT( opts
.hasKey("iters") );
32 props
.iters
= opts
.getStringAs
<size_t>("iters");
34 if( opts
.hasKey("burnin") )
35 props
.burnin
= opts
.getStringAs
<size_t>("burnin");
39 if( opts
.hasKey("verbose") )
40 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
46 PropertySet
Gibbs::getProperties() const {
48 opts
.Set( "iters", props
.iters
);
49 opts
.Set( "burnin", props
.burnin
);
50 opts
.Set( "verbose", props
.verbose
);
55 string
Gibbs::printProperties() const {
56 stringstream
s( stringstream::out
);
58 s
<< "iters=" << props
.iters
<< ",";
59 s
<< "burnin=" << props
.burnin
<< ",";
60 s
<< "verbose=" << props
.verbose
<< "]";
65 void Gibbs::construct() {
67 _var_counts
.reserve( nrVars() );
68 for( size_t i
= 0; i
< nrVars(); i
++ )
69 _var_counts
.push_back( _count_t( var(i
).states(), 0 ) );
71 _factor_counts
.clear();
72 _factor_counts
.reserve( nrFactors() );
73 for( size_t I
= 0; I
< nrFactors(); I
++ )
74 _factor_counts
.push_back( _count_t( factor(I
).states(), 0 ) );
79 _state
.resize( nrVars(), 0 );
83 void Gibbs::updateCounts() {
85 if( _sample_count
> props
.burnin
) {
86 for( size_t i
= 0; i
< nrVars(); i
++ )
87 _var_counts
[i
][_state
[i
]]++;
88 for( size_t I
= 0; I
< nrFactors(); I
++ )
89 _factor_counts
[I
][getFactorEntry(I
)]++;
94 inline size_t Gibbs::getFactorEntry( size_t I
) {
96 for( int _j
= nbF(I
).size() - 1; _j
>= 0; _j
-- ) {
97 // note that iterating over nbF(I) yields the same ordering
98 // of variables as iterating over factor(I).vars()
99 size_t j
= nbF(I
)[_j
];
100 f_entry
*= var(j
).states();
101 f_entry
+= _state
[j
];
107 inline size_t Gibbs::getFactorEntryDiff( size_t I
, size_t i
) {
109 for( size_t _j
= 0; _j
< nbF(I
).size(); _j
++ ) {
110 // note that iterating over nbF(I) yields the same ordering
111 // of variables as iterating over factor(I).vars()
112 size_t j
= nbF(I
)[_j
];
116 skip
*= var(j
).states();
122 Prob
Gibbs::getVarDist( size_t i
) {
123 DAI_ASSERT( i
< nrVars() );
124 size_t i_states
= var(i
).states();
125 Prob
i_given_MB( i_states
, 1.0 );
127 // use Markov blanket of var(i) to calculate distribution
128 foreach( const Neighbor
&I
, nbV(i
) ) {
129 const Factor
&f_I
= factor(I
);
130 size_t I_skip
= getFactorEntryDiff( I
, i
);
131 size_t I_entry
= getFactorEntry(I
) - (_state
[i
] * I_skip
);
132 for( size_t st_i
= 0; st_i
< i_states
; st_i
++ ) {
133 i_given_MB
[st_i
] *= f_I
[I_entry
];
138 if( i_given_MB
.sum() == 0.0 )
139 // If no state of i is allowed, use uniform distribution
140 // FIXME is that indeed the right thing to do?
141 i_given_MB
= Prob( i_states
);
143 i_given_MB
.normalize();
148 inline void Gibbs::resampleVar( size_t i
) {
149 _state
[i
] = getVarDist(i
).draw();
153 void Gibbs::randomizeState() {
154 for( size_t i
= 0; i
< nrVars(); i
++ )
155 _state
[i
] = rnd( var(i
).states() );
160 for( size_t i
= 0; i
< nrVars(); i
++ )
161 fill( _var_counts
[i
].begin(), _var_counts
[i
].end(), 0 );
162 for( size_t I
= 0; I
< nrFactors(); I
++ )
163 fill( _factor_counts
[I
].begin(), _factor_counts
[I
].end(), 0 );
169 if( props
.verbose
>= 1 )
170 cerr
<< "Starting " << identify() << "...";
171 if( props
.verbose
>= 3 )
178 for( size_t iter
= 0; iter
< props
.iters
; iter
++ ) {
179 for( size_t j
= 0; j
< nrVars(); j
++ )
184 if( props
.verbose
>= 3 ) {
185 for( size_t i
= 0; i
< nrVars(); i
++ ) {
186 cerr
<< "belief for variable " << var(i
) << ": " << beliefV(i
) << endl
;
187 cerr
<< "counts for variable " << var(i
) << ": " << Prob( _var_counts
[i
] ) << endl
;
191 if( props
.verbose
>= 3 )
192 cerr
<< Name
<< "::run: ran " << props
.iters
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
198 Factor
Gibbs::beliefV( size_t i
) const {
199 return Factor( var(i
), _var_counts
[i
] ).normalized();
203 Factor
Gibbs::beliefF( size_t I
) const {
204 return Factor( factor(I
).vars(), _factor_counts
[I
] ).normalized();
208 vector
<Factor
> Gibbs::beliefs() const {
209 vector
<Factor
> result
;
210 for( size_t i
= 0; i
< nrVars(); ++i
)
211 result
.push_back( beliefV(i
) );
212 for( size_t I
= 0; I
< nrFactors(); ++I
)
213 result
.push_back( beliefF(I
) );
218 Factor
Gibbs::belief( const VarSet
&ns
) const {
221 else if( ns
.size() == 1 )
222 return beliefV( findVar( *(ns
.begin()) ) );
225 for( I
= 0; I
< nrFactors(); I
++ )
226 if( factor(I
).vars() >> ns
)
228 if( I
== nrFactors() )
229 DAI_THROW(BELIEF_NOT_AVAILABLE
);
230 return beliefF(I
).marginal(ns
);
235 std::vector
<size_t> getGibbsState( const FactorGraph
&fg
, size_t iters
, size_t burnin
) {
236 PropertySet gibbsProps
;
237 gibbsProps
.Set("iters", iters
);
238 gibbsProps
.Set("burnin", burnin
);
239 gibbsProps
.Set("verbose", size_t(0));
240 Gibbs
gibbs( fg
, gibbsProps
);
242 return gibbs
.state();
246 } // end of namespace dai