Some improvements to jtree and regiongraph and started work on regiongraph unit tests
[libdai.git] / tests / unit / clustergraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
#include <dai/clustergraph.h>
#include <sstream>
#include <string>
#include <strstream>
#include <vector>
17
18
19 using namespace dai;
20
21
// Numerical tolerance for floating-point comparisons (not referenced by the
// test cases in this file).
const double tol = 1.0e-8;
23
24
25 #define BOOST_TEST_MODULE ClusterGraphTest
26
27
28 #include <boost/test/unit_test.hpp>
29 #include <boost/test/floating_point_comparison.hpp>
30
31
32 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
33 ClusterGraph G;
34 BOOST_CHECK_EQUAL( G.clusters(), std::vector<VarSet>() );
35 BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
36 BOOST_CHECK_EQUAL( G.nrVars(), 0 );
37 BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
38 BOOST_CHECK_THROW( G.var( 0 ), Exception );
39 BOOST_CHECK_THROW( G.cluster( 0 ), Exception );
40 BOOST_CHECK_THROW( G.findVar( Var( 0, 2 ) ), Exception );
41
42 Var v0( 0, 2 );
43 Var v1( 1, 3 );
44 Var v2( 2, 2 );
45 Var v3( 3, 4 );
46 std::vector<Var> vs;
47 vs.push_back( v0 );
48 vs.push_back( v1 );
49 vs.push_back( v2 );
50 vs.push_back( v3 );
51 VarSet v01( v0, v1 );
52 VarSet v02( v0, v2 );
53 VarSet v03( v0, v3 );
54 VarSet v12( v1, v2 );
55 VarSet v13( v1, v3 );
56 VarSet v23( v2, v3 );
57 std::vector<VarSet> cl;
58 cl.push_back( v01 );
59 cl.push_back( v12 );
60 cl.push_back( v23 );
61 cl.push_back( v13 );
62 ClusterGraph G2( cl );
63 BOOST_CHECK_EQUAL( G2.nrVars(), 4 );
64 BOOST_CHECK_EQUAL( G2.nrClusters(), 4 );
65 BOOST_CHECK_EQUAL( G2.vars(), vs );
66 BOOST_CHECK_EQUAL( G2.clusters(), cl );
67 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
68 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
69 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
70 BOOST_CHECK_EQUAL( G2.findVar( v3 ), 3 );
71
72 ClusterGraph Gb( G );
73 BOOST_CHECK( G.bipGraph() == Gb.bipGraph() );
74 BOOST_CHECK( G.vars() == Gb.vars() );
75 BOOST_CHECK( G.clusters() == Gb.clusters() );
76
77 ClusterGraph Gc = G;
78 BOOST_CHECK( G.bipGraph() == Gc.bipGraph() );
79 BOOST_CHECK( G.vars() == Gc.vars() );
80 BOOST_CHECK( G.clusters() == Gc.clusters() );
81
82 ClusterGraph G2b( G2 );
83 BOOST_CHECK( G2.bipGraph() == G2b.bipGraph() );
84 BOOST_CHECK( G2.vars() == G2b.vars() );
85 BOOST_CHECK( G2.clusters() == G2b.clusters() );
86
87 ClusterGraph G2c = G2;
88 BOOST_CHECK( G2.bipGraph() == G2c.bipGraph() );
89 BOOST_CHECK( G2.vars() == G2c.vars() );
90 BOOST_CHECK( G2.clusters() == G2c.clusters() );
91 }
92
93
94 BOOST_AUTO_TEST_CASE( QueriesTest ) {
95 Var v0( 0, 2 );
96 Var v1( 1, 3 );
97 Var v2( 2, 2 );
98 Var v3( 3, 4 );
99 Var v4( 4, 2 );
100 std::vector<Var> vs;
101 vs.push_back( v0 );
102 vs.push_back( v1 );
103 vs.push_back( v2 );
104 vs.push_back( v3 );
105 vs.push_back( v4 );
106 VarSet v01( v0, v1 );
107 VarSet v02( v0, v2 );
108 VarSet v03( v0, v3 );
109 VarSet v04( v0, v4 );
110 VarSet v12( v1, v2 );
111 VarSet v13( v1, v3 );
112 VarSet v14( v1, v4 );
113 VarSet v23( v2, v3 );
114 VarSet v24( v2, v4 );
115 VarSet v34( v3, v4 );
116 VarSet v123 = v12 | v3;
117 std::vector<VarSet> cl;
118 cl.push_back( v01 );
119 cl.push_back( v12 );
120 cl.push_back( v123 );
121 cl.push_back( v34 );
122 cl.push_back( v04 );
123 ClusterGraph G( cl );
124
125 BOOST_CHECK_EQUAL( G.nrVars(), 5 );
126 BOOST_CHECK_EQUAL( G.vars(), vs );
127 BOOST_CHECK_EQUAL( G.var(0), v0 );
128 BOOST_CHECK_EQUAL( G.var(1), v1 );
129 BOOST_CHECK_EQUAL( G.var(2), v2 );
130 BOOST_CHECK_EQUAL( G.var(3), v3 );
131 BOOST_CHECK_EQUAL( G.var(4), v4 );
132 BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
133 BOOST_CHECK_EQUAL( G.clusters(), cl );
134 BOOST_CHECK_EQUAL( G.cluster(0), v01 );
135 BOOST_CHECK_EQUAL( G.cluster(1), v12 );
136 BOOST_CHECK_EQUAL( G.cluster(2), v123 );
137 BOOST_CHECK_EQUAL( G.cluster(3), v34 );
138 BOOST_CHECK_EQUAL( G.cluster(4), v04 );
139 BOOST_CHECK_EQUAL( G.findVar( v0 ), 0 );
140 BOOST_CHECK_EQUAL( G.findVar( v1 ), 1 );
141 BOOST_CHECK_EQUAL( G.findVar( v2 ), 2 );
142 BOOST_CHECK_EQUAL( G.findVar( v3 ), 3 );
143 BOOST_CHECK_EQUAL( G.findVar( v4 ), 4 );
144 BipartiteGraph H( 5, 5 );
145 H.addEdge( 0, 0 );
146 H.addEdge( 1, 0 );
147 H.addEdge( 1, 1 );
148 H.addEdge( 2, 1 );
149 H.addEdge( 1, 2 );
150 H.addEdge( 2, 2 );
151 H.addEdge( 3, 2 );
152 H.addEdge( 3, 3 );
153 H.addEdge( 4, 3 );
154 H.addEdge( 0, 4 );
155 H.addEdge( 4, 4 );
156 BOOST_CHECK( G.bipGraph() == H );
157
158 BOOST_CHECK_EQUAL( G.delta( 0 ), v14 );
159 BOOST_CHECK_EQUAL( G.delta( 1 ), v02 | v3 );
160 BOOST_CHECK_EQUAL( G.delta( 2 ), v13 );
161 BOOST_CHECK_EQUAL( G.delta( 3 ), v12 | v4 );
162 BOOST_CHECK_EQUAL( G.delta( 4 ), v03 );
163 BOOST_CHECK_EQUAL( G.Delta( 0 ), v14 | v0 );
164 BOOST_CHECK_EQUAL( G.Delta( 1 ), v01 | v23 );
165 BOOST_CHECK_EQUAL( G.Delta( 2 ), v13 | v2 );
166 BOOST_CHECK_EQUAL( G.Delta( 3 ), v12 | v34 );
167 BOOST_CHECK_EQUAL( G.Delta( 4 ), v03 | v4 );
168
169 BOOST_CHECK( !G.adj( 0, 0 ) );
170 BOOST_CHECK( G.adj( 0, 1 ) );
171 BOOST_CHECK( !G.adj( 0, 2 ) );
172 BOOST_CHECK( !G.adj( 0, 3 ) );
173 BOOST_CHECK( G.adj( 0, 4 ) );
174 BOOST_CHECK( G.adj( 1, 0 ) );
175 BOOST_CHECK( !G.adj( 1, 1 ) );
176 BOOST_CHECK( G.adj( 1, 2 ) );
177 BOOST_CHECK( G.adj( 1, 3 ) );
178 BOOST_CHECK( !G.adj( 1, 4 ) );
179 BOOST_CHECK( !G.adj( 2, 0 ) );
180 BOOST_CHECK( G.adj( 2, 1 ) );
181 BOOST_CHECK( !G.adj( 2, 2 ) );
182 BOOST_CHECK( G.adj( 2, 3 ) );
183 BOOST_CHECK( !G.adj( 2, 4 ) );
184 BOOST_CHECK( !G.adj( 3, 0 ) );
185 BOOST_CHECK( G.adj( 3, 1 ) );
186 BOOST_CHECK( G.adj( 3, 2 ) );
187 BOOST_CHECK( !G.adj( 3, 3 ) );
188 BOOST_CHECK( G.adj( 3, 4 ) );
189 BOOST_CHECK( G.adj( 4, 0 ) );
190 BOOST_CHECK( !G.adj( 4, 1 ) );
191 BOOST_CHECK( !G.adj( 4, 2 ) );
192 BOOST_CHECK( G.adj( 4, 3 ) );
193 BOOST_CHECK( !G.adj( 4, 4 ) );
194
195 BOOST_CHECK( G.isMaximal( 0 ) );
196 BOOST_CHECK( !G.isMaximal( 1 ) );
197 BOOST_CHECK( G.isMaximal( 2 ) );
198 BOOST_CHECK( G.isMaximal( 3 ) );
199 BOOST_CHECK( G.isMaximal( 4 ) );
200 }
201
202
203 BOOST_AUTO_TEST_CASE( OperationsTest ) {
204 Var v0( 0, 2 );
205 Var v1( 1, 3 );
206 Var v2( 2, 2 );
207 Var v3( 3, 4 );
208 Var v4( 4, 2 );
209 VarSet v01( v0, v1 );
210 VarSet v02( v0, v2 );
211 VarSet v03( v0, v3 );
212 VarSet v04( v0, v4 );
213 VarSet v12( v1, v2 );
214 VarSet v13( v1, v3 );
215 VarSet v14( v1, v4 );
216 VarSet v23( v2, v3 );
217 VarSet v24( v2, v4 );
218 VarSet v34( v3, v4 );
219 VarSet v123 = v12 | v3;
220 std::vector<VarSet> cl;
221 cl.push_back( v01 );
222 cl.push_back( v12 );
223 cl.push_back( v123 );
224 cl.push_back( v34 );
225 cl.push_back( v04 );
226 ClusterGraph G( cl );
227
228 BipartiteGraph H( 5, 5 );
229 H.addEdge( 0, 0 );
230 H.addEdge( 1, 0 );
231 H.addEdge( 1, 1 );
232 H.addEdge( 2, 1 );
233 H.addEdge( 1, 2 );
234 H.addEdge( 2, 2 );
235 H.addEdge( 3, 2 );
236 H.addEdge( 3, 3 );
237 H.addEdge( 4, 3 );
238 H.addEdge( 0, 4 );
239 H.addEdge( 4, 4 );
240 BOOST_CHECK( G.bipGraph() == H );
241
242 G.eraseNonMaximal();
243 BOOST_CHECK_EQUAL( G.nrClusters(), 4 );
244 H.eraseNode2( 1 );
245 BOOST_CHECK( G.bipGraph() == H );
246 G.eraseSubsuming( 4 );
247 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
248 H.eraseNode2( 2 );
249 H.eraseNode2( 2 );
250 BOOST_CHECK( G.bipGraph() == H );
251 G.insert( v34 );
252 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
253 G.insert( v123 );
254 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
255 H.addNode2();
256 H.addEdge( 3, 2 );
257 H.addEdge( 4, 2 );
258 BOOST_CHECK( G.bipGraph() == H );
259 G.insert( v12 );
260 G.insert( v23 );
261 BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
262 H.addNode2();
263 H.addNode2();
264 H.addEdge( 1, 3 );
265 H.addEdge( 2, 3 );
266 H.addEdge( 2, 4 );
267 H.addEdge( 3, 4 );
268 BOOST_CHECK( G.bipGraph() == H );
269 G.eraseNonMaximal();
270 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
271 H.eraseNode2( 3 );
272 H.eraseNode2( 3 );
273 BOOST_CHECK( G.bipGraph() == H );
274 G.eraseSubsuming( 2 );
275 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
276 H.eraseNode2( 1 );
277 BOOST_CHECK( G.bipGraph() == H );
278 G.eraseNonMaximal();
279 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
280 BOOST_CHECK( G.bipGraph() == H );
281 G.eraseSubsuming( 0 );
282 BOOST_CHECK_EQUAL( G.nrClusters(), 1 );
283 H.eraseNode2( 0 );
284 BOOST_CHECK( G.bipGraph() == H );
285 G.eraseSubsuming( 4 );
286 BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
287 H.eraseNode2( 0 );
288 BOOST_CHECK( G.bipGraph() == H );
289 }
290
291
292 BOOST_AUTO_TEST_CASE( VarElimTest ) {
293 Var v0( 0, 2 );
294 Var v1( 1, 3 );
295 Var v2( 2, 2 );
296 Var v3( 3, 4 );
297 Var v4( 4, 2 );
298 VarSet v01( v0, v1 );
299 VarSet v02( v0, v2 );
300 VarSet v03( v0, v3 );
301 VarSet v04( v0, v4 );
302 VarSet v12( v1, v2 );
303 VarSet v13( v1, v3 );
304 VarSet v14( v1, v4 );
305 VarSet v23( v2, v3 );
306 VarSet v24( v2, v4 );
307 VarSet v34( v3, v4 );
308 VarSet v123 = v12 | v3;
309 std::vector<VarSet> cl;
310 cl.push_back( v01 );
311 cl.push_back( v12 );
312 cl.push_back( v123 );
313 cl.push_back( v34 );
314 cl.push_back( v04 );
315 ClusterGraph G( cl );
316 ClusterGraph Gorg = G;
317
318 BipartiteGraph H( 5, 5 );
319 H.addEdge( 0, 0 );
320 H.addEdge( 1, 0 );
321 H.addEdge( 1, 1 );
322 H.addEdge( 2, 1 );
323 H.addEdge( 1, 2 );
324 H.addEdge( 2, 2 );
325 H.addEdge( 3, 2 );
326 H.addEdge( 3, 3 );
327 H.addEdge( 4, 3 );
328 H.addEdge( 0, 4 );
329 H.addEdge( 4, 4 );
330 BOOST_CHECK( G.bipGraph() == H );
331
332 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 0 ), 1 );
333 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 1 ), 2 );
334 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 2 ), 0 );
335 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 3 ), 2 );
336 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 4 ), 1 );
337 cl.clear();
338 cl.push_back( v123 );
339 cl.push_back( v01 | v4 );
340 cl.push_back( v13 | v4 );
341 cl.push_back( v34 );
342 cl.push_back( v4 );
343 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinFill ) ).clusters(), cl );
344
345 G = Gorg;
346 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 0 ), 2*3 );
347 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 1 ), 2*2+2*4 );
348 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 2 ), 0 );
349 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 3 ), 3*2+2*2 );
350 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 4 ), 2*4 );
351 cl.clear();
352 cl.push_back( v123 );
353 cl.push_back( v01 | v4 );
354 cl.push_back( v13 | v4 );
355 cl.push_back( v34 );
356 cl.push_back( v4 );
357 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_WeightedMinFill ) ).clusters(), cl );
358
359 G = Gorg;
360 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 0 ), 2 );
361 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 1 ), 3 );
362 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 2 ), 2 );
363 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 3 ), 3 );
364 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 4 ), 2 );
365 cl.clear();
366 cl.push_back( v01 | v4 );
367 cl.push_back( v123 );
368 cl.push_back( v13 | v4 );
369 cl.push_back( v34 );
370 cl.push_back( v4 );
371 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinNeighbors ) ).clusters(), cl );
372
373 G = Gorg;
374 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 0 ), 3*2 );
375 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 1 ), 2*2*4 );
376 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 2 ), 3*4 );
377 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 3 ), 3*2*2 );
378 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 4 ), 2*4 );
379 cl.clear();
380 cl.push_back( v01 | v4 );
381 cl.push_back( v123 );
382 cl.push_back( v13 | v4 );
383 cl.push_back( v14 );
384 cl.push_back( v4 );
385 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinWeight ) ).clusters(), cl );
386
387 G = Gorg;
388 std::vector<Var> vs;
389 vs.push_back( v4 );
390 vs.push_back( v3 );
391 vs.push_back( v2 );
392 vs.push_back( v1 );
393 vs.push_back( v0 );
394 cl.clear();
395 cl.push_back( v03 | v4 );
396 cl.push_back( v01 | v23 );
397 cl.push_back( v01 | v2 );
398 cl.push_back( v01 );
399 cl.push_back( v0 );
400 BOOST_CHECK_EQUAL( G.VarElim( sequentialVariableElimination( vs ) ).clusters(), cl );
401 }
402
403
404 BOOST_AUTO_TEST_CASE( IOTest ) {
405 Var v0( 0, 2 );
406 Var v1( 1, 3 );
407 Var v2( 2, 2 );
408 Var v3( 3, 4 );
409 VarSet v01( v0, v1 );
410 VarSet v02( v0, v2 );
411 VarSet v03( v0, v3 );
412 VarSet v12( v1, v2 );
413 VarSet v13( v1, v3 );
414 VarSet v23( v2, v3 );
415 std::vector<VarSet> cl;
416 cl.push_back( v01 );
417 cl.push_back( v12 );
418 cl.push_back( v23 );
419 cl.push_back( v13 );
420 ClusterGraph G( cl );
421
422 std::stringstream ss;
423 ss << G;
424 std::string s;
425 getline( ss, s );
426 BOOST_CHECK_EQUAL( s, "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
427 }