Extended SWIG python interface (inspired by Kyle Ellrott): inference is possible...
[libdai.git] / tests / unit / index_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/index.h>
10 #include <strstream>
11 #include <map>
12
13
14 using namespace dai;
15
16
17 #define BOOST_TEST_MODULE IndexTest
18
19
20 #include <boost/test/unit_test.hpp>
21
22
23 BOOST_AUTO_TEST_CASE( IndexForTest ) {
24 IndexFor x;
25 BOOST_CHECK( !x.valid() );
26 x.reset();
27 BOOST_CHECK( x.valid() );
28
29 size_t nrVars = 5;
30 std::vector<Var> vars;
31 for( size_t i = 0; i < nrVars; i++ )
32 vars.push_back( Var( i, i+2 ) );
33
34 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
35 VarSet indexVars;
36 VarSet forVars;
37 for( size_t i = 0; i < 5; i++ ) {
38 if( rnd(2) == 0 )
39 indexVars |= vars[i];
40 if( rnd(2) == 0 )
41 forVars |= vars[i];
42 }
43 IndexFor ind( indexVars, forVars );
44 size_t iter = 0;
45 for( ; ind.valid(); ind++, iter++ )
46 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
47 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
48 iter = 0;
49 ind.reset();
50 for( ; ind.valid(); ++ind, iter++ )
51 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
52 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
53 }
54 }
55
56
57 BOOST_AUTO_TEST_CASE( PermuteTest ) {
58 Permute x;
59
60 Var x0(0, 2);
61 Var x1(1, 3);
62 Var x2(2, 2);
63 std::vector<Var> V;
64 V.push_back( x1 );
65 V.push_back( x2 );
66 V.push_back( x0 );
67 VarSet X( V.begin(), V.end() );
68 Permute sigma(V);
69 BOOST_CHECK_EQUAL( sigma.sigma().size(), 3 );
70 BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
71 BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
72 BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
73 BOOST_CHECK_EQUAL( sigma[0], 2 );
74 BOOST_CHECK_EQUAL( sigma[1], 0 );
75 BOOST_CHECK_EQUAL( sigma[2], 1 );
76 BOOST_CHECK_EQUAL( sigma.ranges().size(), 3 );
77 BOOST_CHECK_EQUAL( sigma.ranges()[0], 3 );
78 BOOST_CHECK_EQUAL( sigma.ranges()[1], 2 );
79 BOOST_CHECK_EQUAL( sigma.ranges()[2], 2 );
80 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
81 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
82 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
83 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 3 ), 6 );
84 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 4 ), 8 );
85 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 5 ), 10 );
86 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 6 ), 1 );
87 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 7 ), 3 );
88 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 8 ), 5 );
89 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 9 ), 7 );
90 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
91 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
92
93 Permute sigmar(V, true);
94 BOOST_CHECK_EQUAL( sigmar.sigma().size(), 3 );
95 BOOST_CHECK_EQUAL( sigmar.sigma()[0], 0 );
96 BOOST_CHECK_EQUAL( sigmar.sigma()[1], 2 );
97 BOOST_CHECK_EQUAL( sigmar.sigma()[2], 1 );
98 BOOST_CHECK_EQUAL( sigmar[0], 0 );
99 BOOST_CHECK_EQUAL( sigmar[1], 2 );
100 BOOST_CHECK_EQUAL( sigmar[2], 1 );
101 BOOST_CHECK_EQUAL( sigmar.ranges().size(), 3 );
102 BOOST_CHECK_EQUAL( sigmar.ranges()[0], 2 );
103 BOOST_CHECK_EQUAL( sigmar.ranges()[1], 2 );
104 BOOST_CHECK_EQUAL( sigmar.ranges()[2], 3 );
105 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 0 ), 0 );
106 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 1 ), 1 );
107 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 2 ), 6 );
108 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 3 ), 7 );
109 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 4 ), 2 );
110 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 5 ), 3 );
111 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 6 ), 8 );
112 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 7 ), 9 );
113 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 8 ), 4 );
114 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 9 ), 5 );
115 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 10 ), 10 );
116 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 11 ), 11 );
117
118 std::vector<size_t> rs, sig;
119 rs.push_back(3);
120 rs.push_back(2);
121 rs.push_back(2);
122 sig.push_back(2);
123 sig.push_back(0);
124 sig.push_back(1);
125 Permute tau( rs, sig );
126 BOOST_CHECK( tau.sigma() == sig );
127 BOOST_CHECK( tau.ranges() == rs );
128 BOOST_CHECK_EQUAL( tau[0], 2 );
129 BOOST_CHECK_EQUAL( tau[1], 0 );
130 BOOST_CHECK_EQUAL( tau[2], 1 );
131 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 0 ), 0 );
132 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 1 ), 2 );
133 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 2 ), 4 );
134 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 3 ), 6 );
135 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 4 ), 8 );
136 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 5 ), 10 );
137 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 6 ), 1 );
138 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 7 ), 3 );
139 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 8 ), 5 );
140 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
141 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
142 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
143
144 Permute tauinv = tau.inverse();
145 BOOST_CHECK_EQUAL( tauinv.sigma().size(), 3 );
146 BOOST_CHECK_EQUAL( tauinv.ranges().size(), 3 );
147 BOOST_CHECK_EQUAL( tauinv[0], 1 );
148 BOOST_CHECK_EQUAL( tauinv[1], 2 );
149 BOOST_CHECK_EQUAL( tauinv[2], 0 );
150 BOOST_CHECK_EQUAL( tauinv.ranges()[0], 2 );
151 BOOST_CHECK_EQUAL( tauinv.ranges()[1], 3 );
152 BOOST_CHECK_EQUAL( tauinv.ranges()[2], 2 );
153 for( size_t i = 0; i < 12; i++ ) {
154 BOOST_CHECK_EQUAL( tau.convertLinearIndex( tauinv.convertLinearIndex( i ) ), i );
155 BOOST_CHECK_EQUAL( tauinv.convertLinearIndex( tau.convertLinearIndex( i ) ), i );
156 }
157 }
158
159
160 BOOST_AUTO_TEST_CASE( multiforTest ) {
161 multifor x;
162 BOOST_CHECK( x.valid() );
163
164 std::vector<size_t> ranges;
165 ranges.push_back( 3 );
166 ranges.push_back( 4 );
167 ranges.push_back( 5 );
168 multifor S(ranges);
169 size_t s = 0;
170 for( size_t s2 = 0; s2 < 5; s2++ )
171 for( size_t s1 = 0; s1 < 4; s1++ )
172 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
173 BOOST_CHECK( S.valid() );
174 BOOST_CHECK_EQUAL( s, (size_t)S );
175 BOOST_CHECK_EQUAL( S[0], s0 );
176 BOOST_CHECK_EQUAL( S[1], s1 );
177 BOOST_CHECK_EQUAL( S[2], s2 );
178 }
179 BOOST_CHECK( !S.valid() );
180
181 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
182 std::vector<size_t> dims;
183 size_t total = 1;
184 for( size_t i = 0; i < 4; i++ ) {
185 dims.push_back( rnd(3) + 1 );
186 total *= dims.back();
187 }
188 multifor ind( dims );
189 size_t iter = 0;
190 for( ; ind.valid(); ind++, iter++ ) {
191 BOOST_CHECK_EQUAL( (size_t)ind, iter );
192 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
193 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
194 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
195 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
196 }
197 BOOST_CHECK_EQUAL( iter, total );
198 iter = 0;
199 ind.reset();
200 for( ; ind.valid(); ++ind, iter++ ) {
201 BOOST_CHECK_EQUAL( (size_t)ind, iter );
202 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
203 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
204 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
205 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
206 }
207 BOOST_CHECK_EQUAL( iter, total );
208 }
209 }
210
211
212 BOOST_AUTO_TEST_CASE( StateTest ) {
213 State x;
214 BOOST_CHECK( x.valid() );
215
216 Var v0( 0, 3 );
217 Var v1( 1, 4 );
218 Var v2( 3, 5 );
219 VarSet vars;
220 vars |= v2;
221 vars |= v1;
222 vars |= v0;
223 State S( vars );
224 size_t s = 0;
225 for( size_t s2 = 0; s2 < 5; s2++ )
226 for( size_t s1 = 0; s1 < 4; s1++ )
227 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
228 BOOST_CHECK( S.valid() );
229 BOOST_CHECK_EQUAL( s, (size_t)S );
230 BOOST_CHECK_EQUAL( S(v0), s0 );
231 BOOST_CHECK_EQUAL( S(v1), s1 );
232 BOOST_CHECK_EQUAL( S(v2), s2 );
233 BOOST_CHECK_EQUAL( S( Var( 2, 2 ) ), 0 );
234 }
235 BOOST_CHECK( !S.valid() );
236 S.reset();
237 std::vector<std::pair<Var, size_t> > ps;
238 ps.push_back( std::make_pair( Var( 2, 2 ), 1 ) );
239 ps.push_back( std::make_pair( Var( 4, 2 ), 1 ) );
240 S.insert( ps.begin(), ps.end() );
241 BOOST_CHECK( S.valid() );
242 BOOST_CHECK_EQUAL( (size_t)S, 132 );
243
244 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
245 std::vector<size_t> dims;
246 size_t total = 1;
247 for( size_t i = 0; i < 4; i++ ) {
248 dims.push_back( rnd(3) + 1 );
249 total *= dims.back();
250 }
251 std::vector<Var> vs;
252 for( size_t i = 0; i < 4; i++ )
253 vs.push_back( Var( i, dims[i] ) );
254 State ind( VarSet( vs.begin(), vs.end() ) );
255 size_t iter = 0;
256 for( ; ind.valid(); ind++, iter++ ) {
257 BOOST_CHECK_EQUAL( (size_t)ind, iter );
258 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
259 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
260 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
261 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
262 BOOST_CHECK_EQUAL( ind(VarSet(vs[0], vs[1])), iter % (dims[0] * dims[1]) );
263 BOOST_CHECK_EQUAL( ind(VarSet(vs[1], vs[2])), (iter / dims[0]) % (dims[1] * dims[2]) );
264 BOOST_CHECK_EQUAL( ind(VarSet(vs[2], vs[3])), (iter / (dims[0] * dims[1])) % (dims[2] * dims[3]) );
265 BOOST_CHECK_EQUAL( ind(VarSet(vs.begin(), vs.end())), iter );
266 State indcopy( VarSet(vs.begin(), vs.end()), (size_t)ind );
267 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy(vs[0]) );
268 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy(vs[1]) );
269 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy(vs[2]) );
270 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy(vs[3]) );
271 State indcopy2( indcopy.get() );
272 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy2(vs[0]) );
273 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy2(vs[1]) );
274 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy2(vs[2]) );
275 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy2(vs[3]) );
276 std::map<Var,size_t> indmap( ind );
277 State indcopy3( indmap );
278 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy3(vs[0]) );
279 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy3(vs[1]) );
280 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy3(vs[2]) );
281 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy3(vs[3]) );
282 }
283 BOOST_CHECK_EQUAL( iter, total );
284 iter = 0;
285 ind.reset();
286 for( ; ind.valid(); ++ind, iter++ ) {
287 BOOST_CHECK_EQUAL( (size_t)ind, iter );
288 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
289 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
290 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
291 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
292 State::const_iterator ci = ind.begin();
293 BOOST_CHECK_EQUAL( (ci++)->second, iter % dims[0] );
294 BOOST_CHECK_EQUAL( (ci++)->second, (iter / dims[0]) % dims[1] );
295 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1])) % dims[2] );
296 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
297 BOOST_CHECK( ci == ind.end() );
298 }
299 BOOST_CHECK_EQUAL( iter, total );
300 State::const_iterator ci = ind.begin();
301 BOOST_CHECK_EQUAL( (ci++)->first, vs[0] );
302 BOOST_CHECK_EQUAL( (ci++)->first, vs[1] );
303 BOOST_CHECK_EQUAL( (ci++)->first, vs[2] );
304 BOOST_CHECK_EQUAL( (ci++)->first, vs[3] );
305 BOOST_CHECK( ci == ind.end() );
306 }
307 }