EM bugfix. Convenience methods in Factor, Permute, Properties, EM.
[libdai.git] / include / dai / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines the IndexFor, MultiFor, Permute and State classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_index_h
32 #define __defined_libdai_index_h
33
34
35 #include <vector>
36 #include <algorithm>
37 #include <map>
38 #include <cassert>
39 #include <dai/varset.h>
40
41
42 namespace dai {
43
44
45 /// Tool for looping over the states of several variables.
46 /** The class IndexFor is an important tool for indexing Factor entries.
47 * Its usage can best be explained by an example.
48 * Assume indexVars, forVars are both VarSets.
49 * Then the following code:
50 * \code
51 * IndexFor i( indexVars, forVars );
52 * for( ; i >= 0; ++i ) {
53 * // use long(i)
54 * }
55 * \endcode
56 * loops over all joint states of the variables in forVars,
57 * and (long)i is equal to the linear index of the corresponding
58 * state of indexVars, where the variables in indexVars that are
59 * not in forVars assume their zero'th value.
60 * \idea Optimize all indices as follows: keep a cache of all (or only
61 * relatively small) indices that have been computed (use a hash). Then,
62 * instead of computing on the fly, use the precomputed ones.
63 */
64 class IndexFor {
65 private:
66 /// The current linear index corresponding to the state of indexVars
67 long _index;
68
69 /// For each variable in forVars, the amount of change in _index
70 std::vector<long> _sum;
71
72 /// For each variable in forVars, the current state
73 std::vector<size_t> _count;
74
75 /// For each variable in forVars, its number of possible values
76 std::vector<size_t> _dims;
77
78 public:
79 /// Default constructor
80 IndexFor() {
81 _index = -1;
82 }
83
84 /// Constructor
85 IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
86 long sum = 1;
87
88 _dims.reserve( forVars.size() );
89 _sum.reserve( forVars.size() );
90
91 VarSet::const_iterator j = forVars.begin();
92 for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
93 for( ; j != forVars.end() && *j <= *i; ++j ) {
94 _dims.push_back( j->states() );
95 _sum.push_back( (*i == *j) ? sum : 0 );
96 }
97 sum *= i->states();
98 }
99 for( ; j != forVars.end(); ++j ) {
100 _dims.push_back( j->states() );
101 _sum.push_back( 0 );
102 }
103 _index = 0;
104 }
105
106 /// Sets the index back to zero
107 IndexFor& clear() {
108 fill( _count.begin(), _count.end(), 0 );
109 _index = 0;
110 return( *this );
111 }
112
113 /// Conversion to long
114 operator long () const {
115 return( _index );
116 }
117
118 /// Pre-increment operator
119 IndexFor& operator++ () {
120 if( _index >= 0 ) {
121 size_t i = 0;
122
123 while( i < _count.size() ) {
124 _index += _sum[i];
125 if( ++_count[i] < _dims[i] )
126 break;
127 _index -= _sum[i] * _dims[i];
128 _count[i] = 0;
129 i++;
130 }
131
132 if( i == _count.size() )
133 _index = -1;
134 }
135 return( *this );
136 }
137 };
138
139
/// MultiFor makes it easy to perform a dynamic number of nested for loops.
/** An example of the usage is as follows:
 *  \code
 *      std::vector<size_t> dims;
 *      dims.push_back( 3 );
 *      dims.push_back( 4 );
 *      dims.push_back( 5 );
 *      for( MultiFor s(dims); s.valid(); ++s )
 *          cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
 *  \endcode
 *  which would be equivalent to (note that the first index changes fastest):
 *  \code
 *      size_t s = 0;
 *      for( size_t s2 = 0; s2 < 5; s2++ )
 *          for( size_t s1 = 0; s1 < 4; s1++ )
 *              for( size_t s0 = 0; s0 < 3; s++, s0++ )
 *                  cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
 *  \endcode
 *  An empty vector of dimensions yields exactly one (empty) joint state;
 *  if any dimension is zero there is no joint state at all and the
 *  object is invalid from the start.
 */
class MultiFor {
    private:
        /// Number of possible values of each index
        std::vector<size_t>  _dims;
        /// Current value of each index
        std::vector<size_t>  _states;
        /// Current linear index, or -1 if the loop is exhausted
        long                 _state;

    public:
        /// Default constructor
        MultiFor() : _dims(), _states(), _state(0) {}

        /// Initialize from vector of index dimensions
        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {
            // a dimension without any values leaves no joint state to visit
            if( std::find( _dims.begin(), _dims.end(), static_cast<size_t>(0) ) != _dims.end() )
                _state = -1;
        }

        /// Return linear state
        operator size_t() const {
            assert( valid() );
            return static_cast<size_t>( _state );
        }

