Significant improvement of documentation
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28
29
30 #ifndef __defined_libdai_index_h
31 #define __defined_libdai_index_h
32
33
34 #include <vector>
35 #include <algorithm>
36 #include <map>
37 #include <cassert>
38 #include <dai/varset.h>
39
40
41 namespace dai {
42
43
44 /// Tool for looping over the states of several variables.
45 /** The class IndexFor is an important tool for indexing Factor entries.
46 * Its usage can best be explained by an example.
47 * Assume indexVars, forVars are both VarSets.
48 * Then the following code:
49 * \code
50 * IndexFor i( indexVars, forVars );
51 * for( ; i >= 0; ++i ) {
52 * // use long(i)
53 * }
54 * \endcode
55 * loops over all joint states of the variables in forVars,
56 * and (long)i is equal to the linear index of the corresponding
57 * state of indexVars, where the variables in indexVars that are
58 * not in forVars assume their zero'th value.
59 */
60 class IndexFor {
61 private:
62 /// The current linear index corresponding to the state of indexVars
63 long _index;
64
65 /// For each variable in forVars, the amount of change in _index
66 std::vector<long> _sum;
67
68 /// For each variable in forVars, the current state
69 std::vector<size_t> _count;
70
71 /// For each variable in forVars, its number of possible values
72 std::vector<size_t> _dims;
73
74 public:
75 /// Default constructor
76 IndexFor() {
77 _index = -1;
78 }
79
80 /// Constructor
81 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
82 long sum = 1;
83
84 _dims.reserve( forVars.size() );
85 _sum.reserve( forVars.size() );
86
87 VarSet::const_iterator j = forVars.begin();
88 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
89 for( ; j != forVars.end() && *j <= *i; ++j ) {
90 _dims.push_back( j->states() );
91 _sum.push_back( (*i == *j) ? sum : 0 );
92 }
93 sum *= i->states();
94 }
95 for( ; j != forVars.end(); ++j ) {
96 _dims.push_back( j->states() );
97 _sum.push_back( 0 );
98 }
99 _index = 0;
100 }
101
102 /// Copy constructor
103 IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
104
105 /// Assignment operator
106 IndexFor& operator=( const IndexFor &ind ) {
107 if( this != &ind ) {
108 _index = ind._index;
109 _sum = ind._sum;
110 _count = ind._count;
111 _dims = ind._dims;
112 }
113 return *this;
114 }
115
116 /// Sets the index back to zero
117 IndexFor& clear() {
118 fill( _count.begin(), _count.end(), 0 );
119 _index = 0;
120 return( *this );
121 }
122
123 /// Conversion to long
124 operator long () const {
125 return( _index );
126 }
127
128 /// Pre-increment operator
129 IndexFor& operator++ () {
130 if( _index >= 0 ) {
131 size_t i = 0;
132
133 while( i < _count.size() ) {
134 _index += _sum[i];
135 if( ++_count[i] < _dims[i] )
136 break;
137 _index -= _sum[i] * _dims[i];
138 _count[i] = 0;
139 i++;
140 }
141
142 if( i == _count.size() )
143 _index = -1;
144 }
145 return( *this );
146 }
147 };
148
149
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to:
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  The first index varies fastest, so the linear index equals
 *  s[0] + dims[0] * (s[1] + dims[1] * (s[2] + ...)).
 *
 *  If any dimension is zero, there are no states at all and the MultiFor is
 *  invalid from the start, just as the nested loops would not execute. With
 *  no dimensions at all there is exactly one (empty) state.
 */
class MultiFor {
    private:
        /// Number of possible values of each index
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear state, or -1 once all states have been visited
        long _state;

        /// Returns the initial linear state: 0 if every dimension is nonzero,
        /// -1 (no states to visit) if some dimension is zero
        static long initialState( const std::vector<size_t> &d ) {
            return( std::find( d.begin(), d.end(), size_t(0) ) == d.end() ? 0L : -1L );
        }

    public:
        /// Default constructor (no indices; exactly one state)
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(initialState(d)) {}

        /// Copy constructor
        MultiFor( const MultiFor &x ) : _dims(x._dims), _states(x._states), _state(x._state) {}

        /// Assignment operator
        MultiFor& operator=( const MultiFor & x ) {
            if( this != &x ) {
                _dims = x._dims;
                _states = x._states;
                _state = x._state;
            }
            return *this;
        }

        /// Return linear state (requires valid())
        operator size_t() const {
            assert( valid() );
            return( static_cast<size_t>( _state ) );
        }

