Wrote exceptions, factorgraph unit tests and several other improvements
[libdai.git] / tests / unit / prob.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
#include <dai/prob.h>
#include <cmath>
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>
#include <strstream>
#include <utility>
#include <vector>
16
17
18 using namespace dai;
19
20
/// Tolerance passed to every BOOST_CHECK_CLOSE in this file. Boost.Test
/// interprets that argument as a percentage, so 1e-8 corresponds to a
/// relative error of 1e-10.
static const double tol = 1.0e-8;
22
23
24 #define BOOST_TEST_MODULE ProbTest
25
26
27 #include <boost/test/unit_test.hpp>
28 #include <boost/test/floating_point_comparison.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
32 // check constructors
33 Prob x1;
34 BOOST_CHECK_EQUAL( x1.size(), 0 );
35 BOOST_CHECK( x1.p() == Prob::container_type() );
36
37 Prob x2( 3 );
38 BOOST_CHECK_EQUAL( x2.size(), 3 );
39 BOOST_CHECK( x2.p() == Prob::container_type( 3, 1.0 / 3.0 ) );
40 BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
41 BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
42 BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
43
44 Prob x3( 4, 1.0 );
45 BOOST_CHECK_EQUAL( x3.size(), 4 );
46 BOOST_CHECK( x3.p() == Prob::container_type( 4, 1.0 ) );
47 BOOST_CHECK_EQUAL( x3[0], 1.0 );
48 BOOST_CHECK_EQUAL( x3[1], 1.0 );
49 BOOST_CHECK_EQUAL( x3[2], 1.0 );
50 BOOST_CHECK_EQUAL( x3[3], 1.0 );
51 x3.set( 0, 0.5 );
52 x3.set( 1, 1.0 );
53 x3.set( 2, 2.0 );
54 x3.set( 3, 4.0 );
55
56 std::vector<Real> v;
57 v.push_back( 0.5 );
58 v.push_back( 1.0 );
59 v.push_back( 2.0 );
60 v.push_back( 4.0 );
61 Prob x4( v.begin(), v.end(), 0 );
62 BOOST_CHECK_EQUAL( x4.size(), 4 );
63 BOOST_CHECK( x4.p() == x3.p() );
64 BOOST_CHECK( x4 == x3 );
65 BOOST_CHECK_EQUAL( x4[0], 0.5 );
66 BOOST_CHECK_EQUAL( x4[1], 1.0 );
67 BOOST_CHECK_EQUAL( x4[2], 2.0 );
68 BOOST_CHECK_EQUAL( x4[3], 4.0 );
69
70 Prob x5( v.begin(), v.end(), v.size() );
71 BOOST_CHECK_EQUAL( x5.size(), 4 );
72 BOOST_CHECK( x5.p() == x3.p() );
73 BOOST_CHECK( x5 == x3 );
74 BOOST_CHECK_EQUAL( x5[0], 0.5 );
75 BOOST_CHECK_EQUAL( x5[1], 1.0 );
76 BOOST_CHECK_EQUAL( x5[2], 2.0 );
77 BOOST_CHECK_EQUAL( x5[3], 4.0 );
78
79 std::vector<int> y( 3, 2 );
80 Prob x6( y );
81 BOOST_CHECK_EQUAL( x6.size(), 3 );
82 BOOST_CHECK_EQUAL( x6[0], 2.0 );
83 BOOST_CHECK_EQUAL( x6[1], 2.0 );
84 BOOST_CHECK_EQUAL( x6[2], 2.0 );
85
86 Prob x7( x6 );
87 BOOST_CHECK( x7 == x6 );
88
89 Prob x8 = x6;
90 BOOST_CHECK( x8 == x6 );
91
92 x7.resize( 5 );
93 BOOST_CHECK_EQUAL( x7.size(), 5 );
94 BOOST_CHECK_EQUAL( x7[0], 2.0 );
95 BOOST_CHECK_EQUAL( x7[1], 2.0 );
96 BOOST_CHECK_EQUAL( x7[2], 2.0 );
97 BOOST_CHECK_EQUAL( x7[3], 0.0 );
98 BOOST_CHECK_EQUAL( x7[4], 0.0 );
99
100 x8.resize( 1 );
101 BOOST_CHECK_EQUAL( x8.size(), 1 );
102 BOOST_CHECK_EQUAL( x8[0], 2.0 );
103 }
104
105
106 #ifndef DAI_SPARSE
107 BOOST_AUTO_TEST_CASE( IteratorTest ) {
108 Prob x( 5, 0.0 );
109 size_t i;
110 for( i = 0; i < x.size(); i++ )
111 x.set( i, i );
112
113 i = 0;
114 for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
115 BOOST_CHECK_EQUAL( *cit, i );
116
117 i = 0;
118 for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
119 *it = 4 - i;
120
121 i = 0;
122 for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
123 BOOST_CHECK_EQUAL( *it, 4 - i );
124
125 i = 0;
126 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
127 BOOST_CHECK_EQUAL( *crit, i );
128
129 i = 0;
130 for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
131 *rit = 2 * i;
132
133 i = 0;
134 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
135 BOOST_CHECK_EQUAL( *crit, 2 * i );
136 }
137 #endif
138
139
140 BOOST_AUTO_TEST_CASE( QueriesTest ) {
141 Prob x( 5, 0.0 );
142 for( size_t i = 0; i < x.size(); i++ )
143 x.set( i, 2.0 - i );
144
145 // test accumulate, min, max, sum, sumAbs, maxAbs
146 BOOST_CHECK_EQUAL( x.sum(), 0.0 );
147 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_id<Real>() ), 0.0 );
148 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_id<Real>() ), 1.0 );
149 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_id<Real>() ), -1.0 );
150 BOOST_CHECK_EQUAL( x.max(), 2.0 );
151 BOOST_CHECK_EQUAL( x.accumulateMax( -INFINITY, fo_id<Real>(), false ), 2.0 );
152 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_id<Real>(), false ), 3.0 );
153 BOOST_CHECK_EQUAL( x.accumulateMax( -5.0, fo_id<Real>(), false ), 2.0 );
154 BOOST_CHECK_EQUAL( x.min(), -2.0 );
155 BOOST_CHECK_EQUAL( x.accumulateMax( INFINITY, fo_id<Real>(), true ), -2.0 );
156 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_id<Real>(), true ), -3.0 );
157 BOOST_CHECK_EQUAL( x.accumulateMax( 5.0, fo_id<Real>(), true ), -2.0 );
158 BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
159 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_abs<Real>() ), 6.0 );
160 BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_abs<Real>() ), 7.0 );
161 BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_abs<Real>() ), 7.0 );
162 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
163 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
164 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
165 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
166 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
167 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
168 x.set( 1, 1.0 );
169 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
170 BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
171 BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
172 BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
173 BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
174 BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
175 for( size_t i = 0; i < x.size(); i++ )
176 x.set( i, i ? (1.0 / i) : 0.0 );
177 BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_inv0<Real>() ), 10.0 );
178 x /= x.sum();
179
180 // test entropy
181 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
182 for( size_t i = 1; i < 100; i++ )
183 BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log(i), tol );
184
185 // test hasNaNs and hasNegatives
186 BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
187 Real c = 0.0;
188 BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
189 BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
190 BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
191 BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
192 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
193 BOOST_CHECK( x.hasNegatives() );
194 x.set( 2, -INFINITY );
195 BOOST_CHECK( x.hasNegatives() );
196 x.set( 2, INFINITY );
197 BOOST_CHECK( !x.hasNegatives() );
198 x.set( 2, -1.0 );
199
200 // test argmax
201 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
202 x.set( 4, 0.5 );
203 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
204 x.set( 3, -2.0 );
205 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
206 x.set( 4, -1.0 );
207 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
208 x.set( 0, -2.0 );
209 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
210 x.set( 1, -3.0 );
211 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
212 x.set( 2, -2.0 );
213 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
214
215 // test draw
216 for( size_t i = 0; i < x.size(); i++ )
217 x.set( i, i ? (1.0 / i) : 0.0 );
218 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
219 BOOST_CHECK( x.draw() < x.size() );
220 BOOST_CHECK( x.draw() != 0 );
221 }
222 x.set( 2, 0.0 );
223 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
224 BOOST_CHECK( x.draw() < x.size() );
225 BOOST_CHECK( x.draw() != 0 );
226 BOOST_CHECK( x.draw() != 2 );
227 }
228 x.set( 4, 0.0 );
229 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
230 BOOST_CHECK( x.draw() < x.size() );
231 BOOST_CHECK( x.draw() != 0 );
232 BOOST_CHECK( x.draw() != 2 );
233 BOOST_CHECK( x.draw() != 4 );
234 }
235 x.set( 1, 0.0 );
236 for( size_t repeat = 0; repeat < 10000; repeat++ )
237 BOOST_CHECK( x.draw() == 3 );
238
239 // test <, ==
240 Prob a(3, 1.0), b(3, 1.0);
241 BOOST_CHECK( !(a < b) );
242 BOOST_CHECK( !(b < a) );
243 BOOST_CHECK( a == b );
244 a.set( 0, 0.0 );
245 BOOST_CHECK( a < b );
246 BOOST_CHECK( !(b < a) );
247 BOOST_CHECK( !(a == b) );
248 b.set( 2, 0.0 );
249 BOOST_CHECK( a < b );
250 BOOST_CHECK( !(b < a) );
251 BOOST_CHECK( !(a == b) );
252 b.set( 0, 0.0 );
253 BOOST_CHECK( !(a < b) );
254 BOOST_CHECK( b < a );
255 BOOST_CHECK( !(a == b) );
256 a.set( 1, 0.0 );
257 BOOST_CHECK( a < b );
258 BOOST_CHECK( !(b < a) );
259 BOOST_CHECK( !(a == b) );
260 b.set( 1, 0.0 );
261 BOOST_CHECK( !(a < b) );
262 BOOST_CHECK( b < a );
263 BOOST_CHECK( !(a == b) );
264 a.set( 2, 0.0 );
265 BOOST_CHECK( !(a < b) );
266 BOOST_CHECK( !(b < a) );
267 BOOST_CHECK( a == b );
268 }
269
270
271 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
272 Prob x( 3 );
273 x.set( 0, -2.0 );
274 x.set( 1, 0.0 );
275 x.set( 2, 2.0 );
276
277 Prob y = -x;
278 Prob z = x.pwUnaryTr( std::negate<Real>() );
279 BOOST_CHECK_EQUAL( y[0], 2.0 );
280 BOOST_CHECK_EQUAL( y[1], 0.0 );
281 BOOST_CHECK_EQUAL( y[2], -2.0 );
282 BOOST_CHECK( y == z );
283
284 y = x.abs();
285 z = x.pwUnaryTr( fo_abs<Real>() );
286 BOOST_CHECK_EQUAL( y[0], 2.0 );
287 BOOST_CHECK_EQUAL( y[1], 0.0 );
288 BOOST_CHECK_EQUAL( y[2], 2.0 );
289 BOOST_CHECK( y == z );
290
291 y = x.exp();
292 z = x.pwUnaryTr( fo_exp<Real>() );
293 BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
294 BOOST_CHECK_EQUAL( y[1], 1.0 );
295 BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
296 BOOST_CHECK( y == z );
297
298 y = x.log(false);
299 z = x.pwUnaryTr( fo_log<Real>() );
300 BOOST_CHECK( isnan( y[0] ) );
301 BOOST_CHECK_EQUAL( y[1], -INFINITY );
302 BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
303 BOOST_CHECK( !(y == z) );
304 y.set( 0, 0.0 );
305 z.set( 0, 0.0 );
306 BOOST_CHECK( y == z );
307
308 y = x.log(true);
309 z = x.pwUnaryTr( fo_log0<Real>() );
310 BOOST_CHECK( isnan( y[0] ) );
311 BOOST_CHECK_EQUAL( y[1], 0.0 );
312 BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
313 BOOST_CHECK( !(y == z) );
314 y.set( 0, 0.0 );
315 z.set( 0, 0.0 );
316 BOOST_CHECK( y == z );
317
318 y = x.inverse(false);
319 z = x.pwUnaryTr( fo_inv<Real>() );
320 BOOST_CHECK_EQUAL( y[0], -0.5 );
321 BOOST_CHECK_EQUAL( y[1], INFINITY );
322 BOOST_CHECK_EQUAL( y[2], 0.5 );
323 BOOST_CHECK( y == z );
324
325 y = x.inverse(true);
326 z = x.pwUnaryTr( fo_inv0<Real>() );
327 BOOST_CHECK_EQUAL( y[0], -0.5 );
328 BOOST_CHECK_EQUAL( y[1], 0.0 );
329 BOOST_CHECK_EQUAL( y[2], 0.5 );
330 BOOST_CHECK( y == z );
331
332 x.set( 0, 2.0 );
333 y = x.normalized();
334 BOOST_CHECK_EQUAL( y[0], 0.5 );
335 BOOST_CHECK_EQUAL( y[1], 0.0 );
336 BOOST_CHECK_EQUAL( y[2], 0.5 );
337
338 y = x.normalized( NORMPROB );
339 BOOST_CHECK_EQUAL( y[0], 0.5 );
340 BOOST_CHECK_EQUAL( y[1], 0.0 );
341 BOOST_CHECK_EQUAL( y[2], 0.5 );
342
343 x.set( 0, -2.0 );
344 y = x.normalized( NORMLINF );
345 BOOST_CHECK_EQUAL( y[0], -1.0 );
346 BOOST_CHECK_EQUAL( y[1], 0.0 );
347 BOOST_CHECK_EQUAL( y[2], 1.0 );
348 }
349
350
351 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
352 Prob xorg(3);
353 xorg.set( 0, 2.0 );
354 xorg.set( 1, 0.0 );
355 xorg.set( 2, 1.0 );
356 Prob y(3);
357
358 Prob x = xorg;
359 BOOST_CHECK( x.setUniform() == Prob(3) );
360 BOOST_CHECK( x == Prob(3) );
361
362 y.set( 0, std::exp(2.0) );
363 y.set( 1, 1.0 );
364 y.set( 2, std::exp(1.0) );
365 x = xorg;
366 BOOST_CHECK( x.takeExp() == y );
367 BOOST_CHECK( x == y );
368 x = xorg;
369 BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
370 BOOST_CHECK( x == y );
371
372 y.set( 0, std::log(2.0) );
373 y.set( 1, -INFINITY );
374 y.set( 2, 0.0 );
375 x = xorg;
376 BOOST_CHECK( x.takeLog() == y );
377 BOOST_CHECK( x == y );
378 x = xorg;
379 BOOST_CHECK( x.takeLog(false) == y );
380 BOOST_CHECK( x == y );
381 x = xorg;
382 BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
383 BOOST_CHECK( x == y );
384
385 y.set( 1, 0.0 );
386 x = xorg;
387 BOOST_CHECK( x.takeLog(true) == y );
388 BOOST_CHECK( x == y );
389 x = xorg;
390 BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
391 BOOST_CHECK( x == y );
392
393 y.set( 0, 2.0 / 3.0 );
394 y.set( 1, 0.0 / 3.0 );
395 y.set( 2, 1.0 / 3.0 );
396 x = xorg;
397 BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
398 BOOST_CHECK( x == y );
399
400 x = xorg;
401 BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
402 BOOST_CHECK( x == y );
403
404 y.set( 0, 2.0 / 2.0 );
405 y.set( 1, 0.0 / 2.0 );
406 y.set( 2, 1.0 / 2.0 );
407 x = xorg;
408 BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
409 BOOST_CHECK( x == y );
410
411 xorg.set( 0, -2.0 );
412 y.set( 0, 2.0 );
413 y.set( 1, 0.0 );
414 y.set( 2, 1.0 );
415 x = xorg;
416 BOOST_CHECK( x.takeAbs() == y );
417 BOOST_CHECK( x == y );
418
419 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
420 x.randomize();
421 for( size_t i = 0; i < x.size(); i++ ) {
422 BOOST_CHECK( x[i] < 1.0 );
423 BOOST_CHECK( x[i] >= 0.0 );
424 }
425 }
426 }
427
428
429 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
430 Prob x(3);
431 x.set( 0, 2.0 );
432 x.set( 1, 0.0 );
433 x.set( 2, 1.0 );
434 Prob y(3);
435
436 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
437 BOOST_CHECK( (x + 1.0) == y );
438 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
439 BOOST_CHECK( (x + (-2.0)) == y );
440
441 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
442 BOOST_CHECK( (x - 1.0) == y );
443 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
444 BOOST_CHECK( (x - (-2.0)) == y );
445
446 BOOST_CHECK( (x * 1.0) == x );
447 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
448 BOOST_CHECK( (x * 2.0) == y );
449 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
450 BOOST_CHECK( (x * -0.5) == y );
451
452 BOOST_CHECK( (x / 1.0) == x );
453 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
454 BOOST_CHECK( (x / 2.0) == y );
455 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
456 BOOST_CHECK( (x / -0.5) == y );
457 BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
458
459 BOOST_CHECK( (x ^ 1.0) == x );
460 BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
461 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
462 BOOST_CHECK( (x ^ 2.0) == y );
463 y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
464 Prob z = (x ^ -0.5);
465 BOOST_CHECK_CLOSE( z[0], y[0], tol );
466 BOOST_CHECK_EQUAL( z[1], y[1] );
467 BOOST_CHECK_CLOSE( z[2], y[2], tol );
468 }
469
470
471 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
472 Prob xorg(3), x(3);
473 xorg.set( 0, 2.0 );
474 xorg.set( 1, 0.0 );
475 xorg.set( 2, 1.0 );
476 Prob y(3);
477
478 x = xorg;
479 BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
480 BOOST_CHECK( x == Prob(3, 1.0) );
481 BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
482 BOOST_CHECK( x == Prob(3, 2.0) );
483 BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
484 BOOST_CHECK( x == Prob(3, 0.0) );
485
486 x = xorg;
487 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
488 BOOST_CHECK( (x += 1.0) == y );
489 BOOST_CHECK( x == y );
490 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
491 BOOST_CHECK( (x += -2.0) == y );
492 BOOST_CHECK( x == y );
493
494 x = xorg;
495 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
496 BOOST_CHECK( (x -= 1.0) == y );
497 BOOST_CHECK( x == y );
498 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
499 BOOST_CHECK( (x -= -2.0) == y );
500 BOOST_CHECK( x == y );
501
502 x = xorg;
503 BOOST_CHECK( (x *= 1.0) == x );
504 BOOST_CHECK( x == x );
505 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
506 BOOST_CHECK( (x *= 2.0) == y );
507 BOOST_CHECK( x == y );
508 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
509 BOOST_CHECK( (x *= -0.25) == y );
510 BOOST_CHECK( x == y );
511
512 x = xorg;
513 BOOST_CHECK( (x /= 1.0) == x );
514 BOOST_CHECK( x == x );
515 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
516 BOOST_CHECK( (x /= 2.0) == y );
517 BOOST_CHECK( x == y );
518 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
519 BOOST_CHECK( (x /= -0.25) == y );
520 BOOST_CHECK( x == y );
521 BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
522 BOOST_CHECK( x == Prob(3, 0.0) );
523
524 x = xorg;
525 BOOST_CHECK( (x ^= 1.0) == x );
526 BOOST_CHECK( x == x );
527 BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
528 BOOST_CHECK( x == Prob(3, 1.0) );
529 x = xorg;
530 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
531 BOOST_CHECK( (x ^= 2.0) == y );
532 BOOST_CHECK( x == y );
533 y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
534 BOOST_CHECK( (x ^= -0.5) == y );
535 BOOST_CHECK( x == y );
536 }
537
538
539 BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
540 size_t N = 6;
541 Prob xorg(N), x(N);
542 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
543 Prob y(N);
544 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
545 Prob z(N), r(N);
546
547 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
548 x = xorg;
549 r = (x += y);
550 for( size_t i = 0; i < N; i++ )
551 BOOST_CHECK_CLOSE( r[i], z[i], tol );
552 BOOST_CHECK( x == z );
553 x = xorg;
554 BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
555 BOOST_CHECK( x == z );
556
557 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
558 x = xorg;
559 r = (x -= y);
560 for( size_t i = 0; i < N; i++ )
561 BOOST_CHECK_CLOSE( r[i], z[i], tol );
562 BOOST_CHECK( x == z );
563 x = xorg;
564 BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
565 BOOST_CHECK( x == z );
566
567 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
568 x = xorg;
569 r = (x *= y);
570 for( size_t i = 0; i < N; i++ )
571 BOOST_CHECK_CLOSE( r[i], z[i], tol );
572 BOOST_CHECK( x == z );
573 x = xorg;
574 BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
575 BOOST_CHECK( x == z );
576
577 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
578 x = xorg;
579 r = (x /= y);
580 for( size_t i = 0; i < N; i++ )
581 BOOST_CHECK_CLOSE( r[i], z[i], tol );
582 BOOST_CHECK( x == z );
583 x = xorg;
584 BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
585 BOOST_CHECK( x == z );
586
587 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
588 x = xorg;
589 r = (x.divide( y ));
590 BOOST_CHECK_CLOSE( r[0], z[0], tol );
591 BOOST_CHECK_CLOSE( r[1], z[1], tol );
592 BOOST_CHECK_EQUAL( r[2], z[2] );
593 BOOST_CHECK( isnan(r[3]) );
594 BOOST_CHECK_CLOSE( r[4], z[4], tol );
595 BOOST_CHECK_CLOSE( r[5], z[5], tol );
596 x.set( 3, 0.0 ); r.set( 3, 0.0 );
597 BOOST_CHECK( x == r );
598 x = xorg;
599 r = x.pwBinaryOp( y, std::divides<Real>() );
600 BOOST_CHECK_CLOSE( r[0], z[0], tol );
601 BOOST_CHECK_CLOSE( r[1], z[1], tol );
602 BOOST_CHECK_EQUAL( r[2], z[2] );
603 BOOST_CHECK( isnan(r[3]) );
604 BOOST_CHECK_CLOSE( r[4], z[4], tol );
605 BOOST_CHECK_CLOSE( r[5], z[5], tol );
606 x.set( 3, 0.0 ); r.set( 3, 0.0 );
607 BOOST_CHECK( x == r );
608
609 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
610 x = xorg;
611 r = (x ^= y);
612 BOOST_CHECK_CLOSE( r[0], z[0], tol );
613 BOOST_CHECK_EQUAL( r[1], z[1] );
614 BOOST_CHECK_CLOSE( r[2], z[2], tol );
615 BOOST_CHECK_CLOSE( r[3], z[3], tol );
616 BOOST_CHECK_CLOSE( r[4], z[4], tol );
617 BOOST_CHECK_CLOSE( r[5], z[5], tol );
618 BOOST_CHECK( x == z );
619 x = xorg;
620 BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
621 BOOST_CHECK( x == z );
622 }
623
624
625 BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
626 size_t N = 6;
627 Prob x(N);
628 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
629 Prob y(N);
630 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
631 Prob z(N), r(N);
632
633 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
634 r = x + y;
635 for( size_t i = 0; i < N; i++ )
636 BOOST_CHECK_CLOSE( r[i], z[i], tol );
637 z = x.pwBinaryTr( y, std::plus<Real>() );
638 BOOST_CHECK( r == z );
639
640 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
641 r = x - y;
642 for( size_t i = 0; i < N; i++ )
643 BOOST_CHECK_CLOSE( r[i], z[i], tol );
644 z = x.pwBinaryTr( y, std::minus<Real>() );
645 BOOST_CHECK( r == z );
646
647 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
648 r = x * y;
649 for( size_t i = 0; i < N; i++ )
650 BOOST_CHECK_CLOSE( r[i], z[i], tol );
651 z = x.pwBinaryTr( y, std::multiplies<Real>() );
652 BOOST_CHECK( r == z );
653
654 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
655 r = x / y;
656 for( size_t i = 0; i < N; i++ )
657 BOOST_CHECK_CLOSE( r[i], z[i], tol );
658 z = x.pwBinaryTr( y, fo_divides0<Real>() );
659 BOOST_CHECK( r == z );
660
661 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
662 r = x.divided_by( y );
663 BOOST_CHECK_CLOSE( r[0], z[0], tol );
664 BOOST_CHECK_CLOSE( r[1], z[1], tol );
665 BOOST_CHECK_EQUAL( r[2], z[2] );
666 BOOST_CHECK( isnan(r[3]) );
667 BOOST_CHECK_CLOSE( r[4], z[4], tol );
668 BOOST_CHECK_CLOSE( r[5], z[5], tol );
669 z = x.pwBinaryTr( y, std::divides<Real>() );
670 BOOST_CHECK_CLOSE( r[0], z[0], tol );
671 BOOST_CHECK_CLOSE( r[1], z[1], tol );
672 BOOST_CHECK_EQUAL( r[2], z[2] );
673 BOOST_CHECK( isnan(r[3]) );
674 BOOST_CHECK_CLOSE( r[4], z[4], tol );
675 BOOST_CHECK_CLOSE( r[5], z[5], tol );
676
677 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
678 r = x ^ y;
679 BOOST_CHECK_CLOSE( r[0], z[0], tol );
680 BOOST_CHECK_EQUAL( r[1], z[1] );
681 BOOST_CHECK_CLOSE( r[2], z[2], tol );
682 BOOST_CHECK_CLOSE( r[3], z[3], tol );
683 BOOST_CHECK_CLOSE( r[4], z[4], tol );
684 BOOST_CHECK_CLOSE( r[5], z[5], tol );
685 z = x.pwBinaryTr( y, fo_pow<Real>() );
686 BOOST_CHECK( r == z );
687 }
688
689
690 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
691 Prob x(3), y(3), z(3);
692 x.set( 0, 0.2 );
693 x.set( 1, 0.8 );
694 x.set( 2, 0.0 );
695 y.set( 0, 0.0 );
696 y.set( 1, 0.6 );
697 y.set( 2, 0.4 );
698
699 z = min( x, y );
700 BOOST_CHECK_EQUAL( z[0], 0.0 );
701 BOOST_CHECK_EQUAL( z[1], 0.6 );
702 BOOST_CHECK_EQUAL( z[2], 0.0 );
703 z = max( x, y );
704 BOOST_CHECK_EQUAL( z[0], 0.2 );
705 BOOST_CHECK_EQUAL( z[1], 0.8 );
706 BOOST_CHECK_EQUAL( z[2], 0.4 );
707
708 BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
709 BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
710 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
711 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
712 BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
713 BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
714 BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
715 BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
716 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
717 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
718 BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
719 BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
720 BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
721 BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
722 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
723 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
724 BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
725 BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
726 BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
727 BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
728 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
729 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
730 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
731 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
732 BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
733 BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
734 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
735 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
736 BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
737 BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
738 x.set( 1, 0.7 ); x.set( 2, 0.1 );
739 y.set( 0, 0.1 ); y.set( 1, 0.5 );
740 BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
741 BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
742 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
743 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
744
745 Prob xx(4), yy(4);
746 for( size_t i = 0; i < 3; i++ ) {
747 xx.set( i, x[i] );
748 yy.set( i, y[i] );
749 }
750 std::stringstream ss;
751 ss << xx;
752 std::string s;
753 std::getline( ss, s );
754 #ifdef DAI_SPARSE
755 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
756 #else
757 BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
758 #endif
759 std::stringstream ss2;
760 ss2 << yy;
761 std::getline( ss2, s );
762 #ifdef DAI_SPARSE
763 BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
764 #else
765 BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
766 #endif
767
768 z = min( x, y );
769 BOOST_CHECK_EQUAL( z[0], 0.1 );
770 BOOST_CHECK_EQUAL( z[1], 0.5 );
771 BOOST_CHECK_EQUAL( z[2], 0.1 );
772 z = max( x, y );
773 BOOST_CHECK_EQUAL( z[0], 0.2 );
774 BOOST_CHECK_EQUAL( z[1], 0.7 );
775 BOOST_CHECK_EQUAL( z[2], 0.4 );
776
777 BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
778 BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
779 BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
780 BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
781 BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
782 BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
783 }