        /// Return k'th index
        size_t operator[]( size_t k ) const {
            assert( valid() );
            assert( k < _states.size() );
            return _states[k];
        }

        /// Prefix increment operator; index 0 is advanced first, carrying over into higher indices
        MultiFor & operator++() {
            if( valid() ) {
                ++_state;
                size_t i;
                for( i = 0; i != _states.size(); ++i ) {
                    if( ++(_states[i]) < _dims[i] )
                        break;
                    _states[i] = 0;
                }
                // every index wrapped around: all joint states have been visited
                if( i == _states.size() )
                    _state = -1;
            }
            return *this;
        }

        /// Postfix increment operator (does not return the previous value)
        void operator++( int ) {
            operator++();
        }

        /// Returns true if the current state is valid
        bool valid() const {
            return _state >= 0;
        }
};
211
212
213 /// Tool for calculating permutations of multiple indices.
214 class Permute {
215 private:
216 std::vector<size_t> _dims;
217 std::vector<size_t> _sigma;
218
219 public:
220 /// Default constructor
221 Permute() : _dims(), _sigma() {}
222
223 /// Initialize from vector of index dimensions and permutation sigma
224 Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
225 assert( _dims.size() == _sigma.size() );
226 }
227
228 Permute(const std::vector< Var >& vars) : _dims(vars.size()), _sigma(vars.size()) {
229 VarSet vs(vars.begin(), vars.end(), vars.size());
230 for (size_t i = 0; i < vars.size(); ++i) {
231 _dims[i] = vars[i].states();
232 }
233 VarSet::iterator set_iter = vs.begin();
234 for (size_t i = 0; i < vs.size(); ++i, ++set_iter) {
235 std::vector< Var >::const_iterator j;
236 j = find(vars.begin(), vars.end(), *set_iter);
237 _sigma[i] = j - vars.begin();
238 }
239 }
240
241 /// Calculates a permuted linear index.
242 /** Converts the linear index li to a vector index
243 * corresponding with the dimensions in _dims, permutes it according to sigma,
244 * and converts it back to a linear index according to the permuted dimensions.
245 */
246 size_t convert_linear_index( size_t li ) const {
247 size_t N = _dims.size();
248
249 // calculate vector index corresponding to linear index
250 std::vector<size_t> vi;
251 vi.reserve( N );
252 size_t prod = 1;
253 for( size_t k = 0; k < N; k++ ) {
254 vi.push_back( li % _dims[k] );
255 li /= _dims[k];
256 prod *= _dims[k];
257 }
258
259 // convert permuted vector index to corresponding linear index
260 prod = 1;
261 size_t sigma_li = 0;
262 for( size_t k = 0; k < N; k++ ) {
263 sigma_li += vi[_sigma[k]] * prod;
264 prod *= _dims[_sigma[k]];
265 }
266
267 return sigma_li;
268 }
269 };
270
271
272 /// Contains the joint state of variables within a VarSet and useful things to do with this information.
273 /** This is very similar to a MultiFor, but tailored for Vars and Varsets.
274 */
275 class State {
276 private:
277 typedef std::map<Var, size_t> states_type;
278
279 long state;
280 states_type states;
281
282 public:
283 /// Default constructor
284 State() : state(0), states() {}
285
286 /// Initialize from VarSet
287 State( const VarSet &vs ) : state(0) {
288 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
289 states[*v] = 0;
290 }
291
292 /// Return linear state
293 operator size_t() const {
294 assert( valid() );
295 return( state );
296 }
297
298 /// Return state of variable n, or zero if n is not in this State
299 size_t operator() ( const Var &n ) const {
300 assert( valid() );
301 states_type::const_iterator entry = states.find( n );
302 if( entry == states.end() )
303 return 0;
304 else
305 return entry->second;
306 }
307
308 /// Return linear state of variables in varset, setting them to zero if they are not in this State
309 size_t operator() ( const VarSet &vs ) const {
310 assert( valid() );
311 size_t vs_state = 0;
312 size_t prod = 1;
313 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
314 states_type::const_iterator entry = states.find( *v );
315 if( entry != states.end() )
316 vs_state += entry->second * prod;
317 prod *= v->states();
318 }
319 return vs_state;
320 }
321
322 /// Prefix increment operator
323 void operator++( ) {
324 if( valid() ) {
325 state++;
326 states_type::iterator entry = states.begin();
327 while( entry != states.end() ) {
328 if( ++(entry->second) < entry->first.states() )
329 break;
330 entry->second = 0;
331 entry++;
332 }
333 if( entry == states.end() )
334 state = -1;
335 }
336 }
337
338 /// Postfix increment operator
339 void operator++( int ) {
340 operator++();
341 }
342
343 /// Returns true if the current state is valid
344 bool valid() const {
345 return( state >= 0 );
346 }
347 };
348
349
350 } // end of namespace dai
351
352
353 #endif