5df21ab08fd6e21ea5772faaa03ff91812c8877e
[libdai.git] / tests / unit / prob.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
#define BOOST_TEST_DYN_LINK


#include <dai/prob.h>
#include <cmath>       // std::log, std::exp, std::sqrt, std::pow, INFINITY
#include <functional>  // std::plus, std::minus, std::multiplies, std::divides, std::negate
#include <sstream>     // std::stringstream
#include <string>      // std::string, std::getline
#include <strstream>   // deprecated; retained from the original include list
#include <utility>     // std::make_pair
#include <vector>      // std::vector
16
17
18 using namespace dai;
19
20
21 const double tol = 1e-8;
22
23
24 #define BOOST_TEST_MODULE ProbTest
25
26
27 #include <boost/test/unit_test.hpp>
28 #include <boost/test/floating_point_comparison.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
32 // check constructors
33 Prob x1;
34 BOOST_CHECK_EQUAL( x1.size(), 0 );
35 BOOST_CHECK( x1.p() == Prob::container_type() );
36
37 Prob x2( 3 );
38 BOOST_CHECK_EQUAL( x2.size(), 3 );
39 BOOST_CHECK( x2.p() == Prob::container_type( 3, 1.0 / 3.0 ) );
40 BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
41 BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
42 BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
43
44 Prob x3( 4, 1.0 );
45 BOOST_CHECK_EQUAL( x3.size(), 4 );
46 BOOST_CHECK( x3.p() == Prob::container_type( 4, 1.0 ) );
47 BOOST_CHECK_EQUAL( x3[0], 1.0 );
48 BOOST_CHECK_EQUAL( x3[1], 1.0 );
49 BOOST_CHECK_EQUAL( x3[2], 1.0 );
50 BOOST_CHECK_EQUAL( x3[3], 1.0 );
51 x3.set( 0, 0.5 );
52 x3.set( 1, 1.0 );
53 x3.set( 2, 2.0 );
54 x3.set( 3, 4.0 );
55
56 std::vector<Real> v;
57 v.push_back( 0.5 );
58 v.push_back( 1.0 );
59 v.push_back( 2.0 );
60 v.push_back( 4.0 );
61 Prob x4( v.begin(), v.end(), 0 );
62 BOOST_CHECK_EQUAL( x4.size(), 4 );
63 BOOST_CHECK( x4.p() == x3.p() );
64 BOOST_CHECK( x4 == x3 );
65 BOOST_CHECK_EQUAL( x4[0], 0.5 );
66 BOOST_CHECK_EQUAL( x4[1], 1.0 );
67 BOOST_CHECK_EQUAL( x4[2], 2.0 );
68 BOOST_CHECK_EQUAL( x4[3], 4.0 );
69
70 Prob x5( v.begin(), v.end(), v.size() );
71 BOOST_CHECK_EQUAL( x5.size(), 4 );
72 BOOST_CHECK( x5.p() == x3.p() );
73 BOOST_CHECK( x5 == x3 );
74 BOOST_CHECK_EQUAL( x5[0], 0.5 );
75 BOOST_CHECK_EQUAL( x5[1], 1.0 );
76 BOOST_CHECK_EQUAL( x5[2], 2.0 );
77 BOOST_CHECK_EQUAL( x5[3], 4.0 );
78
79 std::vector<int> y( 3, 2 );
80 Prob x6( y );
81 BOOST_CHECK_EQUAL( x6.size(), 3 );
82 BOOST_CHECK_EQUAL( x6[0], 2.0 );
83 BOOST_CHECK_EQUAL( x6[1], 2.0 );
84 BOOST_CHECK_EQUAL( x6[2], 2.0 );
85
86 Prob x7( x6 );
87 BOOST_CHECK( x7 == x6 );
88
89 Prob x8 = x6;
90 BOOST_CHECK( x8 == x6 );
91 }
92
93
94 #ifndef DAI_SPARSE
95 BOOST_AUTO_TEST_CASE( IteratorTest ) {
96 Prob x( 5, 0.0 );
97 size_t i;
98 for( i = 0; i < x.size(); i++ )
99 x.set( i, i );
100
101 i = 0;
102 for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
103 BOOST_CHECK_EQUAL( *cit, i );
104
105 i = 0;
106 for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
107 *it = 4 - i;
108
109 i = 0;
110 for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
111 BOOST_CHECK_EQUAL( *it, 4 - i );
112
113 i = 0;
114 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
115 BOOST_CHECK_EQUAL( *crit, i );
116
117 i = 0;
118 for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
119 *rit = 2 * i;
120
121 i = 0;
122 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
123 BOOST_CHECK_EQUAL( *crit, 2 * i );
124 }
125 #endif
126
127
128 BOOST_AUTO_TEST_CASE( QueriesTest ) {
129 Prob x( 5, 0.0 );
130 for( size_t i = 0; i < x.size(); i++ )
131 x.set( i, 2.0 - i );
132
133 // test accumulate, min, max, sum, sumAbs, maxAbs
134 BOOST_CHECK_EQUAL( x.sum(), 0.0 );
135 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_id<Real>() ), 0.0 );
136 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_id<Real>() ), 1.0 );
137 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_id<Real>() ), -1.0 );
138 BOOST_CHECK_EQUAL( x.max(), 2.0 );
139 BOOST_CHECK_EQUAL( x.accumulateMax( -INFINITY, fo_id<Real>(), false ), 2.0 );
140 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_id<Real>(), false ), 3.0 );
141 BOOST_CHECK_EQUAL( x.accumulateMax( -5.0, fo_id<Real>(), false ), 2.0 );
142 BOOST_CHECK_EQUAL( x.min(), -2.0 );
143 BOOST_CHECK_EQUAL( x.accumulateMax( INFINITY, fo_id<Real>(), true ), -2.0 );
144 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_id<Real>(), true ), -3.0 );
145 BOOST_CHECK_EQUAL( x.accumulateMax( 5.0, fo_id<Real>(), true ), -2.0 );
146 BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
147 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_abs<Real>() ), 6.0 );
148 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_abs<Real>() ), 7.0 );
149 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_abs<Real>() ), 7.0 );
150 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
151 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
152 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
153 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
154 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
155 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
156 x.set( 1, 1.0 );
157 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
158 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
159 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
160 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
161 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
162 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
163 for( size_t i = 0; i < x.size(); i++ )
164 x.set( i, i ? (1.0 / i) : 0.0 );
165 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_inv0<Real>() ), 10.0 );
166 x /= x.sum();
167
168 // test entropy
169 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
170 for( size_t i = 1; i < 100; i++ )
171 BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log(i), tol );
172
173 // test hasNaNs and hasNegatives
174 BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
175 Real c = 0.0;
176 BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
177 BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
178 BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
179 BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
180 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
181 BOOST_CHECK( x.hasNegatives() );
182 x.set( 2, -INFINITY );
183 BOOST_CHECK( x.hasNegatives() );
184 x.set( 2, INFINITY );
185 BOOST_CHECK( !x.hasNegatives() );
186 x.set( 2, -1.0 );
187
188 // test argmax
189 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
190 x.set( 4, 0.5 );
191 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
192 x.set( 3, -2.0 );
193 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
194 x.set( 4, -1.0 );
195 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
196 x.set( 0, -2.0 );
197 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
198 x.set( 1, -3.0 );
199 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
200 x.set( 2, -2.0 );
201 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
202
203 // test draw
204 for( size_t i = 0; i < x.size(); i++ )
205 x.set( i, i ? (1.0 / i) : 0.0 );
206 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
207 BOOST_CHECK( x.draw() < x.size() );
208 BOOST_CHECK( x.draw() != 0 );
209 }
210 x.set( 2, 0.0 );
211 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
212 BOOST_CHECK( x.draw() < x.size() );
213 BOOST_CHECK( x.draw() != 0 );
214 BOOST_CHECK( x.draw() != 2 );
215 }
216 x.set( 4, 0.0 );
217 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
218 BOOST_CHECK( x.draw() < x.size() );
219 BOOST_CHECK( x.draw() != 0 );
220 BOOST_CHECK( x.draw() != 2 );
221 BOOST_CHECK( x.draw() != 4 );
222 }
223 x.set( 1, 0.0 );
224 for( size_t repeat = 0; repeat < 10000; repeat++ )
225 BOOST_CHECK( x.draw() == 3 );
226
227 // test <, ==
228 Prob a(3, 1.0), b(3, 1.0);
229 BOOST_CHECK( !(a < b) );
230 BOOST_CHECK( !(b < a) );
231 BOOST_CHECK( a == b );
232 a.set( 0, 0.0 );
233 BOOST_CHECK( a < b );
234 BOOST_CHECK( !(b < a) );
235 BOOST_CHECK( !(a == b) );
236 b.set( 2, 0.0 );
237 BOOST_CHECK( a < b );
238 BOOST_CHECK( !(b < a) );
239 BOOST_CHECK( !(a == b) );
240 b.set( 0, 0.0 );
241 BOOST_CHECK( !(a < b) );
242 BOOST_CHECK( b < a );
243 BOOST_CHECK( !(a == b) );
244 a.set( 1, 0.0 );
245 BOOST_CHECK( a < b );
246 BOOST_CHECK( !(b < a) );
247 BOOST_CHECK( !(a == b) );
248 b.set( 1, 0.0 );
249 BOOST_CHECK( !(a < b) );
250 BOOST_CHECK( b < a );
251 BOOST_CHECK( !(a == b) );
252 a.set( 2, 0.0 );
253 BOOST_CHECK( !(a < b) );
254 BOOST_CHECK( !(b < a) );
255 BOOST_CHECK( a == b );
256 }
257
258
259 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
260 Prob x( 3 );
261 x.set( 0, -2.0 );
262 x.set( 1, 0.0 );
263 x.set( 2, 2.0 );
264
265 Prob y = -x;
266 Prob z = x.pwUnaryTr( std::negate<Real>() );
267 BOOST_CHECK_EQUAL( y[0], 2.0 );
268 BOOST_CHECK_EQUAL( y[1], 0.0 );
269 BOOST_CHECK_EQUAL( y[2], -2.0 );
270 BOOST_CHECK( y == z );
271
272 y = x.abs();
273 z = x.pwUnaryTr( fo_abs<Real>() );
274 BOOST_CHECK_EQUAL( y[0], 2.0 );
275 BOOST_CHECK_EQUAL( y[1], 0.0 );
276 BOOST_CHECK_EQUAL( y[2], 2.0 );
277 BOOST_CHECK( y == z );
278
279 y = x.exp();
280 z = x.pwUnaryTr( fo_exp<Real>() );
281 BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
282 BOOST_CHECK_EQUAL( y[1], 1.0 );
283 BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
284 BOOST_CHECK( y == z );
285
286 y = x.log(false);
287 z = x.pwUnaryTr( fo_log<Real>() );
288 BOOST_CHECK( isnan( y[0] ) );
289 BOOST_CHECK_EQUAL( y[1], -INFINITY );
290 BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
291 BOOST_CHECK( !(y == z) );
292 y.set( 0, 0.0 );
293 z.set( 0, 0.0 );
294 BOOST_CHECK( y == z );
295
296 y = x.log(true);
297 z = x.pwUnaryTr( fo_log0<Real>() );
298 BOOST_CHECK( isnan( y[0] ) );
299 BOOST_CHECK_EQUAL( y[1], 0.0 );
300 BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
301 BOOST_CHECK( !(y == z) );
302 y.set( 0, 0.0 );
303 z.set( 0, 0.0 );
304 BOOST_CHECK( y == z );
305
306 y = x.inverse(false);
307 z = x.pwUnaryTr( fo_inv<Real>() );
308 BOOST_CHECK_EQUAL( y[0], -0.5 );
309 BOOST_CHECK_EQUAL( y[1], INFINITY );
310 BOOST_CHECK_EQUAL( y[2], 0.5 );
311 BOOST_CHECK( y == z );
312
313 y = x.inverse(true);
314 z = x.pwUnaryTr( fo_inv0<Real>() );
315 BOOST_CHECK_EQUAL( y[0], -0.5 );
316 BOOST_CHECK_EQUAL( y[1], 0.0 );
317 BOOST_CHECK_EQUAL( y[2], 0.5 );
318 BOOST_CHECK( y == z );
319
320 x.set( 0, 2.0 );
321 y = x.normalized();
322 BOOST_CHECK_EQUAL( y[0], 0.5 );
323 BOOST_CHECK_EQUAL( y[1], 0.0 );
324 BOOST_CHECK_EQUAL( y[2], 0.5 );
325
326 y = x.normalized( NORMPROB );
327 BOOST_CHECK_EQUAL( y[0], 0.5 );
328 BOOST_CHECK_EQUAL( y[1], 0.0 );
329 BOOST_CHECK_EQUAL( y[2], 0.5 );
330
331 x.set( 0, -2.0 );
332 y = x.normalized( NORMLINF );
333 BOOST_CHECK_EQUAL( y[0], -1.0 );
334 BOOST_CHECK_EQUAL( y[1], 0.0 );
335 BOOST_CHECK_EQUAL( y[2], 1.0 );
336 }
337
338
339 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
340 Prob xorg(3);
341 xorg.set( 0, 2.0 );
342 xorg.set( 1, 0.0 );
343 xorg.set( 2, 1.0 );
344 Prob y(3);
345
346 Prob x = xorg;
347 BOOST_CHECK( x.setUniform() == Prob(3) );
348 BOOST_CHECK( x == Prob(3) );
349
350 y.set( 0, std::exp(2.0) );
351 y.set( 1, 1.0 );
352 y.set( 2, std::exp(1.0) );
353 x = xorg;
354 BOOST_CHECK( x.takeExp() == y );
355 BOOST_CHECK( x == y );
356 x = xorg;
357 BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
358 BOOST_CHECK( x == y );
359
360 y.set( 0, std::log(2.0) );
361 y.set( 1, -INFINITY );
362 y.set( 2, 0.0 );
363 x = xorg;
364 BOOST_CHECK( x.takeLog() == y );
365 BOOST_CHECK( x == y );
366 x = xorg;
367 BOOST_CHECK( x.takeLog(false) == y );
368 BOOST_CHECK( x == y );
369 x = xorg;
370 BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
371 BOOST_CHECK( x == y );
372
373 y.set( 1, 0.0 );
374 x = xorg;
375 BOOST_CHECK( x.takeLog(true) == y );
376 BOOST_CHECK( x == y );
377 x = xorg;
378 BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
379 BOOST_CHECK( x == y );
380
381 y.set( 0, 2.0 / 3.0 );
382 y.set( 1, 0.0 / 3.0 );
383 y.set( 2, 1.0 / 3.0 );
384 x = xorg;
385 BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
386 BOOST_CHECK( x == y );
387
388 x = xorg;
389 BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
390 BOOST_CHECK( x == y );
391
392 y.set( 0, 2.0 / 2.0 );
393 y.set( 1, 0.0 / 2.0 );
394 y.set( 2, 1.0 / 2.0 );
395 x = xorg;
396 BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
397 BOOST_CHECK( x == y );
398
399 xorg.set( 0, -2.0 );
400 y.set( 0, 2.0 );
401 y.set( 1, 0.0 );
402 y.set( 2, 1.0 );
403 x = xorg;
404 BOOST_CHECK( x.takeAbs() == y );
405 BOOST_CHECK( x == y );
406
407 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
408 x.randomize();
409 for( size_t i = 0; i < x.size(); i++ ) {
410 BOOST_CHECK( x[i] < 1.0 );
411 BOOST_CHECK( x[i] >= 0.0 );
412 }
413 }
414 }
415
416
417 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
418 Prob x(3);
419 x.set( 0, 2.0 );
420 x.set( 1, 0.0 );
421 x.set( 2, 1.0 );
422 Prob y(3);
423
424 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
425 BOOST_CHECK( (x + 1.0) == y );
426 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
427 BOOST_CHECK( (x + (-2.0)) == y );
428
429 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
430 BOOST_CHECK( (x - 1.0) == y );
431 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
432 BOOST_CHECK( (x - (-2.0)) == y );
433
434 BOOST_CHECK( (x * 1.0) == x );
435 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
436 BOOST_CHECK( (x * 2.0) == y );
437 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
438 BOOST_CHECK( (x * -0.5) == y );
439
440 BOOST_CHECK( (x / 1.0) == x );
441 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
442 BOOST_CHECK( (x / 2.0) == y );
443 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
444 BOOST_CHECK( (x / -0.5) == y );
445 BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
446
447 BOOST_CHECK( (x ^ 1.0) == x );
448 BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
449 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
450 BOOST_CHECK( (x ^ 2.0) == y );
451 y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
452 Prob z = (x ^ -0.5);
453 BOOST_CHECK_CLOSE( z[0], y[0], tol );
454 BOOST_CHECK_EQUAL( z[1], y[1] );
455 BOOST_CHECK_CLOSE( z[2], y[2], tol );
456 }
457
458
459 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
460 Prob xorg(3), x(3);
461 xorg.set( 0, 2.0 );
462 xorg.set( 1, 0.0 );
463 xorg.set( 2, 1.0 );
464 Prob y(3);
465
466 x = xorg;
467 BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
468 BOOST_CHECK( x == Prob(3, 1.0) );
469 BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
470 BOOST_CHECK( x == Prob(3, 2.0) );
471 BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
472 BOOST_CHECK( x == Prob(3, 0.0) );
473
474 x = xorg;
475 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
476 BOOST_CHECK( (x += 1.0) == y );
477 BOOST_CHECK( x == y );
478 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
479 BOOST_CHECK( (x += -2.0) == y );
480 BOOST_CHECK( x == y );
481
482 x = xorg;
483 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
484 BOOST_CHECK( (x -= 1.0) == y );
485 BOOST_CHECK( x == y );
486 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
487 BOOST_CHECK( (x -= -2.0) == y );
488 BOOST_CHECK( x == y );
489
490 x = xorg;
491 BOOST_CHECK( (x *= 1.0) == x );
492 BOOST_CHECK( x == x );
493 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
494 BOOST_CHECK( (x *= 2.0) == y );
495 BOOST_CHECK( x == y );
496 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
497 BOOST_CHECK( (x *= -0.25) == y );
498 BOOST_CHECK( x == y );
499
500 x = xorg;
501 BOOST_CHECK( (x /= 1.0) == x );
502 BOOST_CHECK( x == x );
503 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
504 BOOST_CHECK( (x /= 2.0) == y );
505 BOOST_CHECK( x == y );
506 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
507 BOOST_CHECK( (x /= -0.25) == y );
508 BOOST_CHECK( x == y );
509 BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
510 BOOST_CHECK( x == Prob(3, 0.0) );
511
512 x = xorg;
513 BOOST_CHECK( (x ^= 1.0) == x );
514 BOOST_CHECK( x == x );
515 BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
516 BOOST_CHECK( x == Prob(3, 1.0) );
517 x = xorg;
518 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
519 BOOST_CHECK( (x ^= 2.0) == y );
520 BOOST_CHECK( x == y );
521 y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
522 BOOST_CHECK( (x ^= -0.5) == y );
523 BOOST_CHECK( x == y );
524 }
525
526
527 BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
528 size_t N = 6;
529 Prob xorg(N), x(N);
530 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
531 Prob y(N);
532 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
533 Prob z(N), r(N);
534
535 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
536 x = xorg;
537 r = (x += y);
538 for( size_t i = 0; i < N; i++ )
539 BOOST_CHECK_CLOSE( r[i], z[i], tol );
540 BOOST_CHECK( x == z );
541 x = xorg;
542 BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
543 BOOST_CHECK( x == z );
544
545 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
546 x = xorg;
547 r = (x -= y);
548 for( size_t i = 0; i < N; i++ )
549 BOOST_CHECK_CLOSE( r[i], z[i], tol );
550 BOOST_CHECK( x == z );
551 x = xorg;
552 BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
553 BOOST_CHECK( x == z );
554
555 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
556 x = xorg;
557 r = (x *= y);
558 for( size_t i = 0; i < N; i++ )
559 BOOST_CHECK_CLOSE( r[i], z[i], tol );
560 BOOST_CHECK( x == z );
561 x = xorg;
562 BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
563 BOOST_CHECK( x == z );
564
565 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
566 x = xorg;
567 r = (x /= y);
568 for( size_t i = 0; i < N; i++ )
569 BOOST_CHECK_CLOSE( r[i], z[i], tol );
570 BOOST_CHECK( x == z );
571 x = xorg;
572 BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
573 BOOST_CHECK( x == z );
574
575 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
576 x = xorg;
577 r = (x.divide( y ));
578 BOOST_CHECK_CLOSE( r[0], z[0], tol );
579 BOOST_CHECK_CLOSE( r[1], z[1], tol );
580 BOOST_CHECK_EQUAL( r[2], z[2] );
581 BOOST_CHECK( isnan(r[3]) );
582 BOOST_CHECK_CLOSE( r[4], z[4], tol );
583 BOOST_CHECK_CLOSE( r[5], z[5], tol );
584 x.set( 3, 0.0 ); r.set( 3, 0.0 );
585 BOOST_CHECK( x == r );
586 x = xorg;
587 r = x.pwBinaryOp( y, std::divides<Real>() );
588 BOOST_CHECK_CLOSE( r[0], z[0], tol );
589 BOOST_CHECK_CLOSE( r[1], z[1], tol );
590 BOOST_CHECK_EQUAL( r[2], z[2] );
591 BOOST_CHECK( isnan(r[3]) );
592 BOOST_CHECK_CLOSE( r[4], z[4], tol );
593 BOOST_CHECK_CLOSE( r[5], z[5], tol );
594 x.set( 3, 0.0 ); r.set( 3, 0.0 );
595 BOOST_CHECK( x == r );
596
597 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
598 x = xorg;
599 r = (x ^= y);
600 BOOST_CHECK_CLOSE( r[0], z[0], tol );
601 BOOST_CHECK_EQUAL( r[1], z[1] );
602 BOOST_CHECK_CLOSE( r[2], z[2], tol );
603 BOOST_CHECK_CLOSE( r[3], z[3], tol );
604 BOOST_CHECK_CLOSE( r[4], z[4], tol );
605 BOOST_CHECK_CLOSE( r[5], z[5], tol );
606 BOOST_CHECK( x == z );
607 x = xorg;
608 BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
609 BOOST_CHECK( x == z );
610 }
611
612
613 BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
614 size_t N = 6;
615 Prob x(N);
616 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
617 Prob y(N);
618 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
619 Prob z(N), r(N);
620
621 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
622 r = x + y;
623 for( size_t i = 0; i < N; i++ )
624 BOOST_CHECK_CLOSE( r[i], z[i], tol );
625 z = x.pwBinaryTr( y, std::plus<Real>() );
626 BOOST_CHECK( r == z );
627
628 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
629 r = x - y;
630 for( size_t i = 0; i < N; i++ )
631 BOOST_CHECK_CLOSE( r[i], z[i], tol );
632 z = x.pwBinaryTr( y, std::minus<Real>() );
633 BOOST_CHECK( r == z );
634
635 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
636 r = x * y;
637 for( size_t i = 0; i < N; i++ )
638 BOOST_CHECK_CLOSE( r[i], z[i], tol );
639 z = x.pwBinaryTr( y, std::multiplies<Real>() );
640 BOOST_CHECK( r == z );
641
642 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
643 r = x / y;
644 for( size_t i = 0; i < N; i++ )
645 BOOST_CHECK_CLOSE( r[i], z[i], tol );
646 z = x.pwBinaryTr( y, fo_divides0<Real>() );
647 BOOST_CHECK( r == z );
648
649 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
650 r = x.divided_by( y );
651 BOOST_CHECK_CLOSE( r[0], z[0], tol );
652 BOOST_CHECK_CLOSE( r[1], z[1], tol );
653 BOOST_CHECK_EQUAL( r[2], z[2] );
654 BOOST_CHECK( isnan(r[3]) );
655 BOOST_CHECK_CLOSE( r[4], z[4], tol );
656 BOOST_CHECK_CLOSE( r[5], z[5], tol );
657 z = x.pwBinaryTr( y, std::divides<Real>() );
658 BOOST_CHECK_CLOSE( r[0], z[0], tol );
659 BOOST_CHECK_CLOSE( r[1], z[1], tol );
660 BOOST_CHECK_EQUAL( r[2], z[2] );
661 BOOST_CHECK( isnan(r[3]) );
662 BOOST_CHECK_CLOSE( r[4], z[4], tol );
663 BOOST_CHECK_CLOSE( r[5], z[5], tol );
664
665 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
666 r = x ^ y;
667 BOOST_CHECK_CLOSE( r[0], z[0], tol );
668 BOOST_CHECK_EQUAL( r[1], z[1] );
669 BOOST_CHECK_CLOSE( r[2], z[2], tol );
670 BOOST_CHECK_CLOSE( r[3], z[3], tol );
671 BOOST_CHECK_CLOSE( r[4], z[4], tol );
672 BOOST_CHECK_CLOSE( r[5], z[5], tol );
673 z = x.pwBinaryTr( y, fo_pow<Real>() );
674 BOOST_CHECK( r == z );
675 }
676
677
678 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
679 Prob x(3), y(3), z(3);
680 x.set( 0, 0.2 );
681 x.set( 1, 0.8 );
682 x.set( 2, 0.0 );
683 y.set( 0, 0.0 );
684 y.set( 1, 0.6 );
685 y.set( 2, 0.4 );
686
687 z = min( x, y );
688 BOOST_CHECK_EQUAL( z[0], 0.0 );
689 BOOST_CHECK_EQUAL( z[1], 0.6 );
690 BOOST_CHECK_EQUAL( z[2], 0.0 );
691 z = max( x, y );
692 BOOST_CHECK_EQUAL( z[0], 0.2 );
693 BOOST_CHECK_EQUAL( z[1], 0.8 );
694 BOOST_CHECK_EQUAL( z[2], 0.4 );
695
696 BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
697 BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
698 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
699 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
700 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
701 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
702 BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
703 BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
704 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
705 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
706 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
707 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
708 BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
709 BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
710 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
711 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
712 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
713 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
714 BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
715 BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
716 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
717 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
718 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
719 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
720 BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
721 BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
722 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
723 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
724 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
725 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
726 x.set( 1, 0.7 ); x.set( 2, 0.1 );
727 y.set( 0, 0.1 ); y.set( 1, 0.5 );
728 BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
729 BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
730 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
731 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
732
733 Prob xx(4), yy(4);
734 for( size_t i = 0; i < 3; i++ ) {
735 xx.set( i, x[i] );
736 yy.set( i, y[i] );
737 }
738 std::stringstream ss;
739 ss << xx;
740 std::string s;
741 std::getline( ss, s );
742 #ifdef DAI_SPARSE
743 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
744 #else
745 BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
746 #endif
747 std::stringstream ss2;
748 ss2 << yy;
749 std::getline( ss2, s );
750 #ifdef DAI_SPARSE
751 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
752 #else
753 BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
754 #endif
755
756 z = min( x, y );
757 BOOST_CHECK_EQUAL( z[0], 0.1 );
758 BOOST_CHECK_EQUAL( z[1], 0.5 );
759 BOOST_CHECK_EQUAL( z[2], 0.1 );
760 z = max( x, y );
761 BOOST_CHECK_EQUAL( z[0], 0.2 );
762 BOOST_CHECK_EQUAL( z[1], 0.7 );
763 BOOST_CHECK_EQUAL( z[2], 0.4 );
764
765 BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
766 BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
767 BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
768 BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
769 BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
770 BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
771 }