Oops, correct previous partial commit.
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_index_h
32 #define __defined_libdai_index_h
33
34
35 #include <vector>
36 #include <algorithm>
37 #include <map>
38 #include <cassert>
39 #include <dai/varset.h>
40
41
42 namespace dai {
43
44
45 /// Tool for looping over the states of several variables.
46 /** The class IndexFor is an important tool for indexing Factor entries.
47 * Its usage can best be explained by an example.
48 * Assume indexVars, forVars are both VarSets.
49 * Then the following code:
50 * \code
51 * IndexFor i( indexVars, forVars );
52 * for( ; i >= 0; ++i ) {
53 * // use long(i)
54 * }
55 * \endcode
56 * loops over all joint states of the variables in forVars,
57 * and (long)i is equal to the linear index of the corresponding
58 * state of indexVars, where the variables in indexVars that are
59 * not in forVars assume their zero'th value.
60 */
61 class IndexFor {
62 private:
63 /// The current linear index corresponding to the state of indexVars
64 long _index;
65
66 /// For each variable in forVars, the amount of change in _index
67 std::vector<long> _sum;
68
69 /// For each variable in forVars, the current state
70 std::vector<size_t> _count;
71
72 /// For each variable in forVars, its number of possible values
73 std::vector<size_t> _dims;
74
75 public:
76 /// Default constructor
77 IndexFor() {
78 _index = -1;
79 }
80
81 /// Constructor
82 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
83 long sum = 1;
84
85 _dims.reserve( forVars.size() );
86 _sum.reserve( forVars.size() );
87
88 VarSet::const_iterator j = forVars.begin();
89 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
90 for( ; j != forVars.end() && *j <= *i; ++j ) {
91 _dims.push_back( j->states() );
92 _sum.push_back( (*i == *j) ? sum : 0 );
93 }
94 sum *= i->states();
95 }
96 for( ; j != forVars.end(); ++j ) {
97 _dims.push_back( j->states() );
98 _sum.push_back( 0 );
99 }
100 _index = 0;
101 }
102
103 /// Copy constructor
104 IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
105
106 /// Assignment operator
107 IndexFor& operator=( const IndexFor &ind ) {
108 if( this != &ind ) {
109 _index = ind._index;
110 _sum = ind._sum;
111 _count = ind._count;
112 _dims = ind._dims;
113 }
114 return *this;
115 }
116
117 /// Sets the index back to zero
118 IndexFor& clear() {
119 fill( _count.begin(), _count.end(), 0 );
120 _index = 0;
121 return( *this );
122 }
123
124 /// Conversion to long
125 operator long () const {
126 return( _index );
127 }
128
129 /// Pre-increment operator
130 IndexFor& operator++ () {
131 if( _index >= 0 ) {
132 size_t i = 0;
133
134 while( i < _count.size() ) {
135 _index += _sum[i];
136 if( ++_count[i] < _dims[i] )
137 break;
138 _index -= _sum[i] * _dims[i];
139 _count[i] = 0;
140 i++;
141 }
142
143 if( i == _count.size() )
144 _index = -1;
145 }
146 return( *this );
147 }
148 };
149
150
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to (note that index 0 varies fastest):
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  If any dimension is zero, the loop body is never executed, just like the
 *  equivalent nested for loops. Copying uses the compiler-generated copy operations.
 */
class MultiFor {
    private:
        /// Number of values of each index
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear state, or -1 once the loop is exhausted
        long _state;

        /// Initial linear state for dimensions d: 0, or -1 if some dimension is
        /// zero (such a nested loop has no states at all)
        static long initialState( const std::vector<size_t> &d ) {
            return ( std::find( d.begin(), d.end(), static_cast<size_t>( 0 ) ) == d.end() ) ? 0 : -1;
        }

