Improved coding style of recent changes by Charlie Vaske
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_index_h
32 #define __defined_libdai_index_h
33
34
35 #include <vector>
36 #include <algorithm>
37 #include <map>
38 #include <cassert>
39 #include <dai/varset.h>
40
41
42 namespace dai {
43
44
45 /// Tool for looping over the states of several variables.
46 /** The class IndexFor is an important tool for indexing Factor entries.
47 * Its usage can best be explained by an example.
48 * Assume indexVars, forVars are both VarSets.
49 * Then the following code:
50 * \code
51 * IndexFor i( indexVars, forVars );
52 * for( ; i >= 0; ++i ) {
53 * // use long(i)
54 * }
55 * \endcode
56 * loops over all joint states of the variables in forVars,
57 * and (long)i is equal to the linear index of the corresponding
58 * state of indexVars, where the variables in indexVars that are
59 * not in forVars assume their zero'th value.
60 * \idea Optimize all indices as follows: keep a cache of all (or only
61 * relatively small) indices that have been computed (use a hash). Then,
62 * instead of computing on the fly, use the precomputed ones.
63 */
64 class IndexFor {
65 private:
66 /// The current linear index corresponding to the state of indexVars
67 long _index;
68
69 /// For each variable in forVars, the amount of change in _index
70 std::vector<long> _sum;
71
72 /// For each variable in forVars, the current state
73 std::vector<size_t> _count;
74
75 /// For each variable in forVars, its number of possible values
76 std::vector<size_t> _dims;
77
78 public:
79 /// Default constructor
80 IndexFor() {
81 _index = -1;
82 }
83
84 /// Constructor
85 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
86 long sum = 1;
87
88 _dims.reserve( forVars.size() );
89 _sum.reserve( forVars.size() );
90
91 VarSet::const_iterator j = forVars.begin();
92 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
93 for( ; j != forVars.end() && *j <= *i; ++j ) {
94 _dims.push_back( j->states() );
95 _sum.push_back( (*i == *j) ? sum : 0 );
96 }
97 sum *= i->states();
98 }
99 for( ; j != forVars.end(); ++j ) {
100 _dims.push_back( j->states() );
101 _sum.push_back( 0 );
102 }
103 _index = 0;
104 }
105
106 /// Sets the index back to zero
107 IndexFor& clear() {
108 fill( _count.begin(), _count.end(), 0 );
109 _index = 0;
110 return( *this );
111 }
112
113 /// Conversion to long
114 operator long () const {
115 return( _index );
116 }
117
118 /// Pre-increment operator
119 IndexFor& operator++ () {
120 if( _index >= 0 ) {
121 size_t i = 0;
122
123 while( i < _count.size() ) {
124 _index += _sum[i];
125 if( ++_count[i] < _dims[i] )
126 break;
127 _index -= _sum[i] * _dims[i];
128 _count[i] = 0;
129 i++;
130 }
131
132 if( i == _count.size() )
133 _index = -1;
134 }
135 return( *this );
136 }
137 };
138
139
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *  std::vector<size_t> dims;
 *  dims.push_back( 3 );
 *  dims.push_back( 4 );
 *  dims.push_back( 5 );
 *  for( MultiFor s(dims); s.valid(); ++s )
 *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to (note that the first index changes fastest):
 *  \code
 *  size_t s = 0;
 *  for( size_t s2 = 0; s2 < 5; s2++ )
 *      for( size_t s1 = 0; s1 < 4; s1++ )
 *          for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  As with nested for loops, if any dimension is zero there are no valid states.
 */
class MultiFor {
    private:
        /// Number of values of each index
        std::vector<size_t> _dims;
        /// Current value of each index
        std::vector<size_t> _states;
        /// Current linear state, or -1 once all states have been visited
        long _state;

