e038866cedc4df3777e76b3a3df23cc392b8fdf8
[libdai.git] / tests / unit / prob.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
#include <dai/prob.h>
#include <cmath>
#include <functional>
#include <sstream>     // std::stringstream is declared here, not in <strstream>
#include <string>
#include <strstream>
#include <utility>
#include <vector>
16
17
18 using namespace dai;
19
20
/// Tolerance passed to BOOST_CHECK_CLOSE throughout this file.
/// Boost.Test interprets this value as a *percentage*, so 1e-8 here means a
/// relative difference of at most 1e-8 % (i.e. 1e-10).
static const double tol = 1.0e-8;
22
23
24 #define BOOST_TEST_MODULE ProbTest
25
26
27 #include <boost/test/unit_test.hpp>
28 #include <boost/test/floating_point_comparison.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
32 // check constructors
33 Prob x1;
34 BOOST_CHECK_EQUAL( x1.size(), 0 );
35 BOOST_CHECK( x1.p() == std::vector<Real>() );
36
37 Prob x2( 3 );
38 BOOST_CHECK_EQUAL( x2.size(), 3 );
39 BOOST_CHECK( x2.p() == std::vector<Real>( 3, 1.0 / 3.0 ) );
40 BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
41 BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
42 BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
43
44 Prob x3( 4, 1.0 );
45 BOOST_CHECK_EQUAL( x3.size(), 4 );
46 BOOST_CHECK( x3.p() == std::vector<Real>( 4, 1.0 ) );
47 BOOST_CHECK_EQUAL( x3[0], 1.0 );
48 BOOST_CHECK_EQUAL( x3[1], 1.0 );
49 BOOST_CHECK_EQUAL( x3[2], 1.0 );
50 BOOST_CHECK_EQUAL( x3[3], 1.0 );
51 x3[0] = 0.5;
52 x3[1] = 1.0;
53 x3[2] = 2.0;
54 x3[3] = 4.0;
55
56 Prob x4( x3.begin(), x3.end() );
57 BOOST_CHECK_EQUAL( x4.size(), 4 );
58 BOOST_CHECK( x4.p() == x3.p() );
59 BOOST_CHECK_EQUAL( x4[0], 0.5 );
60 BOOST_CHECK_EQUAL( x4[1], 1.0 );
61 BOOST_CHECK_EQUAL( x4[2], 2.0 );
62 BOOST_CHECK_EQUAL( x4[3], 4.0 );
63
64 x3.p() = std::vector<Real>( 4, 2.5 );
65 Prob x5( x3.begin(), x3.end(), x3.size() );
66 BOOST_CHECK_EQUAL( x5.size(), 4 );
67 BOOST_CHECK( x5.p() == x3.p() );
68 BOOST_CHECK_EQUAL( x5[0], 2.5 );
69 BOOST_CHECK_EQUAL( x5[1], 2.5 );
70 BOOST_CHECK_EQUAL( x5[2], 2.5 );
71 BOOST_CHECK_EQUAL( x5[3], 2.5 );
72
73 std::vector<int> y( 3, 2 );
74 Prob x6( y );
75 BOOST_CHECK_EQUAL( x6.size(), 3 );
76 BOOST_CHECK( x6.p() == std::vector<Real>( 3, 2.0 ) );
77 BOOST_CHECK_EQUAL( x6[0], 2.0 );
78 BOOST_CHECK_EQUAL( x6[1], 2.0 );
79 BOOST_CHECK_EQUAL( x6[2], 2.0 );
80
81 Prob x7( x6 );
82 BOOST_CHECK( x7 == x6 );
83
84 Prob x8 = x6;
85 BOOST_CHECK( x8 == x6 );
86 }
87
88
89 BOOST_AUTO_TEST_CASE( IteratorTest ) {
90 Prob x( 5, 0.0 );
91 size_t i;
92 for( i = 0; i < x.size(); i++ )
93 x[i] = i;
94
95 i = 0;
96 for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
97 BOOST_CHECK_EQUAL( *cit, i );
98
99 i = 0;
100 for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
101 *it = 4 - i;
102
103 i = 0;
104 for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
105 BOOST_CHECK_EQUAL( *it, 4 - i );
106
107 i = 0;
108 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
109 BOOST_CHECK_EQUAL( *crit, i );
110
111 i = 0;
112 for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
113 *rit = 2 * i;
114
115 i = 0;
116 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
117 BOOST_CHECK_EQUAL( *crit, 2 * i );
118 }
119
120
121 BOOST_AUTO_TEST_CASE( QueriesTest ) {
122 Prob x( 5, 0.0 );
123 for( size_t i = 0; i < x.size(); i++ )
124 x[i] = 2.0 - i;
125
126 // test accumulate, min, max, sum, sumAbs, maxAbs
127 BOOST_CHECK_EQUAL( x.sum(), 0.0 );
128 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_id<Real>() ), 0.0 );
129 BOOST_CHECK_EQUAL( x.accumulate( 1.0, std::plus<Real>(), fo_id<Real>() ), 1.0 );
130 BOOST_CHECK_EQUAL( x.accumulate( -1.0, std::plus<Real>(), fo_id<Real>() ), -1.0 );
131 BOOST_CHECK_EQUAL( x.max(), 2.0 );
132 BOOST_CHECK_EQUAL( x.accumulate( -INFINITY, fo_max<Real>(), fo_id<Real>() ), 2.0 );
133 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_id<Real>() ), 3.0 );
134 BOOST_CHECK_EQUAL( x.accumulate( -5.0, fo_max<Real>(), fo_id<Real>() ), 2.0 );
135 BOOST_CHECK_EQUAL( x.min(), -2.0 );
136 BOOST_CHECK_EQUAL( x.accumulate( INFINITY, fo_min<Real>(), fo_id<Real>() ), -2.0 );
137 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_min<Real>(), fo_id<Real>() ), -3.0 );
138 BOOST_CHECK_EQUAL( x.accumulate( 5.0, fo_min<Real>(), fo_id<Real>() ), -2.0 );
139 BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
140 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_abs<Real>() ), 6.0 );
141 BOOST_CHECK_EQUAL( x.accumulate( 1.0, std::plus<Real>(), fo_abs<Real>() ), 7.0 );
142 BOOST_CHECK_EQUAL( x.accumulate( -1.0, std::plus<Real>(), fo_abs<Real>() ), 7.0 );
143 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
144 BOOST_CHECK_EQUAL( x.accumulate( 0.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
145 BOOST_CHECK_EQUAL( x.accumulate( 1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
146 BOOST_CHECK_EQUAL( x.accumulate( -1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
147 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
148 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
149 x[1] = 1.0;
150 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
151 BOOST_CHECK_EQUAL( x.accumulate( 0.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
152 BOOST_CHECK_EQUAL( x.accumulate( 1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
153 BOOST_CHECK_EQUAL( x.accumulate( -1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
154 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
155 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
156 for( size_t i = 0; i < x.size(); i++ )
157 x[i] = i ? (1.0 / i) : 0.0;
158 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_inv0<Real>() ), 10.0 );
159 x /= x.sum();
160
161 // test entropy
162 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
163 for( size_t i = 1; i < 100; i++ )
164 BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log(i), tol );
165
166 // test hasNaNs and hasNegatives
167 BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
168 Real c = 0.0;
169 BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
170 BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
171 BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
172 BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
173 x[0] = 0.0; x[1] = 0.0; x[2] = -1.0; x[3] = 1.0; x[4] = 100.0;
174 BOOST_CHECK( x.hasNegatives() );
175 x[2] = -INFINITY;
176 BOOST_CHECK( x.hasNegatives() );
177 x[2] = INFINITY;
178 BOOST_CHECK( !x.hasNegatives() );
179 x[2] = -1.0;
180
181 // test argmax
182 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
183 x[4] = 0.5;
184 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
185 x[3] = -2.0;
186 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
187 x[4] = -1.0;
188 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
189 x[0] = -2.0;
190 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
191 x[1] = -3.0;
192 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
193 x[2] = -2.0;
194 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
195
196 // test draw
197 for( size_t i = 0; i < x.size(); i++ )
198 x[i] = i ? (1.0 / i) : 0.0;
199 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
200 BOOST_CHECK( x.draw() < x.size() );
201 BOOST_CHECK( x.draw() != 0 );
202 }
203 x[2] = 0.0;
204 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
205 BOOST_CHECK( x.draw() < x.size() );
206 BOOST_CHECK( x.draw() != 0 );
207 BOOST_CHECK( x.draw() != 2 );
208 }
209 x[4] = 0.0;
210 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
211 BOOST_CHECK( x.draw() < x.size() );
212 BOOST_CHECK( x.draw() != 0 );
213 BOOST_CHECK( x.draw() != 2 );
214 BOOST_CHECK( x.draw() != 4 );
215 }
216 x[1] = 0.0;
217 for( size_t repeat = 0; repeat < 10000; repeat++ )
218 BOOST_CHECK( x.draw() == 3 );
219
220 // test <, ==
221 Prob a(3, 1.0), b(3, 1.0);
222 BOOST_CHECK( !(a < b) );
223 BOOST_CHECK( !(b < a) );
224 BOOST_CHECK( a == b );
225 a[0] = 0.0;
226 BOOST_CHECK( a < b );
227 BOOST_CHECK( !(b < a) );
228 BOOST_CHECK( !(a == b) );
229 b[2] = 0.0;
230 BOOST_CHECK( a < b );
231 BOOST_CHECK( !(b < a) );
232 BOOST_CHECK( !(a == b) );
233 b[0] = 0.0;
234 BOOST_CHECK( !(a < b) );
235 BOOST_CHECK( b < a );
236 BOOST_CHECK( !(a == b) );
237 a[1] = 0.0;
238 BOOST_CHECK( a < b );
239 BOOST_CHECK( !(b < a) );
240 BOOST_CHECK( !(a == b) );
241 b[1] = 0.0;
242 BOOST_CHECK( !(a < b) );
243 BOOST_CHECK( b < a );
244 BOOST_CHECK( !(a == b) );
245 a[2] = 0.0;
246 BOOST_CHECK( !(a < b) );
247 BOOST_CHECK( !(b < a) );
248 BOOST_CHECK( a == b );
249 }
250
251
252 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
253 Prob x( 3 );
254 x[0] = -2.0;
255 x[1] = 0.0;
256 x[2] = 2.0;
257
258 Prob y = -x;
259 Prob z = x.pwUnaryTr( std::negate<Real>() );
260 BOOST_CHECK_EQUAL( y[0], 2.0 );
261 BOOST_CHECK_EQUAL( y[1], 0.0 );
262 BOOST_CHECK_EQUAL( y[2], -2.0 );
263 BOOST_CHECK( y == z );
264
265 y = x.abs();
266 z = x.pwUnaryTr( fo_abs<Real>() );
267 BOOST_CHECK_EQUAL( y[0], 2.0 );
268 BOOST_CHECK_EQUAL( y[1], 0.0 );
269 BOOST_CHECK_EQUAL( y[2], 2.0 );
270 BOOST_CHECK( y == z );
271
272 y = x.exp();
273 z = x.pwUnaryTr( fo_exp<Real>() );
274 BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
275 BOOST_CHECK_EQUAL( y[1], 1.0 );
276 BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
277 BOOST_CHECK( y == z );
278
279 y = x.log(false);
280 z = x.pwUnaryTr( fo_log<Real>() );
281 BOOST_CHECK( isnan( y[0] ) );
282 BOOST_CHECK_EQUAL( y[1], -INFINITY );
283 BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
284 BOOST_CHECK( !(y == z) );
285 y[0] = z[0] = 0.0;
286 BOOST_CHECK( y == z );
287
288 y = x.log(true);
289 z = x.pwUnaryTr( fo_log0<Real>() );
290 BOOST_CHECK( isnan( y[0] ) );
291 BOOST_CHECK_EQUAL( y[1], 0.0 );
292 BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
293 BOOST_CHECK( !(y == z) );
294 y[0] = z[0] = 0.0;
295 BOOST_CHECK( y == z );
296
297 y = x.inverse(false);
298 z = x.pwUnaryTr( fo_inv<Real>() );
299 BOOST_CHECK_EQUAL( y[0], -0.5 );
300 BOOST_CHECK_EQUAL( y[1], INFINITY );
301 BOOST_CHECK_EQUAL( y[2], 0.5 );
302 BOOST_CHECK( y == z );
303
304 y = x.inverse(true);
305 z = x.pwUnaryTr( fo_inv0<Real>() );
306 BOOST_CHECK_EQUAL( y[0], -0.5 );
307 BOOST_CHECK_EQUAL( y[1], 0.0 );
308 BOOST_CHECK_EQUAL( y[2], 0.5 );
309 BOOST_CHECK( y == z );
310
311 x[0] = 2.0;
312 y = x.normalized();
313 BOOST_CHECK_EQUAL( y[0], 0.5 );
314 BOOST_CHECK_EQUAL( y[1], 0.0 );
315 BOOST_CHECK_EQUAL( y[2], 0.5 );
316
317 y = x.normalized( Prob::NORMPROB );
318 BOOST_CHECK_EQUAL( y[0], 0.5 );
319 BOOST_CHECK_EQUAL( y[1], 0.0 );
320 BOOST_CHECK_EQUAL( y[2], 0.5 );
321
322 x[0] = -2.0;
323 y = x.normalized( Prob::NORMLINF );
324 BOOST_CHECK_EQUAL( y[0], -1.0 );
325 BOOST_CHECK_EQUAL( y[1], 0.0 );
326 BOOST_CHECK_EQUAL( y[2], 1.0 );
327 }
328
329
330 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
331 Prob xorg(3);
332 xorg[0] = 2.0;
333 xorg[1] = 0.0;
334 xorg[2] = 1.0;
335 Prob y(3);
336
337 Prob x = xorg;
338 BOOST_CHECK( x.setUniform() == Prob(3) );
339 BOOST_CHECK( x == Prob(3) );
340
341 y[0] = std::exp(2.0);
342 y[1] = 1.0;
343 y[2] = std::exp(1.0);
344 x = xorg;
345 BOOST_CHECK( x.takeExp() == y );
346 BOOST_CHECK( x == y );
347 x = xorg;
348 BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
349 BOOST_CHECK( x == y );
350
351 y[0] = std::log(2.0);
352 y[1] = -INFINITY;
353 y[2] = 0.0;
354 x = xorg;
355 BOOST_CHECK( x.takeLog() == y );
356 BOOST_CHECK( x == y );
357 x = xorg;
358 BOOST_CHECK( x.takeLog(false) == y );
359 BOOST_CHECK( x == y );
360 x = xorg;
361 BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
362 BOOST_CHECK( x == y );
363
364 y[1] = 0.0;
365 x = xorg;
366 BOOST_CHECK( x.takeLog(true) == y );
367 BOOST_CHECK( x == y );
368 x = xorg;
369 BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
370 BOOST_CHECK( x == y );
371
372 y[0] = 2.0 / 3.0;
373 y[1] = 0.0 / 3.0;
374 y[2] = 1.0 / 3.0;
375 x = xorg;
376 BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
377 BOOST_CHECK( x == y );
378
379 x = xorg;
380 BOOST_CHECK_EQUAL( x.normalize( Prob::NORMPROB ), 3.0 );
381 BOOST_CHECK( x == y );
382
383 y[0] = 2.0 / 2.0;
384 y[1] = 0.0 / 2.0;
385 y[2] = 1.0 / 2.0;
386 x = xorg;
387 BOOST_CHECK_EQUAL( x.normalize( Prob::NORMLINF ), 2.0 );
388 BOOST_CHECK( x == y );
389
390 xorg[0] = -2.0;
391 y[0] = 2.0;
392 y[1] = 0.0;
393 y[2] = 1.0;
394 x = xorg;
395 BOOST_CHECK( x.takeAbs() == y );
396 BOOST_CHECK( x == y );
397
398 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
399 x.randomize();
400 for( size_t i = 0; i < x.size(); i++ ) {
401 BOOST_CHECK( x[i] < 1.0 );
402 BOOST_CHECK( x[i] >= 0.0 );
403 }
404 }
405 }
406
407
408 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
409 Prob x(3);
410 x[0] = 2.0;
411 x[1] = 0.0;
412 x[2] = 1.0;
413 Prob y(3);
414
415 y[0] = 3.0; y[1] = 1.0; y[2] = 2.0;
416 BOOST_CHECK( (x + 1.0) == y );
417 y[0] = 0.0; y[1] = -2.0; y[2] = -1.0;
418 BOOST_CHECK( (x + (-2.0)) == y );
419
420 y[0] = 1.0; y[1] = -1.0; y[2] = 0.0;
421 BOOST_CHECK( (x - 1.0) == y );
422 y[0] = 4.0; y[1] = 2.0; y[2] = 3.0;
423 BOOST_CHECK( (x - (-2.0)) == y );
424
425 BOOST_CHECK( (x * 1.0) == x );
426 y[0] = 4.0; y[1] = 0.0; y[2] = 2.0;
427 BOOST_CHECK( (x * 2.0) == y );
428 y[0] = -1.0; y[1] = 0.0; y[2] = -0.5;
429 BOOST_CHECK( (x * -0.5) == y );
430
431 BOOST_CHECK( (x / 1.0) == x );
432 y[0] = 1.0; y[1] = 0.0; y[2] = 0.5;
433 BOOST_CHECK( (x / 2.0) == y );
434 y[0] = -4.0; y[1] = 0.0; y[2] = -2.0;
435 BOOST_CHECK( (x / -0.5) == y );
436 BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
437
438 BOOST_CHECK( (x ^ 1.0) == x );
439 BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
440 y[0] = 4.0; y[1] = 0.0; y[2] = 1.0;
441 BOOST_CHECK( (x ^ 2.0) == y );
442 y[0] = 1.0 / std::sqrt(2.0); y[1] = INFINITY; y[2] = 1.0;
443 Prob z = (x ^ -0.5);
444 BOOST_CHECK_CLOSE( z[0], y[0], tol );
445 BOOST_CHECK_EQUAL( z[1], y[1] );
446 BOOST_CHECK_CLOSE( z[2], y[2], tol );
447 }
448
449
450 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
451 Prob xorg(3), x(3);
452 xorg[0] = 2.0;
453 xorg[1] = 0.0;
454 xorg[2] = 1.0;
455 Prob y(3);
456
457 x = xorg;
458 BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
459 BOOST_CHECK( x == Prob(3, 1.0) );
460 BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
461 BOOST_CHECK( x == Prob(3, 2.0) );
462 BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
463 BOOST_CHECK( x == Prob(3, 0.0) );
464
465 x = xorg;
466 y[0] = 3.0; y[1] = 1.0; y[2] = 2.0;
467 BOOST_CHECK( (x += 1.0) == y );
468 BOOST_CHECK( x == y );
469 y[0] = 1.0; y[1] = -1.0; y[2] = 0.0;
470 BOOST_CHECK( (x += -2.0) == y );
471 BOOST_CHECK( x == y );
472
473 x = xorg;
474 y[0] = 1.0; y[1] = -1.0; y[2] = 0.0;
475 BOOST_CHECK( (x -= 1.0) == y );
476 BOOST_CHECK( x == y );
477 y[0] = 3.0; y[1] = 1.0; y[2] = 2.0;
478 BOOST_CHECK( (x -= -2.0) == y );
479 BOOST_CHECK( x == y );
480
481 x = xorg;
482 BOOST_CHECK( (x *= 1.0) == x );
483 BOOST_CHECK( x == x );
484 y[0] = 4.0; y[1] = 0.0; y[2] = 2.0;
485 BOOST_CHECK( (x *= 2.0) == y );
486 BOOST_CHECK( x == y );
487 y[0] = -1.0; y[1] = 0.0; y[2] = -0.5;
488 BOOST_CHECK( (x *= -0.25) == y );
489 BOOST_CHECK( x == y );
490
491 x = xorg;
492 BOOST_CHECK( (x /= 1.0) == x );
493 BOOST_CHECK( x == x );
494 y[0] = 1.0; y[1] = 0.0; y[2] = 0.5;
495 BOOST_CHECK( (x /= 2.0) == y );
496 BOOST_CHECK( x == y );
497 y[0] = -4.0; y[1] = 0.0; y[2] = -2.0;
498 BOOST_CHECK( (x /= -0.25) == y );
499 BOOST_CHECK( x == y );
500 BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
501 BOOST_CHECK( x == Prob(3, 0.0) );
502
503 x = xorg;
504 BOOST_CHECK( (x ^= 1.0) == x );
505 BOOST_CHECK( x == x );
506 BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
507 BOOST_CHECK( x == Prob(3, 1.0) );
508 x = xorg;
509 y[0] = 4.0; y[1] = 0.0; y[2] = 1.0;
510 BOOST_CHECK( (x ^= 2.0) == y );
511 BOOST_CHECK( x == y );
512 y[0] = 0.5; y[1] = INFINITY; y[2] = 1.0;
513 BOOST_CHECK( (x ^= -0.5) == y );
514 BOOST_CHECK( x == y );
515 }
516
517
518 BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
519 size_t N = 6;
520 Prob xorg(N), x(N);
521 xorg[0] = 2.0; xorg[1] = 0.0; xorg[2] = 1.0; xorg[3] = 0.0; xorg[4] = 2.0; xorg[5] = 3.0;
522 Prob y(N);
523 y[0] = 0.5; y[1] = -1.0; y[2] = 0.0; y[3] = 0.0; y[4] = -2.0; y[5] = 3.0;
524 Prob z(N), r(N);
525
526 z[0] = 2.5; z[1] = -1.0; z[2] = 1.0; z[3] = 0.0; z[4] = 0.0; z[5] = 6.0;
527 x = xorg;
528 r = (x += y);
529 for( size_t i = 0; i < N; i++ )
530 BOOST_CHECK_CLOSE( r[i], z[i], tol );
531 BOOST_CHECK( x == z );
532 x = xorg;
533 BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
534 BOOST_CHECK( x == z );
535
536 z[0] = 1.5; z[1] = 1.0; z[2] = 1.0; z[3] = 0.0; z[4] = 4.0; z[5] = 0.0;
537 x = xorg;
538 r = (x -= y);
539 for( size_t i = 0; i < N; i++ )
540 BOOST_CHECK_CLOSE( r[i], z[i], tol );
541 BOOST_CHECK( x == z );
542 x = xorg;
543 BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
544 BOOST_CHECK( x == z );
545
546 z[0] = 1.0; z[1] = 0.0; z[2] = 0.0; z[3] = 0.0; z[4] = -4.0; z[5] = 9.0;
547 x = xorg;
548 r = (x *= y);
549 for( size_t i = 0; i < N; i++ )
550 BOOST_CHECK_CLOSE( r[i], z[i], tol );
551 BOOST_CHECK( x == z );
552 x = xorg;
553 BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
554 BOOST_CHECK( x == z );
555
556 z[0] = 4.0; z[1] = 0.0; z[2] = 0.0; z[3] = 0.0; z[4] = -1.0; z[5] = 1.0;
557 x = xorg;
558 r = (x /= y);
559 for( size_t i = 0; i < N; i++ )
560 BOOST_CHECK_CLOSE( r[i], z[i], tol );
561 BOOST_CHECK( x == z );
562 x = xorg;
563 BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
564 BOOST_CHECK( x == z );
565
566 z[0] = 4.0; z[1] = 0.0; z[2] = INFINITY; /*z[3] = INFINITY;*/ z[4] = -1.0; z[5] = 1.0;
567 x = xorg;
568 r = (x.divide( y ));
569 BOOST_CHECK_CLOSE( r[0], z[0], tol );
570 BOOST_CHECK_CLOSE( r[1], z[1], tol );
571 BOOST_CHECK_EQUAL( r[2], z[2] );
572 BOOST_CHECK( isnan(r[3]) );
573 BOOST_CHECK_CLOSE( r[4], z[4], tol );
574 BOOST_CHECK_CLOSE( r[5], z[5], tol );
575 x[3] = 0.0; r[3] = 0.0;
576 BOOST_CHECK( x == r );
577 x = xorg;
578 r = x.pwBinaryOp( y, std::divides<Real>() );
579 BOOST_CHECK_CLOSE( r[0], z[0], tol );
580 BOOST_CHECK_CLOSE( r[1], z[1], tol );
581 BOOST_CHECK_EQUAL( r[2], z[2] );
582 BOOST_CHECK( isnan(r[3]) );
583 BOOST_CHECK_CLOSE( r[4], z[4], tol );
584 BOOST_CHECK_CLOSE( r[5], z[5], tol );
585 x[3] = 0.0; r[3] = 0.0;
586 BOOST_CHECK( x == r );
587
588 z[0] = std::sqrt(2.0); z[1] = INFINITY; z[2] = 1.0; z[3] = 1.0; z[4] = 0.25; z[5] = 27.0;
589 x = xorg;
590 r = (x ^= y);
591 BOOST_CHECK_CLOSE( r[0], z[0], tol );
592 BOOST_CHECK_EQUAL( r[1], z[1] );
593 BOOST_CHECK_CLOSE( r[2], z[2], tol );
594 BOOST_CHECK_CLOSE( r[3], z[3], tol );
595 BOOST_CHECK_CLOSE( r[4], z[4], tol );
596 BOOST_CHECK_CLOSE( r[5], z[5], tol );
597 BOOST_CHECK( x == z );
598 x = xorg;
599 BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
600 BOOST_CHECK( x == z );
601 }
602
603
604 BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
605 size_t N = 6;
606 Prob x(N);
607 x[0] = 2.0; x[1] = 0.0; x[2] = 1.0; x[3] = 0.0; x[4] = 2.0; x[5] = 3.0;
608 Prob y(N);
609 y[0] = 0.5; y[1] = -1.0; y[2] = 0.0; y[3] = 0.0; y[4] = -2.0; y[5] = 3.0;
610 Prob z(N), r(N);
611
612 z[0] = 2.5; z[1] = -1.0; z[2] = 1.0; z[3] = 0.0; z[4] = 0.0; z[5] = 6.0;
613 r = x + y;
614 for( size_t i = 0; i < N; i++ )
615 BOOST_CHECK_CLOSE( r[i], z[i], tol );
616 z = x.pwBinaryTr( y, std::plus<Real>() );
617 BOOST_CHECK( r == z );
618
619 z[0] = 1.5; z[1] = 1.0; z[2] = 1.0; z[3] = 0.0; z[4] = 4.0; z[5] = 0.0;
620 r = x - y;
621 for( size_t i = 0; i < N; i++ )
622 BOOST_CHECK_CLOSE( r[i], z[i], tol );
623 z = x.pwBinaryTr( y, std::minus<Real>() );
624 BOOST_CHECK( r == z );
625
626 z[0] = 1.0; z[1] = 0.0; z[2] = 0.0; z[3] = 0.0; z[4] = -4.0; z[5] = 9.0;
627 r = x * y;
628 for( size_t i = 0; i < N; i++ )
629 BOOST_CHECK_CLOSE( r[i], z[i], tol );
630 z = x.pwBinaryTr( y, std::multiplies<Real>() );
631 BOOST_CHECK( r == z );
632
633 z[0] = 4.0; z[1] = 0.0; z[2] = 0.0; z[3] = 0.0; z[4] = -1.0; z[5] = 1.0;
634 r = x / y;
635 for( size_t i = 0; i < N; i++ )
636 BOOST_CHECK_CLOSE( r[i], z[i], tol );
637 z = x.pwBinaryTr( y, fo_divides0<Real>() );
638 BOOST_CHECK( r == z );
639
640 z[0] = 4.0; z[1] = 0.0; z[2] = INFINITY; /*z[3] = INFINITY;*/ z[4] = -1.0; z[5] = 1.0;
641 r = x.divided_by( y );
642 BOOST_CHECK_CLOSE( r[0], z[0], tol );
643 BOOST_CHECK_CLOSE( r[1], z[1], tol );
644 BOOST_CHECK_EQUAL( r[2], z[2] );
645 BOOST_CHECK( isnan(r[3]) );
646 BOOST_CHECK_CLOSE( r[4], z[4], tol );
647 BOOST_CHECK_CLOSE( r[5], z[5], tol );
648 z = x.pwBinaryTr( y, std::divides<Real>() );
649 BOOST_CHECK_CLOSE( r[0], z[0], tol );
650 BOOST_CHECK_CLOSE( r[1], z[1], tol );
651 BOOST_CHECK_EQUAL( r[2], z[2] );
652 BOOST_CHECK( isnan(r[3]) );
653 BOOST_CHECK_CLOSE( r[4], z[4], tol );
654 BOOST_CHECK_CLOSE( r[5], z[5], tol );
655
656 z[0] = std::sqrt(2.0); z[1] = INFINITY; z[2] = 1.0; z[3] = 1.0; z[4] = 0.25; z[5] = 27.0;
657 r = x ^ y;
658 BOOST_CHECK_CLOSE( r[0], z[0], tol );
659 BOOST_CHECK_EQUAL( r[1], z[1] );
660 BOOST_CHECK_CLOSE( r[2], z[2], tol );
661 BOOST_CHECK_CLOSE( r[3], z[3], tol );
662 BOOST_CHECK_CLOSE( r[4], z[4], tol );
663 BOOST_CHECK_CLOSE( r[5], z[5], tol );
664 z = x.pwBinaryTr( y, fo_pow<Real>() );
665 BOOST_CHECK( r == z );
666 }
667
668
669 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
670 Prob x(3), y(3), z(3);
671 x[0] = 0.2;
672 x[1] = 0.8;
673 x[2] = 0.0;
674 y[0] = 0.0;
675 y[1] = 0.6;
676 y[2] = 0.4;
677
678 z = min( x, y );
679 BOOST_CHECK_EQUAL( z[0], 0.0 );
680 BOOST_CHECK_EQUAL( z[1], 0.6 );
681 BOOST_CHECK_EQUAL( z[2], 0.0 );
682 z = max( x, y );
683 BOOST_CHECK_EQUAL( z[0], 0.2 );
684 BOOST_CHECK_EQUAL( z[1], 0.8 );
685 BOOST_CHECK_EQUAL( z[2], 0.4 );
686
687 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTL1 ), 0.0 );
688 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTL1 ), 0.0 );
689 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTL1 ), 0.2 + 0.2 + 0.4 );
690 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTL1 ), 0.2 + 0.2 + 0.4 );
691 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
692 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
693 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTLINF ), 0.0 );
694 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTLINF ), 0.0 );
695 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTLINF ), 0.4 );
696 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTLINF ), 0.4 );
697 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
698 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
699 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTTV ), 0.0 );
700 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTTV ), 0.0 );
701 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
702 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
703 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
704 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
705 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTKL ), 0.0 );
706 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTKL ), 0.0 );
707 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), INFINITY );
708 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), INFINITY );
709 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
710 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
711 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTHEL ), 0.0 );
712 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTHEL ), 0.0 );
713 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
714 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
715 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
716 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
717 x[1] = 0.7; x[2] = 0.1;
718 y[0] = 0.1; y[1] = 0.5;
719 BOOST_CHECK_CLOSE( dist( x, y, Prob::DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
720 BOOST_CHECK_CLOSE( dist( y, x, Prob::DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
721 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
722 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
723
724 std::stringstream ss;
725 ss << x;
726 std::string s;
727 std::getline( ss, s );
728 BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1)") );
729 std::stringstream ss2;
730 ss2 << y;
731 std::getline( ss2, s );
732 BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4)") );
733
734 z = min( x, y );
735 BOOST_CHECK_EQUAL( z[0], 0.1 );
736 BOOST_CHECK_EQUAL( z[1], 0.5 );
737 BOOST_CHECK_EQUAL( z[2], 0.1 );
738 z = max( x, y );
739 BOOST_CHECK_EQUAL( z[0], 0.2 );
740 BOOST_CHECK_EQUAL( z[1], 0.7 );
741 BOOST_CHECK_EQUAL( z[2], 0.4 );
742
743 BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
744 BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
745 BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
746 BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
747 BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
748 BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
749 }