Miscellaneous improvements to regiongraph, factorgraph and bipgraph, and finished...
[libdai.git] / tests / unit / regiongraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
14 #include <dai/regiongraph.h>
15 #include <vector>
16 #include <strstream>
17
18
19 using namespace dai;
20
21
22 const double tol = 1e-8;
23
24
25 #define BOOST_TEST_MODULE RegionGraphTest
26
27
28 #include <boost/test/unit_test.hpp>
29 #include <boost/test/floating_point_comparison.hpp>
30
31
32 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
33 Var v0( 0, 2 );
34 Var v1( 1, 2 );
35 Var v2( 2, 2 );
36 VarSet v01( v0, v1 );
37 VarSet v02( v0, v2 );
38 VarSet v12( v1, v2 );
39 RegionGraph G;
40 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
41 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
42 BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
43 BOOST_CHECK( G.DAG() == BipartiteGraph() );
44 BOOST_CHECK_EQUAL( G.nrORs(), 0 );
45 BOOST_CHECK_EQUAL( G.nrIRs(), 0 );
46
47 std::vector<Factor> facs;
48 facs.push_back( Factor( v01 ) );
49 facs.push_back( Factor( v02 ) );
50 facs.push_back( Factor( v12 ) );
51 facs.push_back( Factor( v1 ) );
52 std::vector<Var> vars;
53 vars.push_back( v0 );
54 vars.push_back( v1 );
55 vars.push_back( v2 );
56 BipartiteGraph dag( 3, 3 );
57 dag.addEdge( 0, 0 );
58 dag.addEdge( 0, 1 );
59 dag.addEdge( 1, 0 );
60 dag.addEdge( 1, 2 );
61 dag.addEdge( 2, 1 );
62 dag.addEdge( 2, 2 );
63 BipartiteGraph bipgraph( dag );
64 bipgraph.addNode2();
65 bipgraph.addEdge( 1, 3 );
66
67 FactorGraph fg(facs);
68 RegionGraph G1( fg, fg.maximalFactorDomains() );
69 BOOST_CHECK_EQUAL( G1.vars(), vars );
70 BOOST_CHECK_EQUAL( G1.factors(), facs );
71 BOOST_CHECK( G1.bipGraph() == bipgraph );
72 BOOST_CHECK( G1.DAG() == dag );
73 BOOST_CHECK_EQUAL( G1.nrORs(), 3 );
74 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
75 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[3] );
76 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
77 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
78 BOOST_CHECK_EQUAL( G1.OR(2).c(), 1 );
79 BOOST_CHECK_EQUAL( G1.OR(2), facs[2] );
80 BOOST_CHECK_EQUAL( G1.nrIRs(), 3 );
81 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
82 BOOST_CHECK_EQUAL( G1.IR(0), v0 );
83 BOOST_CHECK_EQUAL( G1.IR(1).c(), -1 );
84 BOOST_CHECK_EQUAL( G1.IR(1), v1 );
85 BOOST_CHECK_EQUAL( G1.IR(2).c(), -1 );
86 BOOST_CHECK_EQUAL( G1.IR(2), v2 );
87 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
88 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
89 BOOST_CHECK_EQUAL( G1.fac2OR(2), 2 );
90 BOOST_CHECK_EQUAL( G1.fac2OR(3), 0 );
91 BOOST_CHECK( G1.checkCountingNumbers() );
92
93 std::vector<VarSet> ors;
94 std::vector<Region> irs;
95 ors.push_back( v01 | v2 );
96 irs.push_back( Region( v01, 0.0 ) );
97 irs.push_back( Region( v2, 0.0 ) );
98 typedef std::pair<size_t,size_t> edge;
99 std::vector<edge> edges;
100 edges.push_back( edge( 0, 0 ) );
101 edges.push_back( edge( 0, 1 ) );
102 BipartiteGraph dag2( 1, 2, edges.begin(), edges.end() );
103 RegionGraph G2( fg, ors, irs, edges );
104 BOOST_CHECK_EQUAL( G2.vars(), vars );
105 BOOST_CHECK_EQUAL( G2.factors(), facs );
106 BOOST_CHECK( G2.bipGraph() == bipgraph );
107 BOOST_CHECK( G2.DAG() == dag2 );
108 BOOST_CHECK_EQUAL( G2.nrORs(), 1 );
109 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
110 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[1] * facs[2] * facs[3] );
111 BOOST_CHECK_EQUAL( G2.nrIRs(), 2 );
112 BOOST_CHECK_EQUAL( G2.IR(0), irs[0] );
113 BOOST_CHECK_EQUAL( G2.IR(1), irs[1] );
114 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
115 BOOST_CHECK_EQUAL( G2.fac2OR(1), 0 );
116 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
117 BOOST_CHECK_EQUAL( G2.fac2OR(3), 0 );
118 BOOST_CHECK( G2.checkCountingNumbers() );
119
120 RegionGraph *G3 = G1.clone();
121 BOOST_CHECK_EQUAL( G3->vars(), G1.vars() );
122 BOOST_CHECK_EQUAL( G3->factors(), G1.factors() );
123 BOOST_CHECK( G3->bipGraph() == bipgraph );
124 BOOST_CHECK( G3->DAG() == dag );
125 BOOST_CHECK_EQUAL( G3->nrORs(), G1.nrORs() );
126 BOOST_CHECK_EQUAL( G3->OR(0), G1.OR(0) );
127 BOOST_CHECK_EQUAL( G3->OR(1), G1.OR(1) );
128 BOOST_CHECK_EQUAL( G3->OR(2), G1.OR(2) );
129 BOOST_CHECK_EQUAL( G3->nrIRs(), G1.nrIRs() );
130 BOOST_CHECK_EQUAL( G3->IR(0), G1.IR(0) );
131 BOOST_CHECK_EQUAL( G3->IR(1), G1.IR(1) );
132 BOOST_CHECK_EQUAL( G3->IR(2), G1.IR(2) );
133 BOOST_CHECK_EQUAL( G3->fac2OR(0), G1.fac2OR(0) );
134 BOOST_CHECK_EQUAL( G3->fac2OR(1), G1.fac2OR(1) );
135 BOOST_CHECK_EQUAL( G3->fac2OR(2), G1.fac2OR(2) );
136 BOOST_CHECK_EQUAL( G3->fac2OR(3), G1.fac2OR(3) );
137 BOOST_CHECK( G3->checkCountingNumbers() );
138 delete G3;
139
140 RegionGraph G4 = G1;
141 BOOST_CHECK_EQUAL( G4.vars(), G1.vars() );
142 BOOST_CHECK_EQUAL( G4.factors(), G1.factors() );
143 BOOST_CHECK( G4.bipGraph() == bipgraph );
144 BOOST_CHECK( G4.DAG() == dag );
145 BOOST_CHECK_EQUAL( G4.nrORs(), G1.nrORs() );
146 BOOST_CHECK_EQUAL( G4.OR(0), G1.OR(0) );
147 BOOST_CHECK_EQUAL( G4.OR(1), G1.OR(1) );
148 BOOST_CHECK_EQUAL( G4.OR(2), G1.OR(2) );
149 BOOST_CHECK_EQUAL( G4.nrIRs(), G1.nrIRs() );
150 BOOST_CHECK_EQUAL( G4.IR(0), G1.IR(0) );
151 BOOST_CHECK_EQUAL( G4.IR(1), G1.IR(1) );
152 BOOST_CHECK_EQUAL( G4.IR(2), G1.IR(2) );
153 BOOST_CHECK_EQUAL( G4.fac2OR(0), G1.fac2OR(0) );
154 BOOST_CHECK_EQUAL( G4.fac2OR(1), G1.fac2OR(1) );
155 BOOST_CHECK_EQUAL( G4.fac2OR(2), G1.fac2OR(2) );
156 BOOST_CHECK_EQUAL( G4.fac2OR(3), G1.fac2OR(3) );
157 BOOST_CHECK( G4.checkCountingNumbers() );
158
159 RegionGraph G5( G1 );
160 BOOST_CHECK_EQUAL( G5.vars(), G1.vars() );
161 BOOST_CHECK_EQUAL( G5.factors(), G1.factors() );
162 BOOST_CHECK( G5.bipGraph() == bipgraph );
163 BOOST_CHECK( G5.DAG() == dag );
164 BOOST_CHECK_EQUAL( G5.nrORs(), G1.nrORs() );
165 BOOST_CHECK_EQUAL( G5.OR(0), G1.OR(0) );
166 BOOST_CHECK_EQUAL( G5.OR(1), G1.OR(1) );
167 BOOST_CHECK_EQUAL( G5.OR(2), G1.OR(2) );
168 BOOST_CHECK_EQUAL( G5.nrIRs(), G1.nrIRs() );
169 BOOST_CHECK_EQUAL( G5.IR(0), G1.IR(0) );
170 BOOST_CHECK_EQUAL( G5.IR(1), G1.IR(1) );
171 BOOST_CHECK_EQUAL( G5.IR(2), G1.IR(2) );
172 BOOST_CHECK_EQUAL( G5.fac2OR(0), G1.fac2OR(0) );
173 BOOST_CHECK_EQUAL( G5.fac2OR(1), G1.fac2OR(1) );
174 BOOST_CHECK_EQUAL( G5.fac2OR(2), G1.fac2OR(2) );
175 BOOST_CHECK_EQUAL( G5.fac2OR(3), G1.fac2OR(3) );
176 BOOST_CHECK( G5.checkCountingNumbers() );
177 }
178
179
180 BOOST_AUTO_TEST_CASE( AccMutTest ) {
181 Var v0( 0, 2 );
182 Var v1( 1, 2 );
183 Var v2( 2, 2 );
184 VarSet v01( v0, v1 );
185 VarSet v02( v0, v2 );
186 VarSet v12( v1, v2 );
187 std::vector<Factor> facs;
188 facs.push_back( Factor( v01 ) );
189 facs.push_back( Factor( v02 ) );
190 facs.push_back( Factor( v12 ) );
191 facs.push_back( Factor( v1 ) );
192 std::vector<Var> vars;
193 vars.push_back( v0 );
194 vars.push_back( v1 );
195 vars.push_back( v2 );
196 BipartiteGraph dag( 3, 3 );
197 dag.addEdge( 0, 0 );
198 dag.addEdge( 0, 1 );
199 dag.addEdge( 1, 0 );
200 dag.addEdge( 1, 2 );
201 dag.addEdge( 2, 1 );
202 dag.addEdge( 2, 2 );
203 BipartiteGraph bipgraph( dag );
204 bipgraph.addNode2();
205 bipgraph.addEdge( 1, 3 );
206
207 FactorGraph fg( facs );
208 RegionGraph G( fg, fg.maximalFactorDomains() );
209 BOOST_CHECK_EQUAL( G.var(0), v0 );
210 BOOST_CHECK_EQUAL( G.var(1), v1 );
211 BOOST_CHECK_EQUAL( G.var(2), v2 );
212 BOOST_CHECK_EQUAL( G.vars(), vars );
213 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
214 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
215 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
216 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
217 BOOST_CHECK_EQUAL( G.factors(), facs );
218 BOOST_CHECK( G.bipGraph() == bipgraph );
219 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
220 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
221 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
222 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
223 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
224 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
225 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
226 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
227 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
228 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
229 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
230 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
231 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
232 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
233 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
234 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
235 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
236 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
237 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
238 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
239 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
240 BOOST_CHECK( G.DAG() == dag );
241 BOOST_CHECK_EQUAL( G.nrORs(), 3 );
242 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
243 BOOST_CHECK_EQUAL( G.OR(0).vars(), v01 );
244 BOOST_CHECK_EQUAL( G.OR(0), facs[0] * facs[3] );
245 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
246 BOOST_CHECK_EQUAL( G.OR(1).vars(), v02 );
247 BOOST_CHECK_EQUAL( G.OR(1), facs[1] );
248 BOOST_CHECK_EQUAL( G.OR(2).c(), 1 );
249 BOOST_CHECK_EQUAL( G.OR(2).vars(), v12 );
250 BOOST_CHECK_EQUAL( G.OR(2), facs[2] );
251 BOOST_CHECK_EQUAL( G.nrIRs(), 3 );
252 BOOST_CHECK_EQUAL( G.IR(0).c(), -1 );
253 BOOST_CHECK_EQUAL( G.IR(0), v0 );
254 BOOST_CHECK_EQUAL( G.IR(1).c(), -1 );
255 BOOST_CHECK_EQUAL( G.IR(1), v1 );
256 BOOST_CHECK_EQUAL( G.IR(2).c(), -1 );
257 BOOST_CHECK_EQUAL( G.IR(2), v2 );
258 BOOST_CHECK_EQUAL( G.fac2OR(0), 0 );
259 BOOST_CHECK_EQUAL( G.fac2OR(1), 1 );
260 BOOST_CHECK_EQUAL( G.fac2OR(2), 2 );
261 BOOST_CHECK_EQUAL( G.fac2OR(3), 0 );
262 BOOST_CHECK_EQUAL( G.nbOR(0).size(), 2 );
263 BOOST_CHECK_EQUAL( G.nbOR(0)[0], 0 );
264 BOOST_CHECK_EQUAL( G.nbOR(0)[1], 1 );
265 BOOST_CHECK_EQUAL( G.nbOR(1).size(), 2 );
266 BOOST_CHECK_EQUAL( G.nbOR(1)[0], 0 );
267 BOOST_CHECK_EQUAL( G.nbOR(1)[1], 2 );
268 BOOST_CHECK_EQUAL( G.nbOR(2).size(), 2 );
269 BOOST_CHECK_EQUAL( G.nbOR(2)[0], 1 );
270 BOOST_CHECK_EQUAL( G.nbOR(2)[1], 2 );
271 BOOST_CHECK_EQUAL( G.nbIR(0).size(), 2 );
272 BOOST_CHECK_EQUAL( G.nbIR(0)[0], 0 );
273 BOOST_CHECK_EQUAL( G.nbIR(0)[1], 1 );
274 BOOST_CHECK_EQUAL( G.nbIR(1).size(), 2 );
275 BOOST_CHECK_EQUAL( G.nbIR(1)[0], 0 );
276 BOOST_CHECK_EQUAL( G.nbIR(1)[1], 2 );
277 BOOST_CHECK_EQUAL( G.nbIR(2).size(), 2 );
278 BOOST_CHECK_EQUAL( G.nbIR(2)[0], 1 );
279 BOOST_CHECK_EQUAL( G.nbIR(2)[1], 2 );
280 }
281
282
283 BOOST_AUTO_TEST_CASE( QueriesTest ) {
284 Var v0( 0, 2 );
285 Var v1( 1, 2 );
286 Var v2( 2, 2 );
287 VarSet v01( v0, v1 );
288 VarSet v02( v0, v2 );
289 VarSet v12( v1, v2 );
290 VarSet v012 = v01 | v2;
291
292 RegionGraph G0;
293 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
294 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
295 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
296 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
297 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
298 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
299 #ifdef DAI_DBEUG
300 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
301 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
302 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
303 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
304 BOOST_CHECK_THROW( G0.fac2OR( 0 ), Exception );
305 #endif
306 BOOST_CHECK( G0.isConnected() );
307 BOOST_CHECK( G0.isTree() );
308 BOOST_CHECK( G0.isBinary() );
309 BOOST_CHECK( G0.isPairwise() );
310 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
311 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
312 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
313 BOOST_CHECK( G0.DAG() == BipartiteGraph() );
314 BOOST_CHECK( G0.checkCountingNumbers() );
315 BOOST_CHECK_EQUAL( G0.nrORs(), 0 );
316 BOOST_CHECK_EQUAL( G0.nrIRs(), 0 );
317
318 std::vector<Factor> facs;
319 facs.push_back( Factor( v01 ) );
320 facs.push_back( Factor( v12 ) );
321 facs.push_back( Factor( v1 ) );
322 std::vector<Var> vars;
323 vars.push_back( v0 );
324 vars.push_back( v1 );
325 vars.push_back( v2 );
326 GraphAL H(3);
327 H.addEdge( 0, 1 );
328 H.addEdge( 1, 2 );
329 BipartiteGraph K(3, 3);
330 K.addEdge( 0, 0 );
331 K.addEdge( 1, 0 );
332 K.addEdge( 1, 1 );
333 K.addEdge( 2, 1 );
334 K.addEdge( 1, 2 );
335 BipartiteGraph dag(2, 1);
336 dag.addEdge( 0, 0 );
337 dag.addEdge( 1, 0 );
338
339 FactorGraph fg( facs );
340 RegionGraph G1( fg, fg.maximalFactorDomains() );
341 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
342 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
343 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
344 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
345 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
346 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
347 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
348 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
349 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
350 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
351 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
352 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
353 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
354 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
355 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
356 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
357 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
358 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
359 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
360 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
361 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
362 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
363 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
364 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
365 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
366 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
367 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
368 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
369 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
370 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
371 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
372 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
373 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
374 BOOST_CHECK( G1.isConnected() );
375 BOOST_CHECK( G1.isTree() );
376 BOOST_CHECK( G1.isBinary() );
377 BOOST_CHECK( G1.isPairwise() );
378 BOOST_CHECK( G1.MarkovGraph() == H );
379 BOOST_CHECK( G1.bipGraph() == K );
380 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
381 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
382 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
383 BOOST_CHECK( G1.DAG() == dag );
384 BOOST_CHECK( G1.checkCountingNumbers() );
385 BOOST_CHECK_EQUAL( G1.nrORs(), 2 );
386 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
387 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[2] );
388 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
389 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
390 BOOST_CHECK_EQUAL( G1.nrIRs(), 1 );
391 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
392 BOOST_CHECK_EQUAL( G1.IR(0), v1 );
393 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
394 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
395 BOOST_CHECK_EQUAL( G1.fac2OR(2), 0 );
396
397 facs.push_back( Factor( v02 ) );
398 H.addEdge( 0, 2 );
399 K.addNode2();
400 K.addEdge( 0, 3 );
401 K.addEdge( 2, 3 );
402 dag = BipartiteGraph( 3, 3 );
403 dag.addEdge( 0, 0 );
404 dag.addEdge( 0, 1 );
405 dag.addEdge( 1, 1 );
406 dag.addEdge( 1, 2 );
407 dag.addEdge( 2, 0 );
408 dag.addEdge( 2, 2 );
409 fg = FactorGraph( facs );
410 RegionGraph G2( fg, fg.maximalFactorDomains() );
411 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
412 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
413 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
414 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
415 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
416 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
417 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
418 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
419 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
420 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
421 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
422 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
423 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
424 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
425 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
426 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
427 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
428 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
429 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
430 BOOST_CHECK( G2.isConnected() );
431 BOOST_CHECK( !G2.isTree() );
432 BOOST_CHECK( G2.isBinary() );
433 BOOST_CHECK( G2.isPairwise() );
434 BOOST_CHECK( G2.MarkovGraph() == H );
435 BOOST_CHECK( G2.bipGraph() == K );
436 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
437 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
438 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
439 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
440 BOOST_CHECK( G2.DAG() == dag );
441 BOOST_CHECK( G2.checkCountingNumbers() );
442 BOOST_CHECK_EQUAL( G2.nrORs(), 3 );
443 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
444 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[2] );
445 BOOST_CHECK_EQUAL( G2.OR(1).c(), 1 );
446 BOOST_CHECK_EQUAL( G2.OR(1), facs[1] );
447 BOOST_CHECK_EQUAL( G2.OR(2).c(), 1 );
448 BOOST_CHECK_EQUAL( G2.OR(2), facs[3] );
449 BOOST_CHECK_EQUAL( G2.nrIRs(), 3 );
450 BOOST_CHECK_EQUAL( G2.IR(0).c(), -1.0 );
451 BOOST_CHECK_EQUAL( G2.IR(0), v0 );
452 BOOST_CHECK_EQUAL( G2.IR(1).c(), -1.0 );
453 BOOST_CHECK_EQUAL( G2.IR(1), v1 );
454 BOOST_CHECK_EQUAL( G2.IR(2).c(), -1.0 );
455 BOOST_CHECK_EQUAL( G2.IR(2), v2 );
456 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
457 BOOST_CHECK_EQUAL( G2.fac2OR(1), 1 );
458 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
459 BOOST_CHECK_EQUAL( G2.fac2OR(3), 2 );
460
461 Var v3( 3, 3 );
462 VarSet v03( v0, v3 );
463 VarSet v13( v1, v3 );
464 VarSet v23( v2, v3 );
465 VarSet v013 = v01 | v3;
466 VarSet v023 = v02 | v3;
467 VarSet v123 = v12 | v3;
468 VarSet v0123 = v012 | v3;
469 vars.push_back( v3 );
470 facs.push_back( Factor( v3 ) );
471 H.addNode();
472 K.addNode1();
473 K.addNode2();
474 K.addEdge( 3, 4 );
475 dag.addNode1();
476 fg = FactorGraph( facs );
477 RegionGraph G3( fg, fg.maximalFactorDomains() );
478 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
479 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
480 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
481 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
482 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
483 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
484 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
485 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
486 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
487 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
488 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
489 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
490 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
491 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
492 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
493 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
494 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
495 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
496 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
497 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
498 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
499 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
500 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
501 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
502 BOOST_CHECK( !G3.isConnected() );
503 BOOST_CHECK( !G3.isTree() );
504 BOOST_CHECK( !G3.isBinary() );
505 BOOST_CHECK( G3.isPairwise() );
506 BOOST_CHECK( G3.MarkovGraph() == H );
507 BOOST_CHECK( G3.bipGraph() == K );
508 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
509 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
510 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
511 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
512 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
513 BOOST_CHECK( G3.DAG() == dag );
514 BOOST_CHECK( G3.checkCountingNumbers() );
515 BOOST_CHECK_EQUAL( G3.nrORs(), 4 );
516 BOOST_CHECK_EQUAL( G3.OR(0).c(), 1 );
517 BOOST_CHECK_EQUAL( G3.OR(0), facs[0] * facs[2] );
518 BOOST_CHECK_EQUAL( G3.OR(1).c(), 1 );
519 BOOST_CHECK_EQUAL( G3.OR(1), facs[1] );
520 BOOST_CHECK_EQUAL( G3.OR(2).c(), 1 );
521 BOOST_CHECK_EQUAL( G3.OR(2), facs[3] );
522 BOOST_CHECK_EQUAL( G3.OR(3).c(), 1 );
523 BOOST_CHECK_EQUAL( G3.OR(3), facs[4] );
524 BOOST_CHECK_EQUAL( G3.nrIRs(), 3 );
525 BOOST_CHECK_EQUAL( G3.IR(0).c(), -1.0 );
526 BOOST_CHECK_EQUAL( G3.IR(0), v0 );
527 BOOST_CHECK_EQUAL( G3.IR(1).c(), -1.0 );
528 BOOST_CHECK_EQUAL( G3.IR(1), v1 );
529 BOOST_CHECK_EQUAL( G3.IR(2).c(), -1.0 );
530 BOOST_CHECK_EQUAL( G3.IR(2), v2 );
531 BOOST_CHECK_EQUAL( G3.fac2OR(0), 0 );
532 BOOST_CHECK_EQUAL( G3.fac2OR(1), 1 );
533 BOOST_CHECK_EQUAL( G3.fac2OR(2), 0 );
534 BOOST_CHECK_EQUAL( G3.fac2OR(3), 2 );
535 BOOST_CHECK_EQUAL( G3.fac2OR(4), 3 );
536
537 facs.push_back( Factor( v123 ) );
538 H.addEdge( 1, 3 );
539 H.addEdge( 2, 3 );
540 K.addNode2();
541 K.addEdge( 1, 5 );
542 K.addEdge( 2, 5 );
543 K.addEdge( 3, 5 );
544 dag = BipartiteGraph( 3, 3 );
545 dag.addEdge( 0, 0 );
546 dag.addEdge( 0, 1 );
547 dag.addEdge( 1, 0 );
548 dag.addEdge( 1, 2 );
549 dag.addEdge( 2, 1 );
550 dag.addEdge( 2, 2 );
551 fg = FactorGraph( facs );
552 RegionGraph G4( fg, fg.maximalFactorDomains() );
553 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
554 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
555 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
556 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
557 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
558 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
559 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
560 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
561 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
562 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
563 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
564 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
565 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
566 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
567 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
568 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
569 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
570 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
571 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
572 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
573 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
574 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
575 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
576 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
577 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
578 BOOST_CHECK( G4.isConnected() );
579 BOOST_CHECK( !G4.isTree() );
580 BOOST_CHECK( !G4.isBinary() );
581 BOOST_CHECK( !G4.isPairwise() );
582 BOOST_CHECK( G4.MarkovGraph() == H );
583 BOOST_CHECK( G4.bipGraph() == K );
584 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
585 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
586 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
587 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
588 BOOST_CHECK( G4.DAG() == dag );
589 BOOST_CHECK( G4.checkCountingNumbers() );
590 BOOST_CHECK_EQUAL( G4.nrORs(), 3 );
591 BOOST_CHECK_EQUAL( G4.OR(0).c(), 1 );
592 BOOST_CHECK_EQUAL( G4.OR(0), facs[0] * facs[2] );
593 BOOST_CHECK_EQUAL( G4.OR(1).c(), 1 );
594 BOOST_CHECK_EQUAL( G4.OR(1), facs[3] );
595 BOOST_CHECK_EQUAL( G4.OR(2).c(), 1 );
596 BOOST_CHECK_EQUAL( G4.OR(2), facs[1] * facs[4] * facs[5] );
597 BOOST_CHECK_EQUAL( G4.nrIRs(), 3 );
598 BOOST_CHECK_EQUAL( G4.IR(0).c(), -1.0 );
599 BOOST_CHECK_EQUAL( G4.IR(0), v0 );
600 BOOST_CHECK_EQUAL( G4.IR(1).c(), -1.0 );
601 BOOST_CHECK_EQUAL( G4.IR(1), v1 );
602 BOOST_CHECK_EQUAL( G4.IR(2).c(), -1.0 );
603 BOOST_CHECK_EQUAL( G4.IR(2), v2 );
604 BOOST_CHECK_EQUAL( G4.fac2OR(0), 0 );
605 BOOST_CHECK_EQUAL( G4.fac2OR(1), 2 );
606 BOOST_CHECK_EQUAL( G4.fac2OR(2), 0 );
607 BOOST_CHECK_EQUAL( G4.fac2OR(3), 1 );
608 BOOST_CHECK_EQUAL( G4.fac2OR(4), 2 );
609 BOOST_CHECK_EQUAL( G4.fac2OR(5), 2 );
610 }
611
612
613 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
614 Var v0( 0, 2 );
615 Var v1( 1, 2 );
616 Var v2( 2, 2 );
617 VarSet v01( v0, v1 );
618 VarSet v02( v0, v2 );
619 VarSet v12( v1, v2 );
620 VarSet v012 = v01 | v2;
621
622 std::vector<Factor> facs;
623 facs.push_back( Factor( v01 ) );
624 facs.push_back( Factor( v12 ) );
625 facs.push_back( Factor( v1 ) );
626 std::vector<Var> vars;
627 vars.push_back( v0 );
628 vars.push_back( v1 );
629 vars.push_back( v2 );
630
631 FactorGraph fg( facs );
632 RegionGraph G( fg, fg.maximalFactorDomains() );
633 RegionGraph Gorg( G );
634 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
635 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
636 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
637 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
638
639 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
640 G.setFactor( 0, Factor( v01, 2.0 ), false );
641 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
642 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
643 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
644 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
645 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
646 G.setFactor( 0, Factor( v01, 3.0 ), true );
647 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
648 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
649 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
650 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
651 G.restoreFactor( 0 );
652 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
653 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
654 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
655 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
656 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
657 G.setFactor( 0, Gorg.factor( 0 ), false );
658 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
659 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
660 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
661 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
662 G.backupFactor( 0 );
663 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
664 G.setFactor( 0, Factor( v01, 2.0 ), false );
665 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
666 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
667 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
668 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
669 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
670 G.restoreFactor( 0 );
671 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
672 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
673 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
674 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
675 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
676
677 std::map<size_t, Factor> fs;
678 fs[0] = Factor( v01, 3.0 );
679 fs[2] = Factor( v1, 2.0 );
680 G.setFactors( fs, false );
681 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
682 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
683 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
684 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
685 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
686 G.restoreFactors();
687 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
688 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
689 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
690 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
691 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
692 G = Gorg;
693 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
694 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
695 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
696 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
697 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
698 G.setFactors( fs, true );
699 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
700 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
701 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
702 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
703 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
704 G.restoreFactors();
705 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
706 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
707 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
708 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
709 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
710 std::set<size_t> fsind;
711 fsind.insert( 0 );
712 fsind.insert( 2 );
713 G.backupFactors( fsind );
714 G.setFactors( fs, false );
715 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
716 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
717 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
718 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
719 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
720 G.restoreFactors();
721 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
722 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
723 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
724 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
725 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
726
727 G.backupFactors( v2 );
728 G.setFactor( 1, Factor(v12, 5.0) );
729 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
730 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
731 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
732 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
733 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
734 G.restoreFactors( v2 );
735 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
736 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
737 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
738 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
739 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
740
741 G.backupFactors( v1 );
742 fs[1] = Factor( v12, 5.0 );
743 G.setFactors( fs, false );
744 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
745 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
746 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
747 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
748 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
749 G.restoreFactors();
750 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
751 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
752 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
753 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
754 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
755 G.setFactors( fs, true );
756 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
757 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
758 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
759 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
760 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
761 G.restoreFactors( v1 );
762 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
763 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
764 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
765 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
766 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
767 }
768
769
770 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
771 Var v0( 0, 2 );
772 Var v1( 1, 2 );
773 Var v2( 2, 2 );
774 VarSet v01( v0, v1 );
775 VarSet v02( v0, v2 );
776 VarSet v12( v1, v2 );
777 VarSet v012 = v01 | v2;
778
779 std::vector<Factor> facs;
780 facs.push_back( Factor( v01 ).randomize() );
781 facs.push_back( Factor( v12 ).randomize() );
782 facs.push_back( Factor( v1 ).randomize() );
783 std::vector<Var> vars;
784 vars.push_back( v0 );
785 vars.push_back( v1 );
786 vars.push_back( v2 );
787
788 FactorGraph fg( facs );
789 RegionGraph G( fg, fg.maximalFactorDomains() );
790
791 FactorGraph Gsmall = G.maximalFactors();
792 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
793 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
794 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
795 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
796
797 size_t i = 0;
798 for( size_t x = 0; x < 2; x++ ) {
799 FactorGraph Gcl = G.clamped( i, x );
800 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
801 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
802 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
803 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
804 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
805 }
806 i = 1;
807 for( size_t x = 0; x < 2; x++ ) {
808 FactorGraph Gcl = G.clamped( i, x );
809 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
810 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
811 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
812 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
813 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
814 }
815 i = 2;
816 for( size_t x = 0; x < 2; x++ ) {
817 FactorGraph Gcl = G.clamped( i, x );
818 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
819 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
820 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
821 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
822 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
823 }
824 }
825
826
827 BOOST_AUTO_TEST_CASE( OperationsTest ) {
828 Var v0( 0, 2 );
829 Var v1( 1, 2 );
830 Var v2( 2, 2 );
831 VarSet v01( v0, v1 );
832 VarSet v02( v0, v2 );
833 VarSet v12( v1, v2 );
834 VarSet v012 = v01 | v2;
835
836 std::vector<Factor> facs;
837 facs.push_back( Factor( v01 ).randomize() );
838 facs.push_back( Factor( v12 ).randomize() );
839 facs.push_back( Factor( v1 ).randomize() );
840 std::vector<Var> vars;
841 vars.push_back( v0 );
842 vars.push_back( v1 );
843 vars.push_back( v2 );
844
845 FactorGraph fg( facs );
846 RegionGraph G( fg, fg.maximalFactorDomains() );
847
848 // clamp
849 RegionGraph Gcl = G;
850 for( size_t i = 0; i < 3; i++ )
851 for( size_t x = 0; x < 2; x++ ) {
852 Gcl.clamp( i, x, true );
853 Factor delta = createFactorDelta( vars[i], x );
854 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
855 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
856 for( size_t j = 0; j < 3; j++ )
857 if( G.factor(j).vars().contains( vars[i] ) )
858 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
859 else
860 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
861 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
862 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
863
864 Gcl.restoreFactors();
865 for( size_t j = 0; j < 3; j++ )
866 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
867 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
868 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
869 }
870
871 // clampVar
872 for( size_t i = 0; i < 3; i++ )
873 for( size_t x = 0; x < 2; x++ ) {
874 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
875 Factor delta = createFactorDelta( vars[i], x );
876 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
877 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
878 for( size_t j = 0; j < 3; j++ )
879 if( G.factor(j).vars().contains( vars[i] ) )
880 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
881 else
882 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
883 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
884 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
885
886 Gcl.restoreFactors();
887 for( size_t j = 0; j < 3; j++ )
888 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
889 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
890 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
891 }
892 for( size_t i = 0; i < 3; i++ )
893 for( size_t x = 0; x < 2; x++ ) {
894 Gcl.clampVar( i, std::vector<size_t>(), true );
895 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
896 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
897 for( size_t j = 0; j < 3; j++ )
898 if( G.factor(j).vars().contains( vars[i] ) )
899 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
900 else
901 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
902 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
903 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
904
905 Gcl.restoreFactors();
906 for( size_t j = 0; j < 3; j++ )
907 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
908 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
909 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
910 }
911 std::vector<size_t> both;
912 both.push_back( 0 );
913 both.push_back( 1 );
914 for( size_t i = 0; i < 3; i++ )
915 for( size_t x = 0; x < 2; x++ ) {
916 Gcl.clampVar( i, both, true );
917 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
918 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
919 for( size_t j = 0; j < 3; j++ )
920 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
921 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
922 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
923 Gcl.restoreFactors();
924 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
925 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
926 }
927
928 // clampFactor
929 for( size_t x = 0; x < 4; x++ ) {
930 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
931 Factor mask( v01, 0.0 );
932 mask.set( x, 1.0 );
933 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
934 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
935 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
936 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
937 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
938 Gcl.restoreFactor( 0 );
939 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
940 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
941 }
942 for( size_t x = 0; x < 4; x++ ) {
943 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
944 Factor mask( v12, 0.0 );
945 mask.set( x, 1.0 );
946 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
947 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
948 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
949 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
950 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
951 Gcl.restoreFactor( 1 );
952 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
953 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
954 }
955 for( size_t x = 0; x < 2; x++ ) {
956 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
957 Factor mask( v1, 0.0 );
958 mask.set( x, 1.0 );
959 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
960 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
961 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
962 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
963 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
964 Gcl.restoreFactors();
965 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
966 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
967 }
968
969 // makeCavity
970 RegionGraph Gcav( G );
971 Gcav.makeCavity( 0, true );
972 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
973 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
974 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
975 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
976 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
977 Gcav.restoreFactors();
978 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
979 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
980 Gcav.makeCavity( 1, true );
981 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
982 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
983 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
984 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
985 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
986 Gcav.restoreFactors();
987 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
988 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
989 Gcav.makeCavity( 2, true );
990 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
991 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
992 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
993 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
994 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
995 Gcav.restoreFactors();
996 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
997 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
998 }
999
1000
1001 BOOST_AUTO_TEST_CASE( IOTest ) {
1002 Var v0( 0, 2 );
1003 Var v1( 1, 2 );
1004 Var v2( 2, 2 );
1005 VarSet v01( v0, v1 );
1006 VarSet v02( v0, v2 );
1007 VarSet v12( v1, v2 );
1008 VarSet v012 = v01 | v2;
1009
1010 std::vector<Factor> facs;
1011 facs.push_back( Factor( v01, 1.0 ) );
1012 facs.push_back( Factor( v12, 2.0 ) );
1013 facs.push_back( Factor( v1, 3.0 ) );
1014 std::vector<Var> vars;
1015 vars.push_back( v0 );
1016 vars.push_back( v1 );
1017 vars.push_back( v2 );
1018
1019 FactorGraph fg( facs );
1020 RegionGraph G( fg, fg.maximalFactorDomains() );
1021 BOOST_CHECK_THROW( G.WriteToFile( "regiongraph_test.fg" ), Exception );
1022 BOOST_CHECK_THROW( G.ReadFromFile( "regiongraph_test.fg" ), Exception );
1023 BOOST_CHECK_THROW( G.printDot( std::cout ), Exception );
1024
1025 std::stringstream ss;
1026 std::string s;
1027 ss << G;
1028 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "digraph RegionGraph {" );
1029 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box];" );
1030 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 [label=\"a0: {x0, x1}, c=1\"];" );
1031 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 [label=\"a1: {x1, x2}, c=1\"];" );
1032 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=ellipse];" );
1033 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tb0 [label=\"b0: {x1}, c=-1\"];" );
1034 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 -> b0;" );
1035 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 -> b0;" );
1036 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
1037 }