e301b2465dcb66cc0dfeb281624ca882ff8e3b9b
[libdai.git] / include / dai / index.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
8 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2002-2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 /// \file
14 /// \brief Defines the IndexFor, multifor, Permute and State classes, which all deal with indexing multi-dimensional arrays
15
16
17 #ifndef __defined_libdai_index_h
18 #define __defined_libdai_index_h
19
20
21 #include <vector>
22 #include <algorithm>
23 #include <map>
24 #include <dai/varset.h>
25
26
27 namespace dai {
28
29
30 /// Tool for looping over the states of several variables.
31 /** The class IndexFor is an important tool for indexing Factor entries.
32 * Its usage can best be explained by an example.
33 * Assume \a indexVars, \a forVars are both VarSet 's.
34 * Then the following code:
35 * \code
36 * IndexFor i( indexVars, forVars );
37 * size_t iter = 0;
38 * for( ; i.valid(); i++, iter++ ) {
39 * cout << "State of forVars: " << calcState( forVars, iter ) << "; ";
40 * cout << "state of indexVars: " << calcState( indexVars, size_t(i) ) << endl;
41 * }
42 * \endcode
43 * loops over all joint states of the variables in \a forVars,
44 * and <tt>(size_t)i</tt> equals the linear index of the corresponding
45 * state of \a indexVars, where the variables in \a indexVars that are
46 * not in \a forVars assume their zero'th value.
47 * \idea Optimize all indices as follows: keep a cache of all (or only
48 * relatively small) indices that have been computed (use a hash). Then,
49 * instead of computing on the fly, use the precomputed ones. Here the
50 * labels of the variables don't matter, but the ranges of the variables do.
51 */
52 class IndexFor {
53 private:
54 /// The current linear index corresponding to the state of indexVars
55 long _index;
56
57 /// For each variable in forVars, the amount of change in _index
58 std::vector<long> _sum;
59
60 /// For each variable in forVars, the current state
61 std::vector<size_t> _state;
62
63 /// For each variable in forVars, its number of possible values
64 std::vector<size_t> _ranges;
65
66 public:
67 /// Default constructor
68 IndexFor() : _index(-1) {}
69
70 /// Construct IndexFor object from \a indexVars and \a forVars
71 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) {
72 long sum = 1;
73
74 _ranges.reserve( forVars.size() );
75 _sum.reserve( forVars.size() );
76
77 VarSet::const_iterator j = forVars.begin();
78 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
79 for( ; j != forVars.end() && *j <= *i; ++j ) {
80 _ranges.push_back( j->states() );
81 _sum.push_back( (*i == *j) ? sum : 0 );
82 }
83 sum *= i->states();
84 }
85 for( ; j != forVars.end(); ++j ) {
86 _ranges.push_back( j->states() );
87 _sum.push_back( 0 );
88 }
89 _index = 0;
90 }
91
92 /// Resets the state
93 IndexFor& reset() {
94 fill( _state.begin(), _state.end(), 0 );
95 _index = 0;
96 return( *this );
97 }
98
99 /// Conversion to \c size_t: returns linear index of the current state of indexVars
100 operator size_t() const {
101 DAI_ASSERT( valid() );
102 return( _index );
103 }
104
105 /// Increments the current state of \a forVars (prefix)
106 IndexFor& operator++ () {
107 if( _index >= 0 ) {
108 size_t i = 0;
109
110 while( i < _state.size() ) {
111 _index += _sum[i];
112 if( ++_state[i] < _ranges[i] )
113 break;
114 _index -= _sum[i] * _ranges[i];
115 _state[i] = 0;
116 i++;
117 }
118
119 if( i == _state.size() )
120 _index = -1;
121 }
122 return( *this );
123 }
124
125 /// Increments the current state of \a forVars (postfix)
126 void operator++( int ) {
127 operator++();
128 }
129
130 /// Returns \c true if the current state is valid
131 bool valid() const {
132 return( _index >= 0 );
133 }
134 };
135
136
137 /// Tool for calculating permutations of linear indices of multi-dimensional arrays.
138 /** \note This is mainly useful for converting indices into multi-dimensional arrays
139 * corresponding to joint states of variables to and from the canonical ordering used in libDAI.
140 */
141 class Permute {
142 private:
143 /// Stores the number of possible values of all indices
144 std::vector<size_t> _ranges;
145 /// Stores the permutation
146 std::vector<size_t> _sigma;
147
148 public:
149 /// Default constructor
150 Permute() : _ranges(), _sigma() {}
151
152 /// Construct from vector of index ranges and permutation
153 Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) {
154 DAI_ASSERT( _ranges.size() == _sigma.size() );
155 }
156
157 /// Construct from vector of variables.
158 /** The implied permutation maps the index of each variable in \a vars according to the canonical ordering
159 * (i.e., sorted ascendingly according to their label) to the index it has in \a vars.
160 * If \a reverse == \c true, reverses the indexing in \a vars first.
161 */
162 Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() {
163 size_t N = vars.size();
164
165 // construct ranges
166 _ranges.reserve( N );
167 for( size_t i = 0; i < N; ++i )
168 if( reverse )
169 _ranges.push_back( vars[N - 1 - i].states() );
170 else
171 _ranges.push_back( vars[i].states() );
172
173 // construct VarSet out of vars
174 VarSet vs( vars.begin(), vars.end(), N );
175 DAI_ASSERT( vs.size() == N );
176
177 // construct sigma
178 _sigma.reserve( N );
179 for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) {
180 size_t ind = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
181 if( reverse )
182 _sigma.push_back( N - 1 - ind );
183 else
184 _sigma.push_back( ind );
185 }
186 }
187
188 /// Calculates a permuted linear index.
189 /** Converts the linear index \a li to a vector index, permutes its
190 * components, and converts it back to a linear index.
191 */
192 size_t convertLinearIndex( size_t li ) const {
193 size_t N = _ranges.size();
194
195 // calculate vector index corresponding to linear index
196 std::vector<size_t> vi;
197 vi.reserve( N );
198 size_t prod = 1;
199 for( size_t k = 0; k < N; k++ ) {
200 vi.push_back( li % _ranges[k] );
201 li /= _ranges[k];
202 prod *= _ranges[k];
203 }
204
205 // convert permuted vector index to corresponding linear index
206 prod = 1;
207 size_t sigma_li = 0;
208 for( size_t k = 0; k < N; k++ ) {
209 sigma_li += vi[_sigma[k]] * prod;
210 prod *= _ranges[_sigma[k]];
211 }
212
213 return sigma_li;
214 }
215
216 /// Returns constant reference to the permutation
217 const std::vector<size_t>& sigma() const { return _sigma; }
218
219 /// Returns reference to the permutation
220 std::vector<size_t>& sigma() { return _sigma; }
221
222 /// Returns constant reference to the dimensionality vector
223 const std::vector<size_t>& ranges() { return _ranges; }
224
225 /// Returns the result of applying the permutation on \a i
226 size_t operator[]( size_t i ) const {
227 #ifdef DAI_DEBUG
228 return _sigma.at(i);
229 #else
230 return _sigma[i];
231 #endif
232 }
233
234 /// Returns the inverse permutation
235 Permute inverse() const {
236 size_t N = _ranges.size();
237 std::vector<size_t> invRanges( N, 0 );
238 std::vector<size_t> invSigma( N, 0 );
239 for( size_t i = 0; i < N; i++ ) {
240 invSigma[_sigma[i]] = i;
241 invRanges[i] = _ranges[_sigma[i]];
242 }
243 return Permute( invRanges, invSigma );
244 }
245 };
246
247
248 /// multifor makes it easy to perform a dynamic number of nested \c for loops.
249 /** An example of the usage is as follows:
250 * \code
251 * std::vector<size_t> ranges;
252 * ranges.push_back( 3 );
253 * ranges.push_back( 4 );
254 * ranges.push_back( 5 );
255 * for( multifor s(ranges); s.valid(); ++s )
256 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[2] << ", " << s[1] << ", " << s[0] << endl;
257 * \endcode
258 * which would be equivalent to:
259 * \code
260 * size_t s = 0;
261 * for( size_t s2 = 0; s2 < 5; s2++ )
262 * for( size_t s1 = 0; s1 < 4; s1++ )
263 * for( size_t s0 = 0; s0 < 3; s++, s0++ )
264 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s2 << ", " << s1 << ", " << s0 << endl;
265 * \endcode
266 */
267 class multifor {
268 private:
269 /// Stores the number of possible values of all indices
270 std::vector<size_t> _ranges;
271 /// Stores the current values of all indices
272 std::vector<size_t> _indices;
273 /// Stores the current linear index
274 long _linear_index;
275
276 public:
277 /// Default constructor
278 multifor() : _ranges(), _indices(), _linear_index(0) {}
279
280 /// Initialize from vector of index ranges
281 multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {}
282
283 /// Returns linear index (i.e., the index in the Cartesian product space)
284 operator size_t() const {
285 DAI_DEBASSERT( valid() );
286 return( _linear_index );
287 }
288
289 /// Returns \a k 'th index
290 size_t operator[]( size_t k ) const {
291 DAI_DEBASSERT( valid() );
292 DAI_DEBASSERT( k < _indices.size() );
293 return _indices[k];
294 }
295
296 /// Increments the current indices (prefix)
297 multifor & operator++() {
298 if( valid() ) {
299 _linear_index++;
300 size_t i;
301 for( i = 0; i != _indices.size(); i++ ) {
302 if( ++(_indices[i]) < _ranges[i] )
303 break;
304 _indices[i] = 0;
305 }
306 if( i == _indices.size() )
307 _linear_index = -1;
308 }
309 return *this;
310 }
311
312 /// Increments the current indices (postfix)
313 void operator++( int ) {
314 operator++();
315 }
316
317 /// Resets the state
318 multifor& reset() {
319 fill( _indices.begin(), _indices.end(), 0 );
320 _linear_index = 0;
321 return( *this );
322 }
323
324 /// Returns \c true if the current indices are valid
325 bool valid() const {
326 return( _linear_index >= 0 );
327 }
328 };
329
330
331 /// Makes it easy to iterate over all possible joint states of variables within a VarSet.
332 /** A joint state of several variables can be represented in two different ways, by a map that maps each variable
333 * to its own state, or by an integer that gives the index of the joint state in the canonical enumeration.
334 *
335 * Both representations are useful, and the main functionality provided by the State class is to simplify iterating
336 * over the various joint states of a VarSet and to provide access to the current state in both representations.
337 *
338 * As an example, consider the following code snippet which iterates over all joint states of variables \a x0 and \a x1:
339 * \code
340 * VarSet vars( x0, x1 );
341 * for( State S(vars); S.valid(); S++ ) {
342 * cout << "Linear state: " << S.get() << ", x0 = " << S(x0) << ", x1 = " << S(x1) << endl;
343 * }
344 * \endcode
345 *
346 * \note The same functionality could be achieved by simply iterating over the linear state and using dai::calcState(),
347 * but the State class offers a more efficient implementation.
348 *
349 * \note A State is very similar to a dai::multifor, but tailored for Var 's and VarSet 's.
350 *
351 * \see dai::calcLinearState(), dai::calcState()
352 *
353 * \idea Make the State class a more prominent part of libDAI
354 * (and document it clearly, explaining the concept of state);
355 * add more optimized variants of the State class like IndexFor
356 * (e.g. for TFactor<>::slice()).
357 */
358 class State {
359 private:
360 /// Type for representing a joint state of some variables as a map, which maps each variable to its state
361 typedef std::map<Var, size_t> states_type;
362
363 /// Current state (represented linearly)
364 long state;
365
366 /// Current state (represented as a map)
367 states_type states;
368
369 public:
370 /// Default constructor
371 State() : state(0), states() {}
372
373 /// Construct from VarSet \a vs and corresponding linear state \a linearState
374 State( const VarSet &vs, size_t linearState=0 ) : state(linearState), states() {
375 if( linearState == 0 )
376 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
377 states[*v] = 0;
378 else {
379 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
380 states[*v] = linearState % v->states();
381 linearState /= v->states();
382 }
383 DAI_ASSERT( linearState == 0 );
384 }
385 }
386
387 /// Construct from a std::map<Var, size_t>
388 State( const std::map<Var, size_t> &s ) : state(0), states() {
389 insert( s.begin(), s.end() );
390 }
391
392 /// Constant iterator over the values
393 typedef states_type::const_iterator const_iterator;
394
395 /// Returns constant iterator that points to the first item
396 const_iterator begin() const { return states.begin(); }
397
398 /// Returns constant iterator that points beyond the last item
399 const_iterator end() const { return states.end(); }
400
401 /// Return current linear state
402 operator size_t() const {
403 DAI_ASSERT( valid() );
404 return( state );
405 }
406
407 /// Inserts a range of variable-state pairs, changing the current state
408 template<typename InputIterator>
409 void insert( InputIterator b, InputIterator e ) {
410 states.insert( b, e );
411 VarSet vars;
412 for( const_iterator it = begin(); it != end(); it++ )
413 vars |= it->first;
414 state = 0;
415 state = this->operator()( vars );
416 }
417
418 /// Return current state represented as a map
419 const std::map<Var,size_t>& get() const { return states; }
420
421 /// Cast into std::map<Var, size_t>
422 operator const std::map<Var,size_t>& () const { return states; }
423
424 /// Return current state of variable \a v, or 0 if \a v is not in \c *this
425 size_t operator() ( const Var &v ) const {
426 DAI_ASSERT( valid() );
427 states_type::const_iterator entry = states.find( v );
428 if( entry == states.end() )
429 return 0;
430 else
431 return entry->second;
432 }
433
434 /// Return linear state of variables in \a vs, assuming that variables that are not in \c *this are in state 0
435 size_t operator() ( const VarSet &vs ) const {
436 DAI_ASSERT( valid() );
437 size_t vs_state = 0;
438 size_t prod = 1;
439 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
440 states_type::const_iterator entry = states.find( *v );
441 if( entry != states.end() )
442 vs_state += entry->second * prod;
443 prod *= v->states();
444 }
445 return vs_state;
446 }
447
448 /// Increments the current state (prefix)
449 void operator++( ) {
450 if( valid() ) {
451 state++;
452 states_type::iterator entry = states.begin();
453 while( entry != states.end() ) {
454 if( ++(entry->second) < entry->first.states() )
455 break;
456 entry->second = 0;
457 entry++;
458 }
459 if( entry == states.end() )
460 state = -1;
461 }
462 }
463
464 /// Increments the current state (postfix)
465 void operator++( int ) {
466 operator++();
467 }
468
469 /// Returns \c true if the current state is valid
470 bool valid() const {
471 return( state >= 0 );
472 }
473
474 /// Resets the current state (to the joint state represented by linear state 0)
475 void reset() {
476 state = 0;
477 for( states_type::iterator s = states.begin(); s != states.end(); s++ )
478 s->second = 0;
479 }
480 };
481
482
483 } // end of namespace dai
484
485
486 /** \example example_permute.cpp
487 * This example shows how to use the Permute, multifor and State classes.
488 *
489 * \section Output
490 * \verbinclude examples/example_permute.out
491 *
492 * \section Source
493 */
494
495
496 #endif