Some improvements to jtree and regiongraph and started work on regiongraph unit tests
[libdai.git] / tests / unit / index.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
14 #include <dai/index.h>
15 #include <strstream>
16 #include <map>
17
18
19 using namespace dai;
20
21
22 #define BOOST_TEST_MODULE IndexTest
23
24
25 #include <boost/test/unit_test.hpp>
26
27
28 BOOST_AUTO_TEST_CASE( IndexForTest ) {
29 IndexFor x;
30 BOOST_CHECK( !x.valid() );
31 x.reset();
32 BOOST_CHECK( x.valid() );
33
34 size_t nrVars = 5;
35 std::vector<Var> vars;
36 for( size_t i = 0; i < nrVars; i++ )
37 vars.push_back( Var( i, i+2 ) );
38
39 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
40 VarSet indexVars;
41 VarSet forVars;
42 for( size_t i = 0; i < 5; i++ ) {
43 if( rnd(2) == 0 )
44 indexVars |= vars[i];
45 if( rnd(2) == 0 )
46 forVars |= vars[i];
47 }
48 IndexFor ind( indexVars, forVars );
49 size_t iter = 0;
50 for( ; ind.valid(); ind++, iter++ )
51 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
52 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
53 iter = 0;
54 ind.reset();
55 for( ; ind.valid(); ++ind, iter++ )
56 BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
57 BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
58 }
59 }
60
61
62 BOOST_AUTO_TEST_CASE( PermuteTest ) {
63 Permute x;
64
65 Var x0(0, 2);
66 Var x1(1, 3);
67 Var x2(2, 2);
68 std::vector<Var> V;
69 V.push_back( x1 );
70 V.push_back( x2 );
71 V.push_back( x0 );
72 VarSet X( V.begin(), V.end() );
73 Permute sigma(V);
74 BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
75 BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
76 BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
77 BOOST_CHECK_EQUAL( sigma[0], 2 );
78 BOOST_CHECK_EQUAL( sigma[1], 0 );
79 BOOST_CHECK_EQUAL( sigma[2], 1 );
80 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
81 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
82 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
83 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 3 ), 6 );
84 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 4 ), 8 );
85 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 5 ), 10 );
86 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 6 ), 1 );
87 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 7 ), 3 );
88 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 8 ), 5 );
89 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 9 ), 7 );
90 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
91 BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
92
93 std::vector<size_t> rs, sig;
94 rs.push_back(2);
95 rs.push_back(3);
96 rs.push_back(2);
97 sig.push_back(2);
98 sig.push_back(0);
99 sig.push_back(1);
100 Permute tau( rs, sig );
101 BOOST_CHECK( tau.sigma() == sig );
102 BOOST_CHECK_EQUAL( tau[0], 2 );
103 BOOST_CHECK_EQUAL( tau[1], 0 );
104 BOOST_CHECK_EQUAL( tau[2], 1 );
105 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 0 ), 0 );
106 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 1 ), 2 );
107 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 2 ), 4 );
108 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 3 ), 6 );
109 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 4 ), 8 );
110 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 5 ), 10 );
111 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 6 ), 1 );
112 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 7 ), 3 );
113 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 8 ), 5 );
114 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
115 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
116 BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
117 }
118
119
120 BOOST_AUTO_TEST_CASE( multiforTest ) {
121 multifor x;
122 BOOST_CHECK( x.valid() );
123
124 std::vector<size_t> ranges;
125 ranges.push_back( 3 );
126 ranges.push_back( 4 );
127 ranges.push_back( 5 );
128 multifor S(ranges);
129 size_t s = 0;
130 for( size_t s2 = 0; s2 < 5; s2++ )
131 for( size_t s1 = 0; s1 < 4; s1++ )
132 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
133 BOOST_CHECK( S.valid() );
134 BOOST_CHECK_EQUAL( s, (size_t)S );
135 BOOST_CHECK_EQUAL( S[0], s0 );
136 BOOST_CHECK_EQUAL( S[1], s1 );
137 BOOST_CHECK_EQUAL( S[2], s2 );
138 }
139 BOOST_CHECK( !S.valid() );
140
141 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
142 std::vector<size_t> dims;
143 size_t total = 1;
144 for( size_t i = 0; i < 4; i++ ) {
145 dims.push_back( rnd(3) + 1 );
146 total *= dims.back();
147 }
148 multifor ind( dims );
149 size_t iter = 0;
150 for( ; ind.valid(); ind++, iter++ ) {
151 BOOST_CHECK_EQUAL( (size_t)ind, iter );
152 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
153 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
154 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
155 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
156 }
157 BOOST_CHECK_EQUAL( iter, total );
158 iter = 0;
159 ind.reset();
160 for( ; ind.valid(); ++ind, iter++ ) {
161 BOOST_CHECK_EQUAL( (size_t)ind, iter );
162 BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
163 BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
164 BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
165 BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
166 }
167 BOOST_CHECK_EQUAL( iter, total );
168 }
169 }
170
171
172 BOOST_AUTO_TEST_CASE( StateTest ) {
173 State x;
174 BOOST_CHECK( x.valid() );
175
176 Var v0( 0, 3 );
177 Var v1( 1, 4 );
178 Var v2( 3, 5 );
179 VarSet vars;
180 vars |= v2;
181 vars |= v1;
182 vars |= v0;
183 State S( vars );
184 size_t s = 0;
185 for( size_t s2 = 0; s2 < 5; s2++ )
186 for( size_t s1 = 0; s1 < 4; s1++ )
187 for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
188 BOOST_CHECK( S.valid() );
189 BOOST_CHECK_EQUAL( s, (size_t)S );
190 BOOST_CHECK_EQUAL( S(v0), s0 );
191 BOOST_CHECK_EQUAL( S(v1), s1 );
192 BOOST_CHECK_EQUAL( S(v2), s2 );
193 BOOST_CHECK_EQUAL( S( Var( 2, 2 ) ), 0 );
194 }
195 BOOST_CHECK( !S.valid() );
196 S.reset();
197 std::vector<std::pair<Var, size_t> > ps;
198 ps.push_back( std::make_pair( Var( 2, 2 ), 1 ) );
199 ps.push_back( std::make_pair( Var( 4, 2 ), 1 ) );
200 S.insert( ps.begin(), ps.end() );
201 BOOST_CHECK( S.valid() );
202 BOOST_CHECK_EQUAL( (size_t)S, 132 );
203
204 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
205 std::vector<size_t> dims;
206 size_t total = 1;
207 for( size_t i = 0; i < 4; i++ ) {
208 dims.push_back( rnd(3) + 1 );
209 total *= dims.back();
210 }
211 std::vector<Var> vs;
212 for( size_t i = 0; i < 4; i++ )
213 vs.push_back( Var( i, dims[i] ) );
214 State ind( VarSet( vs.begin(), vs.end() ) );
215 size_t iter = 0;
216 for( ; ind.valid(); ind++, iter++ ) {
217 BOOST_CHECK_EQUAL( (size_t)ind, iter );
218 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
219 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
220 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
221 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
222 BOOST_CHECK_EQUAL( ind(VarSet(vs[0], vs[1])), iter % (dims[0] * dims[1]) );
223 BOOST_CHECK_EQUAL( ind(VarSet(vs[1], vs[2])), (iter / dims[0]) % (dims[1] * dims[2]) );
224 BOOST_CHECK_EQUAL( ind(VarSet(vs[2], vs[3])), (iter / (dims[0] * dims[1])) % (dims[2] * dims[3]) );
225 BOOST_CHECK_EQUAL( ind(VarSet(vs.begin(), vs.end())), iter );
226 State indcopy( VarSet(vs.begin(), vs.end()), (size_t)ind );
227 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy(vs[0]) );
228 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy(vs[1]) );
229 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy(vs[2]) );
230 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy(vs[3]) );
231 State indcopy2( indcopy.get() );
232 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy2(vs[0]) );
233 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy2(vs[1]) );
234 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy2(vs[2]) );
235 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy2(vs[3]) );
236 std::map<Var,size_t> indmap( ind );
237 State indcopy3( indmap );
238 BOOST_CHECK_EQUAL( ind(vs[0]), indcopy3(vs[0]) );
239 BOOST_CHECK_EQUAL( ind(vs[1]), indcopy3(vs[1]) );
240 BOOST_CHECK_EQUAL( ind(vs[2]), indcopy3(vs[2]) );
241 BOOST_CHECK_EQUAL( ind(vs[3]), indcopy3(vs[3]) );
242 }
243 BOOST_CHECK_EQUAL( iter, total );
244 iter = 0;
245 ind.reset();
246 for( ; ind.valid(); ++ind, iter++ ) {
247 BOOST_CHECK_EQUAL( (size_t)ind, iter );
248 BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
249 BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
250 BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
251 BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
252 State::const_iterator ci = ind.begin();
253 BOOST_CHECK_EQUAL( (ci++)->second, iter % dims[0] );
254 BOOST_CHECK_EQUAL( (ci++)->second, (iter / dims[0]) % dims[1] );
255 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1])) % dims[2] );
256 BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
257 BOOST_CHECK( ci == ind.end() );
258 }
259 BOOST_CHECK_EQUAL( iter, total );
260 State::const_iterator ci = ind.begin();
261 BOOST_CHECK_EQUAL( (ci++)->first, vs[0] );
262 BOOST_CHECK_EQUAL( (ci++)->first, vs[1] );
263 BOOST_CHECK_EQUAL( (ci++)->first, vs[2] );
264 BOOST_CHECK_EQUAL( (ci++)->first, vs[3] );
265 BOOST_CHECK( ci == ind.end() );
266 }
267 }