    public:
        /// Default constructor (zero nested loops: exactly one valid state)
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        /** If any dimension is zero the object starts out invalid, just like
         *  nested for loops with an empty range execute their body zero times.
         */
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {
            if( std::find( _dims.begin(), _dims.end(), static_cast<size_t>(0) ) != _dims.end() )
                _state = -1;
        }

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator
        /** Advances to the next joint state (index 0 changes fastest); after the
         *  last state valid() becomes false. Does nothing if already invalid.
         */
        MultiFor & operator++() {
            if( valid() ) {
                _state++;
                size_t i;
                for( i = 0; i != _states.size(); i++ ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    // carry into the next index
                    _states[i] = 0;
                }
                // carried past the last index: all states have been visited
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return( _state >= 0 );
        }
};
211
212
213 /// Tool for calculating permutations of multiple indices.
214 class Permute {
215 private:
216 std::vector<size_t> _dims;
217 std::vector<size_t> _sigma;
218
219 public:
220 /// Default constructor
221 Permute() : _dims(), _sigma() {}
222
223 /// Construct from vector of index dimensions and permutation sigma
224 Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
225 assert( _dims.size() == _sigma.size() );
226 }
227
228 /// Construct from vector of variables
229 Permute( const std::vector<Var> &vars ) : _dims(vars.size()), _sigma(vars.size()) {
230 VarSet vs( vars.begin(), vars.end(), vars.size() );
231 for( size_t i = 0; i < vars.size(); ++i )
232 _dims[i] = vars[i].states();
233 VarSet::const_iterator set_iter = vs.begin();
234 for( size_t i = 0; i < vs.size(); ++i, ++set_iter )
235 _sigma[i] = find( vars.begin(), vars.end(), *set_iter ) - vars.begin();
236 }
237
238 /// Calculates a permuted linear index.
239 /** Converts the linear index li to a vector index
240 * corresponding with the dimensions in _dims, permutes it according to sigma,
241 * and converts it back to a linear index according to the permuted dimensions.
242 */
243 size_t convert_linear_index( size_t li ) const {
244 size_t N = _dims.size();
245
246 // calculate vector index corresponding to linear index
247 std::vector<size_t> vi;
248 vi.reserve( N );
249 size_t prod = 1;
250 for( size_t k = 0; k < N; k++ ) {
251 vi.push_back( li % _dims[k] );
252 li /= _dims[k];
253 prod *= _dims[k];
254 }
255
256 // convert permuted vector index to corresponding linear index
257 prod = 1;
258 size_t sigma_li = 0;
259 for( size_t k = 0; k < N; k++ ) {
260 sigma_li += vi[_sigma[k]] * prod;
261 prod *= _dims[_sigma[k]];
262 }
263
264 return sigma_li;
265 }
266 };
267
268
269 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
270 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
271 */
272 class State {
273 private:
274 typedef std::map<Var, size_t> states_type;
275
276 long state;
277 states_type states;
278
279 public:
280 /// Default constructor
281 State() : state(0), states() {}
282
283 /// Initialize from VarSet
284 State( const VarSet &vs ) : state(0) {
285 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
286 states[*v] = 0;
287 }
288
289 /// Return linear state
290 operator size_t() const {
291 assert( valid() );
292 return( state );
293 }
294
295 /// Return state of variable n, or zero if n is not in this State
296 size_t operator() ( const Var &n ) const {
297 assert( valid() );
298 states_type::const_iterator entry = states.find( n );
299 if( entry == states.end() )
300 return 0;
301 else
302 return entry->second;
303 }
304
305 /// Return linear state of variables in varset, setting them to zero if they are not in this State
306 size_t operator() ( const VarSet &vs ) const {
307 assert( valid() );
308 size_t vs_state = 0;
309 size_t prod = 1;
310 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
311 states_type::const_iterator entry = states.find( *v );
312 if( entry != states.end() )
313 vs_state += entry->second * prod;
314 prod *= v->states();
315 }
316 return vs_state;
317 }
318
319 /// Prefix increment operator
320 void operator++( ) {
321 if( valid() ) {
322 state++;
323 states_type::iterator entry = states.begin();
324 while( entry != states.end() ) {
325 if( ++(entry->second) < entry->first.states() )
326 break;
327 entry->second = 0;
328 entry++;
329 }
330 if( entry == states.end() )
331 state = -1;
332 }
333 }
334
335 /// Postfix increment operator
336 void operator++( int ) {
337 operator++();
338 }
339
340 /// Returns true if the current state is valid
341 bool valid() const {
342 return( state >= 0 );
343 }
344 };
345
346
347 } // end of namespace dai
348
349
350 #endif