Started work on factor.h/cpp unit tests
[libdai.git] / tests / unit / factor.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
14 #include <dai/factor.h>
15 #include <strstream> // NOTE(review): <strstream> is deprecated; the disabled RelatedFunctionsTest uses std::stringstream, which is declared in <sstream> - confirm which header is intended
16
17
18 using namespace dai;
19
20
21 const double tol = 1e-8; // tolerance passed to BOOST_CHECK_CLOSE (Boost treats it as a percentage); as far as visible here it is used only by the disabled test cases
22
23
24 #define BOOST_TEST_MODULE FactorTest
25
26
27 #include <boost/test/unit_test.hpp>
28 #include <boost/test/floating_point_comparison.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
32 // Check the Factor constructors. Only default construction is active; all other checks are commented out below.
33 Factor x1; // no assertion is made on x1 yet. NOTE(review): the disabled checks look Prob-style (states(), p(), set()) - confirm them against the Factor API before enabling
34 /* BOOST_CHECK_EQUAL( x1.states(), 0 );
35 BOOST_CHECK( x1.p() == std::vector<Real>() );
36
37 Factor x2( 3 );
38 BOOST_CHECK_EQUAL( x2.states(), 3 );
39 BOOST_CHECK( x2.p() == std::vector<Real>( 3, 1.0 / 3.0 ) );
40 BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
41 BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
42 BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
43
44 Factor x3( 4, 1.0 );
45 BOOST_CHECK_EQUAL( x3.states(), 4 );
46 BOOST_CHECK( x3.p() == std::vector<Real>( 4, 1.0 ) );
47 BOOST_CHECK_EQUAL( x3[0], 1.0 );
48 BOOST_CHECK_EQUAL( x3[1], 1.0 );
49 BOOST_CHECK_EQUAL( x3[2], 1.0 );
50 BOOST_CHECK_EQUAL( x3[3], 1.0 );
51 x3.set( 0, 0.5 );
52 x3.set( 1, 1.0 );
53 x3.set( 2, 2.0 );
54 x3.set( 3, 4.0 );
55
56 Factor x4( x3.begin(), x3.end() );
57 BOOST_CHECK_EQUAL( x4.states(), 4 );
58 BOOST_CHECK( x4.p() == x3.p() );
59 BOOST_CHECK_EQUAL( x4[0], 0.5 );
60 BOOST_CHECK_EQUAL( x4[1], 1.0 );
61 BOOST_CHECK_EQUAL( x4[2], 2.0 );
62 BOOST_CHECK_EQUAL( x4[3], 4.0 );
63
64 x3.p() = std::vector<Real>( 4, 2.5 );
65 Factor x5( x3.begin(), x3.end(), x3.states() );
66 BOOST_CHECK_EQUAL( x5.states(), 4 );
67 BOOST_CHECK( x5.p() == x3.p() );
68 BOOST_CHECK_EQUAL( x5[0], 2.5 );
69 BOOST_CHECK_EQUAL( x5[1], 2.5 );
70 BOOST_CHECK_EQUAL( x5[2], 2.5 );
71 BOOST_CHECK_EQUAL( x5[3], 2.5 );
72
73 std::vector<int> y( 3, 2 );
74 Factor x6( y );
75 BOOST_CHECK_EQUAL( x6.states(), 3 );
76 BOOST_CHECK( x6.p() == std::vector<Real>( 3, 2.0 ) );
77 BOOST_CHECK_EQUAL( x6[0], 2.0 );
78 BOOST_CHECK_EQUAL( x6[1], 2.0 );
79 BOOST_CHECK_EQUAL( x6[2], 2.0 );
80
81 Factor x7( x6 );
82 BOOST_CHECK( x7 == x6 );
83
84 Factor x8 = x6;
85 BOOST_CHECK( x8 == x6 );*/
86 }
87
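88 /* Disabled test cases: everything up to the closing comment marker is commented out. Each case below declares Prob objects rather than Factor objects, so they presumably still need to be ported to Factor before they can be enabled.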
89
90 BOOST_AUTO_TEST_CASE( IteratorTest ) {
91 Prob x( 5, 0.0 );
92 size_t i;
93 for( i = 0; i < x.size(); i++ )
94 x.set( i, i );
95
96 i = 0;
97 for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
98 BOOST_CHECK_EQUAL( *cit, i );
99
100 i = 0;
101 for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
102 *it = 4 - i;
103
104 i = 0;
105 for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
106 BOOST_CHECK_EQUAL( *it, 4 - i );
107
108 i = 0;
109 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
110 BOOST_CHECK_EQUAL( *crit, i );
111
112 i = 0;
113 for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
114 *rit = 2 * i;
115
116 i = 0;
117 for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
118 BOOST_CHECK_EQUAL( *crit, 2 * i );
119 }
120
121
122 BOOST_AUTO_TEST_CASE( QueriesTest ) {
123 Prob x( 5, 0.0 );
124 for( size_t i = 0; i < x.size(); i++ )
125 x.set( i, 2.0 - i );
126
127 // test accumulate, min, max, sum, sumAbs, maxAbs
128 BOOST_CHECK_EQUAL( x.sum(), 0.0 );
129 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_id<Real>() ), 0.0 );
130 BOOST_CHECK_EQUAL( x.accumulate( 1.0, std::plus<Real>(), fo_id<Real>() ), 1.0 );
131 BOOST_CHECK_EQUAL( x.accumulate( -1.0, std::plus<Real>(), fo_id<Real>() ), -1.0 );
132 BOOST_CHECK_EQUAL( x.max(), 2.0 );
133 BOOST_CHECK_EQUAL( x.accumulate( -INFINITY, fo_max<Real>(), fo_id<Real>() ), 2.0 );
134 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_id<Real>() ), 3.0 );
135 BOOST_CHECK_EQUAL( x.accumulate( -5.0, fo_max<Real>(), fo_id<Real>() ), 2.0 );
136 BOOST_CHECK_EQUAL( x.min(), -2.0 );
137 BOOST_CHECK_EQUAL( x.accumulate( INFINITY, fo_min<Real>(), fo_id<Real>() ), -2.0 );
138 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_min<Real>(), fo_id<Real>() ), -3.0 );
139 BOOST_CHECK_EQUAL( x.accumulate( 5.0, fo_min<Real>(), fo_id<Real>() ), -2.0 );
140 BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
141 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_abs<Real>() ), 6.0 );
142 BOOST_CHECK_EQUAL( x.accumulate( 1.0, std::plus<Real>(), fo_abs<Real>() ), 7.0 );
143 BOOST_CHECK_EQUAL( x.accumulate( -1.0, std::plus<Real>(), fo_abs<Real>() ), 7.0 );
144 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
145 BOOST_CHECK_EQUAL( x.accumulate( 0.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
146 BOOST_CHECK_EQUAL( x.accumulate( 1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
147 BOOST_CHECK_EQUAL( x.accumulate( -1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
148 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
149 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
150 x.set( 1, 1.0 );
151 BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
152 BOOST_CHECK_EQUAL( x.accumulate( 0.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
153 BOOST_CHECK_EQUAL( x.accumulate( 1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
154 BOOST_CHECK_EQUAL( x.accumulate( -1.0, fo_max<Real>(), fo_abs<Real>() ), 2.0 );
155 BOOST_CHECK_EQUAL( x.accumulate( 3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
156 BOOST_CHECK_EQUAL( x.accumulate( -3.0, fo_max<Real>(), fo_abs<Real>() ), 3.0 );
157 for( size_t i = 0; i < x.size(); i++ )
158 x.set( i, i ? (1.0 / i) : 0.0 );
159 BOOST_CHECK_EQUAL( x.accumulate( 0.0, std::plus<Real>(), fo_inv0<Real>() ), 10.0 );
160 x /= x.sum();
161
162 // test entropy
163 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
164 for( size_t i = 1; i < 100; i++ )
165 BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log(i), tol );
166
167 // test hasNaNs and hasNegatives
168 BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
169 Real c = 0.0;
170 BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
171 BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
172 BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
173 BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
174 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
175 BOOST_CHECK( x.hasNegatives() );
176 x.set( 2, -INFINITY );
177 BOOST_CHECK( x.hasNegatives() );
178 x.set( 2, INFINITY );
179 BOOST_CHECK( !x.hasNegatives() );
180 x.set( 2, -1.0 );
181
182 // test argmax
183 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
184 x.set( 4, 0.5 );
185 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
186 x.set( 3, -2.0 );
187 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
188 x.set( 4, -1.0 );
189 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
190 x.set( 0, -2.0 );
191 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
192 x.set( 1, -3.0 );
193 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
194 x.set( 2, -2.0 );
195 BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
196
197 // test draw
198 for( size_t i = 0; i < x.size(); i++ )
199 x.set( i, i ? (1.0 / i) : 0.0 );
200 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
201 BOOST_CHECK( x.draw() < x.size() );
202 BOOST_CHECK( x.draw() != 0 );
203 }
204 x.set( 2, 0.0 );
205 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
206 BOOST_CHECK( x.draw() < x.size() );
207 BOOST_CHECK( x.draw() != 0 );
208 BOOST_CHECK( x.draw() != 2 );
209 }
210 x.set( 4, 0.0 );
211 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
212 BOOST_CHECK( x.draw() < x.size() );
213 BOOST_CHECK( x.draw() != 0 );
214 BOOST_CHECK( x.draw() != 2 );
215 BOOST_CHECK( x.draw() != 4 );
216 }
217 x.set( 1, 0.0 );
218 for( size_t repeat = 0; repeat < 10000; repeat++ )
219 BOOST_CHECK( x.draw() == 3 );
220
221 // test <, ==
222 Prob a(3, 1.0), b(3, 1.0);
223 BOOST_CHECK( !(a < b) );
224 BOOST_CHECK( !(b < a) );
225 BOOST_CHECK( a == b );
226 a.set( 0, 0.0 );
227 BOOST_CHECK( a < b );
228 BOOST_CHECK( !(b < a) );
229 BOOST_CHECK( !(a == b) );
230 b.set( 2, 0.0 );
231 BOOST_CHECK( a < b );
232 BOOST_CHECK( !(b < a) );
233 BOOST_CHECK( !(a == b) );
234 b.set( 0, 0.0 );
235 BOOST_CHECK( !(a < b) );
236 BOOST_CHECK( b < a );
237 BOOST_CHECK( !(a == b) );
238 a.set( 1, 0.0 );
239 BOOST_CHECK( a < b );
240 BOOST_CHECK( !(b < a) );
241 BOOST_CHECK( !(a == b) );
242 b.set( 1, 0.0 );
243 BOOST_CHECK( !(a < b) );
244 BOOST_CHECK( b < a );
245 BOOST_CHECK( !(a == b) );
246 a.set( 2, 0.0 );
247 BOOST_CHECK( !(a < b) );
248 BOOST_CHECK( !(b < a) );
249 BOOST_CHECK( a == b );
250 }
251
252
253 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
254 Prob x( 3 );
255 x.set( 0, -2.0 );
256 x.set( 1, 0.0 );
257 x.set( 2, 2.0 );
258
259 Prob y = -x;
260 Prob z = x.pwUnaryTr( std::negate<Real>() );
261 BOOST_CHECK_EQUAL( y[0], 2.0 );
262 BOOST_CHECK_EQUAL( y[1], 0.0 );
263 BOOST_CHECK_EQUAL( y[2], -2.0 );
264 BOOST_CHECK( y == z );
265
266 y = x.abs();
267 z = x.pwUnaryTr( fo_abs<Real>() );
268 BOOST_CHECK_EQUAL( y[0], 2.0 );
269 BOOST_CHECK_EQUAL( y[1], 0.0 );
270 BOOST_CHECK_EQUAL( y[2], 2.0 );
271 BOOST_CHECK( y == z );
272
273 y = x.exp();
274 z = x.pwUnaryTr( fo_exp<Real>() );
275 BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
276 BOOST_CHECK_EQUAL( y[1], 1.0 );
277 BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
278 BOOST_CHECK( y == z );
279
280 y = x.log(false);
281 z = x.pwUnaryTr( fo_log<Real>() );
282 BOOST_CHECK( isnan( y[0] ) );
283 BOOST_CHECK_EQUAL( y[1], -INFINITY );
284 BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
285 BOOST_CHECK( !(y == z) );
286 y.set( 0, 0.0 );
287 z.set( 0, 0.0 );
288 BOOST_CHECK( y == z );
289
290 y = x.log(true);
291 z = x.pwUnaryTr( fo_log0<Real>() );
292 BOOST_CHECK( isnan( y[0] ) );
293 BOOST_CHECK_EQUAL( y[1], 0.0 );
294 BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
295 BOOST_CHECK( !(y == z) );
296 y.set( 0, 0.0 );
297 z.set( 0, 0.0 );
298 BOOST_CHECK( y == z );
299
300 y = x.inverse(false);
301 z = x.pwUnaryTr( fo_inv<Real>() );
302 BOOST_CHECK_EQUAL( y[0], -0.5 );
303 BOOST_CHECK_EQUAL( y[1], INFINITY );
304 BOOST_CHECK_EQUAL( y[2], 0.5 );
305 BOOST_CHECK( y == z );
306
307 y = x.inverse(true);
308 z = x.pwUnaryTr( fo_inv0<Real>() );
309 BOOST_CHECK_EQUAL( y[0], -0.5 );
310 BOOST_CHECK_EQUAL( y[1], 0.0 );
311 BOOST_CHECK_EQUAL( y[2], 0.5 );
312 BOOST_CHECK( y == z );
313
314 x.set( 0, 2.0 );
315 y = x.normalized();
316 BOOST_CHECK_EQUAL( y[0], 0.5 );
317 BOOST_CHECK_EQUAL( y[1], 0.0 );
318 BOOST_CHECK_EQUAL( y[2], 0.5 );
319
320 y = x.normalized( Prob::NORMPROB );
321 BOOST_CHECK_EQUAL( y[0], 0.5 );
322 BOOST_CHECK_EQUAL( y[1], 0.0 );
323 BOOST_CHECK_EQUAL( y[2], 0.5 );
324
325 x.set( 0, -2.0 );
326 y = x.normalized( Prob::NORMLINF );
327 BOOST_CHECK_EQUAL( y[0], -1.0 );
328 BOOST_CHECK_EQUAL( y[1], 0.0 );
329 BOOST_CHECK_EQUAL( y[2], 1.0 );
330 }
331
332
333 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
334 Prob xorg(3);
335 xorg.set( 0, 2.0 );
336 xorg.set( 1, 0.0 );
337 xorg.set( 2, 1.0 );
338 Prob y(3);
339
340 Prob x = xorg;
341 BOOST_CHECK( x.setUniform() == Prob(3) );
342 BOOST_CHECK( x == Prob(3) );
343
344 y.set( 0, std::exp(2.0) );
345 y.set( 1, 1.0 );
346 y.set( 2, std::exp(1.0) );
347 x = xorg;
348 BOOST_CHECK( x.takeExp() == y );
349 BOOST_CHECK( x == y );
350 x = xorg;
351 BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
352 BOOST_CHECK( x == y );
353
354 y.set( 0, std::log(2.0) );
355 y.set( 1, -INFINITY );
356 y.set( 2, 0.0 );
357 x = xorg;
358 BOOST_CHECK( x.takeLog() == y );
359 BOOST_CHECK( x == y );
360 x = xorg;
361 BOOST_CHECK( x.takeLog(false) == y );
362 BOOST_CHECK( x == y );
363 x = xorg;
364 BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
365 BOOST_CHECK( x == y );
366
367 y.set( 1, 0.0 );
368 x = xorg;
369 BOOST_CHECK( x.takeLog(true) == y );
370 BOOST_CHECK( x == y );
371 x = xorg;
372 BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
373 BOOST_CHECK( x == y );
374
375 y.set( 0, 2.0 / 3.0 );
376 y.set( 1, 0.0 / 3.0 );
377 y.set( 2, 1.0 / 3.0 );
378 x = xorg;
379 BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
380 BOOST_CHECK( x == y );
381
382 x = xorg;
383 BOOST_CHECK_EQUAL( x.normalize( Prob::NORMPROB ), 3.0 );
384 BOOST_CHECK( x == y );
385
386 y.set( 0, 2.0 / 2.0 );
387 y.set( 1, 0.0 / 2.0 );
388 y.set( 2, 1.0 / 2.0 );
389 x = xorg;
390 BOOST_CHECK_EQUAL( x.normalize( Prob::NORMLINF ), 2.0 );
391 BOOST_CHECK( x == y );
392
393 xorg.set( 0, -2.0 );
394 y.set( 0, 2.0 );
395 y.set( 1, 0.0 );
396 y.set( 2, 1.0 );
397 x = xorg;
398 BOOST_CHECK( x.takeAbs() == y );
399 BOOST_CHECK( x == y );
400
401 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
402 x.randomize();
403 for( size_t i = 0; i < x.size(); i++ ) {
404 BOOST_CHECK( x[i] < 1.0 );
405 BOOST_CHECK( x[i] >= 0.0 );
406 }
407 }
408 }
409
410
411 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
412 Prob x(3);
413 x.set( 0, 2.0 );
414 x.set( 1, 0.0 );
415 x.set( 2, 1.0 );
416 Prob y(3);
417
418 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
419 BOOST_CHECK( (x + 1.0) == y );
420 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
421 BOOST_CHECK( (x + (-2.0)) == y );
422
423 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
424 BOOST_CHECK( (x - 1.0) == y );
425 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
426 BOOST_CHECK( (x - (-2.0)) == y );
427
428 BOOST_CHECK( (x * 1.0) == x );
429 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
430 BOOST_CHECK( (x * 2.0) == y );
431 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
432 BOOST_CHECK( (x * -0.5) == y );
433
434 BOOST_CHECK( (x / 1.0) == x );
435 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
436 BOOST_CHECK( (x / 2.0) == y );
437 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
438 BOOST_CHECK( (x / -0.5) == y );
439 BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
440
441 BOOST_CHECK( (x ^ 1.0) == x );
442 BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
443 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
444 BOOST_CHECK( (x ^ 2.0) == y );
445 y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
446 Prob z = (x ^ -0.5);
447 BOOST_CHECK_CLOSE( z[0], y[0], tol );
448 BOOST_CHECK_EQUAL( z[1], y[1] );
449 BOOST_CHECK_CLOSE( z[2], y[2], tol );
450 }
451
452
453 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
454 Prob xorg(3), x(3);
455 xorg.set( 0, 2.0 );
456 xorg.set( 1, 0.0 );
457 xorg.set( 2, 1.0 );
458 Prob y(3);
459
460 x = xorg;
461 BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
462 BOOST_CHECK( x == Prob(3, 1.0) );
463 BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
464 BOOST_CHECK( x == Prob(3, 2.0) );
465 BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
466 BOOST_CHECK( x == Prob(3, 0.0) );
467
468 x = xorg;
469 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
470 BOOST_CHECK( (x += 1.0) == y );
471 BOOST_CHECK( x == y );
472 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
473 BOOST_CHECK( (x += -2.0) == y );
474 BOOST_CHECK( x == y );
475
476 x = xorg;
477 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
478 BOOST_CHECK( (x -= 1.0) == y );
479 BOOST_CHECK( x == y );
480 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
481 BOOST_CHECK( (x -= -2.0) == y );
482 BOOST_CHECK( x == y );
483
484 x = xorg;
485 BOOST_CHECK( (x *= 1.0) == x );
486 BOOST_CHECK( x == x );
487 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
488 BOOST_CHECK( (x *= 2.0) == y );
489 BOOST_CHECK( x == y );
490 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
491 BOOST_CHECK( (x *= -0.25) == y );
492 BOOST_CHECK( x == y );
493
494 x = xorg;
495 BOOST_CHECK( (x /= 1.0) == x );
496 BOOST_CHECK( x == x );
497 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
498 BOOST_CHECK( (x /= 2.0) == y );
499 BOOST_CHECK( x == y );
500 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
501 BOOST_CHECK( (x /= -0.25) == y );
502 BOOST_CHECK( x == y );
503 BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
504 BOOST_CHECK( x == Prob(3, 0.0) );
505
506 x = xorg;
507 BOOST_CHECK( (x ^= 1.0) == x );
508 BOOST_CHECK( x == x );
509 BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
510 BOOST_CHECK( x == Prob(3, 1.0) );
511 x = xorg;
512 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
513 BOOST_CHECK( (x ^= 2.0) == y );
514 BOOST_CHECK( x == y );
515 y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
516 BOOST_CHECK( (x ^= -0.5) == y );
517 BOOST_CHECK( x == y );
518 }
519
520
521 BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
522 size_t N = 6;
523 Prob xorg(N), x(N);
524 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
525 Prob y(N);
526 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
527 Prob z(N), r(N);
528
529 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
530 x = xorg;
531 r = (x += y);
532 for( size_t i = 0; i < N; i++ )
533 BOOST_CHECK_CLOSE( r[i], z[i], tol );
534 BOOST_CHECK( x == z );
535 x = xorg;
536 BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
537 BOOST_CHECK( x == z );
538
539 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
540 x = xorg;
541 r = (x -= y);
542 for( size_t i = 0; i < N; i++ )
543 BOOST_CHECK_CLOSE( r[i], z[i], tol );
544 BOOST_CHECK( x == z );
545 x = xorg;
546 BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
547 BOOST_CHECK( x == z );
548
549 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
550 x = xorg;
551 r = (x *= y);
552 for( size_t i = 0; i < N; i++ )
553 BOOST_CHECK_CLOSE( r[i], z[i], tol );
554 BOOST_CHECK( x == z );
555 x = xorg;
556 BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
557 BOOST_CHECK( x == z );
558
559 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
560 x = xorg;
561 r = (x /= y);
562 for( size_t i = 0; i < N; i++ )
563 BOOST_CHECK_CLOSE( r[i], z[i], tol );
564 BOOST_CHECK( x == z );
565 x = xorg;
566 BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
567 BOOST_CHECK( x == z );
568
569 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY );
570 // z.set( 3, INFINITY );
571 z.set( 4, -1.0 ); z.set( 5, 1.0 );
572 x = xorg;
573 r = (x.divide( y ));
574 BOOST_CHECK_CLOSE( r[0], z[0], tol );
575 BOOST_CHECK_CLOSE( r[1], z[1], tol );
576 BOOST_CHECK_EQUAL( r[2], z[2] );
577 BOOST_CHECK( isnan(r[3]) );
578 BOOST_CHECK_CLOSE( r[4], z[4], tol );
579 BOOST_CHECK_CLOSE( r[5], z[5], tol );
580 x.set( 3, 0.0 ); r.set( 3, 0.0 );
581 BOOST_CHECK( x == r );
582 x = xorg;
583 r = x.pwBinaryOp( y, std::divides<Real>() );
584 BOOST_CHECK_CLOSE( r[0], z[0], tol );
585 BOOST_CHECK_CLOSE( r[1], z[1], tol );
586 BOOST_CHECK_EQUAL( r[2], z[2] );
587 BOOST_CHECK( isnan(r[3]) );
588 BOOST_CHECK_CLOSE( r[4], z[4], tol );
589 BOOST_CHECK_CLOSE( r[5], z[5], tol );
590 x.set( 3, 0.0 ); r.set( 3, 0.0 );
591 BOOST_CHECK( x == r );
592
593 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
594 x = xorg;
595 r = (x ^= y);
596 BOOST_CHECK_CLOSE( r[0], z[0], tol );
597 BOOST_CHECK_EQUAL( r[1], z[1] );
598 BOOST_CHECK_CLOSE( r[2], z[2], tol );
599 BOOST_CHECK_CLOSE( r[3], z[3], tol );
600 BOOST_CHECK_CLOSE( r[4], z[4], tol );
601 BOOST_CHECK_CLOSE( r[5], z[5], tol );
602 BOOST_CHECK( x == z );
603 x = xorg;
604 BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
605 BOOST_CHECK( x == z );
606 }
607
608
609 BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
610 size_t N = 6;
611 Prob x(N);
612 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
613 Prob y(N);
614 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
615 Prob z(N), r(N);
616
617 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
618 r = x + y;
619 for( size_t i = 0; i < N; i++ )
620 BOOST_CHECK_CLOSE( r[i], z[i], tol );
621 z = x.pwBinaryTr( y, std::plus<Real>() );
622 BOOST_CHECK( r == z );
623
624 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
625 r = x - y;
626 for( size_t i = 0; i < N; i++ )
627 BOOST_CHECK_CLOSE( r[i], z[i], tol );
628 z = x.pwBinaryTr( y, std::minus<Real>() );
629 BOOST_CHECK( r == z );
630
631 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
632 r = x * y;
633 for( size_t i = 0; i < N; i++ )
634 BOOST_CHECK_CLOSE( r[i], z[i], tol );
635 z = x.pwBinaryTr( y, std::multiplies<Real>() );
636 BOOST_CHECK( r == z );
637
638 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
639 r = x / y;
640 for( size_t i = 0; i < N; i++ )
641 BOOST_CHECK_CLOSE( r[i], z[i], tol );
642 z = x.pwBinaryTr( y, fo_divides0<Real>() );
643 BOOST_CHECK( r == z );
644
645 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY );
646 // z.set( 3, INFINITY );
647 z.set( 4, -1.0 ); z.set( 5, 1.0 );
648 r = x.divided_by( y );
649 BOOST_CHECK_CLOSE( r[0], z[0], tol );
650 BOOST_CHECK_CLOSE( r[1], z[1], tol );
651 BOOST_CHECK_EQUAL( r[2], z[2] );
652 BOOST_CHECK( isnan(r[3]) );
653 BOOST_CHECK_CLOSE( r[4], z[4], tol );
654 BOOST_CHECK_CLOSE( r[5], z[5], tol );
655 z = x.pwBinaryTr( y, std::divides<Real>() );
656 BOOST_CHECK_CLOSE( r[0], z[0], tol );
657 BOOST_CHECK_CLOSE( r[1], z[1], tol );
658 BOOST_CHECK_EQUAL( r[2], z[2] );
659 BOOST_CHECK( isnan(r[3]) );
660 BOOST_CHECK_CLOSE( r[4], z[4], tol );
661 BOOST_CHECK_CLOSE( r[5], z[5], tol );
662
663 z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
664 r = x ^ y;
665 BOOST_CHECK_CLOSE( r[0], z[0], tol );
666 BOOST_CHECK_EQUAL( r[1], z[1] );
667 BOOST_CHECK_CLOSE( r[2], z[2], tol );
668 BOOST_CHECK_CLOSE( r[3], z[3], tol );
669 BOOST_CHECK_CLOSE( r[4], z[4], tol );
670 BOOST_CHECK_CLOSE( r[5], z[5], tol );
671 z = x.pwBinaryTr( y, fo_pow<Real>() );
672 BOOST_CHECK( r == z );
673 }
674
675
676 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
677 Prob x(3), y(3), z(3);
678 x.set( 0, 0.2 );
679 x.set( 1, 0.8 );
680 x.set( 2, 0.0 );
681 y.set( 0, 0.0 );
682 y.set( 1, 0.6 );
683 y.set( 2, 0.4 );
684
685 z = min( x, y );
686 BOOST_CHECK_EQUAL( z[0], 0.0 );
687 BOOST_CHECK_EQUAL( z[1], 0.6 );
688 BOOST_CHECK_EQUAL( z[2], 0.0 );
689 z = max( x, y );
690 BOOST_CHECK_EQUAL( z[0], 0.2 );
691 BOOST_CHECK_EQUAL( z[1], 0.8 );
692 BOOST_CHECK_EQUAL( z[2], 0.4 );
693
694 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTL1 ), 0.0 );
695 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTL1 ), 0.0 );
696 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTL1 ), 0.2 + 0.2 + 0.4 );
697 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTL1 ), 0.2 + 0.2 + 0.4 );
698 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
699 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
700 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTLINF ), 0.0 );
701 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTLINF ), 0.0 );
702 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTLINF ), 0.4 );
703 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTLINF ), 0.4 );
704 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
705 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
706 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTTV ), 0.0 );
707 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTTV ), 0.0 );
708 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
709 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
710 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
711 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
712 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTKL ), 0.0 );
713 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTKL ), 0.0 );
714 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), INFINITY );
715 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), INFINITY );
716 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
717 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
718 BOOST_CHECK_EQUAL( dist( x, x, Prob::DISTHEL ), 0.0 );
719 BOOST_CHECK_EQUAL( dist( y, y, Prob::DISTHEL ), 0.0 );
720 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
721 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
722 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
723 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
724 x.set( 1, 0.7 ); x.set( 2, 0.1 );
725 y.set( 0, 0.1 ); y.set( 1, 0.5 );
726 BOOST_CHECK_CLOSE( dist( x, y, Prob::DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
727 BOOST_CHECK_CLOSE( dist( y, x, Prob::DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
728 BOOST_CHECK_EQUAL( dist( x, y, Prob::DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
729 BOOST_CHECK_EQUAL( dist( y, x, Prob::DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
730
731 std::stringstream ss;
732 ss << x;
733 std::string s;
734 std::getline( ss, s );
735 BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1)") );
736 std::stringstream ss2;
737 ss2 << y;
738 std::getline( ss2, s );
739 BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4)") );
740
741 z = min( x, y );
742 BOOST_CHECK_EQUAL( z[0], 0.1 );
743 BOOST_CHECK_EQUAL( z[1], 0.5 );
744 BOOST_CHECK_EQUAL( z[2], 0.1 );
745 z = max( x, y );
746 BOOST_CHECK_EQUAL( z[0], 0.2 );
747 BOOST_CHECK_EQUAL( z[1], 0.7 );
748 BOOST_CHECK_EQUAL( z[2], 0.4 );
749
750 BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
751 BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
752 BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
753 BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
754 BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
755 BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
756 }
757 */