Improved index.h (merged from SVN head), which yields a 25% speedup.
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
3 Radboud University Nijmegen, The Netherlands
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_index_h
24 #define __defined_libdai_index_h
25
26
27 #include <vector>
28 #include <algorithm>
29 #include <map>
30 #include <cassert>
31 #include <dai/varset.h>
32
33
34 namespace dai {
35
36
37 /// Tool for looping over the states of several variables.
38 /** The class IndexFor is an important tool for indexing of Factors.
39 * Its usage can best be explained by an example.
40 * Assume indexVars, forVars are two VarSets.
41 * Then the following code:
42 * \code
43 * IndexFor i( indexVars, forVars );
44 * for( ; i >= 0; ++i ) {
45 * // use long(i)
46 * }
47 * \endcode
48 * loops over all joint states of the variables in forVars,
49 * and (long)i is equal to the linear index of the corresponding
50 * state of indexVars, where the variables in indexVars that are
51 * not in forVars assume their zero'th value.
52 */
53 class IndexFor {
54 private:
55 /// The current linear index corresponding to the state of indexVars
56 long _index;
57
58 /// For each variable in forVars, the amount of change in _index
59 std::vector<long> _sum;
60
61 /// For each variable in forVars, the current state
62 std::vector<size_t> _count;
63
64 /// For each variable in forVars, its number of possible values
65 std::vector<size_t> _dims;
66
67 public:
68 /// Default constructor
69 IndexFor() {
70 _index = -1;
71 }
72
73 /// Constructor
74 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
75 long sum = 1;
76
77 _dims.reserve( forVars.size() );
78 _sum.reserve( forVars.size() );
79
80 VarSet::const_iterator j = forVars.begin();
81 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
82 for( ; j != forVars.end() && *j <= *i; ++j ) {
83 _dims.push_back( j->states() );
84 _sum.push_back( (*i == *j) ? sum : 0 );
85 }
86 sum *= i->states();
87 }
88 for( ; j != forVars.end(); ++j ) {
89 _dims.push_back( j->states() );
90 _sum.push_back( 0 );
91 }
92 _index = 0;
93 }
94
95 /// Copy constructor
96 IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
97
98 /// Assignment operator
99 IndexFor& operator=( const IndexFor &ind ) {
100 if( this != &ind ) {
101 _index = ind._index;
102 _sum = ind._sum;
103 _count = ind._count;
104 _dims = ind._dims;
105 }
106 return *this;
107 }
108
109 /// Sets the index back to zero
110 IndexFor& clear() {
111 fill( _count.begin(), _count.end(), 0 );
112 _index = 0;
113 return( *this );
114 }
115
116 /// Conversion to long
117 operator long () const {
118 return( _index );
119 }
120
121 /// Pre-increment operator
122 IndexFor& operator++ () {
123 if( _index >= 0 ) {
124 size_t i = 0;
125
126 while( i < _count.size() ) {
127 _index += _sum[i];
128 if( ++_count[i] < _dims[i] )
129 break;
130 _index -= _sum[i] * _dims[i];
131 _count[i] = 0;
132 i++;
133 }
134
135 if( i == _count.size() )
136 _index = -1;
137 }
138 return( *this );
139 }
140 };
141
142
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      std::cout << "linear index: " << (size_t)s << " corresponds with indices " << s[0] << ", " << s[1] << ", " << s[2] << std::endl;
 *  \endcode
 *  which would be equivalent to:
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              std::cout << "linear index: " << (size_t)s << " corresponds with indices " << s0 << ", " << s1 << ", " << s2 << std::endl;
 *  \endcode
 *  i.e. the first index changes fastest and the linear index equals
 *  s0 + 3*s1 + 3*4*s2.
 *
 *  If any dimension is zero there are no states at all and the loop body is
 *  never executed, just like the nested for loops. With no dimensions at all
 *  there is exactly one (empty) state, with linear index 0.
 */
class MultiFor {
    private:
        /// Number of values of each index
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear index, or -1 once all states have been visited
        long _state;

        /// Returns 0 (the first state) if every dimension is non-zero, else -1 (no states)
        static long initialState( const std::vector<size_t> &d ) {
            return ( std::find( d.begin(), d.end(), size_t(0) ) == d.end() ) ? 0 : -1;
        }

