Extended InfAlg interface with setProperties(), getProperties() and printProperties()
[libdai.git] / include / dai / index.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
8 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2002-2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 /// \file
14 /// \brief Defines the IndexFor, multifor, Permute and State classes, which all deal with indexing multi-dimensional arrays
15
16
17 #ifndef __defined_libdai_index_h
18 #define __defined_libdai_index_h
19
20
21 #include <vector>
22 #include <algorithm>
23 #include <map>
24 #include <dai/varset.h>
25
26
27 namespace dai {
28
29
30 /// Tool for looping over the states of several variables.
31 /** The class IndexFor is an important tool for indexing Factor entries.
32 * Its usage can best be explained by an example.
33 * Assume \a indexVars, \a forVars are both VarSet 's.
34 * Then the following code:
35 * \code
36 * IndexFor i( indexVars, forVars );
37 * size_t iter = 0;
38 * for( ; i.valid(); i++, iter++ ) {
39 * cout << "State of forVars: " << forVars.calcStates( iter ) << "; ";
40 * cout << "state of indexVars: " << indexVars.calcStates( long(i) ) << endl;
41 * }
42 * \endcode
43 * loops over all joint states of the variables in \a forVars,
44 * and <tt>(long)i</tt> equals the linear index of the corresponding
45 * state of \a indexVars, where the variables in \a indexVars that are
46 * not in \a forVars assume their zero'th value.
47 * \idea Optimize all indices as follows: keep a cache of all (or only
48 * relatively small) indices that have been computed (use a hash). Then,
49 * instead of computing on the fly, use the precomputed ones. Here the
50 * labels of the variables don't matter, but the ranges of the variables do.
51 */
52 class IndexFor {
53 private:
54 /// The current linear index corresponding to the state of indexVars
55 long _index;
56
57 /// For each variable in forVars, the amount of change in _index
58 std::vector<long> _sum;
59
60 /// For each variable in forVars, the current state
61 std::vector<size_t> _state;
62
63 /// For each variable in forVars, its number of possible values
64 std::vector<size_t> _ranges;
65
66 public:
67 /// Default constructor
68 IndexFor() : _index(-1) {}
69
70 /// Construct IndexFor object from \a indexVars and \a forVars
71 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _state( forVars.size(), 0 ) {
72 long sum = 1;
73
74 _ranges.reserve( forVars.size() );
75 _sum.reserve( forVars.size() );
76
77 VarSet::const_iterator j = forVars.begin();
78 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
79 for( ; j != forVars.end() && *j <= *i; ++j ) {
80 _ranges.push_back( j->states() );
81 _sum.push_back( (*i == *j) ? sum : 0 );
82 }
83 sum *= i->states();
84 }
85 for( ; j != forVars.end(); ++j ) {
86 _ranges.push_back( j->states() );
87 _sum.push_back( 0 );
88 }
89 _index = 0;
90 }
91
92 /// Resets the state
93 IndexFor& reset() {
94 fill( _state.begin(), _state.end(), 0 );
95 _index = 0;
96 return( *this );
97 }
98
99 // OBSOLETE
100 /// Conversion to \c long: returns linear index of the current state of indexVars
101 operator long () const {
102 return( _index );
103 }
104
105 // NEW VERSION
106 /*
107 /// Conversion to \c size_t: returns linear index of the current state of indexVars
108 operator size_t() const {
109 DAI_ASSERT( valid() );
110 return( _index );
111 }
112 */
113
114 /// Increments the current state of \a forVars (prefix)
115 IndexFor& operator++ () {
116 if( _index >= 0 ) {
117 size_t i = 0;
118
119 while( i < _state.size() ) {
120 _index += _sum[i];
121 if( ++_state[i] < _ranges[i] )
122 break;
123 _index -= _sum[i] * _ranges[i];
124 _state[i] = 0;
125 i++;
126 }
127
128 if( i == _state.size() )
129 _index = -1;
130 }
131 return( *this );
132 }
133
134 /// Increments the current state of \a forVars (postfix)
135 void operator++( int ) {
136 operator++();
137 }
138
139 /// Returns \c true if the current state is valid
140 bool valid() const {
141 return( _index >= 0 );
142 }
143 };
144
145
146 /// Tool for calculating permutations of linear indices of multi-dimensional arrays.
147 /** \note This is mainly useful for converting indices into multi-dimensional arrays
148 * corresponding to joint states of variables to and from the canonical ordering used in libDAI.
149 */
150 class Permute {
151 private:
152 /// Stores the number of possible values of all indices
153 std::vector<size_t> _ranges;
154 /// Stores the permutation
155 std::vector<size_t> _sigma;
156
157 public:
158 /// Default constructor
159 Permute() : _ranges(), _sigma() {}
160
161 /// Construct from vector of index ranges and permutation
162 Permute( const std::vector<size_t> &rs, const std::vector<size_t> &sigma ) : _ranges(rs), _sigma(sigma) {
163 DAI_ASSERT( _ranges.size() == _sigma.size() );
164 }
165
166 /// Construct from vector of variables.
167 /** The implied permutation maps the index of each variable in \a vars according to the canonical ordering
168 * (i.e., sorted ascendingly according to their label) to the index it has in \a vars.
169 */
170 Permute( const std::vector<Var> &vars ) : _ranges(vars.size()), _sigma(vars.size()) {
171 for( size_t i = 0; i < vars.size(); ++i )
172 _ranges[i] = vars[i].states();
173 VarSet vs( vars.begin(), vars.end(), vars.size() );
174 VarSet::const_iterator vs_i = vs.begin();
175 for( size_t i = 0; i < vs.size(); ++i, ++vs_i )
176 _sigma[i] = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
177 }
178
179 /// Calculates a permuted linear index.
180 /** Converts the linear index \a li to a vector index, permutes its
181 * components, and converts it back to a linear index.
182 */
183 size_t convertLinearIndex( size_t li ) const {
184 size_t N = _ranges.size();
185
186 // calculate vector index corresponding to linear index
187 std::vector<size_t> vi;
188 vi.reserve( N );
189 size_t prod = 1;
190 for( size_t k = 0; k < N; k++ ) {
191 vi.push_back( li % _ranges[k] );
192 li /= _ranges[k];
193 prod *= _ranges[k];
194 }
195
196 // convert permuted vector index to corresponding linear index
197 prod = 1;
198 size_t sigma_li = 0;
199 for( size_t k = 0; k < N; k++ ) {
200 sigma_li += vi[_sigma[k]] * prod;
201 prod *= _ranges[_sigma[k]];
202 }
203
204 return sigma_li;
205 }
206
207 // OBSOLETE
208 /// For backwards compatibility (to be removed soon)
209 size_t convert_linear_index( size_t li ) const { return convertLinearIndex(li); }
210
211 /// Returns const reference to the permutation
212 const std::vector<size_t>& sigma() const { return _sigma; }
213
214 /// Returns reference to the permutation
215 std::vector<size_t>& sigma() { return _sigma; }
216
217 /// Returns the result of applying the permutation on \a i
218 size_t operator[]( size_t i ) const {
219 #ifdef DAI_DEBUG
220 return _sigma.at(i);
221 #else
222 return _sigma[i];
223 #endif
224 }
225 };
226
227
228 /// multifor makes it easy to perform a dynamic number of nested \c for loops.
229 /** An example of the usage is as follows:
230 * \code
231 * std::vector<size_t> ranges;
232 * ranges.push_back( 3 );
233 * ranges.push_back( 4 );
234 * ranges.push_back( 5 );
235 * for( multifor s(ranges); s.valid(); ++s )
236 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[2] << ", " << s[1] << ", " << s[0] << endl;
237 * \endcode
238 * which would be equivalent to:
239 * \code
240 * size_t s = 0;
241 * for( size_t s2 = 0; s2 < 5; s2++ )
242 * for( size_t s1 = 0; s1 < 4; s1++ )
243 * for( size_t s0 = 0; s0 < 3; s++, s0++ )
244 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s2 << ", " << s1 << ", " << s0 << endl;
245 * \endcode
246 */
247 class multifor {
248 private:
249 /// Stores the number of possible values of all indices
250 std::vector<size_t> _ranges;
251 /// Stores the current values of all indices
252 std::vector<size_t> _indices;
253 /// Stores the current linear index
254 long _linear_index;
255
256 public:
257 /// Default constructor
258 multifor() : _ranges(), _indices(), _linear_index(0) {}
259
260 /// Initialize from vector of index ranges
261 multifor( const std::vector<size_t> &d ) : _ranges(d), _indices(d.size(),0), _linear_index(0) {}
262
263 /// Returns linear index (i.e., the index in the Cartesian product space)
264 operator size_t() const {
265 DAI_DEBASSERT( valid() );
266 return( _linear_index );
267 }
268
269 /// Returns \a k 'th index
270 size_t operator[]( size_t k ) const {
271 DAI_DEBASSERT( valid() );
272 DAI_DEBASSERT( k < _indices.size() );
273 return _indices[k];
274 }
275
276 /// Increments the current indices (prefix)
277 multifor & operator++() {
278 if( valid() ) {
279 _linear_index++;
280 size_t i;
281 for( i = 0; i != _indices.size(); i++ ) {
282 if( ++(_indices[i]) < _ranges[i] )
283 break;
284 _indices[i] = 0;
285 }
286 if( i == _indices.size() )
287 _linear_index = -1;
288 }
289 return *this;
290 }
291
292 /// Increments the current indices (postfix)
293 void operator++( int ) {
294 operator++();
295 }
296
297 /// Returns \c true if the current indices are valid
298 bool valid() const {
299 return( _linear_index >= 0 );
300 }
301 };
302
303
304 /// Makes it easy to iterate over all possible joint states of variables within a VarSet.
305 /** A joint state of several variables can be represented in two different ways, by a map that maps each variable
306 * to its own state, or by an integer that gives the index of the joint state in the canonical enumeration.
307 *
308 * Both representations are useful, and the main functionality provided by the State class is to simplify iterating
309 * over the various joint states of a VarSet and to provide access to the current state in both representations.
310 *
311 * As an example, consider the following code snippet which iterates over all joint states of variables \a x0 and \a x1:
312 * \code
313 * VarSet vars( x0, x1 );
314 * for( State S(vars); S.valid(); S++ ) {
315 * cout << "Linear state: " << S.get() << ", x0 = " << S(x0) << ", x1 = " << S(x1) << endl;
316 * }
317 * \endcode
318 *
319 * \note The same functionality could be achieved by simply iterating over the linear state and using VarSet::calcStates,
320 * but the State class offers a more efficient implementation.
321 *
322 * \note A State is very similar to a \link multifor \endlink, but tailored for Var 's and VarSet 's.
323 *
324 * \see VarSet::calcState(), VarSet::calcStates()
325 *
326 * \todo Rename VarSet::calcState(), VarSet::calcStates() and make them non-member functions;
327 * make the State class a more prominent part of libDAI (and document it clearly, explaining
328 * the concept of state); add more optimized variants of the State class like IndexFor (e.g. for TFactor<>::slice()).
329 */
330 class State {
331 private:
332 /// Type for representing a joint state of some variables as a map, which maps each variable to its state
333 typedef std::map<Var, size_t> states_type;
334
335 /// Current state (represented linearly)
336 long state;
337
338 /// Current state (represented as a map)
339 states_type states;
340
341 public:
342 /// Default constructor
343 State() : state(0), states() {}
344
345 /// Construct from VarSet \a vs and corresponding linear state \a linearState
346 State( const VarSet &vs, size_t linearState=0 ) : state(linearState), states() {
347 if( linearState == 0 )
348 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
349 states[*v] = 0;
350 else {
351 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
352 states[*v] = linearState % v->states();
353 linearState /= v->states();
354 }
355 DAI_ASSERT( linearState == 0 );
356 }
357 }
358
359 /// Construct from a std::map<Var, size_t>
360 State( const std::map<Var, size_t> &s ) : state(0), states() {
361 insert( s.begin(), s.end() );
362 }
363
364 /// Constant iterator over the values
365 typedef states_type::const_iterator const_iterator;
366
367 /// Returns constant iterator that points to the first item
368 const_iterator begin() const { return states.begin(); }
369
370 /// Returns constant iterator that points beyond the last item
371 const_iterator end() const { return states.end(); }
372
373 /// Return current linear state
374 operator size_t() const {
375 DAI_ASSERT( valid() );
376 return( state );
377 }
378
379 /// Inserts a range of variable-state pairs, changing the current state
380 template<typename InputIterator>
381 void insert( InputIterator b, InputIterator e ) {
382 states.insert( b, e );
383 VarSet vars;
384 for( const_iterator it = begin(); it != end(); it++ )
385 vars |= it->first;
386 state = 0;
387 state = this->operator()( vars );
388 }
389
390 /// Return current state represented as a map
391 const std::map<Var,size_t>& get() const { return states; }
392
393 /// Cast into std::map<Var, size_t>
394 operator const std::map<Var,size_t>& () const { return states; }
395
396 /// Return current state of variable \a v, or 0 if \a v is not in \c *this
397 size_t operator() ( const Var &v ) const {
398 DAI_ASSERT( valid() );
399 states_type::const_iterator entry = states.find( v );
400 if( entry == states.end() )
401 return 0;
402 else
403 return entry->second;
404 }
405
406 /// Return linear state of variables in \a vs, assuming that variables that are not in \c *this are in state 0
407 size_t operator() ( const VarSet &vs ) const {
408 DAI_ASSERT( valid() );
409 size_t vs_state = 0;
410 size_t prod = 1;
411 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
412 states_type::const_iterator entry = states.find( *v );
413 if( entry != states.end() )
414 vs_state += entry->second * prod;
415 prod *= v->states();
416 }
417 return vs_state;
418 }
419
420 /// Increments the current state (prefix)
421 void operator++( ) {
422 if( valid() ) {
423 state++;
424 states_type::iterator entry = states.begin();
425 while( entry != states.end() ) {
426 if( ++(entry->second) < entry->first.states() )
427 break;
428 entry->second = 0;
429 entry++;
430 }
431 if( entry == states.end() )
432 state = -1;
433 }
434 }
435
436 /// Increments the current state (postfix)
437 void operator++( int ) {
438 operator++();
439 }
440
441 /// Returns \c true if the current state is valid
442 bool valid() const {
443 return( state >= 0 );
444 }
445
446 /// Resets the current state (to the joint state represented by linear state 0)
447 void reset() {
448 state = 0;
449 for( states_type::iterator s = states.begin(); s != states.end(); s++ )
450 s->second = 0;
451 }
452 };
453
454
455 } // end of namespace dai
456
457
458 /** \example example_permute.cpp
459 * This example shows how to use the Permute, multifor and State classes.
460 *
461 * \section Output
462 * \verbinclude examples/example_permute.out
463 *
464 * \section Source
465 */
466
467
468 #endif