New git HEAD version
[libdai.git] / tests / unit / daialg_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/daialg.h>
10 #include <dai/alldai.h>
11 #include <strstream>
12
13
14 using namespace dai;
15
16
// Tolerance for the checks in this file: each one requires
// dist( ..., DISTTV ) to be strictly below this value.
17 const double tol = 1e-8;
18
19
// Names the Boost.Test module; it is defined ahead of the
// <boost/test/unit_test.hpp> include that follows.
20 #define BOOST_TEST_MODULE DAIAlgTest
21
22
23 #include <boost/test/unit_test.hpp>
24
25
26 BOOST_AUTO_TEST_CASE( calcMarginalTest ) {
27 Var v0( 0, 2 );
28 Var v1( 1, 2 );
29 Var v2( 2, 2 );
30 Var v3( 3, 2 );
31 VarSet v01( v0, v1 );
32 VarSet v02( v0, v2 );
33 VarSet v03( v0, v3 );
34 VarSet v12( v1, v2 );
35 VarSet v13( v1, v3 );
36 VarSet v23( v2, v3 );
37 std::vector<Factor> facs;
38 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
39 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
40 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
41 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
42 facs.push_back( createFactorIsing( v0, -1.0 ) );
43 facs.push_back( createFactorIsing( v1, -1.0 ) );
44 facs.push_back( createFactorIsing( v2, -1.0 ) );
45 facs.push_back( createFactorIsing( v3, 1.0 ) );
46 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
47 FactorGraph fg( facs );
48 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
49 ei.init();
50 ei.run();
51 VarSet vs;
52
53 vs = v0; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
54 vs = v1; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
55 vs = v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
56 vs = v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
57 vs = v01; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
58 vs = v02; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
59 vs = v03; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
60 vs = v12; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
61 vs = v13; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
62 vs = v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
63 vs = v01 | v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
64 vs = v01 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
65 vs = v02 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
66 vs = v12 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
67 vs = v01 | v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
68 }
69
70
71 BOOST_AUTO_TEST_CASE( calcPairBeliefsTest ) {
72 Var v0( 0, 2 );
73 Var v1( 1, 2 );
74 Var v2( 2, 2 );
75 Var v3( 3, 2 );
76 VarSet v01( v0, v1 );
77 VarSet v02( v0, v2 );
78 VarSet v03( v0, v3 );
79 VarSet v12( v1, v2 );
80 VarSet v13( v1, v3 );
81 VarSet v23( v2, v3 );
82 std::vector<Factor> facs;
83 facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
84 facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
85 facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
86 facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
87 facs.push_back( createFactorIsing( v0, -1.0 ) );
88 facs.push_back( createFactorIsing( v1, -1.0 ) );
89 facs.push_back( createFactorIsing( v2, -1.0 ) );
90 facs.push_back( createFactorIsing( v3, 1.0 ) );
91 Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
92 FactorGraph fg( facs );
93 ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
94 ei.init();
95 ei.run();
96 VarSet vs;
97
98 std::vector<Factor> pb = calcPairBeliefs( ei, v01 | v23, false, false );
99 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
100 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
101 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
102 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
103 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
104 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
105
106 pb = calcPairBeliefs( ei, v01 | v23, false, true );
107 BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
108 BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
109 BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
110 BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
111 BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
112 BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
113 }