Miscellaneous improvements to regiongraph, factorgraph and bipgraph, and finished...
[libdai.git] / tests / unit / clustergraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
#define BOOST_TEST_DYN_LINK


#include <dai/clustergraph.h>
#include <sstream>    // std::stringstream (used by IOTest)
#include <string>     // std::string, std::getline (used by IOTest)
#include <strstream>  // NOTE(review): deprecated header, nothing in this file appears to need it
#include <vector>


using namespace dai;


// Tolerance for floating-point comparisons; no test case in this file currently references it.
const double tol = 1e-8;


// Must be defined before including the Boost.Test header so that it generates the test module.
#define BOOST_TEST_MODULE ClusterGraphTest


#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
30
31
32 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
33 ClusterGraph G;
34 BOOST_CHECK_EQUAL( G.clusters(), std::vector<VarSet>() );
35 BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
36 BOOST_CHECK_EQUAL( G.nrVars(), 0 );
37 BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
38 #ifdef DAI_DEBUG
39 BOOST_CHECK_THROW( G.var( 0 ), Exception );
40 BOOST_CHECK_THROW( G.cluster( 0 ), Exception );
41 #endif
42 BOOST_CHECK_THROW( G.findVar( Var( 0, 2 ) ), Exception );
43
44 Var v0( 0, 2 );
45 Var v1( 1, 3 );
46 Var v2( 2, 2 );
47 Var v3( 3, 4 );
48 std::vector<Var> vs;
49 vs.push_back( v0 );
50 vs.push_back( v1 );
51 vs.push_back( v2 );
52 vs.push_back( v3 );
53 VarSet v01( v0, v1 );
54 VarSet v02( v0, v2 );
55 VarSet v03( v0, v3 );
56 VarSet v12( v1, v2 );
57 VarSet v13( v1, v3 );
58 VarSet v23( v2, v3 );
59 std::vector<VarSet> cl;
60 cl.push_back( v01 );
61 cl.push_back( v12 );
62 cl.push_back( v23 );
63 cl.push_back( v13 );
64 ClusterGraph G2( cl );
65 BOOST_CHECK_EQUAL( G2.nrVars(), 4 );
66 BOOST_CHECK_EQUAL( G2.nrClusters(), 4 );
67 BOOST_CHECK_EQUAL( G2.vars(), vs );
68 BOOST_CHECK_EQUAL( G2.clusters(), cl );
69 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
70 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
71 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
72 BOOST_CHECK_EQUAL( G2.findVar( v3 ), 3 );
73
74 ClusterGraph Gb( G );
75 BOOST_CHECK( G.bipGraph() == Gb.bipGraph() );
76 BOOST_CHECK( G.vars() == Gb.vars() );
77 BOOST_CHECK( G.clusters() == Gb.clusters() );
78
79 ClusterGraph Gc = G;
80 BOOST_CHECK( G.bipGraph() == Gc.bipGraph() );
81 BOOST_CHECK( G.vars() == Gc.vars() );
82 BOOST_CHECK( G.clusters() == Gc.clusters() );
83
84 ClusterGraph G2b( G2 );
85 BOOST_CHECK( G2.bipGraph() == G2b.bipGraph() );
86 BOOST_CHECK( G2.vars() == G2b.vars() );
87 BOOST_CHECK( G2.clusters() == G2b.clusters() );
88
89 ClusterGraph G2c = G2;
90 BOOST_CHECK( G2.bipGraph() == G2c.bipGraph() );
91 BOOST_CHECK( G2.vars() == G2c.vars() );
92 BOOST_CHECK( G2.clusters() == G2c.clusters() );
93 }
94
95
96 BOOST_AUTO_TEST_CASE( QueriesTest ) {
97 Var v0( 0, 2 );
98 Var v1( 1, 3 );
99 Var v2( 2, 2 );
100 Var v3( 3, 4 );
101 Var v4( 4, 2 );
102 std::vector<Var> vs;
103 vs.push_back( v0 );
104 vs.push_back( v1 );
105 vs.push_back( v2 );
106 vs.push_back( v3 );
107 vs.push_back( v4 );
108 VarSet v01( v0, v1 );
109 VarSet v02( v0, v2 );
110 VarSet v03( v0, v3 );
111 VarSet v04( v0, v4 );
112 VarSet v12( v1, v2 );
113 VarSet v13( v1, v3 );
114 VarSet v14( v1, v4 );
115 VarSet v23( v2, v3 );
116 VarSet v24( v2, v4 );
117 VarSet v34( v3, v4 );
118 VarSet v123 = v12 | v3;
119 std::vector<VarSet> cl;
120 cl.push_back( v01 );
121 cl.push_back( v12 );
122 cl.push_back( v123 );
123 cl.push_back( v34 );
124 cl.push_back( v04 );
125 ClusterGraph G( cl );
126
127 BOOST_CHECK_EQUAL( G.nrVars(), 5 );
128 BOOST_CHECK_EQUAL( G.vars(), vs );
129 BOOST_CHECK_EQUAL( G.var(0), v0 );
130 BOOST_CHECK_EQUAL( G.var(1), v1 );
131 BOOST_CHECK_EQUAL( G.var(2), v2 );
132 BOOST_CHECK_EQUAL( G.var(3), v3 );
133 BOOST_CHECK_EQUAL( G.var(4), v4 );
134 BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
135 BOOST_CHECK_EQUAL( G.clusters(), cl );
136 BOOST_CHECK_EQUAL( G.cluster(0), v01 );
137 BOOST_CHECK_EQUAL( G.cluster(1), v12 );
138 BOOST_CHECK_EQUAL( G.cluster(2), v123 );
139 BOOST_CHECK_EQUAL( G.cluster(3), v34 );
140 BOOST_CHECK_EQUAL( G.cluster(4), v04 );
141 BOOST_CHECK_EQUAL( G.findVar( v0 ), 0 );
142 BOOST_CHECK_EQUAL( G.findVar( v1 ), 1 );
143 BOOST_CHECK_EQUAL( G.findVar( v2 ), 2 );
144 BOOST_CHECK_EQUAL( G.findVar( v3 ), 3 );
145 BOOST_CHECK_EQUAL( G.findVar( v4 ), 4 );
146 BipartiteGraph H( 5, 5 );
147 H.addEdge( 0, 0 );
148 H.addEdge( 1, 0 );
149 H.addEdge( 1, 1 );
150 H.addEdge( 2, 1 );
151 H.addEdge( 1, 2 );
152 H.addEdge( 2, 2 );
153 H.addEdge( 3, 2 );
154 H.addEdge( 3, 3 );
155 H.addEdge( 4, 3 );
156 H.addEdge( 0, 4 );
157 H.addEdge( 4, 4 );
158 BOOST_CHECK( G.bipGraph() == H );
159
160 BOOST_CHECK_EQUAL( G.delta( 0 ), v14 );
161 BOOST_CHECK_EQUAL( G.delta( 1 ), v02 | v3 );
162 BOOST_CHECK_EQUAL( G.delta( 2 ), v13 );
163 BOOST_CHECK_EQUAL( G.delta( 3 ), v12 | v4 );
164 BOOST_CHECK_EQUAL( G.delta( 4 ), v03 );
165 BOOST_CHECK_EQUAL( G.Delta( 0 ), v14 | v0 );
166 BOOST_CHECK_EQUAL( G.Delta( 1 ), v01 | v23 );
167 BOOST_CHECK_EQUAL( G.Delta( 2 ), v13 | v2 );
168 BOOST_CHECK_EQUAL( G.Delta( 3 ), v12 | v34 );
169 BOOST_CHECK_EQUAL( G.Delta( 4 ), v03 | v4 );
170
171 BOOST_CHECK( !G.adj( 0, 0 ) );
172 BOOST_CHECK( G.adj( 0, 1 ) );
173 BOOST_CHECK( !G.adj( 0, 2 ) );
174 BOOST_CHECK( !G.adj( 0, 3 ) );
175 BOOST_CHECK( G.adj( 0, 4 ) );
176 BOOST_CHECK( G.adj( 1, 0 ) );
177 BOOST_CHECK( !G.adj( 1, 1 ) );
178 BOOST_CHECK( G.adj( 1, 2 ) );
179 BOOST_CHECK( G.adj( 1, 3 ) );
180 BOOST_CHECK( !G.adj( 1, 4 ) );
181 BOOST_CHECK( !G.adj( 2, 0 ) );
182 BOOST_CHECK( G.adj( 2, 1 ) );
183 BOOST_CHECK( !G.adj( 2, 2 ) );
184 BOOST_CHECK( G.adj( 2, 3 ) );
185 BOOST_CHECK( !G.adj( 2, 4 ) );
186 BOOST_CHECK( !G.adj( 3, 0 ) );
187 BOOST_CHECK( G.adj( 3, 1 ) );
188 BOOST_CHECK( G.adj( 3, 2 ) );
189 BOOST_CHECK( !G.adj( 3, 3 ) );
190 BOOST_CHECK( G.adj( 3, 4 ) );
191 BOOST_CHECK( G.adj( 4, 0 ) );
192 BOOST_CHECK( !G.adj( 4, 1 ) );
193 BOOST_CHECK( !G.adj( 4, 2 ) );
194 BOOST_CHECK( G.adj( 4, 3 ) );
195 BOOST_CHECK( !G.adj( 4, 4 ) );
196
197 BOOST_CHECK( G.isMaximal( 0 ) );
198 BOOST_CHECK( !G.isMaximal( 1 ) );
199 BOOST_CHECK( G.isMaximal( 2 ) );
200 BOOST_CHECK( G.isMaximal( 3 ) );
201 BOOST_CHECK( G.isMaximal( 4 ) );
202 }
203
204
205 BOOST_AUTO_TEST_CASE( OperationsTest ) {
206 Var v0( 0, 2 );
207 Var v1( 1, 3 );
208 Var v2( 2, 2 );
209 Var v3( 3, 4 );
210 Var v4( 4, 2 );
211 VarSet v01( v0, v1 );
212 VarSet v02( v0, v2 );
213 VarSet v03( v0, v3 );
214 VarSet v04( v0, v4 );
215 VarSet v12( v1, v2 );
216 VarSet v13( v1, v3 );
217 VarSet v14( v1, v4 );
218 VarSet v23( v2, v3 );
219 VarSet v24( v2, v4 );
220 VarSet v34( v3, v4 );
221 VarSet v123 = v12 | v3;
222 std::vector<VarSet> cl;
223 cl.push_back( v01 );
224 cl.push_back( v12 );
225 cl.push_back( v123 );
226 cl.push_back( v34 );
227 cl.push_back( v04 );
228 ClusterGraph G( cl );
229
230 BipartiteGraph H( 5, 5 );
231 H.addEdge( 0, 0 );
232 H.addEdge( 1, 0 );
233 H.addEdge( 1, 1 );
234 H.addEdge( 2, 1 );
235 H.addEdge( 1, 2 );
236 H.addEdge( 2, 2 );
237 H.addEdge( 3, 2 );
238 H.addEdge( 3, 3 );
239 H.addEdge( 4, 3 );
240 H.addEdge( 0, 4 );
241 H.addEdge( 4, 4 );
242 BOOST_CHECK( G.bipGraph() == H );
243
244 G.eraseNonMaximal();
245 BOOST_CHECK_EQUAL( G.nrClusters(), 4 );
246 H.eraseNode2( 1 );
247 BOOST_CHECK( G.bipGraph() == H );
248 G.eraseSubsuming( 4 );
249 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
250 H.eraseNode2( 2 );
251 H.eraseNode2( 2 );
252 BOOST_CHECK( G.bipGraph() == H );
253 G.insert( v34 );
254 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
255 G.insert( v123 );
256 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
257 H.addNode2();
258 H.addEdge( 3, 2 );
259 H.addEdge( 4, 2 );
260 BOOST_CHECK( G.bipGraph() == H );
261 G.insert( v12 );
262 G.insert( v23 );
263 BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
264 H.addNode2();
265 H.addNode2();
266 H.addEdge( 1, 3 );
267 H.addEdge( 2, 3 );
268 H.addEdge( 2, 4 );
269 H.addEdge( 3, 4 );
270 BOOST_CHECK( G.bipGraph() == H );
271 G.eraseNonMaximal();
272 BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
273 H.eraseNode2( 3 );
274 H.eraseNode2( 3 );
275 BOOST_CHECK( G.bipGraph() == H );
276 G.eraseSubsuming( 2 );
277 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
278 H.eraseNode2( 1 );
279 BOOST_CHECK( G.bipGraph() == H );
280 G.eraseNonMaximal();
281 BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
282 BOOST_CHECK( G.bipGraph() == H );
283 G.eraseSubsuming( 0 );
284 BOOST_CHECK_EQUAL( G.nrClusters(), 1 );
285 H.eraseNode2( 0 );
286 BOOST_CHECK( G.bipGraph() == H );
287 G.eraseSubsuming( 4 );
288 BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
289 H.eraseNode2( 0 );
290 BOOST_CHECK( G.bipGraph() == H );
291 }
292
293
294 BOOST_AUTO_TEST_CASE( VarElimTest ) {
295 Var v0( 0, 2 );
296 Var v1( 1, 3 );
297 Var v2( 2, 2 );
298 Var v3( 3, 4 );
299 Var v4( 4, 2 );
300 VarSet v01( v0, v1 );
301 VarSet v02( v0, v2 );
302 VarSet v03( v0, v3 );
303 VarSet v04( v0, v4 );
304 VarSet v12( v1, v2 );
305 VarSet v13( v1, v3 );
306 VarSet v14( v1, v4 );
307 VarSet v23( v2, v3 );
308 VarSet v24( v2, v4 );
309 VarSet v34( v3, v4 );
310 VarSet v123 = v12 | v3;
311 std::vector<VarSet> cl;
312 cl.push_back( v01 );
313 cl.push_back( v12 );
314 cl.push_back( v123 );
315 cl.push_back( v34 );
316 cl.push_back( v04 );
317 ClusterGraph G( cl );
318 ClusterGraph Gorg = G;
319
320 BipartiteGraph H( 5, 5 );
321 H.addEdge( 0, 0 );
322 H.addEdge( 1, 0 );
323 H.addEdge( 1, 1 );
324 H.addEdge( 2, 1 );
325 H.addEdge( 1, 2 );
326 H.addEdge( 2, 2 );
327 H.addEdge( 3, 2 );
328 H.addEdge( 3, 3 );
329 H.addEdge( 4, 3 );
330 H.addEdge( 0, 4 );
331 H.addEdge( 4, 4 );
332 BOOST_CHECK( G.bipGraph() == H );
333
334 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 0 ), 1 );
335 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 1 ), 2 );
336 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 2 ), 0 );
337 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 3 ), 2 );
338 BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 4 ), 1 );
339 cl.clear();
340 cl.push_back( v123 );
341 cl.push_back( v01 | v4 );
342 cl.push_back( v13 | v4 );
343 cl.push_back( v34 );
344 cl.push_back( v4 );
345 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinFill ) ).clusters(), cl );
346
347 G = Gorg;
348 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 0 ), 2*3 );
349 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 1 ), 2*2+2*4 );
350 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 2 ), 0 );
351 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 3 ), 3*2+2*2 );
352 BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 4 ), 2*4 );
353 cl.clear();
354 cl.push_back( v123 );
355 cl.push_back( v01 | v4 );
356 cl.push_back( v13 | v4 );
357 cl.push_back( v34 );
358 cl.push_back( v4 );
359 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_WeightedMinFill ) ).clusters(), cl );
360
361 G = Gorg;
362 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 0 ), 2 );
363 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 1 ), 3 );
364 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 2 ), 2 );
365 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 3 ), 3 );
366 BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 4 ), 2 );
367 cl.clear();
368 cl.push_back( v01 | v4 );
369 cl.push_back( v123 );
370 cl.push_back( v13 | v4 );
371 cl.push_back( v34 );
372 cl.push_back( v4 );
373 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinNeighbors ) ).clusters(), cl );
374
375 G = Gorg;
376 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 0 ), 3*2 );
377 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 1 ), 2*2*4 );
378 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 2 ), 3*4 );
379 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 3 ), 3*2*2 );
380 BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 4 ), 2*4 );
381 cl.clear();
382 cl.push_back( v01 | v4 );
383 cl.push_back( v123 );
384 cl.push_back( v13 | v4 );
385 cl.push_back( v14 );
386 cl.push_back( v4 );
387 BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinWeight ) ).clusters(), cl );
388
389 G = Gorg;
390 std::vector<Var> vs;
391 vs.push_back( v4 );
392 vs.push_back( v3 );
393 vs.push_back( v2 );
394 vs.push_back( v1 );
395 vs.push_back( v0 );
396 cl.clear();
397 cl.push_back( v03 | v4 );
398 cl.push_back( v01 | v23 );
399 cl.push_back( v01 | v2 );
400 cl.push_back( v01 );
401 cl.push_back( v0 );
402 BOOST_CHECK_EQUAL( G.VarElim( sequentialVariableElimination( vs ) ).clusters(), cl );
403 }
404
405
406 BOOST_AUTO_TEST_CASE( IOTest ) {
407 Var v0( 0, 2 );
408 Var v1( 1, 3 );
409 Var v2( 2, 2 );
410 Var v3( 3, 4 );
411 VarSet v01( v0, v1 );
412 VarSet v02( v0, v2 );
413 VarSet v03( v0, v3 );
414 VarSet v12( v1, v2 );
415 VarSet v13( v1, v3 );
416 VarSet v23( v2, v3 );
417 std::vector<VarSet> cl;
418 cl.push_back( v01 );
419 cl.push_back( v12 );
420 cl.push_back( v23 );
421 cl.push_back( v13 );
422 ClusterGraph G( cl );
423
424 std::stringstream ss;
425 ss << G;
426 std::string s;
427 getline( ss, s );
428 BOOST_CHECK_EQUAL( s, "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
429 }