    public:
        /// Default constructor: zero indices, hence exactly one (empty) state
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state( initialState(d) ) {}

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( static_cast<size_t>( _state ) );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator: advance to the next state (index 0 fastest)
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // index i wrapped around; carry over to index i+1
                    _states[i] = 0;
                }
                // carry ran past the last index: all states have been visited
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
235
236
/// Tool for calculating permutations of multiple indices.
/** A multi-index with dimensions d is linearized with index 0 varying fastest.
 *  The permuted multi-index has as its k'th component the sigma[k]'th component
 *  of the original one, and is linearized with the permuted dimensions.
 *  Copying uses the compiler-generated copy operations.
 */
class Permute {
    private:
        /// Dimensions of the original multi-index
        std::vector<size_t> _dims;
        /// Permutation: component k of the permuted index is original component _sigma[k]
        std::vector<size_t> _sigma;

    public:
        /// Default constructor
        Permute() : _dims(), _sigma() {}

        /// Initialize from vector of index dimensions and permutation sigma
        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
            assert( _dims.size() == _sigma.size() );
        }

        /// Converts the linear index li to a vector index
        /// corresponding with the dimensions in _dims,
        /// permutes it according to sigma,
        /// and converts it back to a linear index
        /// according to the permuted dimensions.
        /** \pre every dimension is nonzero, li is smaller than the product of the
         *       dimensions, and every entry of sigma is smaller than the number of
         *       dimensions (checked by assertions in debug builds).
         */
        size_t convert_linear_index( size_t li ) const {
            const size_t N = _dims.size();

            // calculate vector index corresponding to linear index (index 0 fastest)
            std::vector<size_t> vi;
            vi.reserve( N );
            for( size_t k = 0; k < N; k++ ) {
                assert( _dims[k] > 0 );
                vi.push_back( li % _dims[k] );
                li /= _dims[k];
            }
            // any remaining quotient means li exceeded the product of the dimensions
            assert( li == 0 );

            // convert permuted vector index to corresponding linear index
            size_t prod = 1;
            size_t sigma_li = 0;
            for( size_t k = 0; k < N; k++ ) {
                assert( _sigma[k] < N );
                sigma_li += vi[_sigma[k]] * prod;
                prod *= _dims[_sigma[k]];
            }

            return sigma_li;
        }
};
293
294
295 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
296 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
297 */
298 class State {
299 private:
300 typedef std::map<Var, size_t> states_type;
301
302 long state;
303 states_type states;
304
305 public:
306 /// Default constructor
307 State() : state(0), states() {}
308
309 /// Initialize from VarSet
310 State( const VarSet &vs ) : state(0) {
311 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
312 states[*v] = 0;
313 }
314
315 /// Copy constructor
316 State( const State & x ) : state(x.state), states(x.states) {}
317
318 /// Assignment operator
319 State& operator=( const State &x ) {
320 if( this != &x ) {
321 state = x.state;
322 states = x.states;
323 }
324 return *this;
325 }
326
327 /// Return linear state
328 operator size_t() const {
329 assert( valid() );
330 return( state );
331 }
332
333 /// Return state of variable n,
334 /// or zero if n is not in this State
335 size_t operator() ( const Var &n ) const {
336 assert( valid() );
337 states_type::const_iterator entry = states.find( n );
338 if( entry == states.end() )
339 return 0;
340 else
341 return entry->second;
342 }
343
344 /// Return linear state of variables in varset,
345 /// setting them to zero if they are not in this State
346 size_t operator() ( const VarSet &vs ) const {
347 assert( valid() );
348 size_t vs_state = 0;
349 size_t prod = 1;
350 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
351 states_type::const_iterator entry = states.find( *v );
352 if( entry != states.end() )
353 vs_state += entry->second * prod;
354 prod *= v->states();
355 }
356 return vs_state;
357 }
358
359 /// Postfix increment operator
360 void operator++( int ) {
361 if( valid() ) {
362 state++;
363 states_type::iterator entry = states.begin();
364 while( entry != states.end() ) {
365 if( ++(entry->second) < entry->first.states() )
366 break;
367 entry->second = 0;
368 entry++;
369 }
370 if( entry == states.end() )
371 state = -1;
372 }
373 }
374
375 /// Returns true if the current state is valid
376 bool valid() const {
377 return( state >= 0 );
378 }
379 };
380
381
382 } // end of namespace dai
383
384
385 #endif