1b060a8e31fa6223dc07c72392ec17e4927d0bc8
[libdai.git] / tests / unit / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
14 #include <dai/daialg.h>
15 #include <dai/alldai.h>
16 #include <strstream>
17
18
19 using namespace dai;
20
21
// Upper bound on the total-variation distance (DISTTV) accepted between a
// computed belief and the corresponding reference marginal in the tests below.
static const double tol = 1.0e-8;
23
24
25 #define BOOST_TEST_MODULE DAIAlgTest
26
27
28 #include <boost/test/unit_test.hpp>
29
30
31 BOOST_AUTO_TEST_CASE( calcMarginalTest ) {
32 Var v0( 0, 2 );
33 Var v1( 1, 2 );
34 Var v2( 2, 2 );
35 Var v3( 3, 2 );
36 VarSet v01( v0, v1 );
37 VarSet v02( v0, v2 );
38 VarSet v03( v0, v3 );
39 VarSet v12( v1, v2 );
40 VarSet v13( v1, v3 );
41 VarSet v23( v2, v3 );
42 std::vector<Factor> facs;
43 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
44 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
45 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
46 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
47 facs.push_back( createFactorIsing( v0, -1.0 ) );
48 facs.push_back( createFactorIsing( v1, -1.0 ) );
49 facs.push_back( createFactorIsing( v2, -1.0 ) );
50 facs.push_back( createFactorIsing( v3, 1.0 ) );
51 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
52 FactorGraph fg( facs );
53 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
54 VarSet vs;
55
56 vs = v0; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
57 vs = v1; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
58 vs = v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
59 vs = v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
60 vs = v01; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
61 vs = v02; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
62 vs = v03; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
63 vs = v12; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
64 vs = v13; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
65 vs = v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
66 vs = v01 | v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
67 vs = v01 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
68 vs = v02 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
69 vs = v12 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
70 vs = v01 | v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
71 }
72
73
74 BOOST_AUTO_TEST_CASE( calcPairBeliefsTest ) {
75 Var v0( 0, 2 );
76 Var v1( 1, 2 );
77 Var v2( 2, 2 );
78 Var v3( 3, 2 );
79 VarSet v01( v0, v1 );
80 VarSet v02( v0, v2 );
81 VarSet v03( v0, v3 );
82 VarSet v12( v1, v2 );
83 VarSet v13( v1, v3 );
84 VarSet v23( v2, v3 );
85 std::vector<Factor> facs;
86 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
87 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
88 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
89 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
90 facs.push_back( createFactorIsing( v0, -1.0 ) );
91 facs.push_back( createFactorIsing( v1, -1.0 ) );
92 facs.push_back( createFactorIsing( v2, -1.0 ) );
93 facs.push_back( createFactorIsing( v3, 1.0 ) );
94 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
95 FactorGraph fg( facs );
96 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
97 VarSet vs;
98
99 std::vector<Factor> pb = calcPairBeliefs( ei, v01 | v23, false, false );
100 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
101 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
102 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
103 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
104 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
105 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
106
107 pb = calcPairBeliefs( ei, v01 | v23, false, true );
108 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
109 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
110 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
111 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
112 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
113 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
114 }