Wrote alldai.h/cpp unit tests
[libdai.git] / tests / unit / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
14 #include <dai/daialg.h>
15 #include <dai/alldai.h>
16 #include <strstream>
17
18
19 using namespace dai;
20
21
/// Largest dist(..., DISTTV) accepted between a computed belief and the
/// corresponding marginal of the brute-force joint distribution.
static const double tol( 1e-8 );
23
24
25 #define BOOST_TEST_MODULE DAIAlgTest
26
27
28 #include <boost/test/unit_test.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( calcMarginalTest ) {
32 Var v0( 0, 2 );
33 Var v1( 1, 2 );
34 Var v2( 2, 2 );
35 Var v3( 3, 2 );
36 VarSet v01( v0, v1 );
37 VarSet v02( v0, v2 );
38 VarSet v03( v0, v3 );
39 VarSet v12( v1, v2 );
40 VarSet v13( v1, v3 );
41 VarSet v23( v2, v3 );
42 std::vector<Factor> facs;
43 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
44 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
45 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
46 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
47 facs.push_back( createFactorIsing( v0, -1.0 ) );
48 facs.push_back( createFactorIsing( v1, -1.0 ) );
49 facs.push_back( createFactorIsing( v2, -1.0 ) );
50 facs.push_back( createFactorIsing( v3, 1.0 ) );
51 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
52 FactorGraph fg( facs );
53 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
54 ei.init();
55 ei.run();
56 VarSet vs;
57
58 vs = v0; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
59 vs = v1; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
60 vs = v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
61 vs = v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
62 vs = v01; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
63 vs = v02; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
64 vs = v03; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
65 vs = v12; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
66 vs = v13; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
67 vs = v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
68 vs = v01 | v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
69 vs = v01 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
70 vs = v02 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
71 vs = v12 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
72 vs = v01 | v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
73 }
74
75
76 BOOST_AUTO_TEST_CASE( calcPairBeliefsTest ) {
77 Var v0( 0, 2 );
78 Var v1( 1, 2 );
79 Var v2( 2, 2 );
80 Var v3( 3, 2 );
81 VarSet v01( v0, v1 );
82 VarSet v02( v0, v2 );
83 VarSet v03( v0, v3 );
84 VarSet v12( v1, v2 );
85 VarSet v13( v1, v3 );
86 VarSet v23( v2, v3 );
87 std::vector<Factor> facs;
88 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
89 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
90 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
91 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
92 facs.push_back( createFactorIsing( v0, -1.0 ) );
93 facs.push_back( createFactorIsing( v1, -1.0 ) );
94 facs.push_back( createFactorIsing( v2, -1.0 ) );
95 facs.push_back( createFactorIsing( v3, 1.0 ) );
96 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
97 FactorGraph fg( facs );
98 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
99 ei.init();
100 ei.run();
101 VarSet vs;
102
103 std::vector<Factor> pb = calcPairBeliefs( ei, v01 | v23, false, false );
104 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
105 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
106 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
107 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
108 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
109 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
110
111 pb = calcPairBeliefs( ei, v01 | v23, false, true );
112 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
113 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
114 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
115 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
116 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
117 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
118 }