Updated copyrights
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 #ifndef __defined_libdai_index_h
27 #define __defined_libdai_index_h
28
29
30 #include <vector>
31 #include <algorithm>
32 #include <map>
33 #include <cassert>
34 #include <dai/varset.h>
35
36
37 namespace dai {
38
39
40 /// Tool for looping over the states of several variables.
41 /** The class IndexFor is an important tool for indexing of Factors.
42 * Its usage can best be explained by an example.
43 * Assume indexVars, forVars are two VarSets.
44 * Then the following code:
45 * \code
46 * IndexFor i( indexVars, forVars );
47 * for( ; i >= 0; ++i ) {
48 * // use long(i)
49 * }
50 * \endcode
51 * loops over all joint states of the variables in forVars,
52 * and (long)i is equal to the linear index of the corresponding
53 * state of indexVars, where the variables in indexVars that are
54 * not in forVars assume their zero'th value.
55 */
56 class IndexFor {
57 private:
58 /// The current linear index corresponding to the state of indexVars
59 long _index;
60
61 /// For each variable in forVars, the amount of change in _index
62 std::vector<long> _sum;
63
64 /// For each variable in forVars, the current state
65 std::vector<size_t> _count;
66
67 /// For each variable in forVars, its number of possible values
68 std::vector<size_t> _dims;
69
70 public:
71 /// Default constructor
72 IndexFor() {
73 _index = -1;
74 }
75
76 /// Constructor
77 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
78 long sum = 1;
79
80 _dims.reserve( forVars.size() );
81 _sum.reserve( forVars.size() );
82
83 VarSet::const_iterator j = forVars.begin();
84 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
85 for( ; j != forVars.end() && *j <= *i; ++j ) {
86 _dims.push_back( j->states() );
87 _sum.push_back( (*i == *j) ? sum : 0 );
88 }
89 sum *= i->states();
90 }
91 for( ; j != forVars.end(); ++j ) {
92 _dims.push_back( j->states() );
93 _sum.push_back( 0 );
94 }
95 _index = 0;
96 }
97
98 /// Copy constructor
99 IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
100
101 /// Assignment operator
102 IndexFor& operator=( const IndexFor &ind ) {
103 if( this != &ind ) {
104 _index = ind._index;
105 _sum = ind._sum;
106 _count = ind._count;
107 _dims = ind._dims;
108 }
109 return *this;
110 }
111
112 /// Sets the index back to zero
113 IndexFor& clear() {
114 fill( _count.begin(), _count.end(), 0 );
115 _index = 0;
116 return( *this );
117 }
118
119 /// Conversion to long
120 operator long () const {
121 return( _index );
122 }
123
124 /// Pre-increment operator
125 IndexFor& operator++ () {
126 if( _index >= 0 ) {
127 size_t i = 0;
128
129 while( i < _count.size() ) {
130 _index += _sum[i];
131 if( ++_count[i] < _dims[i] )
132 break;
133 _index -= _sum[i] * _dims[i];
134 _count[i] = 0;
135 i++;
136 }
137
138 if( i == _count.size() )
139 _index = -1;
140 }
141 return( *this );
142 }
143 };
144
145
146 /// MultiFor makes it easy to perform a dynamic number of nested for loops.
147 /** An example of the usage is as follows:
148 * \code
149 * std::vector<size_t> dims;
150 * dims.push_back( 3 );
151 * dims.push_back( 4 );
152 * dims.push_back( 5 );
153 * for( MultiFor s(dims); s.valid(); ++s )
154 * cout << "linear index: " << (size_t)s << " corresponds with indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
155 * \endcode
156 * which would be equivalent to:
157 * \code
158 * size_t s = 0;
159 * for( size_t s0 = 0; s0 < 3; s0++ )
160 * for( size_t s1 = 0; s1 < 4; s1++ )
161 * for( size_t s2 = 0; s2 < 5; s++, s2++ )
162 * cout << "linear index: " << (size_t)s << " corresponds with indices " << s0 << ", " << s1 << ", " << s2 << endl;
163 * \endcode
164 */
165 class MultiFor {
166 private:
167 std::vector<size_t> _dims;
168 std::vector<size_t> _states;
169 long _state;
170
171 public:
172 /// Default constructor
173 MultiFor() : _dims(), _states(), _state(0) {}
174
175 /// Initialize from vector of index dimensions
176 MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}
177
178 /// Copy constructor
179 MultiFor( const MultiFor &x ) : _dims(x._dims), _states(x._states), _state(x._state) {}
180
181 /// Assignment operator
182 MultiFor& operator=( const MultiFor & x ) {
183 if( this != &x ) {
184 _dims = x._dims;
185 _states = x._states;
186 _state = x._state;
187 }
188 return *this;
189 }
190
191 /// Return linear state
192 operator size_t() const {
193 assert( valid() );
194 return( _state );
195 }
196
197 /// Return k'th index
198 size_t operator[]( size_t k ) const {
199 assert( valid() );
200 assert( k < _states.size() );
201 return _states[k];
202 }
203
204 /// Prefix increment operator
205 MultiFor & operator++() {
206 if( valid() ) {
207 _state++;
208 size_t i;
209 for( i = 0; i != _states.size(); i++ ) {
210 if( ++(_states[i]) < _dims[i] )
211 break;
212 _states[i] = 0;
213 }
214 if( i == _states.size() )
215 _state = -1;
216 }
217 return *this;
218 }
219
220 /// Postfix increment operator
221 void operator++( int ) {
222 operator++();
223 }
224
225 /// Returns true if the current state is valid
226 bool valid() const {
227 return( _state >= 0 );
228 }
229 };
230
231
232 /// Tool for calculating permutations of multiple indices.
233 class Permute {
234 private:
235 std::vector<size_t> _dims;
236 std::vector<size_t> _sigma;
237
238 public:
239 /// Default constructor
240 Permute() : _dims(), _sigma() {}
241
242 /// Initialize from vector of index dimensions and permutation sigma
243 Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
244 assert( _dims.size() == _sigma.size() );
245 }
246
247 /// Copy constructor
248 Permute( const Permute &x ) : _dims(x._dims), _sigma(x._sigma) {}
249
250 /// Assignment operator
251 Permute& operator=( const Permute &x ) {
252 if( this != &x ) {
253 _dims = x._dims;
254 _sigma = x._sigma;
255 }
256 return *this;
257 }
258
259 /// Converts the linear index li to a vector index
260 /// corresponding with the dimensions in _dims,
261 /// permutes it according to sigma,
262 /// and converts it back to a linear index
263 /// according to the permuted dimensions.
264 size_t convert_linear_index( size_t li ) {
265 size_t N = _dims.size();
266
267 // calculate vector index corresponding to linear index
268 std::vector<size_t> vi;
269 vi.reserve( N );
270 size_t prod = 1;
271 for( size_t k = 0; k < N; k++ ) {
272 vi.push_back( li % _dims[k] );
273 li /= _dims[k];
274 prod *= _dims[k];
275 }
276
277 // convert permuted vector index to corresponding linear index
278 prod = 1;
279 size_t sigma_li = 0;
280 for( size_t k = 0; k < N; k++ ) {
281 sigma_li += vi[_sigma[k]] * prod;
282 prod *= _dims[_sigma[k]];
283 }
284
285 return sigma_li;
286 }
287 };
288
289
290 /// Contains the state of variables within a VarSet and useful things to do with this information.
291 /// This is very similar to a MultiFor, but tailored for Vars and Varsets.
292 class State {
293 private:
294 typedef std::map<Var, size_t> states_type;
295
296 long state;
297 states_type states;
298
299 public:
300 /// Default constructor
301 State() : state(0), states() {}
302
303 /// Initialize from VarSet
304 State( const VarSet &vs ) : state(0) {
305 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
306 states[*v] = 0;
307 }
308
309 /// Copy constructor
310 State( const State & x ) : state(x.state), states(x.states) {}
311
312 /// Assignment operator
313 State& operator=( const State &x ) {
314 if( this != &x ) {
315 state = x.state;
316 states = x.states;
317 }
318 return *this;
319 }
320
321 /// Return linear state
322 operator size_t() const {
323 assert( valid() );
324 return( state );
325 }
326
327 /// Return state of variable n,
328 /// or zero if n is not in this State
329 size_t operator() ( const Var &n ) const {
330 assert( valid() );
331 states_type::const_iterator entry = states.find( n );
332 if( entry == states.end() )
333 return 0;
334 else
335 return entry->second;
336 }
337
338 /// Return linear state of variables in varset,
339 /// setting them to zero if they are not in this State
340 size_t operator() ( const VarSet &vs ) const {
341 assert( valid() );
342 size_t vs_state = 0;
343 size_t prod = 1;
344 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
345 states_type::const_iterator entry = states.find( *v );
346 if( entry != states.end() )
347 vs_state += entry->second * prod;
348 prod *= v->states();
349 }
350 return vs_state;
351 }
352
353 /// Postfix increment operator
354 void operator++( int ) {
355 if( valid() ) {
356 state++;
357 states_type::iterator entry = states.begin();
358 while( entry != states.end() ) {
359 if( ++(entry->second) < entry->first.states() )
360 break;
361 entry->second = 0;
362 entry++;
363 }
364 if( entry == states.end() )
365 state = -1;
366 }
367 }
368
369 /// Returns true if the current state is valid
370 bool valid() const {
371 return( state >= 0 );
372 }
373 };
374
375
376 } // end of namespace dai
377
378
379 #endif