Updated copyright headers
[libdai.git] / include / dai / index.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
8 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2002-2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 /// \file
14 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
15 /// \todo Improve documentation
16
17
18 #ifndef __defined_libdai_index_h
19 #define __defined_libdai_index_h
20
21
22 #include <vector>
23 #include <algorithm>
24 #include <map>
25 #include <cassert>
26 #include <dai/varset.h>
27
28
29 namespace dai {
30
31
32 /// Tool for looping over the states of several variables.
33 /** The class IndexFor is an important tool for indexing Factor entries.
34 * Its usage can best be explained by an example.
35 * Assume indexVars, forVars are both VarSets.
36 * Then the following code:
37 * \code
38 * IndexFor i( indexVars, forVars );
39 * for( ; i >= 0; ++i ) {
40 * // use long(i)
41 * }
42 * \endcode
43 * loops over all joint states of the variables in forVars,
44 * and (long)i is equal to the linear index of the corresponding
45 * state of indexVars, where the variables in indexVars that are
46 * not in forVars assume their zero'th value.
47 * \idea Optimize all indices as follows: keep a cache of all (or only
48 * relatively small) indices that have been computed (use a hash). Then,
49 * instead of computing on the fly, use the precomputed ones.
50 */
51 class IndexFor {
52 private:
53 /// The current linear index corresponding to the state of indexVars
54 long _index;
55
56 /// For each variable in forVars, the amount of change in _index
57 std::vector<long> _sum;
58
59 /// For each variable in forVars, the current state
60 std::vector<size_t> _count;
61
62 /// For each variable in forVars, its number of possible values
63 std::vector<size_t> _dims;
64
65 public:
66 /// Default constructor
67 IndexFor() {
68 _index = -1;
69 }
70
71 /// Constructor
72 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
73 long sum = 1;
74
75 _dims.reserve( forVars.size() );
76 _sum.reserve( forVars.size() );
77
78 VarSet::const_iterator j = forVars.begin();
79 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
80 for( ; j != forVars.end() && *j <= *i; ++j ) {
81 _dims.push_back( j->states() );
82 _sum.push_back( (*i == *j) ? sum : 0 );
83 }
84 sum *= i->states();
85 }
86 for( ; j != forVars.end(); ++j ) {
87 _dims.push_back( j->states() );
88 _sum.push_back( 0 );
89 }
90 _index = 0;
91 }
92
93 /// Sets the index back to zero
94 IndexFor& clear() {
95 fill( _count.begin(), _count.end(), 0 );
96 _index = 0;
97 return( *this );
98 }
99
100 /// Conversion to long
101 operator long () const {
102 return( _index );
103 }
104
105 /// Pre-increment operator
106 IndexFor& operator++ () {
107 if( _index >= 0 ) {
108 size_t i = 0;
109
110 while( i < _count.size() ) {
111 _index += _sum[i];
112 if( ++_count[i] < _dims[i] )
113 break;
114 _index -= _sum[i] * _dims[i];
115 _count[i] = 0;
116 i++;
117 }
118
119 if( i == _count.size() )
120 _index = -1;
121 }
122 return( *this );
123 }
124 };
125
126
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to:
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  Note that index 0 changes fastest. If any dimension is zero there are no
 *  joint states at all, and the MultiFor is invalid right from the start.
 */
class MultiFor {
    private:
        /// The number of values of each index
        std::vector<size_t> _dims;
        /// The current value of each index
        std::vector<size_t> _states;
        /// The current linear state, or -1 once all joint states have been visited
        long _state;

