Extended SWIG python interface (inspired by Kyle Ellrott): inference is possible...
[libdai.git] / tests / unit / regiongraph_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
#include <dai/regiongraph.h>
#include <map>
#include <memory>
#include <set>
#include <strstream>
#include <vector>
12
13
14 using namespace dai;
15
16
// Numerical tolerance constant (presumably for approximate floating-point
// comparisons; not referenced by the tests visible here -- confirm usage).
constexpr double tol = 1e-8;
18
19
20 #define BOOST_TEST_MODULE RegionGraphTest
21
22
23 #include <boost/test/unit_test.hpp>
24 #include <boost/test/floating_point_comparison.hpp>
25
26
27 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
28 Var v0( 0, 2 );
29 Var v1( 1, 2 );
30 Var v2( 2, 2 );
31 VarSet v01( v0, v1 );
32 VarSet v02( v0, v2 );
33 VarSet v12( v1, v2 );
34 RegionGraph G;
35 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
36 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
37 BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
38 BOOST_CHECK( G.DAG() == BipartiteGraph() );
39 BOOST_CHECK_EQUAL( G.nrORs(), 0 );
40 BOOST_CHECK_EQUAL( G.nrIRs(), 0 );
41
42 std::vector<Factor> facs;
43 facs.push_back( Factor( v01 ) );
44 facs.push_back( Factor( v02 ) );
45 facs.push_back( Factor( v12 ) );
46 facs.push_back( Factor( v1 ) );
47 std::vector<Var> vars;
48 vars.push_back( v0 );
49 vars.push_back( v1 );
50 vars.push_back( v2 );
51 BipartiteGraph dag( 3, 3 );
52 dag.addEdge( 0, 0 );
53 dag.addEdge( 0, 1 );
54 dag.addEdge( 1, 0 );
55 dag.addEdge( 1, 2 );
56 dag.addEdge( 2, 1 );
57 dag.addEdge( 2, 2 );
58 BipartiteGraph bipgraph( dag );
59 bipgraph.addNode2();
60 bipgraph.addEdge( 1, 3 );
61
62 FactorGraph fg(facs);
63 RegionGraph G1( fg, fg.maximalFactorDomains() );
64 BOOST_CHECK_EQUAL( G1.vars(), vars );
65 BOOST_CHECK_EQUAL( G1.factors(), facs );
66 BOOST_CHECK( G1.bipGraph() == bipgraph );
67 BOOST_CHECK( G1.DAG() == dag );
68 BOOST_CHECK_EQUAL( G1.nrORs(), 3 );
69 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
70 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[3] );
71 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
72 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
73 BOOST_CHECK_EQUAL( G1.OR(2).c(), 1 );
74 BOOST_CHECK_EQUAL( G1.OR(2), facs[2] );
75 BOOST_CHECK_EQUAL( G1.nrIRs(), 3 );
76 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
77 BOOST_CHECK_EQUAL( G1.IR(0), v0 );
78 BOOST_CHECK_EQUAL( G1.IR(1).c(), -1 );
79 BOOST_CHECK_EQUAL( G1.IR(1), v1 );
80 BOOST_CHECK_EQUAL( G1.IR(2).c(), -1 );
81 BOOST_CHECK_EQUAL( G1.IR(2), v2 );
82 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
83 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
84 BOOST_CHECK_EQUAL( G1.fac2OR(2), 2 );
85 BOOST_CHECK_EQUAL( G1.fac2OR(3), 0 );
86 BOOST_CHECK( G1.checkCountingNumbers() );
87
88 std::vector<VarSet> ors;
89 std::vector<Region> irs;
90 ors.push_back( v01 | v2 );
91 irs.push_back( Region( v01, 0.0 ) );
92 irs.push_back( Region( v2, 0.0 ) );
93 typedef std::pair<size_t,size_t> edge;
94 std::vector<edge> edges;
95 edges.push_back( edge( 0, 0 ) );
96 edges.push_back( edge( 0, 1 ) );
97 BipartiteGraph dag2( 1, 2, edges.begin(), edges.end() );
98 RegionGraph G2( fg, ors, irs, edges );
99 BOOST_CHECK_EQUAL( G2.vars(), vars );
100 BOOST_CHECK_EQUAL( G2.factors(), facs );
101 BOOST_CHECK( G2.bipGraph() == bipgraph );
102 BOOST_CHECK( G2.DAG() == dag2 );
103 BOOST_CHECK_EQUAL( G2.nrORs(), 1 );
104 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
105 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[1] * facs[2] * facs[3] );
106 BOOST_CHECK_EQUAL( G2.nrIRs(), 2 );
107 BOOST_CHECK_EQUAL( G2.IR(0), irs[0] );
108 BOOST_CHECK_EQUAL( G2.IR(1), irs[1] );
109 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
110 BOOST_CHECK_EQUAL( G2.fac2OR(1), 0 );
111 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
112 BOOST_CHECK_EQUAL( G2.fac2OR(3), 0 );
113 BOOST_CHECK( G2.checkCountingNumbers() );
114
115 RegionGraph *G3 = G1.clone();
116 BOOST_CHECK_EQUAL( G3->vars(), G1.vars() );
117 BOOST_CHECK_EQUAL( G3->factors(), G1.factors() );
118 BOOST_CHECK( G3->bipGraph() == bipgraph );
119 BOOST_CHECK( G3->DAG() == dag );
120 BOOST_CHECK_EQUAL( G3->nrORs(), G1.nrORs() );
121 BOOST_CHECK_EQUAL( G3->OR(0), G1.OR(0) );
122 BOOST_CHECK_EQUAL( G3->OR(1), G1.OR(1) );
123 BOOST_CHECK_EQUAL( G3->OR(2), G1.OR(2) );
124 BOOST_CHECK_EQUAL( G3->nrIRs(), G1.nrIRs() );
125 BOOST_CHECK_EQUAL( G3->IR(0), G1.IR(0) );
126 BOOST_CHECK_EQUAL( G3->IR(1), G1.IR(1) );
127 BOOST_CHECK_EQUAL( G3->IR(2), G1.IR(2) );
128 BOOST_CHECK_EQUAL( G3->fac2OR(0), G1.fac2OR(0) );
129 BOOST_CHECK_EQUAL( G3->fac2OR(1), G1.fac2OR(1) );
130 BOOST_CHECK_EQUAL( G3->fac2OR(2), G1.fac2OR(2) );
131 BOOST_CHECK_EQUAL( G3->fac2OR(3), G1.fac2OR(3) );
132 BOOST_CHECK( G3->checkCountingNumbers() );
133 delete G3;
134
135 RegionGraph G4 = G1;
136 BOOST_CHECK_EQUAL( G4.vars(), G1.vars() );
137 BOOST_CHECK_EQUAL( G4.factors(), G1.factors() );
138 BOOST_CHECK( G4.bipGraph() == bipgraph );
139 BOOST_CHECK( G4.DAG() == dag );
140 BOOST_CHECK_EQUAL( G4.nrORs(), G1.nrORs() );
141 BOOST_CHECK_EQUAL( G4.OR(0), G1.OR(0) );
142 BOOST_CHECK_EQUAL( G4.OR(1), G1.OR(1) );
143 BOOST_CHECK_EQUAL( G4.OR(2), G1.OR(2) );
144 BOOST_CHECK_EQUAL( G4.nrIRs(), G1.nrIRs() );
145 BOOST_CHECK_EQUAL( G4.IR(0), G1.IR(0) );
146 BOOST_CHECK_EQUAL( G4.IR(1), G1.IR(1) );
147 BOOST_CHECK_EQUAL( G4.IR(2), G1.IR(2) );
148 BOOST_CHECK_EQUAL( G4.fac2OR(0), G1.fac2OR(0) );
149 BOOST_CHECK_EQUAL( G4.fac2OR(1), G1.fac2OR(1) );
150 BOOST_CHECK_EQUAL( G4.fac2OR(2), G1.fac2OR(2) );
151 BOOST_CHECK_EQUAL( G4.fac2OR(3), G1.fac2OR(3) );
152 BOOST_CHECK( G4.checkCountingNumbers() );
153
154 RegionGraph G5( G1 );
155 BOOST_CHECK_EQUAL( G5.vars(), G1.vars() );
156 BOOST_CHECK_EQUAL( G5.factors(), G1.factors() );
157 BOOST_CHECK( G5.bipGraph() == bipgraph );
158 BOOST_CHECK( G5.DAG() == dag );
159 BOOST_CHECK_EQUAL( G5.nrORs(), G1.nrORs() );
160 BOOST_CHECK_EQUAL( G5.OR(0), G1.OR(0) );
161 BOOST_CHECK_EQUAL( G5.OR(1), G1.OR(1) );
162 BOOST_CHECK_EQUAL( G5.OR(2), G1.OR(2) );
163 BOOST_CHECK_EQUAL( G5.nrIRs(), G1.nrIRs() );
164 BOOST_CHECK_EQUAL( G5.IR(0), G1.IR(0) );
165 BOOST_CHECK_EQUAL( G5.IR(1), G1.IR(1) );
166 BOOST_CHECK_EQUAL( G5.IR(2), G1.IR(2) );
167 BOOST_CHECK_EQUAL( G5.fac2OR(0), G1.fac2OR(0) );
168 BOOST_CHECK_EQUAL( G5.fac2OR(1), G1.fac2OR(1) );
169 BOOST_CHECK_EQUAL( G5.fac2OR(2), G1.fac2OR(2) );
170 BOOST_CHECK_EQUAL( G5.fac2OR(3), G1.fac2OR(3) );
171 BOOST_CHECK( G5.checkCountingNumbers() );
172 }
173
174
175 BOOST_AUTO_TEST_CASE( AccMutTest ) {
176 Var v0( 0, 2 );
177 Var v1( 1, 2 );
178 Var v2( 2, 2 );
179 VarSet v01( v0, v1 );
180 VarSet v02( v0, v2 );
181 VarSet v12( v1, v2 );
182 std::vector<Factor> facs;
183 facs.push_back( Factor( v01 ) );
184 facs.push_back( Factor( v02 ) );
185 facs.push_back( Factor( v12 ) );
186 facs.push_back( Factor( v1 ) );
187 std::vector<Var> vars;
188 vars.push_back( v0 );
189 vars.push_back( v1 );
190 vars.push_back( v2 );
191 BipartiteGraph dag( 3, 3 );
192 dag.addEdge( 0, 0 );
193 dag.addEdge( 0, 1 );
194 dag.addEdge( 1, 0 );
195 dag.addEdge( 1, 2 );
196 dag.addEdge( 2, 1 );
197 dag.addEdge( 2, 2 );
198 BipartiteGraph bipgraph( dag );
199 bipgraph.addNode2();
200 bipgraph.addEdge( 1, 3 );
201
202 FactorGraph fg( facs );
203 RegionGraph G( fg, fg.maximalFactorDomains() );
204 BOOST_CHECK_EQUAL( G.var(0), v0 );
205 BOOST_CHECK_EQUAL( G.var(1), v1 );
206 BOOST_CHECK_EQUAL( G.var(2), v2 );
207 BOOST_CHECK_EQUAL( G.vars(), vars );
208 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
209 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
210 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
211 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
212 BOOST_CHECK_EQUAL( G.factors(), facs );
213 BOOST_CHECK( G.bipGraph() == bipgraph );
214 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
215 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
216 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
217 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
218 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
219 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
220 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
221 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
222 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
223 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
224 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
225 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
226 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
227 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
228 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
229 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
230 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
231 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
232 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
233 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
234 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
235 BOOST_CHECK( G.DAG() == dag );
236 BOOST_CHECK_EQUAL( G.nrORs(), 3 );
237 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
238 BOOST_CHECK_EQUAL( G.OR(0).vars(), v01 );
239 BOOST_CHECK_EQUAL( G.OR(0), facs[0] * facs[3] );
240 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
241 BOOST_CHECK_EQUAL( G.OR(1).vars(), v02 );
242 BOOST_CHECK_EQUAL( G.OR(1), facs[1] );
243 BOOST_CHECK_EQUAL( G.OR(2).c(), 1 );
244 BOOST_CHECK_EQUAL( G.OR(2).vars(), v12 );
245 BOOST_CHECK_EQUAL( G.OR(2), facs[2] );
246 BOOST_CHECK_EQUAL( G.nrIRs(), 3 );
247 BOOST_CHECK_EQUAL( G.IR(0).c(), -1 );
248 BOOST_CHECK_EQUAL( G.IR(0), v0 );
249 BOOST_CHECK_EQUAL( G.IR(1).c(), -1 );
250 BOOST_CHECK_EQUAL( G.IR(1), v1 );
251 BOOST_CHECK_EQUAL( G.IR(2).c(), -1 );
252 BOOST_CHECK_EQUAL( G.IR(2), v2 );
253 BOOST_CHECK_EQUAL( G.fac2OR(0), 0 );
254 BOOST_CHECK_EQUAL( G.fac2OR(1), 1 );
255 BOOST_CHECK_EQUAL( G.fac2OR(2), 2 );
256 BOOST_CHECK_EQUAL( G.fac2OR(3), 0 );
257 BOOST_CHECK_EQUAL( G.nbOR(0).size(), 2 );
258 BOOST_CHECK_EQUAL( G.nbOR(0)[0], 0 );
259 BOOST_CHECK_EQUAL( G.nbOR(0)[1], 1 );
260 BOOST_CHECK_EQUAL( G.nbOR(1).size(), 2 );
261 BOOST_CHECK_EQUAL( G.nbOR(1)[0], 0 );
262 BOOST_CHECK_EQUAL( G.nbOR(1)[1], 2 );
263 BOOST_CHECK_EQUAL( G.nbOR(2).size(), 2 );
264 BOOST_CHECK_EQUAL( G.nbOR(2)[0], 1 );
265 BOOST_CHECK_EQUAL( G.nbOR(2)[1], 2 );
266 BOOST_CHECK_EQUAL( G.nbIR(0).size(), 2 );
267 BOOST_CHECK_EQUAL( G.nbIR(0)[0], 0 );
268 BOOST_CHECK_EQUAL( G.nbIR(0)[1], 1 );
269 BOOST_CHECK_EQUAL( G.nbIR(1).size(), 2 );
270 BOOST_CHECK_EQUAL( G.nbIR(1)[0], 0 );
271 BOOST_CHECK_EQUAL( G.nbIR(1)[1], 2 );
272 BOOST_CHECK_EQUAL( G.nbIR(2).size(), 2 );
273 BOOST_CHECK_EQUAL( G.nbIR(2)[0], 1 );
274 BOOST_CHECK_EQUAL( G.nbIR(2)[1], 2 );
275 }
276
277
278 BOOST_AUTO_TEST_CASE( QueriesTest ) {
279 Var v0( 0, 2 );
280 Var v1( 1, 2 );
281 Var v2( 2, 2 );
282 VarSet v01( v0, v1 );
283 VarSet v02( v0, v2 );
284 VarSet v12( v1, v2 );
285 VarSet v012 = v01 | v2;
286
287 RegionGraph G0;
288 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
289 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
290 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
291 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
292 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
293 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
294 #ifdef DAI_DBEUG
295 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
296 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
297 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
298 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
299 BOOST_CHECK_THROW( G0.fac2OR( 0 ), Exception );
300 #endif
301 BOOST_CHECK( G0.isConnected() );
302 BOOST_CHECK( G0.isTree() );
303 BOOST_CHECK( G0.isBinary() );
304 BOOST_CHECK( G0.isPairwise() );
305 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
306 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
307 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 1 );
308 BOOST_CHECK( G0.DAG() == BipartiteGraph() );
309 BOOST_CHECK( G0.checkCountingNumbers() );
310 BOOST_CHECK_EQUAL( G0.nrORs(), 0 );
311 BOOST_CHECK_EQUAL( G0.nrIRs(), 0 );
312
313 std::vector<Factor> facs;
314 facs.push_back( Factor( v01 ) );
315 facs.push_back( Factor( v12 ) );
316 facs.push_back( Factor( v1 ) );
317 std::vector<Var> vars;
318 vars.push_back( v0 );
319 vars.push_back( v1 );
320 vars.push_back( v2 );
321 GraphAL H(3);
322 H.addEdge( 0, 1 );
323 H.addEdge( 1, 2 );
324 BipartiteGraph K(3, 3);
325 K.addEdge( 0, 0 );
326 K.addEdge( 1, 0 );
327 K.addEdge( 1, 1 );
328 K.addEdge( 2, 1 );
329 K.addEdge( 1, 2 );
330 BipartiteGraph dag(2, 1);
331 dag.addEdge( 0, 0 );
332 dag.addEdge( 1, 0 );
333
334 FactorGraph fg( facs );
335 RegionGraph G1( fg, fg.maximalFactorDomains() );
336 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
337 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
338 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
339 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
340 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
341 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
342 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
343 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
344 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
345 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
346 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
347 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
348 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
349 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
350 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
351 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
352 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
353 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
354 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
355 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
356 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
357 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
358 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
359 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
360 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
361 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
362 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
363 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
364 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
365 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
366 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
367 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
368 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
369 BOOST_CHECK( G1.isConnected() );
370 BOOST_CHECK( G1.isTree() );
371 BOOST_CHECK( G1.isBinary() );
372 BOOST_CHECK( G1.isPairwise() );
373 BOOST_CHECK( G1.MarkovGraph() == H );
374 BOOST_CHECK( G1.bipGraph() == K );
375 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
376 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
377 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
378 BOOST_CHECK( G1.DAG() == dag );
379 BOOST_CHECK( G1.checkCountingNumbers() );
380 BOOST_CHECK_EQUAL( G1.nrORs(), 2 );
381 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
382 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[2] );
383 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
384 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
385 BOOST_CHECK_EQUAL( G1.nrIRs(), 1 );
386 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
387 BOOST_CHECK_EQUAL( G1.IR(0), v1 );
388 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
389 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
390 BOOST_CHECK_EQUAL( G1.fac2OR(2), 0 );
391
392 facs.push_back( Factor( v02 ) );
393 H.addEdge( 0, 2 );
394 K.addNode2();
395 K.addEdge( 0, 3 );
396 K.addEdge( 2, 3 );
397 dag = BipartiteGraph( 3, 3 );
398 dag.addEdge( 0, 0 );
399 dag.addEdge( 0, 1 );
400 dag.addEdge( 1, 1 );
401 dag.addEdge( 1, 2 );
402 dag.addEdge( 2, 0 );
403 dag.addEdge( 2, 2 );
404 fg = FactorGraph( facs );
405 RegionGraph G2( fg, fg.maximalFactorDomains() );
406 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
407 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
408 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
409 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
410 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
411 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
412 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
413 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
414 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
415 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
416 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
417 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
418 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
419 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
420 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
421 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
422 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
423 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
424 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
425 BOOST_CHECK( G2.isConnected() );
426 BOOST_CHECK( !G2.isTree() );
427 BOOST_CHECK( G2.isBinary() );
428 BOOST_CHECK( G2.isPairwise() );
429 BOOST_CHECK( G2.MarkovGraph() == H );
430 BOOST_CHECK( G2.bipGraph() == K );
431 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
432 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
433 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
434 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
435 BOOST_CHECK( G2.DAG() == dag );
436 BOOST_CHECK( G2.checkCountingNumbers() );
437 BOOST_CHECK_EQUAL( G2.nrORs(), 3 );
438 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
439 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[2] );
440 BOOST_CHECK_EQUAL( G2.OR(1).c(), 1 );
441 BOOST_CHECK_EQUAL( G2.OR(1), facs[1] );
442 BOOST_CHECK_EQUAL( G2.OR(2).c(), 1 );
443 BOOST_CHECK_EQUAL( G2.OR(2), facs[3] );
444 BOOST_CHECK_EQUAL( G2.nrIRs(), 3 );
445 BOOST_CHECK_EQUAL( G2.IR(0).c(), -1.0 );
446 BOOST_CHECK_EQUAL( G2.IR(0), v0 );
447 BOOST_CHECK_EQUAL( G2.IR(1).c(), -1.0 );
448 BOOST_CHECK_EQUAL( G2.IR(1), v1 );
449 BOOST_CHECK_EQUAL( G2.IR(2).c(), -1.0 );
450 BOOST_CHECK_EQUAL( G2.IR(2), v2 );
451 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
452 BOOST_CHECK_EQUAL( G2.fac2OR(1), 1 );
453 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
454 BOOST_CHECK_EQUAL( G2.fac2OR(3), 2 );
455
456 Var v3( 3, 3 );
457 VarSet v03( v0, v3 );
458 VarSet v13( v1, v3 );
459 VarSet v23( v2, v3 );
460 VarSet v013 = v01 | v3;
461 VarSet v023 = v02 | v3;
462 VarSet v123 = v12 | v3;
463 VarSet v0123 = v012 | v3;
464 vars.push_back( v3 );
465 facs.push_back( Factor( v3 ) );
466 H.addNode();
467 K.addNode1();
468 K.addNode2();
469 K.addEdge( 3, 4 );
470 dag.addNode1();
471 fg = FactorGraph( facs );
472 RegionGraph G3( fg, fg.maximalFactorDomains() );
473 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
474 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
475 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
476 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
477 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
478 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
479 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
480 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
481 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
482 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
483 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
484 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
485 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
486 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
487 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
488 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
489 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
490 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
491 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
492 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
493 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
494 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
495 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
496 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
497 BOOST_CHECK( !G3.isConnected() );
498 BOOST_CHECK( !G3.isTree() );
499 BOOST_CHECK( !G3.isBinary() );
500 BOOST_CHECK( G3.isPairwise() );
501 BOOST_CHECK( G3.MarkovGraph() == H );
502 BOOST_CHECK( G3.bipGraph() == K );
503 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
504 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
505 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
506 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
507 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
508 BOOST_CHECK( G3.DAG() == dag );
509 BOOST_CHECK( G3.checkCountingNumbers() );
510 BOOST_CHECK_EQUAL( G3.nrORs(), 4 );
511 BOOST_CHECK_EQUAL( G3.OR(0).c(), 1 );
512 BOOST_CHECK_EQUAL( G3.OR(0), facs[0] * facs[2] );
513 BOOST_CHECK_EQUAL( G3.OR(1).c(), 1 );
514 BOOST_CHECK_EQUAL( G3.OR(1), facs[1] );
515 BOOST_CHECK_EQUAL( G3.OR(2).c(), 1 );
516 BOOST_CHECK_EQUAL( G3.OR(2), facs[3] );
517 BOOST_CHECK_EQUAL( G3.OR(3).c(), 1 );
518 BOOST_CHECK_EQUAL( G3.OR(3), facs[4] );
519 BOOST_CHECK_EQUAL( G3.nrIRs(), 3 );
520 BOOST_CHECK_EQUAL( G3.IR(0).c(), -1.0 );
521 BOOST_CHECK_EQUAL( G3.IR(0), v0 );
522 BOOST_CHECK_EQUAL( G3.IR(1).c(), -1.0 );
523 BOOST_CHECK_EQUAL( G3.IR(1), v1 );
524 BOOST_CHECK_EQUAL( G3.IR(2).c(), -1.0 );
525 BOOST_CHECK_EQUAL( G3.IR(2), v2 );
526 BOOST_CHECK_EQUAL( G3.fac2OR(0), 0 );
527 BOOST_CHECK_EQUAL( G3.fac2OR(1), 1 );
528 BOOST_CHECK_EQUAL( G3.fac2OR(2), 0 );
529 BOOST_CHECK_EQUAL( G3.fac2OR(3), 2 );
530 BOOST_CHECK_EQUAL( G3.fac2OR(4), 3 );
531
532 facs.push_back( Factor( v123 ) );
533 H.addEdge( 1, 3 );
534 H.addEdge( 2, 3 );
535 K.addNode2();
536 K.addEdge( 1, 5 );
537 K.addEdge( 2, 5 );
538 K.addEdge( 3, 5 );
539 dag = BipartiteGraph( 3, 3 );
540 dag.addEdge( 0, 0 );
541 dag.addEdge( 0, 1 );
542 dag.addEdge( 1, 0 );
543 dag.addEdge( 1, 2 );
544 dag.addEdge( 2, 1 );
545 dag.addEdge( 2, 2 );
546 fg = FactorGraph( facs );
547 RegionGraph G4( fg, fg.maximalFactorDomains() );
548 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
549 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
550 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
551 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
552 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
553 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
554 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
555 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
556 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
557 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
558 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
559 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
560 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
561 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
562 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
563 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
564 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
565 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
566 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
567 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
568 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
569 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
570 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
571 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
572 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
573 BOOST_CHECK( G4.isConnected() );
574 BOOST_CHECK( !G4.isTree() );
575 BOOST_CHECK( !G4.isBinary() );
576 BOOST_CHECK( !G4.isPairwise() );
577 BOOST_CHECK( G4.MarkovGraph() == H );
578 BOOST_CHECK( G4.bipGraph() == K );
579 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
580 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
581 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
582 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
583 BOOST_CHECK( G4.DAG() == dag );
584 BOOST_CHECK( G4.checkCountingNumbers() );
585 BOOST_CHECK_EQUAL( G4.nrORs(), 3 );
586 BOOST_CHECK_EQUAL( G4.OR(0).c(), 1 );
587 BOOST_CHECK_EQUAL( G4.OR(0), facs[0] * facs[2] );
588 BOOST_CHECK_EQUAL( G4.OR(1).c(), 1 );
589 BOOST_CHECK_EQUAL( G4.OR(1), facs[3] );
590 BOOST_CHECK_EQUAL( G4.OR(2).c(), 1 );
591 BOOST_CHECK_EQUAL( G4.OR(2), facs[1] * facs[4] * facs[5] );
592 BOOST_CHECK_EQUAL( G4.nrIRs(), 3 );
593 BOOST_CHECK_EQUAL( G4.IR(0).c(), -1.0 );
594 BOOST_CHECK_EQUAL( G4.IR(0), v0 );
595 BOOST_CHECK_EQUAL( G4.IR(1).c(), -1.0 );
596 BOOST_CHECK_EQUAL( G4.IR(1), v1 );
597 BOOST_CHECK_EQUAL( G4.IR(2).c(), -1.0 );
598 BOOST_CHECK_EQUAL( G4.IR(2), v2 );
599 BOOST_CHECK_EQUAL( G4.fac2OR(0), 0 );
600 BOOST_CHECK_EQUAL( G4.fac2OR(1), 2 );
601 BOOST_CHECK_EQUAL( G4.fac2OR(2), 0 );
602 BOOST_CHECK_EQUAL( G4.fac2OR(3), 1 );
603 BOOST_CHECK_EQUAL( G4.fac2OR(4), 2 );
604 BOOST_CHECK_EQUAL( G4.fac2OR(5), 2 );
605 }
606
607
608 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
609 Var v0( 0, 2 );
610 Var v1( 1, 2 );
611 Var v2( 2, 2 );
612 VarSet v01( v0, v1 );
613 VarSet v02( v0, v2 );
614 VarSet v12( v1, v2 );
615 VarSet v012 = v01 | v2;
616
617 std::vector<Factor> facs;
618 facs.push_back( Factor( v01 ) );
619 facs.push_back( Factor( v12 ) );
620 facs.push_back( Factor( v1 ) );
621 std::vector<Var> vars;
622 vars.push_back( v0 );
623 vars.push_back( v1 );
624 vars.push_back( v2 );
625
626 FactorGraph fg( facs );
627 RegionGraph G( fg, fg.maximalFactorDomains() );
628 RegionGraph Gorg( G );
629 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
630 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
631 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
632 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
633
634 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
635 G.setFactor( 0, Factor( v01, 2.0 ), false );
636 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
637 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
638 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
639 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
640 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
641 G.setFactor( 0, Factor( v01, 3.0 ), true );
642 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
643 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
644 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
645 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
646 G.restoreFactor( 0 );
647 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
648 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
649 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
650 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
651 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
652 G.setFactor( 0, Gorg.factor( 0 ), false );
653 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
654 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
655 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
656 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
657 G.backupFactor( 0 );
658 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
659 G.setFactor( 0, Factor( v01, 2.0 ), false );
660 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
661 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
662 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
663 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
664 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
665 G.restoreFactor( 0 );
666 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
667 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
668 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
669 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
670 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
671
672 std::map<size_t, Factor> fs;
673 fs[0] = Factor( v01, 3.0 );
674 fs[2] = Factor( v1, 2.0 );
675 G.setFactors( fs, false );
676 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
677 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
678 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
679 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
680 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
681 G.restoreFactors();
682 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
683 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
684 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
685 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
686 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
687 G = Gorg;
688 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
689 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
690 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
691 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
692 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
693 G.setFactors( fs, true );
694 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
695 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
696 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
697 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
698 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
699 G.restoreFactors();
700 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
701 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
702 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
703 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
704 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
705 std::set<size_t> fsind;
706 fsind.insert( 0 );
707 fsind.insert( 2 );
708 G.backupFactors( fsind );
709 G.setFactors( fs, false );
710 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
711 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
712 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
713 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
714 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
715 G.restoreFactors();
716 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
717 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
718 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
719 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
720 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
721
722 G.backupFactors( v2 );
723 G.setFactor( 1, Factor(v12, 5.0) );
724 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
725 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
726 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
727 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
728 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
729 G.restoreFactors( v2 );
730 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
731 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
732 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
733 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
734 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
735
736 G.backupFactors( v1 );
737 fs[1] = Factor( v12, 5.0 );
738 G.setFactors( fs, false );
739 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
740 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
741 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
742 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
743 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
744 G.restoreFactors();
745 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
746 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
747 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
748 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
749 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
750 G.setFactors( fs, true );
751 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
752 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
753 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
754 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
755 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
756 G.restoreFactors( v1 );
757 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
758 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
759 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
760 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
761 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
762 }
763
764
765 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
766 Var v0( 0, 2 );
767 Var v1( 1, 2 );
768 Var v2( 2, 2 );
769 VarSet v01( v0, v1 );
770 VarSet v02( v0, v2 );
771 VarSet v12( v1, v2 );
772 VarSet v012 = v01 | v2;
773
774 std::vector<Factor> facs;
775 facs.push_back( Factor( v01 ).randomize() );
776 facs.push_back( Factor( v12 ).randomize() );
777 facs.push_back( Factor( v1 ).randomize() );
778 std::vector<Var> vars;
779 vars.push_back( v0 );
780 vars.push_back( v1 );
781 vars.push_back( v2 );
782
783 FactorGraph fg( facs );
784 RegionGraph G( fg, fg.maximalFactorDomains() );
785
786 FactorGraph Gsmall = G.maximalFactors();
787 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
788 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
789 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
790 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
791
792 size_t i = 0;
793 for( size_t x = 0; x < 2; x++ ) {
794 FactorGraph Gcl = G.clamped( i, x );
795 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
796 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
797 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
798 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
799 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
800 }
801 i = 1;
802 for( size_t x = 0; x < 2; x++ ) {
803 FactorGraph Gcl = G.clamped( i, x );
804 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
805 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
806 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
807 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
808 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
809 }
810 i = 2;
811 for( size_t x = 0; x < 2; x++ ) {
812 FactorGraph Gcl = G.clamped( i, x );
813 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
814 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
815 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
816 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
817 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
818 }
819 }
820
821
822 BOOST_AUTO_TEST_CASE( OperationsTest ) {
823 Var v0( 0, 2 );
824 Var v1( 1, 2 );
825 Var v2( 2, 2 );
826 VarSet v01( v0, v1 );
827 VarSet v02( v0, v2 );
828 VarSet v12( v1, v2 );
829 VarSet v012 = v01 | v2;
830
831 std::vector<Factor> facs;
832 facs.push_back( Factor( v01 ).randomize() );
833 facs.push_back( Factor( v12 ).randomize() );
834 facs.push_back( Factor( v1 ).randomize() );
835 std::vector<Var> vars;
836 vars.push_back( v0 );
837 vars.push_back( v1 );
838 vars.push_back( v2 );
839
840 FactorGraph fg( facs );
841 RegionGraph G( fg, fg.maximalFactorDomains() );
842
843 // clamp
844 RegionGraph Gcl = G;
845 for( size_t i = 0; i < 3; i++ )
846 for( size_t x = 0; x < 2; x++ ) {
847 Gcl.clamp( i, x, true );
848 Factor delta = createFactorDelta( vars[i], x );
849 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
850 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
851 for( size_t j = 0; j < 3; j++ )
852 if( G.factor(j).vars().contains( vars[i] ) )
853 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
854 else
855 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
856 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
857 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
858
859 Gcl.restoreFactors();
860 for( size_t j = 0; j < 3; j++ )
861 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
862 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
863 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
864 }
865
866 // clampVar
867 for( size_t i = 0; i < 3; i++ )
868 for( size_t x = 0; x < 2; x++ ) {
869 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
870 Factor delta = createFactorDelta( vars[i], x );
871 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
872 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
873 for( size_t j = 0; j < 3; j++ )
874 if( G.factor(j).vars().contains( vars[i] ) )
875 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
876 else
877 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
878 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
879 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
880
881 Gcl.restoreFactors();
882 for( size_t j = 0; j < 3; j++ )
883 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
884 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
885 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
886 }
887 for( size_t i = 0; i < 3; i++ )
888 for( size_t x = 0; x < 2; x++ ) {
889 Gcl.clampVar( i, std::vector<size_t>(), true );
890 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
891 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
892 for( size_t j = 0; j < 3; j++ )
893 if( G.factor(j).vars().contains( vars[i] ) )
894 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
895 else
896 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
897 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
898 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
899
900 Gcl.restoreFactors();
901 for( size_t j = 0; j < 3; j++ )
902 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
903 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
904 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
905 }
906 std::vector<size_t> both;
907 both.push_back( 0 );
908 both.push_back( 1 );
909 for( size_t i = 0; i < 3; i++ )
910 for( size_t x = 0; x < 2; x++ ) {
911 Gcl.clampVar( i, both, true );
912 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
913 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
914 for( size_t j = 0; j < 3; j++ )
915 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
916 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
917 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
918 Gcl.restoreFactors();
919 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
920 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
921 }
922
923 // clampFactor
924 for( size_t x = 0; x < 4; x++ ) {
925 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
926 Factor mask( v01, 0.0 );
927 mask.set( x, 1.0 );
928 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
929 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
930 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
931 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
932 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
933 Gcl.restoreFactor( 0 );
934 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
935 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
936 }
937 for( size_t x = 0; x < 4; x++ ) {
938 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
939 Factor mask( v12, 0.0 );
940 mask.set( x, 1.0 );
941 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
942 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
943 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
944 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
945 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
946 Gcl.restoreFactor( 1 );
947 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
948 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
949 }
950 for( size_t x = 0; x < 2; x++ ) {
951 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
952 Factor mask( v1, 0.0 );
953 mask.set( x, 1.0 );
954 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
955 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
956 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
957 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
958 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
959 Gcl.restoreFactors();
960 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
961 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
962 }
963
964 // makeCavity
965 RegionGraph Gcav( G );
966 Gcav.makeCavity( 0, true );
967 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
968 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
969 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
970 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
971 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
972 Gcav.restoreFactors();
973 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
974 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
975 Gcav.makeCavity( 1, true );
976 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
977 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
978 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
979 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
980 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
981 Gcav.restoreFactors();
982 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
983 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
984 Gcav.makeCavity( 2, true );
985 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
986 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
987 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
988 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
989 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
990 Gcav.restoreFactors();
991 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
992 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
993 }
994
995
996 BOOST_AUTO_TEST_CASE( IOTest ) {
997 Var v0( 0, 2 );
998 Var v1( 1, 2 );
999 Var v2( 2, 2 );
1000 VarSet v01( v0, v1 );
1001 VarSet v02( v0, v2 );
1002 VarSet v12( v1, v2 );
1003 VarSet v012 = v01 | v2;
1004
1005 std::vector<Factor> facs;
1006 facs.push_back( Factor( v01, 1.0 ) );
1007 facs.push_back( Factor( v12, 2.0 ) );
1008 facs.push_back( Factor( v1, 3.0 ) );
1009 std::vector<Var> vars;
1010 vars.push_back( v0 );
1011 vars.push_back( v1 );
1012 vars.push_back( v2 );
1013
1014 FactorGraph fg( facs );
1015 RegionGraph G( fg, fg.maximalFactorDomains() );
1016 BOOST_CHECK_THROW( G.WriteToFile( "regiongraph_test.fg" ), Exception );
1017 BOOST_CHECK_THROW( G.ReadFromFile( "regiongraph_test.fg" ), Exception );
1018 BOOST_CHECK_THROW( G.printDot( std::cout ), Exception );
1019
1020 std::stringstream ss;
1021 std::string s;
1022 ss << G;
1023 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "digraph RegionGraph {" );
1024 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box];" );
1025 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 [label=\"a0: {x0, x1}, c=1\"];" );
1026 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 [label=\"a1: {x1, x2}, c=1\"];" );
1027 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=ellipse];" );
1028 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tb0 [label=\"b0: {x1}, c=-1\"];" );
1029 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 -> b0;" );
1030 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 -> b0;" );
1031 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
1032
1033 BOOST_CHECK_EQUAL( G.toString(), "digraph RegionGraph {\nnode[shape=box];\n\ta0 [label=\"a0: {x0, x1}, c=1\"];\n\ta1 [label=\"a1: {x1, x2}, c=1\"];\nnode[shape=ellipse];\n\tb0 [label=\"b0: {x1}, c=-1\"];\n\ta0 -> b0;\n\ta1 -> b0;\n}\n" );
1034 }