        /// Return k'th index (requires valid() and k < number of indices)
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator
        /** Advances to the next state (index 0 varies fastest); becomes invalid
         *  after the last state. Does nothing if already invalid.
         */
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // Index i wrapped around: reset it and carry into index i+1
                    _states[i] = 0;
                }
                // Carried past the last index: all states have been visited
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
234
235
/// Tool for calculating permutations of multiple indices.
/** A Permute object stores the dimensions of a number of indices together with
 *  a permutation sigma of these indices. Linear indices are computed with the
 *  first index varying fastest.
 */
class Permute {
    private:
        /// Number of possible values of each (unpermuted) index
        std::vector<size_t> _dims;
        /// The permutation: the k'th permuted index is the _sigma[k]'th original index
        std::vector<size_t> _sigma;

    public:
        /// Default constructor
        Permute() : _dims(), _sigma() {}

        /// Initialize from vector of index dimensions and permutation sigma
        /** \param d Dimensions of the indices.
         *  \param sigma Permutation; must have the same length as d.
         */
        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
            assert( _dims.size() == _sigma.size() );
        }

        /// Copy constructor
        Permute( const Permute &x ) : _dims(x._dims), _sigma(x._sigma) {}

        /// Assignment operator
        Permute& operator=( const Permute &x ) {
            if( this != &x ) {
                _dims = x._dims;
                _sigma = x._sigma;
            }
            return *this;
        }

        /// Converts the linear index li to a vector index
        /// corresponding with the dimensions in _dims,
        /// permutes it according to sigma,
        /// and converts it back to a linear index
        /// according to the permuted dimensions.
        /** Does not modify the object. */
        size_t convert_linear_index( size_t li ) const {
            const size_t N = _dims.size();

            // Calculate vector index corresponding to linear index (first index fastest)
            std::vector<size_t> vi;
            vi.reserve( N );
            for( size_t k = 0; k < N; k++ ) {
                vi.push_back( li % _dims[k] );
                li /= _dims[k];
            }

            // Convert permuted vector index to corresponding linear index
            size_t prod = 1;
            size_t sigma_li = 0;
            for( size_t k = 0; k < N; k++ ) {
                assert( _sigma[k] < N );
                sigma_li += vi[_sigma[k]] * prod;
                prod *= _dims[_sigma[k]];
            }

            return sigma_li;
        }
};
292
293
294 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
295 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
296 */
297 class State {
298 private:
299 typedef std::map<Var, size_t> states_type;
300
301 long state;
302 states_type states;
303
304 public:
305 /// Default constructor
306 State() : state(0), states() {}
307
308 /// Initialize from VarSet
309 State( const VarSet &vs ) : state(0) {
310 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
311 states[*v] = 0;
312 }
313
314 /// Copy constructor
315 State( const State & x ) : state(x.state), states(x.states) {}
316
317 /// Assignment operator
318 State& operator=( const State &x ) {
319 if( this != &x ) {
320 state = x.state;
321 states = x.states;
322 }
323 return *this;
324 }
325
326 /// Return linear state
327 operator size_t() const {
328 assert( valid() );
329 return( state );
330 }
331
332 /// Return state of variable n,
333 /// or zero if n is not in this State
334 size_t operator() ( const Var &n ) const {
335 assert( valid() );
336 states_type::const_iterator entry = states.find( n );
337 if( entry == states.end() )
338 return 0;
339 else
340 return entry->second;
341 }
342
343 /// Return linear state of variables in varset,
344 /// setting them to zero if they are not in this State
345 size_t operator() ( const VarSet &vs ) const {
346 assert( valid() );
347 size_t vs_state = 0;
348 size_t prod = 1;
349 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
350 states_type::const_iterator entry = states.find( *v );
351 if( entry != states.end() )
352 vs_state += entry->second * prod;
353 prod *= v->states();
354 }
355 return vs_state;
356 }
357
358 /// Postfix increment operator
359 void operator++( int ) {
360 if( valid() ) {
361 state++;
362 states_type::iterator entry = states.begin();
363 while( entry != states.end() ) {
364 if( ++(entry->second) < entry->first.states() )
365 break;
366 entry->second = 0;
367 entry++;
368 }
369 if( entry == states.end() )
370 state = -1;
371 }
372 }
373
374 /// Returns true if the current state is valid
375 bool valid() const {
376 return( state >= 0 );
377 }
378 };
379
380
381 } // end of namespace dai
382
383
384 #endif