Replaced the standard assert() macro by DAI_ASSERT
[libdai.git] / include / dai / index.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
8 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 * Copyright (C) 2002-2007 Radboud University Nijmegen, The Netherlands
10 */
11
12
13 /// \file
14 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
15 /// \todo Improve documentation
16
17
18 #ifndef __defined_libdai_index_h
19 #define __defined_libdai_index_h
20
21
22 #include <vector>
23 #include <algorithm>
24 #include <map>
25 #include <dai/varset.h>
26
27
28 namespace dai {
29
30
31 /// Tool for looping over the states of several variables.
32 /** The class IndexFor is an important tool for indexing Factor entries.
33 * Its usage can best be explained by an example.
34 * Assume indexVars, forVars are both VarSets.
35 * Then the following code:
36 * \code
37 * IndexFor i( indexVars, forVars );
38 * for( ; i >= 0; ++i ) {
39 * // use long(i)
40 * }
41 * \endcode
42 * loops over all joint states of the variables in forVars,
43 * and (long)i is equal to the linear index of the corresponding
44 * state of indexVars, where the variables in indexVars that are
45 * not in forVars assume their zero'th value.
46 * \idea Optimize all indices as follows: keep a cache of all (or only
47 * relatively small) indices that have been computed (use a hash). Then,
48 * instead of computing on the fly, use the precomputed ones.
49 */
50 class IndexFor {
51 private:
52 /// The current linear index corresponding to the state of indexVars
53 long _index;
54
55 /// For each variable in forVars, the amount of change in _index
56 std::vector<long> _sum;
57
58 /// For each variable in forVars, the current state
59 std::vector<size_t> _count;
60
61 /// For each variable in forVars, its number of possible values
62 std::vector<size_t> _dims;
63
64 public:
65 /// Default constructor
66 IndexFor() {
67 _index = -1;
68 }
69
70 /// Constructor
71 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
72 long sum = 1;
73
74 _dims.reserve( forVars.size() );
75 _sum.reserve( forVars.size() );
76
77 VarSet::const_iterator j = forVars.begin();
78 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
79 for( ; j != forVars.end() && *j <= *i; ++j ) {
80 _dims.push_back( j->states() );
81 _sum.push_back( (*i == *j) ? sum : 0 );
82 }
83 sum *= i->states();
84 }
85 for( ; j != forVars.end(); ++j ) {
86 _dims.push_back( j->states() );
87 _sum.push_back( 0 );
88 }
89 _index = 0;
90 }
91
92 /// Sets the index back to zero
93 IndexFor& clear() {
94 fill( _count.begin(), _count.end(), 0 );
95 _index = 0;
96 return( *this );
97 }
98
99 /// Conversion to long
100 operator long () const {
101 return( _index );
102 }
103
104 /// Pre-increment operator
105 IndexFor& operator++ () {
106 if( _index >= 0 ) {
107 size_t i = 0;
108
109 while( i < _count.size() ) {
110 _index += _sum[i];
111 if( ++_count[i] < _dims[i] )
112 break;
113 _index -= _sum[i] * _dims[i];
114 _count[i] = 0;
115 i++;
116 }
117
118 if( i == _count.size() )
119 _index = -1;
120 }
121 return( *this );
122 }
123 };
124
125
126 /// MultiFor makes it easy to perform a dynamic number of nested for loops.
127 /** An example of the usage is as follows:
128 * \code
129 * std::vector<size_t> dims;
130 * dims.push_back( 3 );
131 * dims.push_back( 4 );
132 * dims.push_back( 5 );
133 * for( MultiFor s(dims); s.valid(); ++s )
134 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
135 * \endcode
136 * which would be equivalent to:
137 * \code
138 * size_t s = 0;
139 * for( size_t s0 = 0; s0 < 3; s0++ )
140 * for( size_t s1 = 0; s1 < 4; s1++ )
141 * for( size_t s2 = 0; s2 < 5; s++, s2++ )
142 * cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
143 * \endcode
144 */
145 class MultiFor {
146 private:
147 std::vector<size_t> _dims;
148 std::vector<size_t> _states;
149 long _state;
150
151 public:
152 /// Default constructor
153 MultiFor() : _dims(), _states(), _state(0) {}
154
155 /// Initialize from vector of index dimensions
156 MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}
157
158 /// Return linear state
159 operator size_t() const {
160 DAI_ASSERT( valid() );
161 return( _state );
162 }
163
164 /// Return k'th index
165 size_t operator[]( size_t k ) const {
166 DAI_ASSERT( valid() );
167 DAI_ASSERT( k < _states.size() );
168 return _states[k];
169 }
170
171 /// Prefix increment operator
172 MultiFor & operator++() {
173 if( valid() ) {
174 _state++;
175 size_t i;
176 for( i = 0; i != _states.size(); i++ ) {
177 if( ++(_states[i]) < _dims[i] )
178 break;
179 _states[i] = 0;
180 }
181 if( i == _states.size() )
182 _state = -1;
183 }
184 return *this;
185 }
186
187 /// Postfix increment operator
188 void operator++( int ) {
189 operator++();
190 }
191
192 /// Returns true if the current state is valid
193 bool valid() const {
194 return( _state >= 0 );
195 }
196 };
197
198
199 /// Tool for calculating permutations of multiple indices.
200 class Permute {
201 private:
202 std::vector<size_t> _dims;
203 std::vector<size_t> _sigma;
204
205 public:
206 /// Default constructor
207 Permute() : _dims(), _sigma() {}
208
209 /// Construct from vector of index dimensions and permutation sigma
210 Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
211 DAI_ASSERT( _dims.size() == _sigma.size() );
212 }
213
214 /// Construct from vector of variables
215 Permute( const std::vector<Var> &vars ) : _dims(vars.size()), _sigma(vars.size()) {
216 VarSet vs( vars.begin(), vars.end(), vars.size() );
217 for( size_t i = 0; i < vars.size(); ++i )
218 _dims[i] = vars[i].states();
219 VarSet::const_iterator set_iter = vs.begin();
220 for( size_t i = 0; i < vs.size(); ++i, ++set_iter )
221 _sigma[i] = find( vars.begin(), vars.end(), *set_iter ) - vars.begin();
222 }
223
224 /// Calculates a permuted linear index.
225 /** Converts the linear index li to a vector index
226 * corresponding with the dimensions in _dims, permutes it according to sigma,
227 * and converts it back to a linear index according to the permuted dimensions.
228 */
229 size_t convert_linear_index( size_t li ) const {
230 size_t N = _dims.size();
231
232 // calculate vector index corresponding to linear index
233 std::vector<size_t> vi;
234 vi.reserve( N );
235 size_t prod = 1;
236 for( size_t k = 0; k < N; k++ ) {
237 vi.push_back( li % _dims[k] );
238 li /= _dims[k];
239 prod *= _dims[k];
240 }
241
242 // convert permuted vector index to corresponding linear index
243 prod = 1;
244 size_t sigma_li = 0;
245 for( size_t k = 0; k < N; k++ ) {
246 sigma_li += vi[_sigma[k]] * prod;
247 prod *= _dims[_sigma[k]];
248 }
249
250 return sigma_li;
251 }
252 };
253
254
255 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
256 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
257 */
258 class State {
259 private:
260 typedef std::map<Var, size_t> states_type;
261
262 long state;
263 states_type states;
264
265 public:
266 /// Default constructor
267 State() : state(0), states() {}
268
269 /// Initialize from VarSet
270 State( const VarSet &vs ) : state(0) {
271 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
272 states[*v] = 0;
273 }
274
275 /// Return linear state
276 operator size_t() const {
277 DAI_ASSERT( valid() );
278 return( state );
279 }
280
281 /// Return state of variable n, or zero if n is not in this State
282 size_t operator() ( const Var &n ) const {
283 DAI_ASSERT( valid() );
284 states_type::const_iterator entry = states.find( n );
285 if( entry == states.end() )
286 return 0;
287 else
288 return entry->second;
289 }
290
291 /// Return linear state of variables in varset, setting them to zero if they are not in this State
292 size_t operator() ( const VarSet &vs ) const {
293 DAI_ASSERT( valid() );
294 size_t vs_state = 0;
295 size_t prod = 1;
296 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
297 states_type::const_iterator entry = states.find( *v );
298 if( entry != states.end() )
299 vs_state += entry->second * prod;
300 prod *= v->states();
301 }
302 return vs_state;
303 }
304
305 /// Prefix increment operator
306 void operator++( ) {
307 if( valid() ) {
308 state++;
309 states_type::iterator entry = states.begin();
310 while( entry != states.end() ) {
311 if( ++(entry->second) < entry->first.states() )
312 break;
313 entry->second = 0;
314 entry++;
315 }
316 if( entry == states.end() )
317 state = -1;
318 }
319 }
320
321 /// Postfix increment operator
322 void operator++( int ) {
323 operator++();
324 }
325
326 /// Returns true if the current state is valid
327 bool valid() const {
328 return( state >= 0 );
329 }
330 };
331
332
333 } // end of namespace dai
334
335
336 #endif