New git HEAD version
[libdai.git] / include / dai / index.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines the IndexFor, multifor, Permute and State classes, which all deal with indexing multi-dimensional arrays
11
12
13 #ifndef __defined_libdai_index_h
14 #define __defined_libdai_index_h
15
16
17 #include <vector>
18 #include <algorithm>
19 #include <map>
20 #include <dai/varset.h>
21
22
23 namespace dai {
24
25
26 /// Tool for looping over the states of several variables.
27 /** The class IndexFor is an important tool for indexing Factor entries.
28 * Its usage can best be explained by an example.
29 * Assume \a indexVars, \a forVars are both VarSet 's.
30 * Then the following code:
31 * \code
32 * IndexFor i( indexVars, forVars );
33 * size_t iter = 0;
34 * for( ; i.valid(); i++, iter++ ) {
35 * cout << "State of forVars: " << calcState( forVars, iter ) << "; ";
36 * cout << "state of indexVars: " << calcState( indexVars, size_t(i) ) << endl;
37 * }
38 * \endcode
39 * loops over all joint states of the variables in \a forVars,
40 * and <tt>(size_t)i</tt> equals the linear index of the corresponding
41 * state of \a indexVars, where the variables in \a indexVars that are
42 * not in \a forVars assume their zero'th value.
43 * \idea Optimize all indices as follows: keep a cache of all (or only
44 * relatively small) indices that have been computed (use a hash). Then,
45 * instead of computing on the fly, use the precomputed ones. Here the
46 * labels of the variables don't matter, but the ranges of the variables do.
47 */
48 class IndexFor {
49 private:
50 /// The current linear index corresponding to the state of indexVars
51 long _index;
52
53 /// For each variable in forVars, the amount of change in _index
54 std::vector<long> _sum;
55
56 /// For each variable in forVars, the current state
57 std::vector<size_t> _state;
58
59 /// For each variable in forVars, its number of possible values
60 std::vector<size_t> _ranges;
61
62 public:
63 /// Default constructor
64 IndexFor() : _index(-1) {}
65
66 /// Construct IndexFor object from \a indexVars and \a forVars
67 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) {
68 long sum = 1;
69
70 _ranges.reserve( forVars.size() );
71 _sum.reserve( forVars.size() );
72
73 VarSet::const_iterator j = forVars.begin();
74 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
75 for( ; j != forVars.end() && *j <= *i; ++j ) {
76 _ranges.push_back( j->states() );
77 _sum.push_back( (*i == *j) ? sum : 0 );
78 }
79 sum *= i->states();
80 }
81 for( ; j != forVars.end(); ++j ) {
82 _ranges.push_back( j->states() );
83 _sum.push_back( 0 );
84 }
85 _index = 0;
86 }
87
88 /// Resets the state
89 IndexFor& reset() {
90 fill( _state.begin(), _state.end(), 0 );
91 _index = 0;
92 return( *this );
93 }
94
95 /// Conversion to \c size_t: returns linear index of the current state of indexVars
96 operator size_t() const {
97 DAI_ASSERT( valid() );
98 return( _index );
99 }
100
101 /// Increments the current state of \a forVars (prefix)
102 IndexFor& operator++ () {
103 if( _index >= 0 ) {
104 size_t i = 0;
105
106 while( i < _state.size() ) {
107 _index += _sum[i];
108 if( ++_state[i] < _ranges[i] )
109 break;
110 _index -= _sum[i] * _ranges[i];
111 _state[i] = 0;
112 i++;
113 }
114
115 if( i == _state.size() )
116 _index = -1;
117 }
118 return( *this );
119 }
120
121 /// Increments the current state of \a forVars (postfix)
122 void operator++( int ) {
123 operator++();
124 }
125
126 /// Returns \c true if the current state is valid
127 bool valid() const {
128 return( _index >= 0 );
129 }
130 };
131
132
133 /// Tool for calculating permutations of linear indices of multi-dimensional arrays.
134 /** \note This is mainly useful for converting indices into multi-dimensional arrays
135 * corresponding to joint states of variables to and from the canonical ordering used in libDAI.
136 */
137 class Permute {
138 private:
139 /// Stores the number of possible values of all indices
140 std::vector<size_t> _ranges;
141 /// Stores the permutation
142 std::vector<size_t> _sigma;
143
144 public:
145 /// Default constructor
146 Permute() : _ranges(), _sigma() {}
147
148 /// Construct from vector of index ranges and permutation
149 Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) {
150 DAI_ASSERT( _ranges.size() == _sigma.size() );
151 }
152
153 /// Construct from vector of variables.
154 /** The implied permutation maps the index of each variable in \a vars according to the canonical ordering
155 * (i.e., sorted ascendingly according to their label) to the index it has in \a vars.
156 * If \a reverse == \c true, reverses the indexing in \a vars first.
157 */
158 Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() {
159 size_t N = vars.size();
160
161 // construct ranges
162 _ranges.reserve( N );
163 for( size_t i = 0; i < N; ++i )
164 if( reverse )
165 _ranges.push_back( vars[N - 1 - i].states() );
166 else
167 _ranges.push_back( vars[i].states() );
168
169 // construct VarSet out of vars
170 VarSet vs( vars.begin(), vars.end(), N );
171 DAI_ASSERT( vs.size() == N );
172
173 // construct sigma
174 _sigma.reserve( N );
175 for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) {
176 size_t ind = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
177 if( reverse )
178 _sigma.push_back( N - 1 - ind );
179 else
180 _sigma.push_back( ind );
181 }
182 }
183
184 /// Calculates a permuted linear index.
185 /** Converts the linear index \a li to a vector index, permutes its
186 * components, and converts it back to a linear index.
187 */
188 size_t convertLinearIndex( size_t li ) const {
189 size_t N = _ranges.size();
190
191 // calculate vector index corresponding to linear index
192 std::vector<size_t> vi;
193 vi.reserve( N );
194 size_t prod = 1;
195 for( size_t k = 0; k < N; k++ ) {
196 vi.push_back( li % _ranges[k] );
197 li /= _ranges[k];
198 prod *= _ranges[k];
199 }
200
201 // convert permuted vector index to corresponding linear index
202 prod = 1;
203 size_t sigma_li = 0;
204 for( size_t k = 0; k < N; k++ ) {
205 sigma_li += vi[_sigma[k]] * prod;
206 prod *= _ranges[_sigma[k]];
207 }
208
209 return sigma_li;
210 }
211
212 /// Returns constant reference to the permutation
213 const std::vector<size_t>& sigma() const { return _sigma; }
214
215 /// Returns reference to the permutation
216 std::vector<size_t>& sigma() { return _sigma; }
217
218 /// Returns constant reference to the dimensionality vector
219 const std::vector<size_t>& ranges() { return _ranges; }
220
221 /// Returns the result of applying the permutation on \a i
222 size_t operator[]( size_t i ) const {
223 #ifdef DAI_DEBUG
224 return _sigma.at(i);
225 #else
226 return _sigma[i];
227 #endif
228 }
229
230 /// Returns the inverse permutation
231 Permute inverse() const {
232 size_t N = _ranges.size();
233 std::vector<size_t> invRanges( N, 0 );
234 std::vector<size_t> invSigma( N, 0 );
235 for( size_t i = 0; i < N; i++ ) {
236 invSigma[_sigma[i]] = i;
237 invRanges[i] = _ranges[_sigma[i]];
238 }
239 return Permute( invRanges, invSigma );
240 }
241 };
242
243
244 /// multifor makes it easy to perform a dynamic number of nested \c for loops.
245 /** An example of the usage is as follows:
246 * \code
247 * std::vector<size_t> ranges;
248 * ranges.push_back( 3 );
249 * ranges.push_back( 4 );
250 * ranges.push_back( 5 );
251 * for( multifor s(ranges); s.valid(); ++s )
252 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[2] << ", " << s[1] << ", " << s[0] << endl;
253 * \endcode
254 * which would be equivalent to:
255 * \code
256 * size_t s = 0;
257 * for( size_t s2 = 0; s2 < 5; s2++ )
258 * for( size_t s1 = 0; s1 < 4; s1++ )
259 * for( size_t s0 = 0; s0 < 3; s++, s0++ )
260 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s2 << ", " << s1 << ", " << s0 << endl;
261 * \endcode
262 */
263 class multifor {
264 private:
265 /// Stores the number of possible values of all indices
266 std::vector<size_t> _ranges;
267 /// Stores the current values of all indices
268 std::vector<size_t> _indices;
269 /// Stores the current linear index
270 long _linear_index;
271
272 public:
273 /// Default constructor
274 multifor() : _ranges(), _indices(), _linear_index(0) {}
275
276 /// Initialize from vector of index ranges
277 multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {}
278
279 /// Returns linear index (i.e., the index in the Cartesian product space)
280 operator size_t() const {
281 DAI_DEBASSERT( valid() );
282 return( _linear_index );
283 }
284
285 /// Returns \a k 'th index
286 size_t operator[]( size_t k ) const {
287 DAI_DEBASSERT( valid() );
288 DAI_DEBASSERT( k < _indices.size() );
289 return _indices[k];
290 }
291
292 /// Increments the current indices (prefix)
293 multifor & operator++() {
294 if( valid() ) {
295 _linear_index++;
296 size_t i;
297 for( i = 0; i != _indices.size(); i++ ) {
298 if( ++(_indices[i]) < _ranges[i] )
299 break;
300 _indices[i] = 0;
301 }
302 if( i == _indices.size() )
303 _linear_index = -1;
304 }
305 return *this;
306 }
307
308 /// Increments the current indices (postfix)
309 void operator++( int ) {
310 operator++();
311 }
312
313 /// Resets the state
314 multifor& reset() {
315 fill( _indices.begin(), _indices.end(), 0 );
316 _linear_index = 0;
317 return( *this );
318 }
319
320 /// Returns \c true if the current indices are valid
321 bool valid() const {
322 return( _linear_index >= 0 );
323 }
324 };
325
326
327 /// Makes it easy to iterate over all possible joint states of variables within a VarSet.
328 /** A joint state of several variables can be represented in two different ways, by a map that maps each variable
329 * to its own state, or by an integer that gives the index of the joint state in the canonical enumeration.
330 *
331 * Both representations are useful, and the main functionality provided by the State class is to simplify iterating
332 * over the various joint states of a VarSet and to provide access to the current state in both representations.
333 *
334 * As an example, consider the following code snippet which iterates over all joint states of variables \a x0 and \a x1:
335 * \code
336 * VarSet vars( x0, x1 );
337 * for( State S(vars); S.valid(); S++ ) {
338 * cout << "Linear state: " << S.get() << ", x0 = " << S(x0) << ", x1 = " << S(x1) << endl;
339 * }
340 * \endcode
341 *
342 * \note The same functionality could be achieved by simply iterating over the linear state and using dai::calcState(),
343 * but the State class offers a more efficient implementation.
344 *
345 * \note A State is very similar to a dai::multifor, but tailored for Var 's and VarSet 's.
346 *
347 * \see dai::calcLinearState(), dai::calcState()
348 *
349 * \idea Make the State class a more prominent part of libDAI (and document it clearly, explaining the concept of state);
350 * add more optimized variants of the State class like IndexFor (e.g. for TFactor<>::slice()).
351 */
352 class State {
353 private:
354 /// Type for representing a joint state of some variables as a map, which maps each variable to its state
355 typedef std::map<Var, size_t> states_type;
356
357 /// Current state (represented linearly)
358 BigInt state;
359
360 /// Current state (represented as a map)
361 states_type states;
362
363 public:
364 /// Default constructor
365 State() : state(0), states() {}
366
367 /// Construct from VarSet \a vs and corresponding linear state \a linearState
368 State( const VarSet &vs, BigInt linearState=0 ) : state(linearState), states() {
369 if( linearState == 0 )
370 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
371 states[*v] = 0;
372 else {
373 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
374 states[*v] = BigInt_size_t( linearState % (BigInt)v->states() );
375 linearState /= (BigInt)v->states();
376 }
377 DAI_ASSERT( linearState == 0 );
378 }
379 }
380
381 /// Construct from a std::map<Var, size_t>
382 State( const std::map<Var, size_t> &s ) : state(0), states() {
383 insert( s.begin(), s.end() );
384 }
385
386 /// Constant iterator over the values
387 typedef states_type::const_iterator const_iterator;
388
389 /// Returns constant iterator that points to the first item
390 const_iterator begin() const { return states.begin(); }
391
392 /// Returns constant iterator that points beyond the last item
393 const_iterator end() const { return states.end(); }
394
395 /// Return current linear state
396 operator size_t() const {
397 DAI_ASSERT( valid() );
398 return( BigInt_size_t( state ) );
399 }
400
401 /// Inserts a range of variable-state pairs, changing the current state
402 template<typename InputIterator>
403 void insert( InputIterator b, InputIterator e ) {
404 states.insert( b, e );
405 VarSet vars;
406 for( const_iterator it = begin(); it != end(); it++ )
407 vars |= it->first;
408 state = 0;
409 state = this->operator()( vars );
410 }
411
412 /// Return current state represented as a map
413 const std::map<Var,size_t>& get() const { return states; }
414
415 /// Cast into std::map<Var, size_t>
416 operator const std::map<Var,size_t>& () const { return states; }
417
418 /// Return current state of variable \a v, or 0 if \a v is not in \c *this
419 size_t operator() ( const Var &v ) const {
420 states_type::const_iterator entry = states.find( v );
421 if( entry == states.end() )
422 return 0;
423 else
424 return entry->second;
425 }
426
427 /// Return linear state of variables in \a vs, assuming that variables that are not in \c *this are in state 0
428 BigInt operator() ( const VarSet &vs ) const {
429 BigInt vs_state = 0;
430 BigInt prod = 1;
431 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
432 states_type::const_iterator entry = states.find( *v );
433 if( entry != states.end() )
434 vs_state += (BigInt)entry->second * prod;
435 prod *= (BigInt)v->states();
436 }
437 return vs_state;
438 }
439
440 /// Increments the current state (prefix)
441 void operator++( ) {
442 if( valid() ) {
443 state++;
444 states_type::iterator entry = states.begin();
445 while( entry != states.end() ) {
446 if( ++(entry->second) < entry->first.states() )
447 break;
448 entry->second = 0;
449 entry++;
450 }
451 if( entry == states.end() )
452 state = -1;
453 }
454 }
455
456 /// Increments the current state (postfix)
457 void operator++( int ) {
458 operator++();
459 }
460
461 /// Returns \c true if the current state is valid
462 bool valid() const {
463 return( state >= 0 );
464 }
465
466 /// Resets the current state (to the joint state represented by linear state 0)
467 void reset() {
468 state = 0;
469 for( states_type::iterator s = states.begin(); s != states.end(); s++ )
470 s->second = 0;
471 }
472 };
473
474
475 } // end of namespace dai
476
477
478 /** \example example_permute.cpp
479 * This example shows how to use the Permute, multifor and State classes.
480 *
481 * \section Output
482 * \verbinclude examples/example_permute.out
483 *
484 * \section Source
485 */
486
487
488 #endif