// Commit message: Fixed two problems related to g++ 4.0.0 on Darwin 9.8.0
// File: libdai.git/tests/unit/factor_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
#include <cmath>       // std::tanh, std::cosh, std::sqrt
#include <functional>  // std::plus, std::minus, std::multiplies
#include <sstream>     // std::stringstream (used in RelatedFunctionsTest)
#include <string>      // std::string, std::getline
#include <strstream>   // NOTE(review): deprecated and not used by this file; <sstream> above is what is needed
#include <vector>

#include <dai/factor.h>
13
14
using namespace dai;


// Tolerance shared by the checks in this file. Per the Boost.Test
// documentation, BOOST_CHECK_CLOSE interprets its tolerance as a *percentage*
// (so 1e-8 here means a relative difference of 1e-10), whereas
// BOOST_CHECK_SMALL interprets it as an absolute bound.
const Real tol = 1e-8;


// Names the test module; per the Boost.Test documentation this macro must be
// defined before <boost/test/unit_test.hpp> is included.
#define BOOST_TEST_MODULE FactorTest


#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
26
27
28 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
29 // check constructors
30 Factor x1;
31 BOOST_CHECK_EQUAL( x1.nrStates(), 1 );
32 BOOST_CHECK( x1.p() == Prob( 1, 1.0 ) );
33 BOOST_CHECK( x1.vars() == VarSet() );
34
35 Factor x2( 5.0 );
36 BOOST_CHECK_EQUAL( x2.nrStates(), 1 );
37 BOOST_CHECK( x2.p() == Prob( 1, 5.0 ) );
38 BOOST_CHECK( x2.vars() == VarSet() );
39
40 Var v1( 0, 3 );
41 Factor x3( v1 );
42 BOOST_CHECK_EQUAL( x3.nrStates(), 3 );
43 BOOST_CHECK_SMALL( dist( x3.p(), Prob( 3, 1.0 / 3.0 ), DISTL1 ), tol );
44 BOOST_CHECK( x3.vars() == VarSet( v1 ) );
45 BOOST_CHECK_CLOSE( x3[0], (Real)(1.0 / 3.0), tol );
46 BOOST_CHECK_CLOSE( x3[1], (Real)(1.0 / 3.0), tol );
47 BOOST_CHECK_CLOSE( x3[2], (Real)(1.0 / 3.0), tol );
48
49 Var v2( 1, 2 );
50 Factor x4( VarSet( v1, v2 ) );
51 BOOST_CHECK_EQUAL( x4.nrStates(), 6 );
52 BOOST_CHECK_SMALL( dist( x4.p(), Prob( 6, 1.0 / 6.0 ), DISTL1 ), tol );
53 BOOST_CHECK( x4.vars() == VarSet( v1, v2 ) );
54 for( size_t i = 0; i < 6; i++ )
55 BOOST_CHECK_CLOSE( x4[i], (Real)(1.0 / 6.0), tol );
56
57 Factor x5( VarSet( v1, v2 ), 1.0 );
58 BOOST_CHECK_EQUAL( x5.nrStates(), 6 );
59 BOOST_CHECK( x5.p() == Prob( 6, 1.0 ) );
60 BOOST_CHECK( x5.vars() == VarSet( v1, v2 ) );
61 for( size_t i = 0; i < 6; i++ )
62 BOOST_CHECK_EQUAL( x5[i], (Real)1.0 );
63
64 std::vector<Real> x( 6, 1.0 );
65 for( size_t i = 0; i < 6; i++ )
66 x[i] = 10.0 - i;
67 Factor x6( VarSet( v1, v2 ), x );
68 BOOST_CHECK_EQUAL( x6.nrStates(), 6 );
69 BOOST_CHECK( x6.vars() == VarSet( v1, v2 ) );
70 for( size_t i = 0; i < 6; i++ )
71 BOOST_CHECK_EQUAL( x6[i], x[i] );
72
73 x.resize( 4 );
74 BOOST_CHECK_THROW( Factor x7( VarSet( v1, v2 ), x ), Exception );
75
76 x.resize( 6 );
77 x[4] = 10.0 - 4; x[5] = 10.0 - 5;
78 Factor x8( VarSet( v2, v1 ), &(x[0]) );
79 BOOST_CHECK_EQUAL( x8.nrStates(), 6 );
80 BOOST_CHECK( x8.vars() == VarSet( v1, v2 ) );
81 for( size_t i = 0; i < 6; i++ )
82 BOOST_CHECK_EQUAL( x8[i], x[i] );
83
84 Prob xx( x );
85 Factor x9( VarSet( v2, v1 ), xx );
86 BOOST_CHECK_EQUAL( x9.nrStates(), 6 );
87 BOOST_CHECK( x9.vars() == VarSet( v1, v2 ) );
88 for( size_t i = 0; i < 6; i++ )
89 BOOST_CHECK_EQUAL( x9[i], x[i] );
90
91 xx.resize( 4 );
92 BOOST_CHECK_THROW( Factor x10( VarSet( v2, v1 ), xx ), Exception );
93
94 std::vector<Real> w;
95 w.push_back( 0.1 );
96 w.push_back( 3.5 );
97 w.push_back( 2.8 );
98 w.push_back( 6.3 );
99 w.push_back( 8.4 );
100 w.push_back( 0.0 );
101 w.push_back( 7.4 );
102 w.push_back( 2.4 );
103 w.push_back( 8.9 );
104 w.push_back( 1.3 );
105 w.push_back( 1.6 );
106 w.push_back( 2.6 );
107 Var v4( 4, 3 );
108 Var v8( 8, 2 );
109 Var v7( 7, 2 );
110 std::vector<Var> vars;
111 vars.push_back( v4 );
112 vars.push_back( v8 );
113 vars.push_back( v7 );
114 Factor x11( vars, w );
115 BOOST_CHECK_EQUAL( x11.nrStates(), 12 );
116 BOOST_CHECK( x11.vars() == VarSet( vars.begin(), vars.end() ) );
117 BOOST_CHECK_EQUAL( x11[0], (Real)0.1 );
118 BOOST_CHECK_EQUAL( x11[1], (Real)3.5 );
119 BOOST_CHECK_EQUAL( x11[2], (Real)2.8 );
120 BOOST_CHECK_EQUAL( x11[3], (Real)7.4 );
121 BOOST_CHECK_EQUAL( x11[4], (Real)2.4 );
122 BOOST_CHECK_EQUAL( x11[5], (Real)8.9 );
123 BOOST_CHECK_EQUAL( x11[6], (Real)6.3 );
124 BOOST_CHECK_EQUAL( x11[7], (Real)8.4 );
125 BOOST_CHECK_EQUAL( x11[8], (Real)0.0 );
126 BOOST_CHECK_EQUAL( x11[9], (Real)1.3 );
127 BOOST_CHECK_EQUAL( x11[10], (Real)1.6 );
128 BOOST_CHECK_EQUAL( x11[11], (Real)2.6 );
129
130 Factor x12( x11 );
131 BOOST_CHECK( x12 == x11 );
132
133 Factor x13 = x12;
134 BOOST_CHECK( x13 == x11 );
135 }
136
137
138 BOOST_AUTO_TEST_CASE( QueriesTest ) {
139 Factor x( Var( 5, 5 ), 0.0 );
140 for( size_t i = 0; i < x.nrStates(); i++ )
141 x.set( i, 2.0 - i );
142
143 // test min, max, sum, sumAbs, maxAbs
144 BOOST_CHECK_CLOSE( x.sum(), (Real)0.0, tol );
145 BOOST_CHECK_CLOSE( x.max(), (Real)2.0, tol );
146 BOOST_CHECK_CLOSE( x.min(), (Real)-2.0, tol );
147 BOOST_CHECK_CLOSE( x.sumAbs(), (Real)6.0, tol );
148 BOOST_CHECK_CLOSE( x.maxAbs(), (Real)2.0, tol );
149 x.set( 1, 1.0 );
150 BOOST_CHECK_CLOSE( x.maxAbs(), (Real)2.0, tol );
151 x /= x.sum();
152
153 // test entropy
154 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
155 for( size_t i = 1; i < 100; i++ )
156 BOOST_CHECK_CLOSE( Factor( Var(0,i) ).entropy(), dai::log((Real)i), tol );
157
158 // test hasNaNs and hasNegatives
159 BOOST_CHECK( !Factor( 0.0 ).hasNaNs() );
160 Real c = 0.0;
161 BOOST_CHECK( Factor( c / c ).hasNaNs() );
162 BOOST_CHECK( !Factor( 0.0 ).hasNegatives() );
163 BOOST_CHECK( !Factor( 1.0 ).hasNegatives() );
164 BOOST_CHECK( Factor( -1.0 ).hasNegatives() );
165 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
166 BOOST_CHECK( x.hasNegatives() );
167 x.set( 2, -INFINITY );
168 BOOST_CHECK( x.hasNegatives() );
169 x.set( 2, INFINITY );
170 BOOST_CHECK( !x.hasNegatives() );
171 x.set( 2, -1.0 );
172
173 // test strength
174 Var x0(0,2);
175 Var x1(1,2);
176 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 1.0 ).strength( x0, x1 ), std::tanh( (Real)1.0 ), tol );
177 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, -1.0 ).strength( x0, x1 ), std::tanh( (Real)1.0 ), tol );
178 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 0.5 ).strength( x0, x1 ), std::tanh( (Real)0.5 ), tol );
179
180 // test ==
181 Factor a(Var(0,3)), b(Var(0,3));
182 Factor d(Var(1,3));
183 BOOST_CHECK( !(a == d) );
184 BOOST_CHECK( !(b == d) );
185 BOOST_CHECK( a == b );
186 a.set( 0, 0.0 );
187 BOOST_CHECK( !(a == b) );
188 b.set( 2, 0.0 );
189 BOOST_CHECK( !(a == b) );
190 b.set( 0, 0.0 );
191 BOOST_CHECK( !(a == b) );
192 a.set( 1, 0.0 );
193 BOOST_CHECK( !(a == b) );
194 b.set( 1, 0.0 );
195 BOOST_CHECK( !(a == b) );
196 a.set( 2, 0.0 );
197 BOOST_CHECK( a == b );
198 }
199
200
201 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
202 Var v( 0, 3 );
203 Factor x( v );
204 x.set( 0, -2.0 );
205 x.set( 1, 0.0 );
206 x.set( 2, 2.0 );
207
208 Factor y = -x;
209 BOOST_CHECK_CLOSE( y[0], (Real)2.0, tol );
210 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
211 BOOST_CHECK_CLOSE( y[2], (Real)-2.0, tol );
212
213 y = x.abs();
214 BOOST_CHECK_CLOSE( y[0], (Real)2.0, tol );
215 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
216 BOOST_CHECK_CLOSE( y[2], (Real)2.0, tol );
217
218 y = x.exp();
219 BOOST_CHECK_CLOSE( y[0], dai::exp((Real)-2.0), tol );
220 BOOST_CHECK_CLOSE( y[1], (Real)1.0, tol );
221 BOOST_CHECK_CLOSE( y[2], (Real)1.0 / y[0], tol );
222
223 y = x.log(false);
224 BOOST_CHECK( dai::isnan( y[0] ) );
225 BOOST_CHECK_EQUAL( y[1], -INFINITY );
226 BOOST_CHECK_CLOSE( y[2], dai::log((Real)2.0), tol );
227
228 y = x.log(true);
229 BOOST_CHECK( dai::isnan( y[0] ) );
230 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
231 BOOST_CHECK_CLOSE( y[2], dai::log((Real)2.0), tol );
232
233 y = x.inverse(false);
234 BOOST_CHECK_CLOSE( y[0], (Real)-0.5, tol );
235 BOOST_CHECK_EQUAL( y[1], INFINITY );
236 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
237
238 y = x.inverse(true);
239 BOOST_CHECK_CLOSE( y[0], (Real)-0.5, tol );
240 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
241 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
242
243 x.set( 0, 2.0 );
244 y = x.normalized();
245 BOOST_CHECK_CLOSE( y[0], (Real)0.5, tol );
246 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
247 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
248
249 y = x.normalized( NORMPROB );
250 BOOST_CHECK_CLOSE( y[0], (Real)0.5, tol );
251 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
252 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
253
254 x.set( 0, -2.0 );
255 y = x.normalized( NORMLINF );
256 BOOST_CHECK_CLOSE( y[0], (Real)-1.0, tol );
257 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
258 BOOST_CHECK_CLOSE( y[2], (Real)1.0, tol );
259 }
260
261
262 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
263 Var v( 0, 3 );
264 Factor xorg( v );
265 xorg.set( 0, 2.0 );
266 xorg.set( 1, 0.0 );
267 xorg.set( 2, 1.0 );
268 Factor y( v );
269
270 Factor x = xorg;
271 BOOST_CHECK( x.setUniform() == Factor( v ) );
272 BOOST_CHECK( x == Factor( v ) );
273
274 y.set( 0, dai::exp(2.0) );
275 y.set( 1, 1.0 );
276 y.set( 2, dai::exp(1.0) );
277 x = xorg;
278 BOOST_CHECK_SMALL( dist( x.takeExp(), y, DISTL1 ), tol );
279 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
280
281 y.set( 0, dai::log(2.0) );
282 y.set( 1, -INFINITY );
283 y.set( 2, 0.0 );
284 x = xorg;
285 Factor z = x.takeLog();
286 BOOST_CHECK_CLOSE( z[0], y[0], tol );
287 //BOOST_CHECK_CLOSE( z[1], y[1], tol );
288 BOOST_CHECK_CLOSE( z[2], y[2], tol );
289 BOOST_CHECK( z.vars() == y.vars() );
290 BOOST_CHECK( x == z );
291 x = xorg;
292 z = x.takeLog(false);
293 BOOST_CHECK_CLOSE( z[0], y[0], tol );
294 //BOOST_CHECK_CLOSE( z[1], y[1], tol );
295 BOOST_CHECK_CLOSE( z[2], y[2], tol );
296 BOOST_CHECK( z.vars() == y.vars() );
297 BOOST_CHECK( x == z );
298
299 y.set( 1, 0.0 );
300 x = xorg;
301 BOOST_CHECK_SMALL( dist( x.takeLog(true), y, DISTL1 ), tol );
302 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
303
304 y.set( 0, 2.0 / 3.0 );
305 y.set( 1, 0.0 / 3.0 );
306 y.set( 2, 1.0 / 3.0 );
307 x = xorg;
308 BOOST_CHECK_CLOSE( x.normalize(), (Real)3.0, tol );
309 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
310
311 x = xorg;
312 BOOST_CHECK_CLOSE( x.normalize( NORMPROB ), (Real)3.0, tol );
313 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
314
315 y.set( 0, 2.0 / 2.0 );
316 y.set( 1, 0.0 / 2.0 );
317 y.set( 2, 1.0 / 2.0 );
318 x = xorg;
319 BOOST_CHECK_CLOSE( x.normalize( NORMLINF ), (Real)2.0, tol );
320 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
321
322 xorg.set( 0, -2.0 );
323 y.set( 0, 2.0 );
324 y.set( 1, 0.0 );
325 y.set( 2, 1.0 );
326 x = xorg;
327 BOOST_CHECK_SMALL( dist( x.takeAbs(), y, DISTL1 ), tol );
328 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
329
330 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
331 x.randomize();
332 for( size_t i = 0; i < x.nrStates(); i++ ) {
333 BOOST_CHECK( x[i] < (Real)1.0 );
334 BOOST_CHECK( x[i] >= (Real)0.0 );
335 }
336 }
337 }
338
339
340 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
341 Var v( 0, 3 );
342 Factor xorg( v ), x( v );
343 xorg.set( 0, 2.0 );
344 xorg.set( 1, 0.0 );
345 xorg.set( 2, 1.0 );
346 Factor y( v );
347
348 x = xorg;
349 BOOST_CHECK( x.fill( 1.0 ) == Factor(v, 1.0) );
350 BOOST_CHECK( x == Factor(v, 1.0) );
351 BOOST_CHECK( x.fill( 2.0 ) == Factor(v, 2.0) );
352 BOOST_CHECK( x == Factor(v, 2.0) );
353 BOOST_CHECK( x.fill( 0.0 ) == Factor(v, 0.0) );
354 BOOST_CHECK( x == Factor(v, 0.0) );
355
356 x = xorg;
357 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
358 BOOST_CHECK_SMALL( dist( (x += 1.0), y, DISTL1 ), tol );
359 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
360 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
361 BOOST_CHECK_SMALL( dist( (x += -2.0), y, DISTL1 ), tol );
362 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
363
364 x = xorg;
365 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
366 BOOST_CHECK_SMALL( dist( (x -= 1.0), y, DISTL1 ), tol );
367 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
368 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
369 BOOST_CHECK_SMALL( dist( (x -= -2.0), y, DISTL1 ), tol );
370 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
371
372 x = xorg;
373 BOOST_CHECK_SMALL( dist( (x *= 1.0), x, DISTL1 ), tol );
374 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
375 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
376 BOOST_CHECK_SMALL( dist( (x *= 2.0), y, DISTL1 ), tol );
377 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
378 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
379 BOOST_CHECK_SMALL( dist( (x *= -0.25), y, DISTL1 ), tol );
380 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
381
382 x = xorg;
383 BOOST_CHECK_SMALL( dist( (x /= 1.0), x, DISTL1 ), tol );
384 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
385 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
386 BOOST_CHECK_SMALL( dist( (x /= 2.0), y, DISTL1 ), tol );
387 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
388 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
389 BOOST_CHECK_SMALL( dist( (x /= -0.25), y, DISTL1 ), tol );
390 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
391 BOOST_CHECK_SMALL( dist( (x /= 0.0), Factor(v, 0.0), DISTL1 ), tol );
392 BOOST_CHECK_SMALL( dist( x, Factor(v, 0.0), DISTL1 ), tol );
393
394 x = xorg;
395 BOOST_CHECK_SMALL( dist( (x ^= 1.0), x, DISTL1 ), tol );
396 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
397 BOOST_CHECK_SMALL( dist( (x ^= 0.0), Factor(v, 1.0), DISTL1 ), tol );
398 BOOST_CHECK_SMALL( dist( x, Factor(v, 1.0), DISTL1 ), tol );
399 x = xorg;
400 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
401 BOOST_CHECK_SMALL( dist( (x ^= 2.0), y, DISTL1 ), tol );
402 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
403 y.set( 0, 2.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
404 BOOST_CHECK( (x ^= 0.5) == y );
405 BOOST_CHECK( x == y );
406 }
407
408
409 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
410 Var v( 0, 3 );
411 Factor x( v );
412 x.set( 0, 2.0 );
413 x.set( 1, 0.0 );
414 x.set( 2, 1.0 );
415 Factor y( v );
416
417 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
418 BOOST_CHECK_SMALL( dist( (x + 1.0), y, DISTL1 ), tol );
419 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
420 BOOST_CHECK_SMALL( dist( (x + (-2.0)), y, DISTL1 ), tol );
421
422 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
423 BOOST_CHECK_SMALL( dist( (x - 1.0), y, DISTL1 ), tol );
424 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
425 BOOST_CHECK_SMALL( dist( (x - (-2.0)), y, DISTL1 ), tol );
426
427 BOOST_CHECK_SMALL( dist( (x * 1.0), x, DISTL1 ), tol );
428 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
429 BOOST_CHECK_SMALL( dist( (x * 2.0), y, DISTL1 ), tol );
430 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
431 BOOST_CHECK_SMALL( dist( (x * -0.5), y, DISTL1 ), tol );
432
433 BOOST_CHECK_SMALL( dist( (x / 1.0), x, DISTL1 ), tol );
434 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
435 BOOST_CHECK_SMALL( dist( (x / 2.0), y, DISTL1 ), tol );
436 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
437 BOOST_CHECK_SMALL( dist( (x / -0.5), y, DISTL1 ), tol );
438 BOOST_CHECK_SMALL( dist( (x / 0.0), Factor(v, 0.0), DISTL1 ), tol );
439
440 BOOST_CHECK_SMALL( dist( (x ^ 1.0), x, DISTL1 ), tol );
441 BOOST_CHECK_SMALL( dist( (x ^ 0.0), Factor(v, 1.0), DISTL1 ), tol );
442 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
443 BOOST_CHECK_SMALL( dist( (x ^ 2.0), y, DISTL1 ), tol );
444 y.set( 0, std::sqrt(2.0) ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
445 BOOST_CHECK_SMALL( dist( (x ^ 0.5), y, DISTL1 ), tol );
446 }
447
448
449 BOOST_AUTO_TEST_CASE( SimilarFactorOperationsTest ) {
450 size_t N = 6;
451 Var v( 0, N );
452 Factor xorg( v ), x( v );
453 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
454 Factor y( v );
455 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
456 Factor z( v ), r( v );
457
458 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
459 x = xorg;
460 r = (x += y);
461 for( size_t i = 0; i < N; i++ )
462 BOOST_CHECK_CLOSE( r[i], z[i], tol );
463 BOOST_CHECK( x == r );
464 x = xorg;
465 BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == r );
466 BOOST_CHECK( x == r );
467
468 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
469 x = xorg;
470 r = (x -= y);
471 for( size_t i = 0; i < N; i++ )
472 BOOST_CHECK_CLOSE( r[i], z[i], tol );
473 BOOST_CHECK( x == r );
474 x = xorg;
475 BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == r );
476 BOOST_CHECK( x == r );
477
478 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
479 x = xorg;
480 r = (x *= y);
481 for( size_t i = 0; i < N; i++ )
482 BOOST_CHECK_CLOSE( r[i], z[i], tol );
483 BOOST_CHECK( x == r );
484 x = xorg;
485 BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == r );
486 BOOST_CHECK( x == r );
487
488 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
489 x = xorg;
490 r = (x /= y);
491 for( size_t i = 0; i < N; i++ )
492 BOOST_CHECK_CLOSE( r[i], z[i], tol );
493 BOOST_CHECK( x == r );
494 x = xorg;
495 BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == r );
496 BOOST_CHECK( x == r );
497 }
498
499
500 BOOST_AUTO_TEST_CASE( SimilarFactorTransformationsTest ) {
501 size_t N = 6;
502 Var v( 0, N );
503 Factor x( v );
504 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
505 Factor y( v );
506 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
507 Factor z( v ), r( v );
508
509 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
510 r = x + y;
511 for( size_t i = 0; i < N; i++ )
512 BOOST_CHECK_CLOSE( r[i], z[i], tol );
513 z = x.binaryTr( y, std::plus<Real>() );
514 BOOST_CHECK( r == z );
515
516 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
517 r = x - y;
518 for( size_t i = 0; i < N; i++ )
519 BOOST_CHECK_CLOSE( r[i], z[i], tol );
520 z = x.binaryTr( y, std::minus<Real>() );
521 BOOST_CHECK( r == z );
522
523 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
524 r = x * y;
525 for( size_t i = 0; i < N; i++ )
526 BOOST_CHECK_CLOSE( r[i], z[i], tol );
527 z = x.binaryTr( y, std::multiplies<Real>() );
528 BOOST_CHECK( r == z );
529
530 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
531 r = x / y;
532 for( size_t i = 0; i < N; i++ )
533 BOOST_CHECK_CLOSE( r[i], z[i], tol );
534 z = x.binaryTr( y, fo_divides0<Real>() );
535 BOOST_CHECK( r == z );
536 }
537
538
539 BOOST_AUTO_TEST_CASE( FactorOperationsTest ) {
540 size_t N = 9;
541 Var v1( 1, 3 );
542 Var v2( 2, 3 );
543 Factor xorg( v1 ), x( v1 );
544 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, -1.0 );
545 Factor y( v2 );
546 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
547 Factor r;
548
549 Factor z( VarSet( v1, v2 ) );
550 z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
551 z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
552 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
553 x = xorg;
554 r = (x += y);
555 for( size_t i = 0; i < N; i++ )
556 BOOST_CHECK_CLOSE( r[i], z[i], tol );
557 BOOST_CHECK( x == r );
558 x = xorg;
559 BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == r );
560 BOOST_CHECK( x == r );
561
562 z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
563 z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
564 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
565 x = xorg;
566 r = (x -= y);
567 for( size_t i = 0; i < N; i++ )
568 BOOST_CHECK_CLOSE( r[i], z[i], tol );
569 BOOST_CHECK( x == r );
570 x = xorg;
571 BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == r );
572 BOOST_CHECK( x == r );
573
574 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
575 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
576 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
577 x = xorg;
578 r = (x *= y);
579 for( size_t i = 0; i < N; i++ )
580 BOOST_CHECK_CLOSE( r[i], z[i], tol );
581 BOOST_CHECK( x == r );
582 x = xorg;
583 BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == r );
584 BOOST_CHECK( x == r );
585
586 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
587 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
588 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
589 x = xorg;
590 r = (x /= y);
591 for( size_t i = 0; i < N; i++ )
592 BOOST_CHECK_CLOSE( r[i], z[i], tol );
593 BOOST_CHECK( x == r );
594 x = xorg;
595 BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == r );
596 BOOST_CHECK( x == r );
597 }
598
599
600 BOOST_AUTO_TEST_CASE( FactorTransformationsTest ) {
601 size_t N = 9;
602 Var v1( 1, 3 );
603 Var v2( 2, 3 );
604 Factor x( v1 );
605 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 );
606 Factor y( v2 );
607 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
608 Factor r;
609
610 Factor z( VarSet( v1, v2 ) );
611 z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
612 z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
613 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
614 r = x + y;
615 for( size_t i = 0; i < N; i++ )
616 BOOST_CHECK_CLOSE( r[i], z[i], tol );
617 z = x.binaryTr( y, std::plus<Real>() );
618 BOOST_CHECK( r == z );
619
620 z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
621 z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
622 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
623 r = x - y;
624 for( size_t i = 0; i < N; i++ )
625 BOOST_CHECK_CLOSE( r[i], z[i], tol );
626 z = x.binaryTr( y, std::minus<Real>() );
627 BOOST_CHECK( r == z );
628
629 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
630 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
631 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
632 r = x * y;
633 for( size_t i = 0; i < N; i++ )
634 BOOST_CHECK_CLOSE( r[i], z[i], tol );
635 z = x.binaryTr( y, std::multiplies<Real>() );
636 BOOST_CHECK( r == z );
637
638 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
639 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
640 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
641 r = x / y;
642 for( size_t i = 0; i < N; i++ )
643 BOOST_CHECK_CLOSE( r[i], z[i], tol );
644 z = x.binaryOp( y, fo_divides0<Real>() );
645 BOOST_CHECK( r == z );
646 }
647
648
649 BOOST_AUTO_TEST_CASE( MiscOperationsTest ) {
650 Var v1(1, 2);
651 Var v2(2, 3);
652 Factor x( VarSet( v1, v2 ) );
653 x.randomize();
654
655 // slice
656 Factor y = x.slice( v1, 0 );
657 BOOST_CHECK( y.vars() == VarSet( v2 ) );
658 BOOST_CHECK_EQUAL( y.nrStates(), 3 );
659 BOOST_CHECK_EQUAL( y[0], x[0] );
660 BOOST_CHECK_EQUAL( y[1], x[2] );
661 BOOST_CHECK_EQUAL( y[2], x[4] );
662 y = x.slice( v1, 1 );
663 BOOST_CHECK( y.vars() == VarSet( v2 ) );
664 BOOST_CHECK_EQUAL( y.nrStates(), 3 );
665 BOOST_CHECK_EQUAL( y[0], x[1] );
666 BOOST_CHECK_EQUAL( y[1], x[3] );
667 BOOST_CHECK_EQUAL( y[2], x[5] );
668 y = x.slice( v2, 0 );
669 BOOST_CHECK( y.vars() == VarSet( v1 ) );
670 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
671 BOOST_CHECK_EQUAL( y[0], x[0] );
672 BOOST_CHECK_EQUAL( y[1], x[1] );
673 y = x.slice( v2, 1 );
674 BOOST_CHECK( y.vars() == VarSet( v1 ) );
675 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
676 BOOST_CHECK_EQUAL( y[0], x[2] );
677 BOOST_CHECK_EQUAL( y[1], x[3] );
678 y = x.slice( v2, 2 );
679 BOOST_CHECK( y.vars() == VarSet( v1 ) );
680 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
681 BOOST_CHECK_EQUAL( y[0], x[4] );
682 BOOST_CHECK_EQUAL( y[1], x[5] );
683 for( size_t i = 0; i < x.nrStates(); i++ ) {
684 y = x.slice( VarSet( v1, v2 ), 0 );
685 BOOST_CHECK( y.vars() == VarSet() );
686 BOOST_CHECK_EQUAL( y.nrStates(), 1 );
687 BOOST_CHECK_EQUAL( y[0], x[0] );
688 }
689 y = x.slice( VarSet(), 0 );
690 BOOST_CHECK_EQUAL( y, x );
691
692 // embed
693 Var v3(3, 4);
694 BOOST_CHECK_THROW( x.embed( VarSet( v3 ) ), Exception );
695 BOOST_CHECK_THROW( x.embed( VarSet( v3, v2 ) ), Exception );
696 y = x.embed( VarSet( v3, v2 ) | v1 );
697 for( size_t i = 0; i < y.nrStates(); i++ )
698 BOOST_CHECK_EQUAL( y[i], x[i % 6] );
699 y = x.embed( VarSet( v1, v2 ) );
700 BOOST_CHECK_EQUAL( x, y );
701
702 // marginal
703 y = x.marginal( v1 );
704 BOOST_CHECK( y.vars() == VarSet( v1 ) );
705 BOOST_CHECK_CLOSE( y[0], (x[0] + x[2] + x[4]) / x.sum(), tol );
706 BOOST_CHECK_CLOSE( y[1], (x[1] + x[3] + x[5]) / x.sum(),tol );
707 y = x.marginal( v2 );
708 BOOST_CHECK( y.vars() == VarSet( v2 ) );
709 BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
710 BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
711 BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
712 y = x.marginal( VarSet() );
713 BOOST_CHECK( y.vars() == VarSet() );
714 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
715 y = x.marginal( VarSet( v1, v2 ) );
716 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
717
718 y = x.marginal( v1, true );
719 BOOST_CHECK( y.vars() == VarSet( v1 ) );
720 BOOST_CHECK_CLOSE( y[0], (x[0] + x[2] + x[4]) / x.sum(), tol );
721 BOOST_CHECK_CLOSE( y[1], (x[1] + x[3] + x[5]) / x.sum(), tol );
722 y = x.marginal( v2, true );
723 BOOST_CHECK( y.vars() == VarSet( v2 ) );
724 BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
725 BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
726 BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
727 y = x.marginal( VarSet(), true );
728 BOOST_CHECK( y.vars() == VarSet() );
729 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
730 y = x.marginal( VarSet( v1, v2 ), true );
731 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
732
733 y = x.marginal( v1, false );
734 BOOST_CHECK( y.vars() == VarSet( v1 ) );
735 BOOST_CHECK_CLOSE( y[0], x[0] + x[2] + x[4], tol );
736 BOOST_CHECK_CLOSE( y[1], x[1] + x[3] + x[5], tol );
737 y = x.marginal( v2, false );
738 BOOST_CHECK( y.vars() == VarSet( v2 ) );
739 BOOST_CHECK_CLOSE( y[0], x[0] + x[1], tol );
740 BOOST_CHECK_CLOSE( y[1], x[2] + x[3], tol );
741 BOOST_CHECK_CLOSE( y[2], x[4] + x[5], tol );
742 y = x.marginal( VarSet(), false );
743 BOOST_CHECK( y.vars() == VarSet() );
744 BOOST_CHECK_CLOSE( y[0], x.sum(), tol );
745 y = x.marginal( VarSet( v1, v2 ), false );
746 BOOST_CHECK_SMALL( dist( y, x, DISTL1 ), tol );
747
748 // maxMarginal
749 y = x.maxMarginal( v1 );
750 BOOST_CHECK( y.vars() == VarSet( v1 ) );
751 BOOST_CHECK_CLOSE( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
752 BOOST_CHECK_CLOSE( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
753 y = x.maxMarginal( v2 );
754 BOOST_CHECK( y.vars() == VarSet( v2 ) );
755 BOOST_CHECK_CLOSE( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
756 BOOST_CHECK_CLOSE( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
757 BOOST_CHECK_CLOSE( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
758 y = x.maxMarginal( VarSet() );
759 BOOST_CHECK( y.vars() == VarSet() );
760 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
761 y = x.maxMarginal( VarSet( v1, v2 ) );
762 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
763
764 y = x.maxMarginal( v1, true );
765 BOOST_CHECK( y.vars() == VarSet( v1 ) );
766 BOOST_CHECK_CLOSE( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
767 BOOST_CHECK_CLOSE( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
768 y = x.maxMarginal( v2, true );
769 BOOST_CHECK( y.vars() == VarSet( v2 ) );
770 BOOST_CHECK_CLOSE( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
771 BOOST_CHECK_CLOSE( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
772 BOOST_CHECK_CLOSE( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
773 y = x.maxMarginal( VarSet(), true );
774 BOOST_CHECK( y.vars() == VarSet() );
775 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
776 y = x.maxMarginal( VarSet( v1, v2 ), true );
777 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
778
779 y = x.maxMarginal( v1, false );
780 BOOST_CHECK( y.vars() == VarSet( v1 ) );
781 BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() );
782 BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() );
783 y = x.maxMarginal( v2, false );
784 BOOST_CHECK( y.vars() == VarSet( v2 ) );
785 BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() );
786 BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() );
787 BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() );
788 y = x.maxMarginal( VarSet(), false );
789 BOOST_CHECK( y.vars() == VarSet() );
790 BOOST_CHECK_EQUAL( y[0], x.max() );
791 y = x.maxMarginal( VarSet( v1, v2 ), false );
792 BOOST_CHECK( y == x );
793 }
794
795
796 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
797 Var v( 0, 3 );
798 Factor x(v), y(v), z(v);
799 x.set( 0, 0.2 );
800 x.set( 1, 0.8 );
801 x.set( 2, 0.0 );
802 y.set( 0, 0.0 );
803 y.set( 1, 0.6 );
804 y.set( 2, 0.4 );
805
806 z = min( x, y );
807 BOOST_CHECK_EQUAL( z[0], (Real)0.0 );
808 BOOST_CHECK_EQUAL( z[1], (Real)0.6 );
809 BOOST_CHECK_EQUAL( z[2], (Real)0.0 );
810 z = max( x, y );
811 BOOST_CHECK_EQUAL( z[0], (Real)0.2 );
812 BOOST_CHECK_EQUAL( z[1], (Real)0.8 );
813 BOOST_CHECK_EQUAL( z[2], (Real)0.4 );
814
815 BOOST_CHECK_CLOSE( dist( x, x, DISTL1 ), (Real)0.0, tol );
816 BOOST_CHECK_CLOSE( dist( y, y, DISTL1 ), (Real)0.0, tol );
817 BOOST_CHECK_CLOSE( dist( x, y, DISTL1 ), (Real)(0.2 + 0.2 + 0.4), tol );
818 BOOST_CHECK_CLOSE( dist( y, x, DISTL1 ), (Real)(0.2 + 0.2 + 0.4), tol );
819 BOOST_CHECK_CLOSE( dist( x, x, DISTLINF ), (Real)0.0, tol );
820 BOOST_CHECK_CLOSE( dist( y, y, DISTLINF ), (Real)0.0, tol );
821 BOOST_CHECK_CLOSE( dist( x, y, DISTLINF ), (Real)0.4, tol );
822 BOOST_CHECK_CLOSE( dist( y, x, DISTLINF ), (Real)0.4, tol );
823 BOOST_CHECK_CLOSE( dist( x, x, DISTTV ), (Real)0.0, tol );
824 BOOST_CHECK_CLOSE( dist( y, y, DISTTV ), (Real)0.0, tol );
825 BOOST_CHECK_CLOSE( dist( x, y, DISTTV ), (Real)(0.5 * (0.2 + 0.2 + 0.4)), tol );
826 BOOST_CHECK_CLOSE( dist( y, x, DISTTV ), (Real)(0.5 * (0.2 + 0.2 + 0.4)), tol );
827 BOOST_CHECK_CLOSE( dist( x, x, DISTKL ), (Real)0.0, tol );
828 BOOST_CHECK_CLOSE( dist( y, y, DISTKL ), (Real)0.0, tol );
829 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
830 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
831 BOOST_CHECK_CLOSE( dist( x, x, DISTHEL ), (Real)0.0, tol );
832 BOOST_CHECK_CLOSE( dist( y, y, DISTHEL ), (Real)0.0, tol );
833 BOOST_CHECK_CLOSE( dist( x, y, DISTHEL ), (Real)(0.5 * (0.2 + dai::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4)), tol );
834 BOOST_CHECK_CLOSE( dist( y, x, DISTHEL ), (Real)(0.5 * (0.2 + dai::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4)), tol );
835 x.set( 1, 0.7 ); x.set( 2, 0.1 );
836 y.set( 0, 0.1 ); y.set( 1, 0.5 );
837 BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), (Real)(0.2 * dai::log(0.2 / 0.1) + 0.7 * dai::log(0.7 / 0.5) + 0.1 * dai::log(0.1 / 0.4)), tol );
838 BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), (Real)(0.1 * dai::log(0.1 / 0.2) + 0.5 * dai::log(0.5 / 0.7) + 0.4 * dai::log(0.4 / 0.1)), tol );
839
840 std::stringstream ss;
841 ss << x;
842 std::string s;
843 std::getline( ss, s );
844 BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.2, 0.7, 0.1))") );
845 std::stringstream ss2;
846 ss2 << y;
847 std::getline( ss2, s );
848 BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.1, 0.5, 0.4))") );
849
850 z = min( x, y );
851 BOOST_CHECK_EQUAL( z[0], (Real)0.1 );
852 BOOST_CHECK_EQUAL( z[1], (Real)0.5 );
853 BOOST_CHECK_EQUAL( z[2], (Real)0.1 );
854 z = max( x, y );
855 BOOST_CHECK_EQUAL( z[0], (Real)0.2 );
856 BOOST_CHECK_EQUAL( z[1], (Real)0.7 );
857 BOOST_CHECK_EQUAL( z[2], (Real)0.4 );
858
859 for( Real J = -1.0; J <= 1.01; J += 0.1 ) {
860 Factor x = createFactorIsing( Var(0,2), Var(1,2), J ).normalized();
861 BOOST_CHECK_CLOSE( x[0], dai::exp(J) / (4.0 * std::cosh(J)), tol );
862 BOOST_CHECK_CLOSE( x[1], dai::exp(-J) / (4.0 * std::cosh(J)), tol );
863 BOOST_CHECK_CLOSE( x[2], dai::exp(-J) / (4.0 * std::cosh(J)), tol );
864 BOOST_CHECK_CLOSE( x[3], dai::exp(J) / (4.0 * std::cosh(J)), tol );
865 BOOST_CHECK_SMALL( MutualInfo( x ) - (J * std::tanh(J) - dai::log(std::cosh(J))), tol );
866 }
867 Var v1( 1, 3 );
868 Var v2( 2, 4 );
869 BOOST_CHECK_SMALL( MutualInfo( (Factor(v1).randomize() * Factor(v2).randomize()).normalized() ), tol );
870 BOOST_CHECK_THROW( MutualInfo( createFactorIsing( Var(0,2), 1.0 ).normalized() ), Exception );
871 BOOST_CHECK_THROW( createFactorIsing( v1, 0.0 ), Exception );
872 BOOST_CHECK_THROW( createFactorIsing( v1, v2, 0.0 ), Exception );
873 for( Real J = -1.0; J <= 1.01; J += 0.1 ) {
874 Factor x = createFactorIsing( Var(0,2), J ).normalized();
875 BOOST_CHECK_CLOSE( x[0], dai::exp(-J) / (2.0 * std::cosh(J)), tol );
876 BOOST_CHECK_CLOSE( x[1], dai::exp(J) / (2.0 * std::cosh(J)), tol );
877 BOOST_CHECK_SMALL( x.entropy() - (-J * std::tanh(J) + dai::log(2.0 * std::cosh(J))), tol );
878 }
879
880 x = createFactorDelta( v1, 2 );
881 BOOST_CHECK_EQUAL( x[0], (Real)0.0 );
882 BOOST_CHECK_EQUAL( x[1], (Real)0.0 );
883 BOOST_CHECK_EQUAL( x[2], (Real)1.0 );
884 x = createFactorDelta( v1, 1 );
885 BOOST_CHECK_EQUAL( x[0], (Real)0.0 );
886 BOOST_CHECK_EQUAL( x[1], (Real)1.0 );
887 BOOST_CHECK_EQUAL( x[2], (Real)0.0 );
888 x = createFactorDelta( v1, 0 );
889 BOOST_CHECK_EQUAL( x[0], (Real)1.0 );
890 BOOST_CHECK_EQUAL( x[1], (Real)0.0 );
891 BOOST_CHECK_EQUAL( x[2], (Real)0.0 );
892 BOOST_CHECK_THROW( createFactorDelta( v1, 4 ), Exception );
893 }