Finished integrating Gibbs sampler by Frederik Eaton into libDAI
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_index_h
32 #define __defined_libdai_index_h
33
34
35 #include <vector>
36 #include <algorithm>
37 #include <map>
38 #include <cassert>
39 #include <dai/varset.h>
40
41
42 namespace dai {
43
44
45 /// Tool for looping over the states of several variables.
46 /** The class IndexFor is an important tool for indexing Factor entries.
47 * Its usage can best be explained by an example.
48 * Assume indexVars, forVars are both VarSets.
49 * Then the following code:
50 * \code
51 * IndexFor i( indexVars, forVars );
52 * for( ; i >= 0; ++i ) {
53 * // use long(i)
54 * }
55 * \endcode
56 * loops over all joint states of the variables in forVars,
57 * and (long)i is equal to the linear index of the corresponding
58 * state of indexVars, where the variables in indexVars that are
59 * not in forVars assume their zero'th value.
60 * \idea Optimize all indices as follows: keep a cache of all (or only
61 * relatively small) indices that have been computed (use a hash). Then,
62 * instead of computing on the fly, use the precomputed ones.
63 */
64 class IndexFor {
65 private:
66 /// The current linear index corresponding to the state of indexVars
67 long _index;
68
69 /// For each variable in forVars, the amount of change in _index
70 std::vector<long> _sum;
71
72 /// For each variable in forVars, the current state
73 std::vector<size_t> _count;
74
75 /// For each variable in forVars, its number of possible values
76 std::vector<size_t> _dims;
77
78 public:
79 /// Default constructor
80 IndexFor() {
81 _index = -1;
82 }
83
84 /// Constructor
85 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
86 long sum = 1;
87
88 _dims.reserve( forVars.size() );
89 _sum.reserve( forVars.size() );
90
91 VarSet::const_iterator j = forVars.begin();
92 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
93 for( ; j != forVars.end() && *j <= *i; ++j ) {
94 _dims.push_back( j->states() );
95 _sum.push_back( (*i == *j) ? sum : 0 );
96 }
97 sum *= i->states();
98 }
99 for( ; j != forVars.end(); ++j ) {
100 _dims.push_back( j->states() );
101 _sum.push_back( 0 );
102 }
103 _index = 0;
104 }
105
106 /// Copy constructor
107 IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
108
109 /// Assignment operator
110 IndexFor& operator=( const IndexFor &ind ) {
111 if( this != &ind ) {
112 _index = ind._index;
113 _sum = ind._sum;
114 _count = ind._count;
115 _dims = ind._dims;
116 }
117 return *this;
118 }
119
120 /// Sets the index back to zero
121 IndexFor& clear() {
122 fill( _count.begin(), _count.end(), 0 );
123 _index = 0;
124 return( *this );
125 }
126
127 /// Conversion to long
128 operator long () const {
129 return( _index );
130 }
131
132 /// Pre-increment operator
133 IndexFor& operator++ () {
134 if( _index >= 0 ) {
135 size_t i = 0;
136
137 while( i < _count.size() ) {
138 _index += _sum[i];
139 if( ++_count[i] < _dims[i] )
140 break;
141 _index -= _sum[i] * _dims[i];
142 _count[i] = 0;
143 i++;
144 }
145
146 if( i == _count.size() )
147 _index = -1;
148 }
149 return( *this );
150 }
151 };
152
153
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to:
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  The first index varies fastest.
 *
 *  Special cases:
 *  - If any dimension is zero, there are no joint states and the MultiFor is invalid immediately.
 *  - A MultiFor without dimensions visits exactly one (empty) joint state.
 */
class MultiFor {
    private:
        /// Number of values of each index
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear state; -1 once all joint states have been visited
        long _state;

        /// Initial linear state for dimensions \a d: -1 if some dimension is zero (empty range), else 0
        static long initialState( const std::vector<size_t> &d ) {
            return ( std::find( d.begin(), d.end(), static_cast<size_t>(0) ) == d.end() ) ? 0 : -1;
        }

    public:
        /// Default constructor
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(initialState(d)) {}

