/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2002       Martijn Leisink  [martijn@mbfys.kun.nl]
 *  Copyright (C) 2006-2009  Joris Mooij      [joris dot mooij at libdai dot org]
 *  Copyright (C) 2002-2007  Radboud University Nijmegen, The Netherlands
 */


/// \file
/// \brief Defines the IndexFor, multifor, Permute and State classes, which all deal with indexing multi-dimensional arrays


#ifndef __defined_libdai_index_h
#define __defined_libdai_index_h


#include <algorithm>
#include <map>
#include <vector>
#include <dai/varset.h>


namespace dai {


30 /// Tool for looping over the states of several variables.
31 /** The class IndexFor is an important tool for indexing Factor entries.
32 * Its usage can best be explained by an example.
33 * Assume \a indexVars, \a forVars are both VarSet 's.
34 * Then the following code:
36 * IndexFor i( indexVars, forVars );
38 * for( ; i.valid(); i++, iter++ ) {
39 * cout << "State of forVars: " << calcState( forVars, iter ) << "; ";
40 * cout << "state of indexVars: " << calcState( indexVars, long(i) ) << endl;
43 * loops over all joint states of the variables in \a forVars,
44 * and <tt>(long)i</tt> equals the linear index of the corresponding
45 * state of \a indexVars, where the variables in \a indexVars that are
46 * not in \a forVars assume their zero'th value.
47 * \idea Optimize all indices as follows: keep a cache of all (or only
48 * relatively small) indices that have been computed (use a hash). Then,
49 * instead of computing on the fly, use the precomputed ones. Here the
50 * labels of the variables don't matter, but the ranges of the variables do.
54 /// The current linear index corresponding to the state of indexVars
57 /// For each variable in forVars, the amount of change in _index
58 std::vector
<long> _sum
;
60 /// For each variable in forVars, the current state
61 std::vector
<size_t> _state
;
63 /// For each variable in forVars, its number of possible values
64 std::vector
<size_t> _ranges
;
67 /// Default constructor
68 IndexFor() : _index(-1) {}
70 /// Construct IndexFor object from \a indexVars and \a forVars
71 IndexFor( const VarSet
& indexVars
, const VarSet
& forVars
) : _state( forVars
.size(), 0 ) {
74 _ranges
.reserve( forVars
.size() );
75 _sum
.reserve( forVars
.size() );
77 VarSet::const_iterator j
= forVars
.begin();
78 for( VarSet::const_iterator i
= indexVars
.begin(); i
!= indexVars
.end(); ++i
) {
79 for( ; j
!= forVars
.end() && *j
<= *i
; ++j
) {
80 _ranges
.push_back( j
->states() );
81 _sum
.push_back( (*i
== *j
) ? sum
: 0 );
85 for( ; j
!= forVars
.end(); ++j
) {
86 _ranges
.push_back( j
->states() );
94 fill( _state
.begin(), _state
.end(), 0 );
100 /// Conversion to \c long: returns linear index of the current state of indexVars
101 /** \deprecated Will be replaced by an operator size_t()
103 operator long () const {
109 /// Conversion to \c size_t: returns linear index of the current state of indexVars
110 operator size_t() const {
111 DAI_ASSERT( valid() );
116 /// Increments the current state of \a forVars (prefix)
117 IndexFor
& operator++ () {
121 while( i
< _state
.size() ) {
123 if( ++_state
[i
] < _ranges
[i
] )
125 _index
-= _sum
[i
] * _ranges
[i
];
130 if( i
== _state
.size() )
136 /// Increments the current state of \a forVars (postfix)
137 void operator++( int ) {
141 /// Returns \c true if the current state is valid
143 return( _index
>= 0 );
148 /// Tool for calculating permutations of linear indices of multi-dimensional arrays.
149 /** \note This is mainly useful for converting indices into multi-dimensional arrays
150 * corresponding to joint states of variables to and from the canonical ordering used in libDAI.
154 /// Stores the number of possible values of all indices
155 std::vector
<size_t> _ranges
;
156 /// Stores the permutation
157 std::vector
<size_t> _sigma
;
160 /// Default constructor
161 Permute() : _ranges(), _sigma() {}
163 /// Construct from vector of index ranges and permutation
164 Permute( const std::vector
<size_t> &rs
, const std::vector
<size_t> &sigma
) : _ranges(rs
), _sigma(sigma
) {
165 DAI_ASSERT( _ranges
.size() == _sigma
.size() );
168 /// Construct from vector of variables.
169 /** The implied permutation maps the index of each variable in \a vars according to the canonical ordering
170 * (i.e., sorted ascendingly according to their label) to the index it has in \a vars.
172 Permute( const std::vector
<Var
> &vars
) : _ranges(vars
.size()), _sigma(vars
.size()) {
173 for( size_t i
= 0; i
< vars
.size(); ++i
)
174 _ranges
[i
] = vars
[i
].states();
175 VarSet
vs( vars
.begin(), vars
.end(), vars
.size() );
176 VarSet::const_iterator vs_i
= vs
.begin();
177 for( size_t i
= 0; i
< vs
.size(); ++i
, ++vs_i
)
178 _sigma
[i
] = find( vars
.begin(), vars
.end(), *vs_i
) - vars
.begin();
181 /// Calculates a permuted linear index.
182 /** Converts the linear index \a li to a vector index, permutes its
183 * components, and converts it back to a linear index.
185 size_t convertLinearIndex( size_t li
) const {
186 size_t N
= _ranges
.size();
188 // calculate vector index corresponding to linear index
189 std::vector
<size_t> vi
;
192 for( size_t k
= 0; k
< N
; k
++ ) {
193 vi
.push_back( li
% _ranges
[k
] );
198 // convert permuted vector index to corresponding linear index
201 for( size_t k
= 0; k
< N
; k
++ ) {
202 sigma_li
+= vi
[_sigma
[k
]] * prod
;
203 prod
*= _ranges
[_sigma
[k
]];
210 /// Calculates a permuted linear index
211 /** \deprecated Renamed into dai::Permute::convertLinearIndex()
213 size_t convert_linear_index( size_t li
) const { return convertLinearIndex(li
); }
215 /// Returns const reference to the permutation
216 const std::vector
<size_t>& sigma() const { return _sigma
; }
218 /// Returns reference to the permutation
219 std::vector
<size_t>& sigma() { return _sigma
; }
221 /// Returns the result of applying the permutation on \a i
222 size_t operator[]( size_t i
) const {
232 /// multifor makes it easy to perform a dynamic number of nested \c for loops.
233 /** An example of the usage is as follows:
235 * std::vector<size_t> ranges;
236 * ranges.push_back( 3 );
237 * ranges.push_back( 4 );
238 * ranges.push_back( 5 );
239 * for( multifor s(ranges); s.valid(); ++s )
240 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[2] << ", " << s[1] << ", " << s[0] << endl;
242 * which would be equivalent to:
245 * for( size_t s2 = 0; s2 < 5; s2++ )
246 * for( size_t s1 = 0; s1 < 4; s1++ )
247 * for( size_t s0 = 0; s0 < 3; s++, s0++ )
248 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s2 << ", " << s1 << ", " << s0 << endl;
253 /// Stores the number of possible values of all indices
254 std::vector
<size_t> _ranges
;
255 /// Stores the current values of all indices
256 std::vector
<size_t> _indices
;
257 /// Stores the current linear index
261 /// Default constructor
262 multifor() : _ranges(), _indices(), _linear_index(0) {}
264 /// Initialize from vector of index ranges
265 multifor( const std::vector
<size_t> &d
) : _ranges(d
), _indices(d
.size(),0), _linear_index(0) {}
267 /// Returns linear index (i.e., the index in the Cartesian product space)
268 operator size_t() const {
269 DAI_DEBASSERT( valid() );
270 return( _linear_index
);
273 /// Returns \a k 'th index
274 size_t operator[]( size_t k
) const {
275 DAI_DEBASSERT( valid() );
276 DAI_DEBASSERT( k
< _indices
.size() );
280 /// Increments the current indices (prefix)
281 multifor
& operator++() {
285 for( i
= 0; i
!= _indices
.size(); i
++ ) {
286 if( ++(_indices
[i
]) < _ranges
[i
] )
290 if( i
== _indices
.size() )
296 /// Increments the current indices (postfix)
297 void operator++( int ) {
301 /// Returns \c true if the current indices are valid
303 return( _linear_index
>= 0 );
308 /// Makes it easy to iterate over all possible joint states of variables within a VarSet.
309 /** A joint state of several variables can be represented in two different ways, by a map that maps each variable
310 * to its own state, or by an integer that gives the index of the joint state in the canonical enumeration.
312 * Both representations are useful, and the main functionality provided by the State class is to simplify iterating
313 * over the various joint states of a VarSet and to provide access to the current state in both representations.
315 * As an example, consider the following code snippet which iterates over all joint states of variables \a x0 and \a x1:
317 * VarSet vars( x0, x1 );
318 * for( State S(vars); S.valid(); S++ ) {
319 * cout << "Linear state: " << S.get() << ", x0 = " << S(x0) << ", x1 = " << S(x1) << endl;
323 * \note The same functionality could be achieved by simply iterating over the linear state and using dai::calcState(),
324 * but the State class offers a more efficient implementation.
326 * \note A State is very similar to a \link multifor \endlink, but tailored for Var 's and VarSet 's.
328 * \see dai::calcLinearState(), dai::calcState()
330 * \idea Make the State class a more prominent part of libDAI
331 * (and document it clearly, explaining the concept of state);
332 * add more optimized variants of the State class like IndexFor
333 * (e.g. for TFactor<>::slice()).
337 /// Type for representing a joint state of some variables as a map, which maps each variable to its state
338 typedef std::map
<Var
, size_t> states_type
;
340 /// Current state (represented linearly)
343 /// Current state (represented as a map)
347 /// Default constructor
348 State() : state(0), states() {}
350 /// Construct from VarSet \a vs and corresponding linear state \a linearState
351 State( const VarSet
&vs
, size_t linearState
=0 ) : state(linearState
), states() {
352 if( linearState
== 0 )
353 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ )
356 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ ) {
357 states
[*v
] = linearState
% v
->states();
358 linearState
/= v
->states();
360 DAI_ASSERT( linearState
== 0 );
364 /// Construct from a std::map<Var, size_t>
365 State( const std::map
<Var
, size_t> &s
) : state(0), states() {
366 insert( s
.begin(), s
.end() );
369 /// Constant iterator over the values
370 typedef states_type::const_iterator const_iterator
;
372 /// Returns constant iterator that points to the first item
373 const_iterator
begin() const { return states
.begin(); }
375 /// Returns constant iterator that points beyond the last item
376 const_iterator
end() const { return states
.end(); }
378 /// Return current linear state
379 operator size_t() const {
380 DAI_ASSERT( valid() );
384 /// Inserts a range of variable-state pairs, changing the current state
385 template<typename InputIterator
>
386 void insert( InputIterator b
, InputIterator e
) {
387 states
.insert( b
, e
);
389 for( const_iterator it
= begin(); it
!= end(); it
++ )
392 state
= this->operator()( vars
);
395 /// Return current state represented as a map
396 const std::map
<Var
,size_t>& get() const { return states
; }
398 /// Cast into std::map<Var, size_t>
399 operator const std::map
<Var
,size_t>& () const { return states
; }
401 /// Return current state of variable \a v, or 0 if \a v is not in \c *this
402 size_t operator() ( const Var
&v
) const {
403 DAI_ASSERT( valid() );
404 states_type::const_iterator entry
= states
.find( v
);
405 if( entry
== states
.end() )
408 return entry
->second
;
411 /// Return linear state of variables in \a vs, assuming that variables that are not in \c *this are in state 0
412 size_t operator() ( const VarSet
&vs
) const {
413 DAI_ASSERT( valid() );
416 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ ) {
417 states_type::const_iterator entry
= states
.find( *v
);
418 if( entry
!= states
.end() )
419 vs_state
+= entry
->second
* prod
;
425 /// Increments the current state (prefix)
429 states_type::iterator entry
= states
.begin();
430 while( entry
!= states
.end() ) {
431 if( ++(entry
->second
) < entry
->first
.states() )
436 if( entry
== states
.end() )
441 /// Increments the current state (postfix)
442 void operator++( int ) {
446 /// Returns \c true if the current state is valid
448 return( state
>= 0 );
451 /// Resets the current state (to the joint state represented by linear state 0)
454 for( states_type::iterator s
= states
.begin(); s
!= states
.end(); s
++ )
} // end of namespace dai


/** \example example_permute.cpp
 *  This example shows how to use the Permute, multifor and State classes.
 *
 *  \verbinclude examples/example_permute.out
 */


#endif