Finished a new release: libDAI v0.2.6
[libdai.git] / tests / unit / index_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/index.h>
12 #include <strstream>
13 #include <map>
14
15
16 using namespace dai;
17
18
19 #define BOOST_TEST_MODULE IndexTest
20
21
22 #include <boost/test/unit_test.hpp>
23
24
25 BOOST_AUTO_TEST_CASE( IndexForTest ) {
26 IndexFor x;
27 BOOST_CHECK( !x.valid() );
28 x.reset();
29 BOOST_CHECK( x.valid() );
30
31 size_t nrVars = 5;
32 std::vector<Var> vars;
33 for( size_t i = 0; i < nrVars; i++ )
34 vars.push_back( Var( i, i+2 ) );
35
36 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
37 VarSet indexVars;
38 VarSet forVars;
39 for( size_t i = 0; i < 5; i++ ) {
40 if( rnd(2) == 0 )
41 indexVars |= vars[i];
42 if( rnd(2) == 0 )
43 forVars |= vars[i];
44 }
45 IndexFor ind( indexVars, forVars );
46 size_t iter = 0;
47 for( ; ind.valid(); ind++, iter++ )
48 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
49 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
50 iter = 0;
51 ind.reset();
52 for( ; ind.valid(); ++ind, iter++ )
53 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
54 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
55 }
56 }
57
58
59 BOOST_AUTO_TEST_CASE( PermuteTest ) {
60 Permute x;
61
62 Var x0(0, 2);
63 Var x1(1, 3);
64 Var x2(2, 2);
65 std::vector<Var> V;
66 V.push_back( x1 );
67 V.push_back( x2 );
68 V.push_back( x0 );
69 VarSet X( V.begin(), V.end() );
70 Permute sigma(V);
71 BOOST_CHECK_EQUAL( sigma.sigma().size(), 3 );
72 BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
73 BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
74 BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
75 BOOST_CHECK_EQUAL( sigma[0], 2 );
76 BOOST_CHECK_EQUAL( sigma[1], 0 );
77 BOOST_CHECK_EQUAL( sigma[2], 1 );
78 BOOST_CHECK_EQUAL( sigma.ranges().size(), 3 );
79 BOOST_CHECK_EQUAL( sigma.ranges()[0], 3 );
80 BOOST_CHECK_EQUAL( sigma.ranges()[1], 2 );
81 BOOST_CHECK_EQUAL( sigma.ranges()[2], 2 );
82 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
83 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
84 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
85 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 3 ), 6 );
86 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 4 ), 8 );
87 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 5 ), 10 );
88 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 6 ), 1 );
89 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 7 ), 3 );
90 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 8 ), 5 );
91 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 9 ), 7 );
92 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
93 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
94
95 Permute sigmar(V, true);
96 BOOST_CHECK_EQUAL( sigmar.sigma().size(), 3 );
97 BOOST_CHECK_EQUAL( sigmar.sigma()[0], 0 );
98 BOOST_CHECK_EQUAL( sigmar.sigma()[1], 2 );
99 BOOST_CHECK_EQUAL( sigmar.sigma()[2], 1 );
100 BOOST_CHECK_EQUAL( sigmar[0], 0 );
101 BOOST_CHECK_EQUAL( sigmar[1], 2 );
102 BOOST_CHECK_EQUAL( sigmar[2], 1 );
103 BOOST_CHECK_EQUAL( sigmar.ranges().size(), 3 );
104 BOOST_CHECK_EQUAL( sigmar.ranges()[0], 2 );
105 BOOST_CHECK_EQUAL( sigmar.ranges()[1], 2 );
106 BOOST_CHECK_EQUAL( sigmar.ranges()[2], 3 );
107 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 0 ), 0 );
108 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 1 ), 1 );
109 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 2 ), 6 );
110 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 3 ), 7 );
111 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 4 ), 2 );
112 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 5 ), 3 );
113 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 6 ), 8 );
114 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 7 ), 9 );
115 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 8 ), 4 );
116 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 9 ), 5 );
117 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 10 ), 10 );
118 BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 11 ), 11 );
119
120 std::vector<size_t> rs, sig;
121 rs.push_back(3);
122 rs.push_back(2);
123 rs.push_back(2);
124 sig.push_back(2);
125 sig.push_back(0);
126 sig.push_back(1);
127 Permute tau( rs, sig );
128 BOOST_CHECK( tau.sigma() == sig );
129 BOOST_CHECK( tau.ranges() == rs );
130 BOOST_CHECK_EQUAL( tau[0], 2 );
131 BOOST_CHECK_EQUAL( tau[1], 0 );
132 BOOST_CHECK_EQUAL( tau[2], 1 );
133 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 0 ), 0 );
134 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 1 ), 2 );
135 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 2 ), 4 );
136 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 3 ), 6 );
137 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 4 ), 8 );
138 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 5 ), 10 );
139 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 6 ), 1 );
140 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 7 ), 3 );
141 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 8 ), 5 );
142 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
143 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
144 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
145
146 Permute tauinv = tau.inverse();
147 BOOST_CHECK_EQUAL( tauinv.sigma().size(), 3 );
148 BOOST_CHECK_EQUAL( tauinv.ranges().size(), 3 );
149 BOOST_CHECK_EQUAL( tauinv[0], 1 );
150 BOOST_CHECK_EQUAL( tauinv[1], 2 );
151 BOOST_CHECK_EQUAL( tauinv[2], 0 );
152 BOOST_CHECK_EQUAL( tauinv.ranges()[0], 2 );
153 BOOST_CHECK_EQUAL( tauinv.ranges()[1], 3 );
154 BOOST_CHECK_EQUAL( tauinv.ranges()[2], 2 );
155 for( size_t i = 0; i < 12; i++ ) {
156 BOOST_CHECK_EQUAL( tau.convertLinearIndex( tauinv.convertLinearIndex( i ) ), i );
157 BOOST_CHECK_EQUAL( tauinv.convertLinearIndex( tau.convertLinearIndex( i ) ), i );
158 }
159 }
160
161
162 BOOST_AUTO_TEST_CASE( multiforTest ) {
163 multifor x;
164 BOOST_CHECK( x.valid() );
165
166 std::vector<size_t> ranges;
167 ranges.push_back( 3 );
168 ranges.push_back( 4 );
169 ranges.push_back( 5 );
170 multifor S(ranges);
171 size_t s = 0;
172 for( size_t s2 = 0; s2 < 5; s2++ )
173 for( size_t s1 = 0; s1 < 4; s1++ )
174 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
175 BOOST_CHECK( S.valid() );
176 BOOST_CHECK_EQUAL( s, (size_t)S );
177 BOOST_CHECK_EQUAL( S[0], s0 );
178 BOOST_CHECK_EQUAL( S[1], s1 );
179 BOOST_CHECK_EQUAL( S[2], s2 );
180 }
181 BOOST_CHECK( !S.valid() );
182
183 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
184 std::vector<size_t> dims;
185 size_t total = 1;
186 for( size_t i = 0; i < 4; i++ ) {
187 dims.push_back( rnd(3) + 1 );
188 total *= dims.back();
189 }
190 multifor ind( dims );
191 size_t iter = 0;
192 for( ; ind.valid(); ind++, iter++ ) {
193 BOOST_CHECK_EQUAL( (size_t)ind, iter );
194 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
195 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
196 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
197 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
198 }
199 BOOST_CHECK_EQUAL( iter, total );
200 iter = 0;
201 ind.reset();
202 for( ; ind.valid(); ++ind, iter++ ) {
203 BOOST_CHECK_EQUAL( (size_t)ind, iter );
204 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
205 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
206 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
207 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
208 }
209 BOOST_CHECK_EQUAL( iter, total );
210 }
211 }
212
213
214 BOOST_AUTO_TEST_CASE( StateTest ) {
215 State x;
216 BOOST_CHECK( x.valid() );
217
218 Var v0( 0, 3 );
219 Var v1( 1, 4 );
220 Var v2( 3, 5 );
221 VarSet vars;
222 vars |= v2;
223 vars |= v1;
224 vars |= v0;
225 State S( vars );
226 size_t s = 0;
227 for( size_t s2 = 0; s2 < 5; s2++ )
228 for( size_t s1 = 0; s1 < 4; s1++ )
229 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
230 BOOST_CHECK( S.valid() );
231 BOOST_CHECK_EQUAL( s, (size_t)S );
232 BOOST_CHECK_EQUAL( S(v0), s0 );
233 BOOST_CHECK_EQUAL( S(v1), s1 );
234 BOOST_CHECK_EQUAL( S(v2), s2 );
235 BOOST_CHECK_EQUAL( S( Var( 2, 2 ) ), 0 );
236 }
237 BOOST_CHECK( !S.valid() );
238 S.reset();
239 std::vector<std::pair<Var, size_t> > ps;
240 ps.push_back( std::make_pair( Var( 2, 2 ), 1 ) );
241 ps.push_back( std::make_pair( Var( 4, 2 ), 1 ) );
242 S.insert( ps.begin(), ps.end() );
243 BOOST_CHECK( S.valid() );
244 BOOST_CHECK_EQUAL( (size_t)S, 132 );
245
246 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
247 std::vector<size_t> dims;
248 size_t total = 1;
249 for( size_t i = 0; i < 4; i++ ) {
250 dims.push_back( rnd(3) + 1 );
251 total *= dims.back();
252 }
253 std::vector<Var> vs;
254 for( size_t i = 0; i < 4; i++ )
255 vs.push_back( Var( i, dims[i] ) );
256 State ind( VarSet( vs.begin(), vs.end() ) );
257 size_t iter = 0;
258 for( ; ind.valid(); ind++, iter++ ) {
259 BOOST_CHECK_EQUAL( (size_t)ind, iter );
260 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
261 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
262 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
263 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
264 BOOST_CHECK_EQUAL( ind(VarSet(vs[0], vs[1])), iter % (dims[0] * dims[1]) );
265 BOOST_CHECK_EQUAL( ind(VarSet(vs[1], vs[2])), (iter / dims[0]) % (dims[1] * dims[2]) );
266 BOOST_CHECK_EQUAL( ind(VarSet(vs[2], vs[3])), (iter / (dims[0] * dims[1])) % (dims[2] * dims[3]) );
267 BOOST_CHECK_EQUAL( ind(VarSet(vs.begin(), vs.end())), iter );
268 State indcopy( VarSet(vs.begin(), vs.end()), (size_t)ind );
269 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy(vs[0]) );
270 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy(vs[1]) );
271 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy(vs[2]) );
272 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy(vs[3]) );
273 State indcopy2( indcopy.get() );
274 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy2(vs[0]) );
275 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy2(vs[1]) );
276 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy2(vs[2]) );
277 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy2(vs[3]) );
278 std::map<Var,size_t> indmap( ind );
279 State indcopy3( indmap );
280 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy3(vs[0]) );
281 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy3(vs[1]) );
282 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy3(vs[2]) );
283 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy3(vs[3]) );
284 }
285 BOOST_CHECK_EQUAL( iter, total );
286 iter = 0;
287 ind.reset();
288 for( ; ind.valid(); ++ind, iter++ ) {
289 BOOST_CHECK_EQUAL( (size_t)ind, iter );
290 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
291 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
292 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
293 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
294 State::const_iterator ci = ind.begin();
295 BOOST_CHECK_EQUAL( (ci++)->second, iter % dims[0] );
296 BOOST_CHECK_EQUAL( (ci++)->second, (iter / dims[0]) % dims[1] );
297 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1])) % dims[2] );
298 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
299 BOOST_CHECK( ci == ind.end() );
300 }
301 BOOST_CHECK_EQUAL( iter, total );
302 State::const_iterator ci = ind.begin();
303 BOOST_CHECK_EQUAL( (ci++)->first, vs[0] );
304 BOOST_CHECK_EQUAL( (ci++)->first, vs[1] );
305 BOOST_CHECK_EQUAL( (ci++)->first, vs[2] );
306 BOOST_CHECK_EQUAL( (ci++)->first, vs[3] );
307 BOOST_CHECK( ci == ind.end() );
308 }
309 }