5edf5d50080de9159d2233243c47075318732066
[libdai.git] / tests / unit / factor_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
#include <dai/factor.h>
#include <cmath>
#include <functional>
#include <sstream>
#include <string>
#include <strstream>   // legacy (deprecated) header retained from the original include list
#include <vector>
11
12
13 using namespace dai;
14
15
16 const Real tol = 1e-8;
17
18
19 #define BOOST_TEST_MODULE FactorTest
20
21
22 #include <boost/test/unit_test.hpp>
23 #include <boost/test/floating_point_comparison.hpp>
24
25
26 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
27 // check constructors
28 Factor x1;
29 BOOST_CHECK_EQUAL( x1.nrStates(), 1 );
30 BOOST_CHECK( x1.p() == Prob( 1, 1.0 ) );
31 BOOST_CHECK( x1.vars() == VarSet() );
32
33 Factor x2( 5.0 );
34 BOOST_CHECK_EQUAL( x2.nrStates(), 1 );
35 BOOST_CHECK( x2.p() == Prob( 1, 5.0 ) );
36 BOOST_CHECK( x2.vars() == VarSet() );
37
38 Var v1( 0, 3 );
39 Factor x3( v1 );
40 BOOST_CHECK_EQUAL( x3.nrStates(), 3 );
41 BOOST_CHECK_SMALL( dist( x3.p(), Prob( 3, 1.0 / 3.0 ), DISTL1 ), tol );
42 BOOST_CHECK( x3.vars() == VarSet( v1 ) );
43 BOOST_CHECK_CLOSE( x3[0], (Real)(1.0 / 3.0), tol );
44 BOOST_CHECK_CLOSE( x3[1], (Real)(1.0 / 3.0), tol );
45 BOOST_CHECK_CLOSE( x3[2], (Real)(1.0 / 3.0), tol );
46
47 Var v2( 1, 2 );
48 Factor x4( VarSet( v1, v2 ) );
49 BOOST_CHECK_EQUAL( x4.nrStates(), 6 );
50 BOOST_CHECK_SMALL( dist( x4.p(), Prob( 6, 1.0 / 6.0 ), DISTL1 ), tol );
51 BOOST_CHECK( x4.vars() == VarSet( v1, v2 ) );
52 for( size_t i = 0; i < 6; i++ )
53 BOOST_CHECK_CLOSE( x4[i], (Real)(1.0 / 6.0), tol );
54
55 Factor x5( VarSet( v1, v2 ), 1.0 );
56 BOOST_CHECK_EQUAL( x5.nrStates(), 6 );
57 BOOST_CHECK( x5.p() == Prob( 6, 1.0 ) );
58 BOOST_CHECK( x5.vars() == VarSet( v1, v2 ) );
59 for( size_t i = 0; i < 6; i++ )
60 BOOST_CHECK_EQUAL( x5[i], (Real)1.0 );
61
62 std::vector<Real> x( 6, 1.0 );
63 for( size_t i = 0; i < 6; i++ )
64 x[i] = 10.0 - i;
65 Factor x6( VarSet( v1, v2 ), x );
66 BOOST_CHECK_EQUAL( x6.nrStates(), 6 );
67 BOOST_CHECK( x6.vars() == VarSet( v1, v2 ) );
68 for( size_t i = 0; i < 6; i++ )
69 BOOST_CHECK_EQUAL( x6[i], x[i] );
70
71 x.resize( 4 );
72 BOOST_CHECK_THROW( Factor x7( VarSet( v1, v2 ), x ), Exception );
73
74 x.resize( 6 );
75 x[4] = 10.0 - 4; x[5] = 10.0 - 5;
76 Factor x8( VarSet( v2, v1 ), &(x[0]) );
77 BOOST_CHECK_EQUAL( x8.nrStates(), 6 );
78 BOOST_CHECK( x8.vars() == VarSet( v1, v2 ) );
79 for( size_t i = 0; i < 6; i++ )
80 BOOST_CHECK_EQUAL( x8[i], x[i] );
81
82 Prob xx( x );
83 Factor x9( VarSet( v2, v1 ), xx );
84 BOOST_CHECK_EQUAL( x9.nrStates(), 6 );
85 BOOST_CHECK( x9.vars() == VarSet( v1, v2 ) );
86 for( size_t i = 0; i < 6; i++ )
87 BOOST_CHECK_EQUAL( x9[i], x[i] );
88
89 xx.resize( 4 );
90 BOOST_CHECK_THROW( Factor x10( VarSet( v2, v1 ), xx ), Exception );
91
92 std::vector<Real> w;
93 w.push_back( 0.1 );
94 w.push_back( 3.5 );
95 w.push_back( 2.8 );
96 w.push_back( 6.3 );
97 w.push_back( 8.4 );
98 w.push_back( 0.0 );
99 w.push_back( 7.4 );
100 w.push_back( 2.4 );
101 w.push_back( 8.9 );
102 w.push_back( 1.3 );
103 w.push_back( 1.6 );
104 w.push_back( 2.6 );
105 Var v4( 4, 3 );
106 Var v8( 8, 2 );
107 Var v7( 7, 2 );
108 std::vector<Var> vars;
109 vars.push_back( v4 );
110 vars.push_back( v8 );
111 vars.push_back( v7 );
112 Factor x11( vars, w );
113 BOOST_CHECK_EQUAL( x11.nrStates(), 12 );
114 BOOST_CHECK( x11.vars() == VarSet( vars.begin(), vars.end() ) );
115 BOOST_CHECK_EQUAL( x11[0], (Real)0.1 );
116 BOOST_CHECK_EQUAL( x11[1], (Real)3.5 );
117 BOOST_CHECK_EQUAL( x11[2], (Real)2.8 );
118 BOOST_CHECK_EQUAL( x11[3], (Real)7.4 );
119 BOOST_CHECK_EQUAL( x11[4], (Real)2.4 );
120 BOOST_CHECK_EQUAL( x11[5], (Real)8.9 );
121 BOOST_CHECK_EQUAL( x11[6], (Real)6.3 );
122 BOOST_CHECK_EQUAL( x11[7], (Real)8.4 );
123 BOOST_CHECK_EQUAL( x11[8], (Real)0.0 );
124 BOOST_CHECK_EQUAL( x11[9], (Real)1.3 );
125 BOOST_CHECK_EQUAL( x11[10], (Real)1.6 );
126 BOOST_CHECK_EQUAL( x11[11], (Real)2.6 );
127
128 Factor x12( x11 );
129 BOOST_CHECK( x12 == x11 );
130
131 Factor x13 = x12;
132 BOOST_CHECK( x13 == x11 );
133 }
134
135
136 BOOST_AUTO_TEST_CASE( QueriesTest ) {
137 Factor x( Var( 5, 5 ), 0.0 );
138 for( size_t i = 0; i < x.nrStates(); i++ )
139 x.set( i, 2.0 - i );
140
141 // test min, max, sum, sumAbs, maxAbs
142 BOOST_CHECK_CLOSE( x.sum(), (Real)0.0, tol );
143 BOOST_CHECK_CLOSE( x.max(), (Real)2.0, tol );
144 BOOST_CHECK_CLOSE( x.min(), (Real)-2.0, tol );
145 BOOST_CHECK_CLOSE( x.sumAbs(), (Real)6.0, tol );
146 BOOST_CHECK_CLOSE( x.maxAbs(), (Real)2.0, tol );
147 x.set( 1, 1.0 );
148 BOOST_CHECK_CLOSE( x.maxAbs(), (Real)2.0, tol );
149 x /= x.sum();
150
151 // test entropy
152 BOOST_CHECK( x.entropy() < Prob(5).entropy() );
153 for( size_t i = 1; i < 100; i++ )
154 BOOST_CHECK_CLOSE( Factor( Var(0,i) ).entropy(), dai::log((Real)i), tol );
155
156 // test hasNaNs and hasNegatives
157 BOOST_CHECK( !Factor( 0.0 ).hasNaNs() );
158 Real c = 0.0;
159 BOOST_CHECK( Factor( c / c ).hasNaNs() );
160 BOOST_CHECK( !Factor( 0.0 ).hasNegatives() );
161 BOOST_CHECK( !Factor( 1.0 ).hasNegatives() );
162 BOOST_CHECK( Factor( -1.0 ).hasNegatives() );
163 x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
164 BOOST_CHECK( x.hasNegatives() );
165 x.set( 2, -INFINITY );
166 BOOST_CHECK( x.hasNegatives() );
167 x.set( 2, INFINITY );
168 BOOST_CHECK( !x.hasNegatives() );
169 x.set( 2, -1.0 );
170
171 // test strength
172 Var x0(0,2);
173 Var x1(1,2);
174 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 1.0 ).strength( x0, x1 ), std::tanh( (Real)1.0 ), tol );
175 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, -1.0 ).strength( x0, x1 ), std::tanh( (Real)1.0 ), tol );
176 BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 0.5 ).strength( x0, x1 ), std::tanh( (Real)0.5 ), tol );
177
178 // test ==
179 Factor a(Var(0,3)), b(Var(0,3));
180 Factor d(Var(1,3));
181 BOOST_CHECK( !(a == d) );
182 BOOST_CHECK( !(b == d) );
183 BOOST_CHECK( a == b );
184 a.set( 0, 0.0 );
185 BOOST_CHECK( !(a == b) );
186 b.set( 2, 0.0 );
187 BOOST_CHECK( !(a == b) );
188 b.set( 0, 0.0 );
189 BOOST_CHECK( !(a == b) );
190 a.set( 1, 0.0 );
191 BOOST_CHECK( !(a == b) );
192 b.set( 1, 0.0 );
193 BOOST_CHECK( !(a == b) );
194 a.set( 2, 0.0 );
195 BOOST_CHECK( a == b );
196 }
197
198
199 BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
200 Var v( 0, 3 );
201 Factor x( v );
202 x.set( 0, -2.0 );
203 x.set( 1, 0.0 );
204 x.set( 2, 2.0 );
205
206 Factor y = -x;
207 BOOST_CHECK_CLOSE( y[0], (Real)2.0, tol );
208 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
209 BOOST_CHECK_CLOSE( y[2], (Real)-2.0, tol );
210
211 y = x.abs();
212 BOOST_CHECK_CLOSE( y[0], (Real)2.0, tol );
213 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
214 BOOST_CHECK_CLOSE( y[2], (Real)2.0, tol );
215
216 y = x.exp();
217 BOOST_CHECK_CLOSE( y[0], dai::exp((Real)-2.0), tol );
218 BOOST_CHECK_CLOSE( y[1], (Real)1.0, tol );
219 BOOST_CHECK_CLOSE( y[2], (Real)1.0 / y[0], tol );
220
221 y = x.log(false);
222 BOOST_CHECK( dai::isnan( y[0] ) );
223 BOOST_CHECK_EQUAL( y[1], -INFINITY );
224 BOOST_CHECK_CLOSE( y[2], dai::log((Real)2.0), tol );
225
226 y = x.log(true);
227 BOOST_CHECK( dai::isnan( y[0] ) );
228 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
229 BOOST_CHECK_CLOSE( y[2], dai::log((Real)2.0), tol );
230
231 y = x.inverse(false);
232 BOOST_CHECK_CLOSE( y[0], (Real)-0.5, tol );
233 BOOST_CHECK_EQUAL( y[1], INFINITY );
234 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
235
236 y = x.inverse(true);
237 BOOST_CHECK_CLOSE( y[0], (Real)-0.5, tol );
238 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
239 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
240
241 x.set( 0, 2.0 );
242 y = x.normalized();
243 BOOST_CHECK_CLOSE( y[0], (Real)0.5, tol );
244 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
245 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
246
247 y = x.normalized( NORMPROB );
248 BOOST_CHECK_CLOSE( y[0], (Real)0.5, tol );
249 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
250 BOOST_CHECK_CLOSE( y[2], (Real)0.5, tol );
251
252 x.set( 0, -2.0 );
253 y = x.normalized( NORMLINF );
254 BOOST_CHECK_CLOSE( y[0], (Real)-1.0, tol );
255 BOOST_CHECK_CLOSE( y[1], (Real)0.0, tol );
256 BOOST_CHECK_CLOSE( y[2], (Real)1.0, tol );
257 }
258
259
260 BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
261 Var v( 0, 3 );
262 Factor xorg( v );
263 xorg.set( 0, 2.0 );
264 xorg.set( 1, 0.0 );
265 xorg.set( 2, 1.0 );
266 Factor y( v );
267
268 Factor x = xorg;
269 BOOST_CHECK( x.setUniform() == Factor( v ) );
270 BOOST_CHECK( x == Factor( v ) );
271
272 y.set( 0, dai::exp(2.0) );
273 y.set( 1, 1.0 );
274 y.set( 2, dai::exp(1.0) );
275 x = xorg;
276 BOOST_CHECK_SMALL( dist( x.takeExp(), y, DISTL1 ), tol );
277 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
278
279 y.set( 0, dai::log(2.0) );
280 y.set( 1, -INFINITY );
281 y.set( 2, 0.0 );
282 x = xorg;
283 Factor z = x.takeLog();
284 BOOST_CHECK_CLOSE( z[0], y[0], tol );
285 //BOOST_CHECK_CLOSE( z[1], y[1], tol );
286 BOOST_CHECK_CLOSE( z[2], y[2], tol );
287 BOOST_CHECK( z.vars() == y.vars() );
288 BOOST_CHECK( x == z );
289 x = xorg;
290 z = x.takeLog(false);
291 BOOST_CHECK_CLOSE( z[0], y[0], tol );
292 //BOOST_CHECK_CLOSE( z[1], y[1], tol );
293 BOOST_CHECK_CLOSE( z[2], y[2], tol );
294 BOOST_CHECK( z.vars() == y.vars() );
295 BOOST_CHECK( x == z );
296
297 y.set( 1, 0.0 );
298 x = xorg;
299 BOOST_CHECK_SMALL( dist( x.takeLog(true), y, DISTL1 ), tol );
300 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
301
302 y.set( 0, 2.0 / 3.0 );
303 y.set( 1, 0.0 / 3.0 );
304 y.set( 2, 1.0 / 3.0 );
305 x = xorg;
306 BOOST_CHECK_CLOSE( x.normalize(), (Real)3.0, tol );
307 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
308
309 x = xorg;
310 BOOST_CHECK_CLOSE( x.normalize( NORMPROB ), (Real)3.0, tol );
311 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
312
313 y.set( 0, 2.0 / 2.0 );
314 y.set( 1, 0.0 / 2.0 );
315 y.set( 2, 1.0 / 2.0 );
316 x = xorg;
317 BOOST_CHECK_CLOSE( x.normalize( NORMLINF ), (Real)2.0, tol );
318 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
319
320 xorg.set( 0, -2.0 );
321 y.set( 0, 2.0 );
322 y.set( 1, 0.0 );
323 y.set( 2, 1.0 );
324 x = xorg;
325 BOOST_CHECK_SMALL( dist( x.takeAbs(), y, DISTL1 ), tol );
326 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
327
328 for( size_t repeat = 0; repeat < 10000; repeat++ ) {
329 x.randomize();
330 for( size_t i = 0; i < x.nrStates(); i++ ) {
331 BOOST_CHECK( x[i] < (Real)1.0 );
332 BOOST_CHECK( x[i] >= (Real)0.0 );
333 }
334 }
335 }
336
337
338 BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
339 Var v( 0, 3 );
340 Factor xorg( v ), x( v );
341 xorg.set( 0, 2.0 );
342 xorg.set( 1, 0.0 );
343 xorg.set( 2, 1.0 );
344 Factor y( v );
345
346 x = xorg;
347 BOOST_CHECK( x.fill( 1.0 ) == Factor(v, 1.0) );
348 BOOST_CHECK( x == Factor(v, 1.0) );
349 BOOST_CHECK( x.fill( 2.0 ) == Factor(v, 2.0) );
350 BOOST_CHECK( x == Factor(v, 2.0) );
351 BOOST_CHECK( x.fill( 0.0 ) == Factor(v, 0.0) );
352 BOOST_CHECK( x == Factor(v, 0.0) );
353
354 x = xorg;
355 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
356 BOOST_CHECK_SMALL( dist( (x += 1.0), y, DISTL1 ), tol );
357 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
358 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
359 BOOST_CHECK_SMALL( dist( (x += -2.0), y, DISTL1 ), tol );
360 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
361
362 x = xorg;
363 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
364 BOOST_CHECK_SMALL( dist( (x -= 1.0), y, DISTL1 ), tol );
365 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
366 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
367 BOOST_CHECK_SMALL( dist( (x -= -2.0), y, DISTL1 ), tol );
368 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
369
370 x = xorg;
371 BOOST_CHECK_SMALL( dist( (x *= 1.0), x, DISTL1 ), tol );
372 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
373 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
374 BOOST_CHECK_SMALL( dist( (x *= 2.0), y, DISTL1 ), tol );
375 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
376 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
377 BOOST_CHECK_SMALL( dist( (x *= -0.25), y, DISTL1 ), tol );
378 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
379
380 x = xorg;
381 BOOST_CHECK_SMALL( dist( (x /= 1.0), x, DISTL1 ), tol );
382 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
383 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
384 BOOST_CHECK_SMALL( dist( (x /= 2.0), y, DISTL1 ), tol );
385 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
386 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
387 BOOST_CHECK_SMALL( dist( (x /= -0.25), y, DISTL1 ), tol );
388 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
389 BOOST_CHECK_SMALL( dist( (x /= 0.0), Factor(v, 0.0), DISTL1 ), tol );
390 BOOST_CHECK_SMALL( dist( x, Factor(v, 0.0), DISTL1 ), tol );
391
392 x = xorg;
393 BOOST_CHECK_SMALL( dist( (x ^= 1.0), x, DISTL1 ), tol );
394 BOOST_CHECK_SMALL( dist( x, x, DISTL1 ), tol );
395 BOOST_CHECK_SMALL( dist( (x ^= 0.0), Factor(v, 1.0), DISTL1 ), tol );
396 BOOST_CHECK_SMALL( dist( x, Factor(v, 1.0), DISTL1 ), tol );
397 x = xorg;
398 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
399 BOOST_CHECK_SMALL( dist( (x ^= 2.0), y, DISTL1 ), tol );
400 BOOST_CHECK_SMALL( dist( x, y, DISTL1 ), tol );
401 y.set( 0, 2.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
402 BOOST_CHECK( (x ^= 0.5) == y );
403 BOOST_CHECK( x == y );
404 }
405
406
407 BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
408 Var v( 0, 3 );
409 Factor x( v );
410 x.set( 0, 2.0 );
411 x.set( 1, 0.0 );
412 x.set( 2, 1.0 );
413 Factor y( v );
414
415 y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
416 BOOST_CHECK_SMALL( dist( (x + 1.0), y, DISTL1 ), tol );
417 y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
418 BOOST_CHECK_SMALL( dist( (x + (-2.0)), y, DISTL1 ), tol );
419
420 y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
421 BOOST_CHECK_SMALL( dist( (x - 1.0), y, DISTL1 ), tol );
422 y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
423 BOOST_CHECK_SMALL( dist( (x - (-2.0)), y, DISTL1 ), tol );
424
425 BOOST_CHECK_SMALL( dist( (x * 1.0), x, DISTL1 ), tol );
426 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
427 BOOST_CHECK_SMALL( dist( (x * 2.0), y, DISTL1 ), tol );
428 y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
429 BOOST_CHECK_SMALL( dist( (x * -0.5), y, DISTL1 ), tol );
430
431 BOOST_CHECK_SMALL( dist( (x / 1.0), x, DISTL1 ), tol );
432 y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
433 BOOST_CHECK_SMALL( dist( (x / 2.0), y, DISTL1 ), tol );
434 y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
435 BOOST_CHECK_SMALL( dist( (x / -0.5), y, DISTL1 ), tol );
436 BOOST_CHECK_SMALL( dist( (x / 0.0), Factor(v, 0.0), DISTL1 ), tol );
437
438 BOOST_CHECK_SMALL( dist( (x ^ 1.0), x, DISTL1 ), tol );
439 BOOST_CHECK_SMALL( dist( (x ^ 0.0), Factor(v, 1.0), DISTL1 ), tol );
440 y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
441 BOOST_CHECK_SMALL( dist( (x ^ 2.0), y, DISTL1 ), tol );
442 y.set( 0, std::sqrt(2.0) ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
443 BOOST_CHECK_SMALL( dist( (x ^ 0.5), y, DISTL1 ), tol );
444 }
445
446
447 BOOST_AUTO_TEST_CASE( SimilarFactorOperationsTest ) {
448 size_t N = 6;
449 Var v( 0, N );
450 Factor xorg( v ), x( v );
451 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
452 Factor y( v );
453 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
454 Factor z( v ), r( v );
455
456 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
457 x = xorg;
458 r = (x += y);
459 for( size_t i = 0; i < N; i++ )
460 BOOST_CHECK_CLOSE( r[i], z[i], tol );
461 BOOST_CHECK( x == r );
462 x = xorg;
463 BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == r );
464 BOOST_CHECK( x == r );
465
466 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
467 x = xorg;
468 r = (x -= y);
469 for( size_t i = 0; i < N; i++ )
470 BOOST_CHECK_CLOSE( r[i], z[i], tol );
471 BOOST_CHECK( x == r );
472 x = xorg;
473 BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == r );
474 BOOST_CHECK( x == r );
475
476 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
477 x = xorg;
478 r = (x *= y);
479 for( size_t i = 0; i < N; i++ )
480 BOOST_CHECK_CLOSE( r[i], z[i], tol );
481 BOOST_CHECK( x == r );
482 x = xorg;
483 BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == r );
484 BOOST_CHECK( x == r );
485
486 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
487 x = xorg;
488 r = (x /= y);
489 for( size_t i = 0; i < N; i++ )
490 BOOST_CHECK_CLOSE( r[i], z[i], tol );
491 BOOST_CHECK( x == r );
492 x = xorg;
493 BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == r );
494 BOOST_CHECK( x == r );
495 }
496
497
498 BOOST_AUTO_TEST_CASE( SimilarFactorTransformationsTest ) {
499 size_t N = 6;
500 Var v( 0, N );
501 Factor x( v );
502 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
503 Factor y( v );
504 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
505 Factor z( v ), r( v );
506
507 z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
508 r = x + y;
509 for( size_t i = 0; i < N; i++ )
510 BOOST_CHECK_CLOSE( r[i], z[i], tol );
511 z = x.binaryTr( y, std::plus<Real>() );
512 BOOST_CHECK( r == z );
513
514 z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
515 r = x - y;
516 for( size_t i = 0; i < N; i++ )
517 BOOST_CHECK_CLOSE( r[i], z[i], tol );
518 z = x.binaryTr( y, std::minus<Real>() );
519 BOOST_CHECK( r == z );
520
521 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
522 r = x * y;
523 for( size_t i = 0; i < N; i++ )
524 BOOST_CHECK_CLOSE( r[i], z[i], tol );
525 z = x.binaryTr( y, std::multiplies<Real>() );
526 BOOST_CHECK( r == z );
527
528 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
529 r = x / y;
530 for( size_t i = 0; i < N; i++ )
531 BOOST_CHECK_CLOSE( r[i], z[i], tol );
532 z = x.binaryTr( y, fo_divides0<Real>() );
533 BOOST_CHECK( r == z );
534 }
535
536
537 BOOST_AUTO_TEST_CASE( FactorOperationsTest ) {
538 size_t N = 9;
539 Var v1( 1, 3 );
540 Var v2( 2, 3 );
541 Factor xorg( v1 ), x( v1 );
542 xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, -1.0 );
543 Factor y( v2 );
544 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
545 Factor r;
546
547 Factor z( VarSet( v1, v2 ) );
548 z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
549 z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
550 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
551 x = xorg;
552 r = (x += y);
553 for( size_t i = 0; i < N; i++ )
554 BOOST_CHECK_CLOSE( r[i], z[i], tol );
555 BOOST_CHECK( x == r );
556 x = xorg;
557 BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == r );
558 BOOST_CHECK( x == r );
559
560 z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
561 z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
562 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
563 x = xorg;
564 r = (x -= y);
565 for( size_t i = 0; i < N; i++ )
566 BOOST_CHECK_CLOSE( r[i], z[i], tol );
567 BOOST_CHECK( x == r );
568 x = xorg;
569 BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == r );
570 BOOST_CHECK( x == r );
571
572 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
573 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
574 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
575 x = xorg;
576 r = (x *= y);
577 for( size_t i = 0; i < N; i++ )
578 BOOST_CHECK_CLOSE( r[i], z[i], tol );
579 BOOST_CHECK( x == r );
580 x = xorg;
581 BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == r );
582 BOOST_CHECK( x == r );
583
584 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
585 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
586 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
587 x = xorg;
588 r = (x /= y);
589 for( size_t i = 0; i < N; i++ )
590 BOOST_CHECK_CLOSE( r[i], z[i], tol );
591 BOOST_CHECK( x == r );
592 x = xorg;
593 BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == r );
594 BOOST_CHECK( x == r );
595 }
596
597
598 BOOST_AUTO_TEST_CASE( FactorTransformationsTest ) {
599 size_t N = 9;
600 Var v1( 1, 3 );
601 Var v2( 2, 3 );
602 Factor x( v1 );
603 x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 );
604 Factor y( v2 );
605 y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
606 Factor r;
607
608 Factor z( VarSet( v1, v2 ) );
609 z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
610 z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
611 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
612 r = x + y;
613 for( size_t i = 0; i < N; i++ )
614 BOOST_CHECK_CLOSE( r[i], z[i], tol );
615 z = x.binaryTr( y, std::plus<Real>() );
616 BOOST_CHECK( r == z );
617
618 z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
619 z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
620 z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
621 r = x - y;
622 for( size_t i = 0; i < N; i++ )
623 BOOST_CHECK_CLOSE( r[i], z[i], tol );
624 z = x.binaryTr( y, std::minus<Real>() );
625 BOOST_CHECK( r == z );
626
627 z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
628 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
629 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
630 r = x * y;
631 for( size_t i = 0; i < N; i++ )
632 BOOST_CHECK_CLOSE( r[i], z[i], tol );
633 z = x.binaryTr( y, std::multiplies<Real>() );
634 BOOST_CHECK( r == z );
635
636 z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
637 z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
638 z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
639 r = x / y;
640 for( size_t i = 0; i < N; i++ )
641 BOOST_CHECK_CLOSE( r[i], z[i], tol );
642 z = x.binaryOp( y, fo_divides0<Real>() );
643 BOOST_CHECK( r == z );
644 }
645
646
647 BOOST_AUTO_TEST_CASE( MiscOperationsTest ) {
648 Var v1(1, 2);
649 Var v2(2, 3);
650 Factor x( VarSet( v1, v2 ) );
651 x.randomize();
652
653 // slice
654 Factor y = x.slice( v1, 0 );
655 BOOST_CHECK( y.vars() == VarSet( v2 ) );
656 BOOST_CHECK_EQUAL( y.nrStates(), 3 );
657 BOOST_CHECK_EQUAL( y[0], x[0] );
658 BOOST_CHECK_EQUAL( y[1], x[2] );
659 BOOST_CHECK_EQUAL( y[2], x[4] );
660 y = x.slice( v1, 1 );
661 BOOST_CHECK( y.vars() == VarSet( v2 ) );
662 BOOST_CHECK_EQUAL( y.nrStates(), 3 );
663 BOOST_CHECK_EQUAL( y[0], x[1] );
664 BOOST_CHECK_EQUAL( y[1], x[3] );
665 BOOST_CHECK_EQUAL( y[2], x[5] );
666 y = x.slice( v2, 0 );
667 BOOST_CHECK( y.vars() == VarSet( v1 ) );
668 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
669 BOOST_CHECK_EQUAL( y[0], x[0] );
670 BOOST_CHECK_EQUAL( y[1], x[1] );
671 y = x.slice( v2, 1 );
672 BOOST_CHECK( y.vars() == VarSet( v1 ) );
673 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
674 BOOST_CHECK_EQUAL( y[0], x[2] );
675 BOOST_CHECK_EQUAL( y[1], x[3] );
676 y = x.slice( v2, 2 );
677 BOOST_CHECK( y.vars() == VarSet( v1 ) );
678 BOOST_CHECK_EQUAL( y.nrStates(), 2 );
679 BOOST_CHECK_EQUAL( y[0], x[4] );
680 BOOST_CHECK_EQUAL( y[1], x[5] );
681 for( size_t i = 0; i < x.nrStates(); i++ ) {
682 y = x.slice( VarSet( v1, v2 ), 0 );
683 BOOST_CHECK( y.vars() == VarSet() );
684 BOOST_CHECK_EQUAL( y.nrStates(), 1 );
685 BOOST_CHECK_EQUAL( y[0], x[0] );
686 }
687 y = x.slice( VarSet(), 0 );
688 BOOST_CHECK_EQUAL( y, x );
689
690 // embed
691 Var v3(3, 4);
692 BOOST_CHECK_THROW( x.embed( VarSet( v3 ) ), Exception );
693 BOOST_CHECK_THROW( x.embed( VarSet( v3, v2 ) ), Exception );
694 y = x.embed( VarSet( v3, v2 ) | v1 );
695 for( size_t i = 0; i < y.nrStates(); i++ )
696 BOOST_CHECK_EQUAL( y[i], x[i % 6] );
697 y = x.embed( VarSet( v1, v2 ) );
698 BOOST_CHECK_EQUAL( x, y );
699
700 // marginal
701 y = x.marginal( v1 );
702 BOOST_CHECK( y.vars() == VarSet( v1 ) );
703 BOOST_CHECK_CLOSE( y[0], (x[0] + x[2] + x[4]) / x.sum(), tol );
704 BOOST_CHECK_CLOSE( y[1], (x[1] + x[3] + x[5]) / x.sum(),tol );
705 y = x.marginal( v2 );
706 BOOST_CHECK( y.vars() == VarSet( v2 ) );
707 BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
708 BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
709 BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
710 y = x.marginal( VarSet() );
711 BOOST_CHECK( y.vars() == VarSet() );
712 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
713 y = x.marginal( VarSet( v1, v2 ) );
714 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
715
716 y = x.marginal( v1, true );
717 BOOST_CHECK( y.vars() == VarSet( v1 ) );
718 BOOST_CHECK_CLOSE( y[0], (x[0] + x[2] + x[4]) / x.sum(), tol );
719 BOOST_CHECK_CLOSE( y[1], (x[1] + x[3] + x[5]) / x.sum(), tol );
720 y = x.marginal( v2, true );
721 BOOST_CHECK( y.vars() == VarSet( v2 ) );
722 BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
723 BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
724 BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
725 y = x.marginal( VarSet(), true );
726 BOOST_CHECK( y.vars() == VarSet() );
727 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
728 y = x.marginal( VarSet( v1, v2 ), true );
729 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
730
731 y = x.marginal( v1, false );
732 BOOST_CHECK( y.vars() == VarSet( v1 ) );
733 BOOST_CHECK_CLOSE( y[0], x[0] + x[2] + x[4], tol );
734 BOOST_CHECK_CLOSE( y[1], x[1] + x[3] + x[5], tol );
735 y = x.marginal( v2, false );
736 BOOST_CHECK( y.vars() == VarSet( v2 ) );
737 BOOST_CHECK_CLOSE( y[0], x[0] + x[1], tol );
738 BOOST_CHECK_CLOSE( y[1], x[2] + x[3], tol );
739 BOOST_CHECK_CLOSE( y[2], x[4] + x[5], tol );
740 y = x.marginal( VarSet(), false );
741 BOOST_CHECK( y.vars() == VarSet() );
742 BOOST_CHECK_CLOSE( y[0], x.sum(), tol );
743 y = x.marginal( VarSet( v1, v2 ), false );
744 BOOST_CHECK_SMALL( dist( y, x, DISTL1 ), tol );
745
746 // maxMarginal
747 y = x.maxMarginal( v1 );
748 BOOST_CHECK( y.vars() == VarSet( v1 ) );
749 BOOST_CHECK_CLOSE( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
750 BOOST_CHECK_CLOSE( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
751 y = x.maxMarginal( v2 );
752 BOOST_CHECK( y.vars() == VarSet( v2 ) );
753 BOOST_CHECK_CLOSE( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
754 BOOST_CHECK_CLOSE( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
755 BOOST_CHECK_CLOSE( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
756 y = x.maxMarginal( VarSet() );
757 BOOST_CHECK( y.vars() == VarSet() );
758 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
759 y = x.maxMarginal( VarSet( v1, v2 ) );
760 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
761
762 y = x.maxMarginal( v1, true );
763 BOOST_CHECK( y.vars() == VarSet( v1 ) );
764 BOOST_CHECK_CLOSE( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
765 BOOST_CHECK_CLOSE( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()), tol );
766 y = x.maxMarginal( v2, true );
767 BOOST_CHECK( y.vars() == VarSet( v2 ) );
768 BOOST_CHECK_CLOSE( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
769 BOOST_CHECK_CLOSE( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
770 BOOST_CHECK_CLOSE( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()), tol );
771 y = x.maxMarginal( VarSet(), true );
772 BOOST_CHECK( y.vars() == VarSet() );
773 BOOST_CHECK_CLOSE( y[0], (Real)1.0, tol );
774 y = x.maxMarginal( VarSet( v1, v2 ), true );
775 BOOST_CHECK_SMALL( dist( y, x.normalized(), DISTL1 ), tol );
776
777 y = x.maxMarginal( v1, false );
778 BOOST_CHECK( y.vars() == VarSet( v1 ) );
779 BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() );
780 BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() );
781 y = x.maxMarginal( v2, false );
782 BOOST_CHECK( y.vars() == VarSet( v2 ) );
783 BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() );
784 BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() );
785 BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() );
786 y = x.maxMarginal( VarSet(), false );
787 BOOST_CHECK( y.vars() == VarSet() );
788 BOOST_CHECK_EQUAL( y[0], x.max() );
789 y = x.maxMarginal( VarSet( v1, v2 ), false );
790 BOOST_CHECK( y == x );
791 }
792
793
794 BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
795 Var v( 0, 3 );
796 Factor x(v), y(v), z(v);
797 x.set( 0, 0.2 );
798 x.set( 1, 0.8 );
799 x.set( 2, 0.0 );
800 y.set( 0, 0.0 );
801 y.set( 1, 0.6 );
802 y.set( 2, 0.4 );
803
804 z = min( x, y );
805 BOOST_CHECK_EQUAL( z[0], (Real)0.0 );
806 BOOST_CHECK_EQUAL( z[1], (Real)0.6 );
807 BOOST_CHECK_EQUAL( z[2], (Real)0.0 );
808 z = max( x, y );
809 BOOST_CHECK_EQUAL( z[0], (Real)0.2 );
810 BOOST_CHECK_EQUAL( z[1], (Real)0.8 );
811 BOOST_CHECK_EQUAL( z[2], (Real)0.4 );
812
813 BOOST_CHECK_CLOSE( dist( x, x, DISTL1 ), (Real)0.0, tol );
814 BOOST_CHECK_CLOSE( dist( y, y, DISTL1 ), (Real)0.0, tol );
815 BOOST_CHECK_CLOSE( dist( x, y, DISTL1 ), (Real)(0.2 + 0.2 + 0.4), tol );
816 BOOST_CHECK_CLOSE( dist( y, x, DISTL1 ), (Real)(0.2 + 0.2 + 0.4), tol );
817 BOOST_CHECK_CLOSE( dist( x, x, DISTLINF ), (Real)0.0, tol );
818 BOOST_CHECK_CLOSE( dist( y, y, DISTLINF ), (Real)0.0, tol );
819 BOOST_CHECK_CLOSE( dist( x, y, DISTLINF ), (Real)0.4, tol );
820 BOOST_CHECK_CLOSE( dist( y, x, DISTLINF ), (Real)0.4, tol );
821 BOOST_CHECK_CLOSE( dist( x, x, DISTTV ), (Real)0.0, tol );
822 BOOST_CHECK_CLOSE( dist( y, y, DISTTV ), (Real)0.0, tol );
823 BOOST_CHECK_CLOSE( dist( x, y, DISTTV ), (Real)(0.5 * (0.2 + 0.2 + 0.4)), tol );
824 BOOST_CHECK_CLOSE( dist( y, x, DISTTV ), (Real)(0.5 * (0.2 + 0.2 + 0.4)), tol );
825 BOOST_CHECK_CLOSE( dist( x, x, DISTKL ), (Real)0.0, tol );
826 BOOST_CHECK_CLOSE( dist( y, y, DISTKL ), (Real)0.0, tol );
827 BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
828 BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
829 BOOST_CHECK_CLOSE( dist( x, x, DISTHEL ), (Real)0.0, tol );
830 BOOST_CHECK_CLOSE( dist( y, y, DISTHEL ), (Real)0.0, tol );
831 BOOST_CHECK_CLOSE( dist( x, y, DISTHEL ), (Real)(0.5 * (0.2 + dai::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4)), tol );
832 BOOST_CHECK_CLOSE( dist( y, x, DISTHEL ), (Real)(0.5 * (0.2 + dai::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4)), tol );
833 x.set( 1, 0.7 ); x.set( 2, 0.1 );
834 y.set( 0, 0.1 ); y.set( 1, 0.5 );
835 BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), (Real)(0.2 * dai::log(0.2 / 0.1) + 0.7 * dai::log(0.7 / 0.5) + 0.1 * dai::log(0.1 / 0.4)), tol );
836 BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), (Real)(0.1 * dai::log(0.1 / 0.2) + 0.5 * dai::log(0.5 / 0.7) + 0.4 * dai::log(0.4 / 0.1)), tol );
837
838 std::stringstream ss;
839 ss << x;
840 std::string s;
841 std::getline( ss, s );
842 BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.2, 0.7, 0.1))") );
843 std::stringstream ss2;
844 ss2 << y;
845 std::getline( ss2, s );
846 BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.1, 0.5, 0.4))") );
847
848 z = min( x, y );
849 BOOST_CHECK_EQUAL( z[0], (Real)0.1 );
850 BOOST_CHECK_EQUAL( z[1], (Real)0.5 );
851 BOOST_CHECK_EQUAL( z[2], (Real)0.1 );
852 z = max( x, y );
853 BOOST_CHECK_EQUAL( z[0], (Real)0.2 );
854 BOOST_CHECK_EQUAL( z[1], (Real)0.7 );
855 BOOST_CHECK_EQUAL( z[2], (Real)0.4 );
856
857 for( Real J = -1.0; J <= 1.01; J += 0.1 ) {
858 Factor x = createFactorIsing( Var(0,2), Var(1,2), J ).normalized();
859 BOOST_CHECK_CLOSE( x[0], dai::exp(J) / (4.0 * std::cosh(J)), tol );
860 BOOST_CHECK_CLOSE( x[1], dai::exp(-J) / (4.0 * std::cosh(J)), tol );
861 BOOST_CHECK_CLOSE( x[2], dai::exp(-J) / (4.0 * std::cosh(J)), tol );
862 BOOST_CHECK_CLOSE( x[3], dai::exp(J) / (4.0 * std::cosh(J)), tol );
863 BOOST_CHECK_SMALL( MutualInfo( x ) - (J * std::tanh(J) - dai::log(std::cosh(J))), tol );
864 }
865 Var v1( 1, 3 );
866 Var v2( 2, 4 );
867 BOOST_CHECK_SMALL( MutualInfo( (Factor(v1).randomize() * Factor(v2).randomize()).normalized() ), tol );
868 BOOST_CHECK_THROW( MutualInfo( createFactorIsing( Var(0,2), 1.0 ).normalized() ), Exception );
869 BOOST_CHECK_THROW( createFactorIsing( v1, 0.0 ), Exception );
870 BOOST_CHECK_THROW( createFactorIsing( v1, v2, 0.0 ), Exception );
871 for( Real J = -1.0; J <= 1.01; J += 0.1 ) {
872 Factor x = createFactorIsing( Var(0,2), J ).normalized();
873 BOOST_CHECK_CLOSE( x[0], dai::exp(-J) / (2.0 * std::cosh(J)), tol );
874 BOOST_CHECK_CLOSE( x[1], dai::exp(J) / (2.0 * std::cosh(J)), tol );
875 BOOST_CHECK_SMALL( x.entropy() - (-J * std::tanh(J) + dai::log(2.0 * std::cosh(J))), tol );
876 }
877
878 x = createFactorDelta( v1, 2 );
879 BOOST_CHECK_EQUAL( x[0], (Real)0.0 );
880 BOOST_CHECK_EQUAL( x[1], (Real)0.0 );
881 BOOST_CHECK_EQUAL( x[2], (Real)1.0 );
882 x = createFactorDelta( v1, 1 );
883 BOOST_CHECK_EQUAL( x[0], (Real)0.0 );
884 BOOST_CHECK_EQUAL( x[1], (Real)1.0 );
885 BOOST_CHECK_EQUAL( x[2], (Real)0.0 );
886 x = createFactorDelta( v1, 0 );
887 BOOST_CHECK_EQUAL( x[0], (Real)1.0 );
888 BOOST_CHECK_EQUAL( x[1], (Real)0.0 );
889 BOOST_CHECK_EQUAL( x[2], (Real)0.0 );
890 BOOST_CHECK_THROW( createFactorDelta( v1, 4 ), Exception );
891
892 for( size_t i = 0; i < 12; i++ ) {
893 Factor xx = createFactorDelta( VarSet( v1, v2 ), i );
894 for( size_t j = 0; j < 12; j++ )
895 BOOST_CHECK_EQUAL( xx[j], (Real)(i == j) );
896 }
897 BOOST_CHECK_THROW( createFactorDelta( VarSet( v1, v2 ), 12 ), Exception );
898 }