    public:
        /// Default constructor (no indices: exactly one, empty, joint state)
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {
            // A zero dimension makes the number of joint states zero; like the
            // equivalent nested for loops, there must be no iteration at all.
            if( std::find( _dims.begin(), _dims.end(), static_cast<size_t>(0) ) != _dims.end() )
                _state = -1;
        }

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator: advance to the next joint state (index 0 fastest)
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // wrap around and carry into the next index
                    _states[i] = 0;
                }
                // carried past the last index: iteration is finished
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
198
199
200 /// Tool for calculating permutations of multiple indices.
201 class Permute {
202 private:
203 std::vector<size_t> _dims;
204 std::vector<size_t> _sigma;
205
206 public:
207 /// Default constructor
208 Permute() : _dims(), _sigma() {}
209
210 /// Construct from vector of index dimensions and permutation sigma
211 Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
212 assert( _dims.size() == _sigma.size() );
213 }
214
215 /// Construct from vector of variables
216 Permute( const std::vector<Var> &vars ) : _dims(vars.size()), _sigma(vars.size()) {
217 VarSet vs( vars.begin(), vars.end(), vars.size() );
218 for( size_t i = 0; i < vars.size(); ++i )
219 _dims[i] = vars[i].states();
220 VarSet::const_iterator set_iter = vs.begin();
221 for( size_t i = 0; i < vs.size(); ++i, ++set_iter )
222 _sigma[i] = find( vars.begin(), vars.end(), *set_iter ) - vars.begin();
223 }
224
225 /// Calculates a permuted linear index.
226 /** Converts the linear index li to a vector index
227 * corresponding with the dimensions in _dims, permutes it according to sigma,
228 * and converts it back to a linear index according to the permuted dimensions.
229 */
230 size_t convert_linear_index( size_t li ) const {
231 size_t N = _dims.size();
232
233 // calculate vector index corresponding to linear index
234 std::vector<size_t> vi;
235 vi.reserve( N );
236 size_t prod = 1;
237 for( size_t k = 0; k < N; k++ ) {
238 vi.push_back( li % _dims[k] );
239 li /= _dims[k];
240 prod *= _dims[k];
241 }
242
243 // convert permuted vector index to corresponding linear index
244 prod = 1;
245 size_t sigma_li = 0;
246 for( size_t k = 0; k < N; k++ ) {
247 sigma_li += vi[_sigma[k]] * prod;
248 prod *= _dims[_sigma[k]];
249 }
250
251 return sigma_li;
252 }
253 };
254
255
256 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
257 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
258 */
259 class State {
260 private:
261 typedef std::map<Var, size_t> states_type;
262
263 long state;
264 states_type states;
265
266 public:
267 /// Default constructor
268 State() : state(0), states() {}
269
270 /// Initialize from VarSet
271 State( const VarSet &vs ) : state(0) {
272 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
273 states[*v] = 0;
274 }
275
276 /// Return linear state
277 operator size_t() const {
278 assert( valid() );
279 return( state );
280 }
281
282 /// Return state of variable n, or zero if n is not in this State
283 size_t operator() ( const Var &n ) const {
284 assert( valid() );
285 states_type::const_iterator entry = states.find( n );
286 if( entry == states.end() )
287 return 0;
288 else
289 return entry->second;
290 }
291
292 /// Return linear state of variables in varset, setting them to zero if they are not in this State
293 size_t operator() ( const VarSet &vs ) const {
294 assert( valid() );
295 size_t vs_state = 0;
296 size_t prod = 1;
297 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
298 states_type::const_iterator entry = states.find( *v );
299 if( entry != states.end() )
300 vs_state += entry->second * prod;
301 prod *= v->states();
302 }
303 return vs_state;
304 }
305
306 /// Prefix increment operator
307 void operator++( ) {
308 if( valid() ) {
309 state++;
310 states_type::iterator entry = states.begin();
311 while( entry != states.end() ) {
312 if( ++(entry->second) < entry->first.states() )
313 break;
314 entry->second = 0;
315 entry++;
316 }
317 if( entry == states.end() )
318 state = -1;
319 }
320 }
321
322 /// Postfix increment operator
323 void operator++( int ) {
324 operator++();
325 }
326
327 /// Returns true if the current state is valid
328 bool valid() const {
329 return( state >= 0 );
330 }
331 };
332
333
334 } // end of namespace dai
335
336
337 #endif