1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
8 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2002-2007 Radboud University Nijmegen, The Netherlands
14 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
15 /// \todo Improve documentation
18 #ifndef __defined_libdai_index_h
19 #define __defined_libdai_index_h
26 #include <dai/varset.h>
32 /// Tool for looping over the states of several variables.
33 /** The class IndexFor is an important tool for indexing Factor entries.
34 * Its usage can best be explained by an example.
35 * Assume indexVars, forVars are both VarSets.
36 * Then the following code:
38 * IndexFor i( indexVars, forVars );
39 * for( ; i >= 0; ++i ) {
43 * loops over all joint states of the variables in forVars,
44 * and (long)i is equal to the linear index of the corresponding
45 * state of indexVars, where the variables in indexVars that are
46 * not in forVars assume their zero'th value.
47 * \idea Optimize all indices as follows: keep a cache of all (or only
48 * relatively small) indices that have been computed (use a hash). Then,
49 * instead of computing on the fly, use the precomputed ones.
53 /// The current linear index corresponding to the state of indexVars
56 /// For each variable in forVars, the amount of change in _index
57 std::vector
<long> _sum
;
59 /// For each variable in forVars, the current state
60 std::vector
<size_t> _count
;
62 /// For each variable in forVars, its number of possible values
63 std::vector
<size_t> _dims
;
66 /// Default constructor
72 IndexFor( const VarSet
& indexVars
, const VarSet
& forVars
) : _count( forVars
.size(), 0 ) {
75 _dims
.reserve( forVars
.size() );
76 _sum
.reserve( forVars
.size() );
78 VarSet::const_iterator j
= forVars
.begin();
79 for( VarSet::const_iterator i
= indexVars
.begin(); i
!= indexVars
.end(); ++i
) {
80 for( ; j
!= forVars
.end() && *j
<= *i
; ++j
) {
81 _dims
.push_back( j
->states() );
82 _sum
.push_back( (*i
== *j
) ? sum
: 0 );
86 for( ; j
!= forVars
.end(); ++j
) {
87 _dims
.push_back( j
->states() );
93 /// Sets the index back to zero
95 fill( _count
.begin(), _count
.end(), 0 );
100 /// Conversion to long
101 operator long () const {
105 /// Pre-increment operator
106 IndexFor
& operator++ () {
110 while( i
< _count
.size() ) {
112 if( ++_count
[i
] < _dims
[i
] )
114 _index
-= _sum
[i
] * _dims
[i
];
119 if( i
== _count
.size() )
127 /// MultiFor makes it easy to perform a dynamic number of nested for loops.
128 /** An example of the usage is as follows:
130 * std::vector<size_t> dims;
131 * dims.push_back( 3 );
132 * dims.push_back( 4 );
133 * dims.push_back( 5 );
134 * for( MultiFor s(dims); s.valid(); ++s )
135 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
137 * which would be equivalent to:
140 * for( size_t s0 = 0; s0 < 3; s0++ )
141 * for( size_t s1 = 0; s1 < 4; s1++ )
142 * for( size_t s2 = 0; s2 < 5; s++, s2++ )
143 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
148 std::vector
<size_t> _dims
;
149 std::vector
<size_t> _states
;
153 /// Default constructor
154 MultiFor() : _dims(), _states(), _state(0) {}
156 /// Initialize from vector of index dimensions
157 MultiFor( const std::vector
<size_t> &d
) : _dims(d
), _states(d
.size(),0), _state(0) {}
159 /// Return linear state
160 operator size_t() const {
165 /// Return k'th index
166 size_t operator[]( size_t k
) const {
168 assert( k
< _states
.size() );
172 /// Prefix increment operator
173 MultiFor
& operator++() {
177 for( i
= 0; i
!= _states
.size(); i
++ ) {
178 if( ++(_states
[i
]) < _dims
[i
] )
182 if( i
== _states
.size() )
188 /// Postfix increment operator
189 void operator++( int ) {
193 /// Returns true if the current state is valid
195 return( _state
>= 0 );
200 /// Tool for calculating permutations of multiple indices.
203 std::vector
<size_t> _dims
;
204 std::vector
<size_t> _sigma
;
207 /// Default constructor
208 Permute() : _dims(), _sigma() {}
210 /// Construct from vector of index dimensions and permutation sigma
211 Permute( const std::vector
<size_t> &d
, const std::vector
<size_t> &sigma
) : _dims(d
), _sigma(sigma
) {
212 assert( _dims
.size() == _sigma
.size() );
215 /// Construct from vector of variables
216 Permute( const std::vector
<Var
> &vars
) : _dims(vars
.size()), _sigma(vars
.size()) {
217 VarSet
vs( vars
.begin(), vars
.end(), vars
.size() );
218 for( size_t i
= 0; i
< vars
.size(); ++i
)
219 _dims
[i
] = vars
[i
].states();
220 VarSet::const_iterator set_iter
= vs
.begin();
221 for( size_t i
= 0; i
< vs
.size(); ++i
, ++set_iter
)
222 _sigma
[i
] = find( vars
.begin(), vars
.end(), *set_iter
) - vars
.begin();
225 /// Calculates a permuted linear index.
226 /** Converts the linear index li to a vector index
227 * corresponding with the dimensions in _dims, permutes it according to sigma,
228 * and converts it back to a linear index according to the permuted dimensions.
230 size_t convert_linear_index( size_t li
) const {
231 size_t N
= _dims
.size();
233 // calculate vector index corresponding to linear index
234 std::vector
<size_t> vi
;
237 for( size_t k
= 0; k
< N
; k
++ ) {
238 vi
.push_back( li
% _dims
[k
] );
243 // convert permuted vector index to corresponding linear index
246 for( size_t k
= 0; k
< N
; k
++ ) {
247 sigma_li
+= vi
[_sigma
[k
]] * prod
;
248 prod
*= _dims
[_sigma
[k
]];
256 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
257 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
261 typedef std::map
<Var
, size_t> states_type
;
267 /// Default constructor
268 State() : state(0), states() {}
270 /// Initialize from VarSet
271 State( const VarSet
&vs
) : state(0) {
272 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ )
276 /// Return linear state
277 operator size_t() const {
282 /// Return state of variable n, or zero if n is not in this State
283 size_t operator() ( const Var
&n
) const {
285 states_type::const_iterator entry
= states
.find( n
);
286 if( entry
== states
.end() )
289 return entry
->second
;
292 /// Return linear state of variables in varset, setting them to zero if they are not in this State
293 size_t operator() ( const VarSet
&vs
) const {
297 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ ) {
298 states_type::const_iterator entry
= states
.find( *v
);
299 if( entry
!= states
.end() )
300 vs_state
+= entry
->second
* prod
;
306 /// Prefix increment operator
310 states_type::iterator entry
= states
.begin();
311 while( entry
!= states
.end() ) {
312 if( ++(entry
->second
) < entry
->first
.states() )
317 if( entry
== states
.end() )
322 /// Postfix increment operator
323 void operator++( int ) {
327 /// Returns true if the current state is valid
329 return( state
>= 0 );
334 } // end of namespace dai