46b2244c6950c894eca5df804bf917c75c939756
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_index_h
32 #define __defined_libdai_index_h
33
34
35 #include <vector>
36 #include <algorithm>
37 #include <map>
38 #include <cassert>
39 #include <dai/varset.h>
40
41
42 namespace dai {
43
44
45 /// Tool for looping over the states of several variables.
46 /** The class IndexFor is an important tool for indexing Factor entries.
47 * Its usage can best be explained by an example.
48 * Assume indexVars, forVars are both VarSets.
49 * Then the following code:
50 * \code
51 * IndexFor i( indexVars, forVars );
52 * for( ; i >= 0; ++i ) {
53 * // use long(i)
54 * }
55 * \endcode
56 * loops over all joint states of the variables in forVars,
57 * and (long)i is equal to the linear index of the corresponding
58 * state of indexVars, where the variables in indexVars that are
59 * not in forVars assume their zero'th value.
60 * \idea Optimize all indices as follows: keep a cache of all (or only
61 * relatively small) indices that have been computed (use a hash). Then,
62 * instead of computing on the fly, use the precomputed ones.
63 */
64 class IndexFor {
65 private:
66 /// The current linear index corresponding to the state of indexVars
67 long _index;
68
69 /// For each variable in forVars, the amount of change in _index
70 std::vector<long> _sum;
71
72 /// For each variable in forVars, the current state
73 std::vector<size_t> _count;
74
75 /// For each variable in forVars, its number of possible values
76 std::vector<size_t> _dims;
77
78 public:
79 /// Default constructor
80 IndexFor() {
81 _index = -1;
82 }
83
84 /// Constructor
85 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
86 long sum = 1;
87
88 _dims.reserve( forVars.size() );
89 _sum.reserve( forVars.size() );
90
91 VarSet::const_iterator j = forVars.begin();
92 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
93 for( ; j != forVars.end() && *j <= *i; ++j ) {
94 _dims.push_back( j->states() );
95 _sum.push_back( (*i == *j) ? sum : 0 );
96 }
97 sum *= i->states();
98 }
99 for( ; j != forVars.end(); ++j ) {
100 _dims.push_back( j->states() );
101 _sum.push_back( 0 );
102 }
103 _index = 0;
104 }
105
106 /// Sets the index back to zero
107 IndexFor& clear() {
108 fill( _count.begin(), _count.end(), 0 );
109 _index = 0;
110 return( *this );
111 }
112
113 /// Conversion to long
114 operator long () const {
115 return( _index );
116 }
117
118 /// Pre-increment operator
119 IndexFor& operator++ () {
120 if( _index >= 0 ) {
121 size_t i = 0;
122
123 while( i < _count.size() ) {
124 _index += _sum[i];
125 if( ++_count[i] < _dims[i] )
126 break;
127 _index -= _sum[i] * _dims[i];
128 _count[i] = 0;
129 i++;
130 }
131
132 if( i == _count.size() )
133 _index = -1;
134 }
135 return( *this );
136 }
137 };
138
139
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to:
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  Note that index 0 varies fastest, so the linear index equals
 *  s[0] + dims[0] * (s[1] + dims[1] * s[2]).
 *  With an empty dimension vector, the loop body is executed exactly once.
 */
class MultiFor {
    private:
        /// Number of values each index can take
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear state; -1 once all states have been visited
        long _state;

    public:
        /// Default constructor
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator; index 0 is advanced first (varies fastest)
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++_states[i] < _dims[i] )
                        break;
                    // Index i wrapped around: reset it and carry to index i+1
                    _states[i] = 0;
                }
                // Every index wrapped: all states have been visited
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
211
212
/// Tool for calculating permutations of multiple indices.
class Permute {
    private:
        /// Dimensions of the original (unpermuted) indices
        std::vector<size_t> _dims;
        /// Permutation: position k of the permuted index holds original index _sigma[k]
        std::vector<size_t> _sigma;

    public:
        /// Default constructor
        Permute() : _dims(), _sigma() {}

        /// Initialize from vector of index dimensions and permutation sigma
        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
            assert( _dims.size() == _sigma.size() );
        }

        /// Calculates a permuted linear index.
        /** Converts the linear index li to a vector index
         *  corresponding with the dimensions in _dims (index 0 varying fastest),
         *  permutes it according to sigma, and converts it back to a linear index
         *  according to the permuted dimensions (_dims[_sigma[0]] varying fastest).
         *  \param li linear index with respect to _dims
         *  \return linear index with respect to the permuted dimensions
         */
        size_t convert_linear_index( size_t li ) const {
            const size_t N = _dims.size();

            // calculate vector index corresponding to linear index
            std::vector<size_t> vi( N );
            for( size_t k = 0; k < N; k++ ) {
                vi[k] = li % _dims[k];
                li /= _dims[k];
            }

            // convert permuted vector index to corresponding linear index
            size_t prod = 1;
            size_t sigma_li = 0;
            for( size_t k = 0; k < N; k++ ) {
                // An out-of-range sigma entry would index vi out of bounds
                assert( _sigma[k] < N );
                sigma_li += vi[_sigma[k]] * prod;
                prod *= _dims[_sigma[k]];
            }

            return sigma_li;
        }
};
257
258
259 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
260 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
261 */
262 class State {
263 private:
264 typedef std::map<Var, size_t> states_type;
265
266 long state;
267 states_type states;
268
269 public:
270 /// Default constructor
271 State() : state(0), states() {}
272
273 /// Initialize from VarSet
274 State( const VarSet &vs ) : state(0) {
275 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
276 states[*v] = 0;
277 }
278
279 /// Return linear state
280 operator size_t() const {
281 assert( valid() );
282 return( state );
283 }
284
285 /// Return state of variable n, or zero if n is not in this State
286 size_t operator() ( const Var &n ) const {
287 assert( valid() );
288 states_type::const_iterator entry = states.find( n );
289 if( entry == states.end() )
290 return 0;
291 else
292 return entry->second;
293 }
294
295 /// Return linear state of variables in varset, setting them to zero if they are not in this State
296 size_t operator() ( const VarSet &vs ) const {
297 assert( valid() );
298 size_t vs_state = 0;
299 size_t prod = 1;
300 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
301 states_type::const_iterator entry = states.find( *v );
302 if( entry != states.end() )
303 vs_state += entry->second * prod;
304 prod *= v->states();
305 }
306 return vs_state;
307 }
308
309 /// Prefix increment operator
310 void operator++( ) {
311 if( valid() ) {
312 state++;
313 states_type::iterator entry = states.begin();
314 while( entry != states.end() ) {
315 if( ++(entry->second) < entry->first.states() )
316 break;
317 entry->second = 0;
318 entry++;
319 }
320 if( entry == states.end() )
321 state = -1;
322 }
323 }
324
325 /// Postfix increment operator
326 void operator++( int ) {
327 operator++();
328 }
329
330 /// Returns true if the current state is valid
331 bool valid() const {
332 return( state >= 0 );
333 }
334 };
335
336
337 } // end of namespace dai
338
339
340 #endif