Renamed unit tests
[libdai.git] / tests / unit / prob_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
#include <dai/prob.h>
#include <sstream>   // std::stringstream
#include <string>    // std::string, std::getline
#include <strstream> // deprecated header; std::stringstream is declared in <sstream>
#include <vector>    // std::vector


using namespace dai;
16
17
// Tolerance passed to BOOST_CHECK_CLOSE throughout this file. Boost.Test
// interprets that argument as a percentage, so 1.0e-8 corresponds to a
// relative difference of 1e-10.
const double tol = 1.0e-8;
19
20
21 #define BOOST_TEST_MODULE ProbTest
22
23
24 #include <boost/test/unit_test.hpp>
25 #include <boost/test/floating_point_comparison.hpp>
26
27
28 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
29 // check constructors
30 Prob x1;
31 BOOST_CHECK_EQUAL( x1.size(), 0 );
32 BOOST_CHECK( x1.p() == Prob::container_type() );
33
34 Prob x2( 3 );
35 BOOST_CHECK_EQUAL( x2.size(), 3 );
36 BOOST_CHECK( x2.p() == Prob::container_type( 3, 1.0 / 3.0 ) );
37 BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
38 BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
39 BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
40
41 Prob x3( 4, 1.0 );
42 BOOST_CHECK_EQUAL( x3.size(), 4 );
43 BOOST_CHECK( x3.p() == Prob::container_type( 4, 1.0 ) );
44 BOOST_CHECK_EQUAL( x3[0], 1.0 );
45 BOOST_CHECK_EQUAL( x3[1], 1.0 );
46 BOOST_CHECK_EQUAL( x3[2], 1.0 );
47 BOOST_CHECK_EQUAL( x3[3], 1.0 );
48 x3.set( 0, 0.5 );
49 x3.set( 1, 1.0 );
50 x3.set( 2, 2.0 );
51 x3.set( 3, 4.0 );
52
53 std::vector<Real> v;
54 v.push_back( 0.5 );
55 v.push_back( 1.0 );
56 v.push_back( 2.0 );
57 v.push_back( 4.0 );
58 Prob x4( v.begin(), v.end(), 0 );
59 BOOST_CHECK_EQUAL( x4.size(), 4 );
60 BOOST_CHECK( x4.p() == x3.p() );
61 BOOST_CHECK( x4 == x3 );
62 BOOST_CHECK_EQUAL( x4[0], 0.5 );
63 BOOST_CHECK_EQUAL( x4[1], 1.0 );
64 BOOST_CHECK_EQUAL( x4[2], 2.0 );
65 BOOST_CHECK_EQUAL( x4[3], 4.0 );
66
67 Prob x5( v.begin(), v.end(), v.size() );
68 BOOST_CHECK_EQUAL( x5.size(), 4 );
69 BOOST_CHECK( x5.p() == x3.p() );
70 BOOST_CHECK( x5 == x3 );
71 BOOST_CHECK_EQUAL( x5[0], 0.5 );
72 BOOST_CHECK_EQUAL( x5[1], 1.0 );
73 BOOST_CHECK_EQUAL( x5[2], 2.0 );
74 BOOST_CHECK_EQUAL( x5[3], 4.0 );
75
76 std::vector<int> y( 3, 2 );
77 Prob x6( y );
78 BOOST_CHECK_EQUAL( x6.size(), 3 );
79 BOOST_CHECK_EQUAL( x6[0], 2.0 );
80 BOOST_CHECK_EQUAL( x6[1], 2.0 );
81 BOOST_CHECK_EQUAL( x6[2], 2.0 );
82
83 Prob x7( x6 );
84 BOOST_CHECK( x7 == x6 );
85
86 Prob x8 = x6;
87 BOOST_CHECK( x8 == x6 );
88
89 x7.resize( 5 );
90 BOOST_CHECK_EQUAL( x7.size(), 5 );
91 BOOST_CHECK_EQUAL( x7[0], 2.0 );
92 BOOST_CHECK_EQUAL( x7[1], 2.0 );
93 BOOST_CHECK_EQUAL( x7[2], 2.0 );
94 BOOST_CHECK_EQUAL( x7[3], 0.0 );
95 BOOST_CHECK_EQUAL( x7[4], 0.0 );
96
97 x8.resize( 1 );
98 BOOST_CHECK_EQUAL( x8.size(), 1 );
99 BOOST_CHECK_EQUAL( x8[0], 2.0 );
100 }
101
102
103 #ifndef DAI_SPARSE
104 BOOST_AUTO_TEST_CASE( IteratorTest ) {
105 Prob x( 5, 0.0 );
106 size_t i;
107 for( i = 0; i < x.size(); i++ )
108 x.set( i, i );
109
110 i = 0;
111 for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
112 BOOST_CHECK_EQUAL( *cit, i );
113
114 i = 0;
115 for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
116 *it = 4 - i;
117
118 i = 0;
119 for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
120 BOOST_CHECK_EQUAL( *it, 4 - i );
121
122 i = 0;
123 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
124 BOOST_CHECK_EQUAL( *crit, i );
125
126 i = 0;
127 for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
128 *rit = 2 * i;
129
130 i = 0;
131 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
132 BOOST_CHECK_EQUAL( *crit, 2 * i );
133 }
134 #endif
135
136
137 BOOST_AUTO_TEST_CASE( QueriesTest ) {
138 Prob x( 5, 0.0 );
139 for( size_t i = 0; i < x.size(); i++ )
140 x.set( i, 2.0 - i );
141
142 // test accumulate, min, max, sum, sumAbs, maxAbs
143 BOOST_CHECK_EQUAL( x.sum(), 0.0 );
144 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_id<Real>() ), 0.0 );
145 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_id<Real>() ), 1.0 );
146 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_id<Real>() ), -1.0 );
147 BOOST_CHECK_EQUAL( x.max(), 2.0 );
148 BOOST_CHECK_EQUAL( x.accumulateMax( -INFINITY, fo_id<Real>(), false ), 2.0 );
149 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_id<Real>(), false ), 3.0 );
150 BOOST_CHECK_EQUAL( x.accumulateMax( -5.0, fo_id<Real>(), false ), 2.0 );
151 BOOST_CHECK_EQUAL( x.min(), -2.0 );
152 BOOST_CHECK_EQUAL( x.accumulateMax( INFINITY, fo_id<Real>(), true ), -2.0 );
153 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_id<Real>(), true ), -3.0 );
154 BOOST_CHECK_EQUAL( x.accumulateMax( 5.0, fo_id<Real>(), true ), -2.0 );
155 BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
156 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_abs<Real>() ), 6.0 );
157 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_abs<Real>() ), 7.0 );
158 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_abs<Real>() ), 7.0 );
159 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
160 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
161 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
162 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
163 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
164 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
165 x.set( 1, 1.0 );
166 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
167 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
168 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
169 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
170 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
171 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
172 for( size_t i = 0; i < x.size(); i++ )
173 x.set( i, i ? (1.0 / i) : 0.0 );
174 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_inv0<Real>() ), 10.0 );
175 x /= x.sum();
176
177 // test entropy
178 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
179 for( size_t i = 1; i < 100; i++ )
180 BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log((Real)i), tol );
181
182 // test hasNaNs and hasNegatives
183 BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
184 Real c = 0.0;
185 BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
186 BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
187 BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
188 BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
189 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
190 BOOST_CHECK( x.hasNegatives() );
191 x.set( 2, -INFINITY );
192 BOOST_CHECK( x.hasNegatives() );
193 x.set( 2, INFINITY );
194 BOOST_CHECK( !x.hasNegatives() );
195 x.set( 2, -1.0 );
196
197 // test argmax
198 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
199 x.set( 4, 0.5 );
200 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
201 x.set( 3, -2.0 );
202 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
203 x.set( 4, -1.0 );
204 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
205 x.set( 0, -2.0 );
206 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
207 x.set( 1, -3.0 );
208 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
209 x.set( 2, -2.0 );
210 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
211
212 // test draw
213 for( size_t i = 0; i < x.size(); i++ )
214 x.set( i, i ? (1.0 / i) : 0.0 );
215 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
216 BOOST_CHECK( x.draw() < x.size() );
217 BOOST_CHECK( x.draw() != 0 );
218 }
219 x.set( 2, 0.0 );
220 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
221 BOOST_CHECK( x.draw() < x.size() );
222 BOOST_CHECK( x.draw() != 0 );
223 BOOST_CHECK( x.draw() != 2 );
224 }
225 x.set( 4, 0.0 );
226 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
227 BOOST_CHECK( x.draw() < x.size() );
228 BOOST_CHECK( x.draw() != 0 );
229 BOOST_CHECK( x.draw() != 2 );
230 BOOST_CHECK( x.draw() != 4 );
231 }
232 x.set( 1, 0.0 );
233 for( size_t repeat = 0; repeat < 10000; repeat++ )
234 BOOST_CHECK( x.draw() == 3 );
235
236 // test <, ==
237 Prob a(3, 1.0), b(3, 1.0);
238 BOOST_CHECK( !(a < b) );
239 BOOST_CHECK( !(b < a) );
240 BOOST_CHECK( a == b );
241 a.set( 0, 0.0 );
242 BOOST_CHECK( a < b );
243 BOOST_CHECK( !(b < a) );
244 BOOST_CHECK( !(a == b) );
245 b.set( 2, 0.0 );
246 BOOST_CHECK( a < b );
247 BOOST_CHECK( !(b < a) );
248 BOOST_CHECK( !(a == b) );
249 b.set( 0, 0.0 );
250 BOOST_CHECK( !(a < b) );
251 BOOST_CHECK( b < a );
252 BOOST_CHECK( !(a == b) );
253 a.set( 1, 0.0 );
254 BOOST_CHECK( a < b );
255 BOOST_CHECK( !(b < a) );
256 BOOST_CHECK( !(a == b) );
257 b.set( 1, 0.0 );
258 BOOST_CHECK( !(a < b) );
259 BOOST_CHECK( b < a );
260 BOOST_CHECK( !(a == b) );
261 a.set( 2, 0.0 );
262 BOOST_CHECK( !(a < b) );
263 BOOST_CHECK( !(b < a) );
264 BOOST_CHECK( a == b );
265 }
266
267
268 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
269 Prob x( 3 );
270 x.set( 0, -2.0 );
271 x.set( 1, 0.0 );
272 x.set( 2, 2.0 );
273
274 Prob y = -x;
275 Prob z = x.pwUnaryTr( std::negate<Real>() );
276 BOOST_CHECK_EQUAL( y[0], 2.0 );
277 BOOST_CHECK_EQUAL( y[1], 0.0 );
278 BOOST_CHECK_EQUAL( y[2], -2.0 );
279 BOOST_CHECK( y == z );
280
281 y = x.abs();
282 z = x.pwUnaryTr( fo_abs<Real>() );
283 BOOST_CHECK_EQUAL( y[0], 2.0 );
284 BOOST_CHECK_EQUAL( y[1], 0.0 );
285 BOOST_CHECK_EQUAL( y[2], 2.0 );
286 BOOST_CHECK( y == z );
287
288 y = x.exp();
289 z = x.pwUnaryTr( fo_exp<Real>() );
290 BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
291 BOOST_CHECK_EQUAL( y[1], 1.0 );
292 BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
293 BOOST_CHECK( y == z );
294
295 y = x.log(false);
296 z = x.pwUnaryTr( fo_log<Real>() );
297 BOOST_CHECK( isnan( y[0] ) );
298 BOOST_CHECK_EQUAL( y[1], -INFINITY );
299 BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
300 BOOST_CHECK( !(y == z) );
301 y.set( 0, 0.0 );
302 z.set( 0, 0.0 );
303 BOOST_CHECK( y == z );
304
305 y = x.log(true);
306 z = x.pwUnaryTr( fo_log0<Real>() );
307 BOOST_CHECK( isnan( y[0] ) );
308 BOOST_CHECK_EQUAL( y[1], 0.0 );
309 BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
310 BOOST_CHECK( !(y == z) );
311 y.set( 0, 0.0 );
312 z.set( 0, 0.0 );
313 BOOST_CHECK( y == z );
314
315 y = x.inverse(false);
316 z = x.pwUnaryTr( fo_inv<Real>() );
317 BOOST_CHECK_EQUAL( y[0], -0.5 );
318 BOOST_CHECK_EQUAL( y[1], INFINITY );
319 BOOST_CHECK_EQUAL( y[2], 0.5 );
320 BOOST_CHECK( y == z );
321
322 y = x.inverse(true);
323 z = x.pwUnaryTr( fo_inv0<Real>() );
324 BOOST_CHECK_EQUAL( y[0], -0.5 );
325 BOOST_CHECK_EQUAL( y[1], 0.0 );
326 BOOST_CHECK_EQUAL( y[2], 0.5 );
327 BOOST_CHECK( y == z );
328
329 x.set( 0, 2.0 );
330 y = x.normalized();
331 BOOST_CHECK_EQUAL( y[0], 0.5 );
332 BOOST_CHECK_EQUAL( y[1], 0.0 );
333 BOOST_CHECK_EQUAL( y[2], 0.5 );
334
335 y = x.normalized( NORMPROB );
336 BOOST_CHECK_EQUAL( y[0], 0.5 );
337 BOOST_CHECK_EQUAL( y[1], 0.0 );
338 BOOST_CHECK_EQUAL( y[2], 0.5 );
339
340 x.set( 0, -2.0 );
341 y = x.normalized( NORMLINF );
342 BOOST_CHECK_EQUAL( y[0], -1.0 );
343 BOOST_CHECK_EQUAL( y[1], 0.0 );
344 BOOST_CHECK_EQUAL( y[2], 1.0 );
345 }
346
347
348 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
349 Prob xorg(3);
350 xorg.set( 0, 2.0 );
351 xorg.set( 1, 0.0 );
352 xorg.set( 2, 1.0 );
353 Prob y(3);
354
355 Prob x = xorg;
356 BOOST_CHECK( x.setUniform() == Prob(3) );
357 BOOST_CHECK( x == Prob(3) );
358
359 y.set( 0, std::exp(2.0) );
360 y.set( 1, 1.0 );
361 y.set( 2, std::exp(1.0) );
362 x = xorg;
363 BOOST_CHECK( x.takeExp() == y );
364 BOOST_CHECK( x == y );
365 x = xorg;
366 BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
367 BOOST_CHECK( x == y );
368
369 y.set( 0, std::log(2.0) );
370 y.set( 1, -INFINITY );
371 y.set( 2, 0.0 );
372 x = xorg;
373 BOOST_CHECK( x.takeLog() == y );
374 BOOST_CHECK( x == y );
375 x = xorg;
376 BOOST_CHECK( x.takeLog(false) == y );
377 BOOST_CHECK( x == y );
378 x = xorg;
379 BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
380 BOOST_CHECK( x == y );
381
382 y.set( 1, 0.0 );
383 x = xorg;
384 BOOST_CHECK( x.takeLog(true) == y );
385 BOOST_CHECK( x == y );
386 x = xorg;
387 BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
388 BOOST_CHECK( x == y );
389
390 y.set( 0, 2.0 / 3.0 );
391 y.set( 1, 0.0 / 3.0 );
392 y.set( 2, 1.0 / 3.0 );
393 x = xorg;
394 BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
395 BOOST_CHECK( x == y );
396
397 x = xorg;
398 BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
399 BOOST_CHECK( x == y );
400
401 y.set( 0, 2.0 / 2.0 );
402 y.set( 1, 0.0 / 2.0 );
403 y.set( 2, 1.0 / 2.0 );
404 x = xorg;
405 BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
406 BOOST_CHECK( x == y );
407
408 xorg.set( 0, -2.0 );
409 y.set( 0, 2.0 );
410 y.set( 1, 0.0 );
411 y.set( 2, 1.0 );
412 x = xorg;
413 BOOST_CHECK( x.takeAbs() == y );
414 BOOST_CHECK( x == y );
415
416 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
417 x.randomize();
418 for( size_t i = 0; i < x.size(); i++ ) {
419 BOOST_CHECK( x[i] < 1.0 );
420 BOOST_CHECK( x[i] >= 0.0 );
421 }
422 }
423 }
424
425
426 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
427 Prob x(3);
428 x.set( 0, 2.0 );
429 x.set( 1, 0.0 );
430 x.set( 2, 1.0 );
431 Prob y(3);
432
433 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
434 BOOST_CHECK( (x + 1.0) == y );
435 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
436 BOOST_CHECK( (x + (-2.0)) == y );
437
438 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
439 BOOST_CHECK( (x - 1.0) == y );
440 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
441 BOOST_CHECK( (x - (-2.0)) == y );
442
443 BOOST_CHECK( (x * 1.0) == x );
444 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
445 BOOST_CHECK( (x * 2.0) == y );
446 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
447 BOOST_CHECK( (x * -0.5) == y );
448
449 BOOST_CHECK( (x / 1.0) == x );
450 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
451 BOOST_CHECK( (x / 2.0) == y );
452 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
453 BOOST_CHECK( (x / -0.5) == y );
454 BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
455
456 BOOST_CHECK( (x ^ 1.0) == x );
457 BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
458 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
459 BOOST_CHECK( (x ^ 2.0) == y );
460 y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
461 Prob z = (x ^ -0.5);
462 BOOST_CHECK_CLOSE( z[0], y[0], tol );
463 BOOST_CHECK_EQUAL( z[1], y[1] );
464 BOOST_CHECK_CLOSE( z[2], y[2], tol );
465 }
466
467
468 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
469 Prob xorg(3), x(3);
470 xorg.set( 0, 2.0 );
471 xorg.set( 1, 0.0 );
472 xorg.set( 2, 1.0 );
473 Prob y(3);
474
475 x = xorg;
476 BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
477 BOOST_CHECK( x == Prob(3, 1.0) );
478 BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
479 BOOST_CHECK( x == Prob(3, 2.0) );
480 BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
481 BOOST_CHECK( x == Prob(3, 0.0) );
482
483 x = xorg;
484 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
485 BOOST_CHECK( (x += 1.0) == y );
486 BOOST_CHECK( x == y );
487 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
488 BOOST_CHECK( (x += -2.0) == y );
489 BOOST_CHECK( x == y );
490
491 x = xorg;
492 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
493 BOOST_CHECK( (x -= 1.0) == y );
494 BOOST_CHECK( x == y );
495 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
496 BOOST_CHECK( (x -= -2.0) == y );
497 BOOST_CHECK( x == y );
498
499 x = xorg;
500 BOOST_CHECK( (x *= 1.0) == x );
501 BOOST_CHECK( x == x );
502 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
503 BOOST_CHECK( (x *= 2.0) == y );
504 BOOST_CHECK( x == y );
505 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
506 BOOST_CHECK( (x *= -0.25) == y );
507 BOOST_CHECK( x == y );
508
509 x = xorg;
510 BOOST_CHECK( (x /= 1.0) == x );
511 BOOST_CHECK( x == x );
512 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
513 BOOST_CHECK( (x /= 2.0) == y );
514 BOOST_CHECK( x == y );
515 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
516 BOOST_CHECK( (x /= -0.25) == y );
517 BOOST_CHECK( x == y );
518 BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
519 BOOST_CHECK( x == Prob(3, 0.0) );
520
521 x = xorg;
522 BOOST_CHECK( (x ^= 1.0) == x );
523 BOOST_CHECK( x == x );
524 BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
525 BOOST_CHECK( x == Prob(3, 1.0) );
526 x = xorg;
527 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
528 BOOST_CHECK( (x ^= 2.0) == y );
529 BOOST_CHECK( x == y );
530 y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
531 BOOST_CHECK( (x ^= -0.5) == y );
532 BOOST_CHECK( x == y );
533 }
534
535
536 BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
537 size_t N = 6;
538 Prob xorg(N), x(N);
539 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
540 Prob y(N);
541 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
542 Prob z(N), r(N);
543
544 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
545 x = xorg;
546 r = (x += y);
547 for( size_t i = 0; i < N; i++ )
548 BOOST_CHECK_CLOSE( r[i], z[i], tol );
549 BOOST_CHECK( x == z );
550 x = xorg;
551 BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
552 BOOST_CHECK( x == z );
553
554 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
555 x = xorg;
556 r = (x -= y);
557 for( size_t i = 0; i < N; i++ )
558 BOOST_CHECK_CLOSE( r[i], z[i], tol );
559 BOOST_CHECK( x == z );
560 x = xorg;
561 BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
562 BOOST_CHECK( x == z );
563
564 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
565 x = xorg;
566 r = (x *= y);
567 for( size_t i = 0; i < N; i++ )
568 BOOST_CHECK_CLOSE( r[i], z[i], tol );
569 BOOST_CHECK( x == z );
570 x = xorg;
571 BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
572 BOOST_CHECK( x == z );
573
574 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
575 x = xorg;
576 r = (x /= y);
577 for( size_t i = 0; i < N; i++ )
578 BOOST_CHECK_CLOSE( r[i], z[i], tol );
579 BOOST_CHECK( x == z );
580 x = xorg;
581 BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
582 BOOST_CHECK( x == z );
583
584 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
585 x = xorg;
586 r = (x.divide( y ));
587 BOOST_CHECK_CLOSE( r[0], z[0], tol );
588 BOOST_CHECK_CLOSE( r[1], z[1], tol );
589 BOOST_CHECK_EQUAL( r[2], z[2] );
590 BOOST_CHECK( isnan(r[3]) );
591 BOOST_CHECK_CLOSE( r[4], z[4], tol );
592 BOOST_CHECK_CLOSE( r[5], z[5], tol );
593 x.set( 3, 0.0 ); r.set( 3, 0.0 );
594 BOOST_CHECK( x == r );
595 x = xorg;
596 r = x.pwBinaryOp( y, std::divides<Real>() );
597 BOOST_CHECK_CLOSE( r[0], z[0], tol );
598 BOOST_CHECK_CLOSE( r[1], z[1], tol );
599 BOOST_CHECK_EQUAL( r[2], z[2] );
600 BOOST_CHECK( isnan(r[3]) );
601 BOOST_CHECK_CLOSE( r[4], z[4], tol );
602 BOOST_CHECK_CLOSE( r[5], z[5], tol );
603 x.set( 3, 0.0 ); r.set( 3, 0.0 );
604 BOOST_CHECK( x == r );
605
606 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
607 x = xorg;
608 r = (x ^= y);
609 BOOST_CHECK_CLOSE( r[0], z[0], tol );
610 BOOST_CHECK_EQUAL( r[1], z[1] );
611 BOOST_CHECK_CLOSE( r[2], z[2], tol );
612 BOOST_CHECK_CLOSE( r[3], z[3], tol );
613 BOOST_CHECK_CLOSE( r[4], z[4], tol );
614 BOOST_CHECK_CLOSE( r[5], z[5], tol );
615 BOOST_CHECK( x == z );
616 x = xorg;
617 BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
618 BOOST_CHECK( x == z );
619 }
620
621
622 BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
623 size_t N = 6;
624 Prob x(N);
625 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
626 Prob y(N);
627 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
628 Prob z(N), r(N);
629
630 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
631 r = x + y;
632 for( size_t i = 0; i < N; i++ )
633 BOOST_CHECK_CLOSE( r[i], z[i], tol );
634 z = x.pwBinaryTr( y, std::plus<Real>() );
635 BOOST_CHECK( r == z );
636
637 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
638 r = x - y;
639 for( size_t i = 0; i < N; i++ )
640 BOOST_CHECK_CLOSE( r[i], z[i], tol );
641 z = x.pwBinaryTr( y, std::minus<Real>() );
642 BOOST_CHECK( r == z );
643
644 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
645 r = x * y;
646 for( size_t i = 0; i < N; i++ )
647 BOOST_CHECK_CLOSE( r[i], z[i], tol );
648 z = x.pwBinaryTr( y, std::multiplies<Real>() );
649 BOOST_CHECK( r == z );
650
651 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
652 r = x / y;
653 for( size_t i = 0; i < N; i++ )
654 BOOST_CHECK_CLOSE( r[i], z[i], tol );
655 z = x.pwBinaryTr( y, fo_divides0<Real>() );
656 BOOST_CHECK( r == z );
657
658 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
659 r = x.divided_by( y );
660 BOOST_CHECK_CLOSE( r[0], z[0], tol );
661 BOOST_CHECK_CLOSE( r[1], z[1], tol );
662 BOOST_CHECK_EQUAL( r[2], z[2] );
663 BOOST_CHECK( isnan(r[3]) );
664 BOOST_CHECK_CLOSE( r[4], z[4], tol );
665 BOOST_CHECK_CLOSE( r[5], z[5], tol );
666 z = x.pwBinaryTr( y, std::divides<Real>() );
667 BOOST_CHECK_CLOSE( r[0], z[0], tol );
668 BOOST_CHECK_CLOSE( r[1], z[1], tol );
669 BOOST_CHECK_EQUAL( r[2], z[2] );
670 BOOST_CHECK( isnan(r[3]) );
671 BOOST_CHECK_CLOSE( r[4], z[4], tol );
672 BOOST_CHECK_CLOSE( r[5], z[5], tol );
673
674 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
675 r = x ^ y;
676 BOOST_CHECK_CLOSE( r[0], z[0], tol );
677 BOOST_CHECK_EQUAL( r[1], z[1] );
678 BOOST_CHECK_CLOSE( r[2], z[2], tol );
679 BOOST_CHECK_CLOSE( r[3], z[3], tol );
680 BOOST_CHECK_CLOSE( r[4], z[4], tol );
681 BOOST_CHECK_CLOSE( r[5], z[5], tol );
682 z = x.pwBinaryTr( y, fo_pow<Real>() );
683 BOOST_CHECK( r == z );
684 }
685
686
687 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
688 Prob x(3), y(3), z(3);
689 x.set( 0, 0.2 );
690 x.set( 1, 0.8 );
691 x.set( 2, 0.0 );
692 y.set( 0, 0.0 );
693 y.set( 1, 0.6 );
694 y.set( 2, 0.4 );
695
696 z = min( x, y );
697 BOOST_CHECK_EQUAL( z[0], 0.0 );
698 BOOST_CHECK_EQUAL( z[1], 0.6 );
699 BOOST_CHECK_EQUAL( z[2], 0.0 );
700 z = max( x, y );
701 BOOST_CHECK_EQUAL( z[0], 0.2 );
702 BOOST_CHECK_EQUAL( z[1], 0.8 );
703 BOOST_CHECK_EQUAL( z[2], 0.4 );
704
705 BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
706 BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
707 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
708 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
709 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
710 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
711 BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
712 BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
713 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
714 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
715 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
716 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
717 BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
718 BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
719 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
720 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
721 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
722 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
723 BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
724 BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
725 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
726 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
727 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
728 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
729 BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
730 BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
731 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
732 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
733 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
734 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
735 x.set( 1, 0.7 ); x.set( 2, 0.1 );
736 y.set( 0, 0.1 ); y.set( 1, 0.5 );
737 BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
738 BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
739 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
740 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
741
742 Prob xx(4), yy(4);
743 for( size_t i = 0; i < 3; i++ ) {
744 xx.set( i, x[i] );
745 yy.set( i, y[i] );
746 }
747 std::stringstream ss;
748 ss << xx;
749 std::string s;
750 std::getline( ss, s );
751 #ifdef DAI_SPARSE
752 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
753 #else
754 BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
755 #endif
756 std::stringstream ss2;
757 ss2 << yy;
758 std::getline( ss2, s );
759 #ifdef DAI_SPARSE
760 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
761 #else
762 BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
763 #endif
764
765 z = min( x, y );
766 BOOST_CHECK_EQUAL( z[0], 0.1 );
767 BOOST_CHECK_EQUAL( z[1], 0.5 );
768 BOOST_CHECK_EQUAL( z[2], 0.1 );
769 z = max( x, y );
770 BOOST_CHECK_EQUAL( z[0], 0.2 );
771 BOOST_CHECK_EQUAL( z[1], 0.7 );
772 BOOST_CHECK_EQUAL( z[2], 0.4 );
773
774 BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
775 BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
776 BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
777 BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
778 BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
779 BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
780 }