        // Copy construction / assignment: compiler-generated versions suffice (Rule of Zero).

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); ++i ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // Index i wrapped around; carry into index i+1
                    _states[i] = 0;
                }
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
238
239
/// Tool for calculating permutations of multiple indices.
/** Constructed from the dimensions of N indices and a permutation sigma of
 *  {0, ..., N-1}. In the permuted order, position k holds original index sigma[k].
 *  Linear indices are formed with the first index varying fastest, both in the
 *  original and in the permuted order.
 */
class Permute {
    private:
        /// Dimensions of the indices, in original order
        std::vector<size_t> _dims;
        /// The permutation: position k of the permuted order holds original index _sigma[k]
        std::vector<size_t> _sigma;

    public:
        /// Default constructor
        Permute() : _dims(), _sigma() {}

        /// Initialize from vector of index dimensions and permutation sigma
        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
            assert( _dims.size() == _sigma.size() );
        }

        // Copy construction / assignment: compiler-generated versions suffice (Rule of Zero).

        /// Calculates a permuted linear index.
        /** Converts the linear index li to a vector index
         *  corresponding with the dimensions in _dims, permutes it according to sigma,
         *  and converts it back to a linear index according to the permuted dimensions.
         */
        size_t convert_linear_index( size_t li ) const {
            const size_t N = _dims.size();

            // calculate vector index corresponding to linear index
            std::vector<size_t> vi;
            vi.reserve( N );
            for( size_t k = 0; k < N; k++ ) {
                assert( _dims[k] > 0 );
                vi.push_back( li % _dims[k] );
                li /= _dims[k];
            }

            // convert permuted vector index to corresponding linear index
            size_t prod = 1;
            size_t sigma_li = 0;
            for( size_t k = 0; k < N; k++ ) {
                assert( _sigma[k] < N );
                sigma_li += vi[_sigma[k]] * prod;
                prod *= _dims[_sigma[k]];
            }

            return sigma_li;
        }
};
296
297
298 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
299 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
300 */
301 class State {
302 private:
303 typedef std::map<Var, size_t> states_type;
304
305 long state;
306 states_type states;
307
308 public:
309 /// Default constructor
310 State() : state(0), states() {}
311
312 /// Initialize from VarSet
313 State( const VarSet &vs ) : state(0) {
314 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
315 states[*v] = 0;
316 }
317
318 /// Copy constructor
319 State( const State & x ) : state(x.state), states(x.states) {}
320
321 /// Assignment operator
322 State& operator=( const State &x ) {
323 if( this != &x ) {
324 state = x.state;
325 states = x.states;
326 }
327 return *this;
328 }
329
330 /// Return linear state
331 operator size_t() const {
332 assert( valid() );
333 return( state );
334 }
335
336 /// Return state of variable n, or zero if n is not in this State
337 size_t operator() ( const Var &n ) const {
338 assert( valid() );
339 states_type::const_iterator entry = states.find( n );
340 if( entry == states.end() )
341 return 0;
342 else
343 return entry->second;
344 }
345
346 /// Return linear state of variables in varset, setting them to zero if they are not in this State
347 size_t operator() ( const VarSet &vs ) const {
348 assert( valid() );
349 size_t vs_state = 0;
350 size_t prod = 1;
351 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
352 states_type::const_iterator entry = states.find( *v );
353 if( entry != states.end() )
354 vs_state += entry->second * prod;
355 prod *= v->states();
356 }
357 return vs_state;
358 }
359
360 /// Postfix increment operator
361 void operator++( int ) {
362 if( valid() ) {
363 state++;
364 states_type::iterator entry = states.begin();
365 while( entry != states.end() ) {
366 if( ++(entry->second) < entry->first.states() )
367 break;
368 entry->second = 0;
369 entry++;
370 }
371 if( entry == states.end() )
372 state = -1;
373 }
374 }
375
376 /// Returns true if the current state is valid
377 bool valid() const {
378 return( state >= 0 );
379 }
380 };
381
382
383 } // end of namespace dai
384
385
386 #endif