    public:
        /// Default constructor (no indices: a single, empty state)
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state( initialState(d) ) {}

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator: advance to the next state (first index fastest)
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // this index wrapped around; carry into the next one
                    _states[i] = 0;
                }
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
227
228
/// Tool for calculating permutations of multiple indices.
/** A Permute object holds index dimensions d and a permutation sigma of
 *  {0, ..., d.size()-1}. convert_linear_index() takes a linear index for the
 *  dimensions d (first index varying fastest), splits it into a multi-index vi,
 *  forms the permuted multi-index w with w[k] = vi[sigma[k]], and returns the
 *  linear index of w with respect to the dimensions d[sigma[0]], d[sigma[1]], ...
 *  (again first index varying fastest).
 *
 *  Copying and assignment are the compiler-generated member-wise operations.
 */
class Permute {
    private:
        /// Dimensions of the unpermuted indices
        std::vector<size_t> _dims;
        /// The permutation; must contain each of 0, ..., _dims.size()-1 exactly once
        std::vector<size_t> _sigma;

    public:
        /// Default constructor (no indices)
        Permute() : _dims(), _sigma() {}

        /// Initialize from vector of index dimensions and permutation sigma
        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
            assert( _dims.size() == _sigma.size() );
        }

        /// Converts the linear index li to a vector index
        /// corresponding with the dimensions in _dims,
        /// permutes it according to sigma,
        /// and converts it back to a linear index
        /// according to the permuted dimensions.
        /** Only the part of li below the product of all dimensions is used;
         *  any higher-order remainder is discarded.
         */
        size_t convert_linear_index( size_t li ) const {
            const size_t N = _dims.size();

            // calculate vector index corresponding to linear index
            std::vector<size_t> vi( N );
            for( size_t k = 0; k < N; ++k ) {
                vi[k] = li % _dims[k];
                li /= _dims[k];
            }

            // convert permuted vector index to corresponding linear index
            size_t sigma_li = 0;
            size_t stride = 1;
            for( size_t k = 0; k < N; ++k ) {
                assert( _sigma[k] < N );
                sigma_li += vi[_sigma[k]] * stride;
                stride *= _dims[_sigma[k]];
            }

            return sigma_li;
        }
};
285
286
287 /// Contains the state of variables within a VarSet and useful things to do with this information.
288 /// This is very similar to a MultiFor, but tailored for Vars and Varsets.
289 class State {
290 private:
291 typedef std::map<Var, size_t> states_type;
292
293 long state;
294 states_type states;
295
296 public:
297 /// Default constructor
298 State() : state(0), states() {}
299
300 /// Initialize from VarSet
301 State( const VarSet &vs ) : state(0) {
302 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
303 states[*v] = 0;
304 }
305
306 /// Copy constructor
307 State( const State & x ) : state(x.state), states(x.states) {}
308
309 /// Assignment operator
310 State& operator=( const State &x ) {
311 if( this != &x ) {
312 state = x.state;
313 states = x.states;
314 }
315 return *this;
316 }
317
318 /// Return linear state
319 operator size_t() const {
320 assert( valid() );
321 return( state );
322 }
323
324 /// Return state of variable n,
325 /// or zero if n is not in this State
326 size_t operator() ( const Var &n ) const {
327 assert( valid() );
328 states_type::const_iterator entry = states.find( n );
329 if( entry == states.end() )
330 return 0;
331 else
332 return entry->second;
333 }
334
335 /// Return linear state of variables in varset,
336 /// setting them to zero if they are not in this State
337 size_t operator() ( const VarSet &vs ) const {
338 assert( valid() );
339 size_t vs_state = 0;
340 size_t prod = 1;
341 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
342 states_type::const_iterator entry = states.find( *v );
343 if( entry != states.end() )
344 vs_state += entry->second * prod;
345 prod *= v->states();
346 }
347 return vs_state;
348 }
349
350 /// Postfix increment operator
351 void operator++( int ) {
352 if( valid() ) {
353 state++;
354 states_type::iterator entry = states.begin();
355 while( entry != states.end() ) {
356 if( ++(entry->second) < entry->first.states() )
357 break;
358 entry->second = 0;
359 entry++;
360 }
361 if( entry == states.end() )
362 state = -1;
363 }
364 }
365
366 /// Returns true if the current state is valid
367 bool valid() const {
368 return( state >= 0 );
369 }
370 };
371
372
373 } // end of namespace dai
374
375
376 #endif