5724e71617031ce049eb9594b6f8261a16ad0673
[libdai.git] / tests / unit / regiongraph_test.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/regiongraph.h>
12 #include <vector>
13 #include <strstream>
14
15
16 using namespace dai;
17
18
19 const double tol = 1e-8;
20
21
22 #define BOOST_TEST_MODULE RegionGraphTest
23
24
25 #include <boost/test/unit_test.hpp>
26 #include <boost/test/floating_point_comparison.hpp>
27
28
29 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
30 Var v0( 0, 2 );
31 Var v1( 1, 2 );
32 Var v2( 2, 2 );
33 VarSet v01( v0, v1 );
34 VarSet v02( v0, v2 );
35 VarSet v12( v1, v2 );
36 RegionGraph G;
37 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
38 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
39 BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
40 BOOST_CHECK( G.DAG() == BipartiteGraph() );
41 BOOST_CHECK_EQUAL( G.nrORs(), 0 );
42 BOOST_CHECK_EQUAL( G.nrIRs(), 0 );
43
44 std::vector<Factor> facs;
45 facs.push_back( Factor( v01 ) );
46 facs.push_back( Factor( v02 ) );
47 facs.push_back( Factor( v12 ) );
48 facs.push_back( Factor( v1 ) );
49 std::vector<Var> vars;
50 vars.push_back( v0 );
51 vars.push_back( v1 );
52 vars.push_back( v2 );
53 BipartiteGraph dag( 3, 3 );
54 dag.addEdge( 0, 0 );
55 dag.addEdge( 0, 1 );
56 dag.addEdge( 1, 0 );
57 dag.addEdge( 1, 2 );
58 dag.addEdge( 2, 1 );
59 dag.addEdge( 2, 2 );
60 BipartiteGraph bipgraph( dag );
61 bipgraph.addNode2();
62 bipgraph.addEdge( 1, 3 );
63
64 FactorGraph fg(facs);
65 RegionGraph G1( fg, fg.maximalFactorDomains() );
66 BOOST_CHECK_EQUAL( G1.vars(), vars );
67 BOOST_CHECK_EQUAL( G1.factors(), facs );
68 BOOST_CHECK( G1.bipGraph() == bipgraph );
69 BOOST_CHECK( G1.DAG() == dag );
70 BOOST_CHECK_EQUAL( G1.nrORs(), 3 );
71 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
72 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[3] );
73 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
74 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
75 BOOST_CHECK_EQUAL( G1.OR(2).c(), 1 );
76 BOOST_CHECK_EQUAL( G1.OR(2), facs[2] );
77 BOOST_CHECK_EQUAL( G1.nrIRs(), 3 );
78 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
79 BOOST_CHECK_EQUAL( G1.IR(0), v0 );
80 BOOST_CHECK_EQUAL( G1.IR(1).c(), -1 );
81 BOOST_CHECK_EQUAL( G1.IR(1), v1 );
82 BOOST_CHECK_EQUAL( G1.IR(2).c(), -1 );
83 BOOST_CHECK_EQUAL( G1.IR(2), v2 );
84 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
85 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
86 BOOST_CHECK_EQUAL( G1.fac2OR(2), 2 );
87 BOOST_CHECK_EQUAL( G1.fac2OR(3), 0 );
88 BOOST_CHECK( G1.checkCountingNumbers() );
89
90 std::vector<VarSet> ors;
91 std::vector<Region> irs;
92 ors.push_back( v01 | v2 );
93 irs.push_back( Region( v01, 0.0 ) );
94 irs.push_back( Region( v2, 0.0 ) );
95 typedef std::pair<size_t,size_t> edge;
96 std::vector<edge> edges;
97 edges.push_back( edge( 0, 0 ) );
98 edges.push_back( edge( 0, 1 ) );
99 BipartiteGraph dag2( 1, 2, edges.begin(), edges.end() );
100 RegionGraph G2( fg, ors, irs, edges );
101 BOOST_CHECK_EQUAL( G2.vars(), vars );
102 BOOST_CHECK_EQUAL( G2.factors(), facs );
103 BOOST_CHECK( G2.bipGraph() == bipgraph );
104 BOOST_CHECK( G2.DAG() == dag2 );
105 BOOST_CHECK_EQUAL( G2.nrORs(), 1 );
106 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
107 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[1] * facs[2] * facs[3] );
108 BOOST_CHECK_EQUAL( G2.nrIRs(), 2 );
109 BOOST_CHECK_EQUAL( G2.IR(0), irs[0] );
110 BOOST_CHECK_EQUAL( G2.IR(1), irs[1] );
111 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
112 BOOST_CHECK_EQUAL( G2.fac2OR(1), 0 );
113 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
114 BOOST_CHECK_EQUAL( G2.fac2OR(3), 0 );
115 BOOST_CHECK( G2.checkCountingNumbers() );
116
117 RegionGraph *G3 = G1.clone();
118 BOOST_CHECK_EQUAL( G3->vars(), G1.vars() );
119 BOOST_CHECK_EQUAL( G3->factors(), G1.factors() );
120 BOOST_CHECK( G3->bipGraph() == bipgraph );
121 BOOST_CHECK( G3->DAG() == dag );
122 BOOST_CHECK_EQUAL( G3->nrORs(), G1.nrORs() );
123 BOOST_CHECK_EQUAL( G3->OR(0), G1.OR(0) );
124 BOOST_CHECK_EQUAL( G3->OR(1), G1.OR(1) );
125 BOOST_CHECK_EQUAL( G3->OR(2), G1.OR(2) );
126 BOOST_CHECK_EQUAL( G3->nrIRs(), G1.nrIRs() );
127 BOOST_CHECK_EQUAL( G3->IR(0), G1.IR(0) );
128 BOOST_CHECK_EQUAL( G3->IR(1), G1.IR(1) );
129 BOOST_CHECK_EQUAL( G3->IR(2), G1.IR(2) );
130 BOOST_CHECK_EQUAL( G3->fac2OR(0), G1.fac2OR(0) );
131 BOOST_CHECK_EQUAL( G3->fac2OR(1), G1.fac2OR(1) );
132 BOOST_CHECK_EQUAL( G3->fac2OR(2), G1.fac2OR(2) );
133 BOOST_CHECK_EQUAL( G3->fac2OR(3), G1.fac2OR(3) );
134 BOOST_CHECK( G3->checkCountingNumbers() );
135 delete G3;
136
137 RegionGraph G4 = G1;
138 BOOST_CHECK_EQUAL( G4.vars(), G1.vars() );
139 BOOST_CHECK_EQUAL( G4.factors(), G1.factors() );
140 BOOST_CHECK( G4.bipGraph() == bipgraph );
141 BOOST_CHECK( G4.DAG() == dag );
142 BOOST_CHECK_EQUAL( G4.nrORs(), G1.nrORs() );
143 BOOST_CHECK_EQUAL( G4.OR(0), G1.OR(0) );
144 BOOST_CHECK_EQUAL( G4.OR(1), G1.OR(1) );
145 BOOST_CHECK_EQUAL( G4.OR(2), G1.OR(2) );
146 BOOST_CHECK_EQUAL( G4.nrIRs(), G1.nrIRs() );
147 BOOST_CHECK_EQUAL( G4.IR(0), G1.IR(0) );
148 BOOST_CHECK_EQUAL( G4.IR(1), G1.IR(1) );
149 BOOST_CHECK_EQUAL( G4.IR(2), G1.IR(2) );
150 BOOST_CHECK_EQUAL( G4.fac2OR(0), G1.fac2OR(0) );
151 BOOST_CHECK_EQUAL( G4.fac2OR(1), G1.fac2OR(1) );
152 BOOST_CHECK_EQUAL( G4.fac2OR(2), G1.fac2OR(2) );
153 BOOST_CHECK_EQUAL( G4.fac2OR(3), G1.fac2OR(3) );
154 BOOST_CHECK( G4.checkCountingNumbers() );
155
156 RegionGraph G5( G1 );
157 BOOST_CHECK_EQUAL( G5.vars(), G1.vars() );
158 BOOST_CHECK_EQUAL( G5.factors(), G1.factors() );
159 BOOST_CHECK( G5.bipGraph() == bipgraph );
160 BOOST_CHECK( G5.DAG() == dag );
161 BOOST_CHECK_EQUAL( G5.nrORs(), G1.nrORs() );
162 BOOST_CHECK_EQUAL( G5.OR(0), G1.OR(0) );
163 BOOST_CHECK_EQUAL( G5.OR(1), G1.OR(1) );
164 BOOST_CHECK_EQUAL( G5.OR(2), G1.OR(2) );
165 BOOST_CHECK_EQUAL( G5.nrIRs(), G1.nrIRs() );
166 BOOST_CHECK_EQUAL( G5.IR(0), G1.IR(0) );
167 BOOST_CHECK_EQUAL( G5.IR(1), G1.IR(1) );
168 BOOST_CHECK_EQUAL( G5.IR(2), G1.IR(2) );
169 BOOST_CHECK_EQUAL( G5.fac2OR(0), G1.fac2OR(0) );
170 BOOST_CHECK_EQUAL( G5.fac2OR(1), G1.fac2OR(1) );
171 BOOST_CHECK_EQUAL( G5.fac2OR(2), G1.fac2OR(2) );
172 BOOST_CHECK_EQUAL( G5.fac2OR(3), G1.fac2OR(3) );
173 BOOST_CHECK( G5.checkCountingNumbers() );
174 }
175
176
177 BOOST_AUTO_TEST_CASE( AccMutTest ) {
178 Var v0( 0, 2 );
179 Var v1( 1, 2 );
180 Var v2( 2, 2 );
181 VarSet v01( v0, v1 );
182 VarSet v02( v0, v2 );
183 VarSet v12( v1, v2 );
184 std::vector<Factor> facs;
185 facs.push_back( Factor( v01 ) );
186 facs.push_back( Factor( v02 ) );
187 facs.push_back( Factor( v12 ) );
188 facs.push_back( Factor( v1 ) );
189 std::vector<Var> vars;
190 vars.push_back( v0 );
191 vars.push_back( v1 );
192 vars.push_back( v2 );
193 BipartiteGraph dag( 3, 3 );
194 dag.addEdge( 0, 0 );
195 dag.addEdge( 0, 1 );
196 dag.addEdge( 1, 0 );
197 dag.addEdge( 1, 2 );
198 dag.addEdge( 2, 1 );
199 dag.addEdge( 2, 2 );
200 BipartiteGraph bipgraph( dag );
201 bipgraph.addNode2();
202 bipgraph.addEdge( 1, 3 );
203
204 FactorGraph fg( facs );
205 RegionGraph G( fg, fg.maximalFactorDomains() );
206 BOOST_CHECK_EQUAL( G.var(0), v0 );
207 BOOST_CHECK_EQUAL( G.var(1), v1 );
208 BOOST_CHECK_EQUAL( G.var(2), v2 );
209 BOOST_CHECK_EQUAL( G.vars(), vars );
210 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
211 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
212 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
213 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
214 BOOST_CHECK_EQUAL( G.factors(), facs );
215 BOOST_CHECK( G.bipGraph() == bipgraph );
216 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
217 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
218 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
219 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
220 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
221 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
222 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
223 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
224 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
225 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
226 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
227 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
228 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
229 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
230 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
231 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
232 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
233 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
234 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
235 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
236 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
237 BOOST_CHECK( G.DAG() == dag );
238 BOOST_CHECK_EQUAL( G.nrORs(), 3 );
239 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
240 BOOST_CHECK_EQUAL( G.OR(0).vars(), v01 );
241 BOOST_CHECK_EQUAL( G.OR(0), facs[0] * facs[3] );
242 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
243 BOOST_CHECK_EQUAL( G.OR(1).vars(), v02 );
244 BOOST_CHECK_EQUAL( G.OR(1), facs[1] );
245 BOOST_CHECK_EQUAL( G.OR(2).c(), 1 );
246 BOOST_CHECK_EQUAL( G.OR(2).vars(), v12 );
247 BOOST_CHECK_EQUAL( G.OR(2), facs[2] );
248 BOOST_CHECK_EQUAL( G.nrIRs(), 3 );
249 BOOST_CHECK_EQUAL( G.IR(0).c(), -1 );
250 BOOST_CHECK_EQUAL( G.IR(0), v0 );
251 BOOST_CHECK_EQUAL( G.IR(1).c(), -1 );
252 BOOST_CHECK_EQUAL( G.IR(1), v1 );
253 BOOST_CHECK_EQUAL( G.IR(2).c(), -1 );
254 BOOST_CHECK_EQUAL( G.IR(2), v2 );
255 BOOST_CHECK_EQUAL( G.fac2OR(0), 0 );
256 BOOST_CHECK_EQUAL( G.fac2OR(1), 1 );
257 BOOST_CHECK_EQUAL( G.fac2OR(2), 2 );
258 BOOST_CHECK_EQUAL( G.fac2OR(3), 0 );
259 BOOST_CHECK_EQUAL( G.nbOR(0).size(), 2 );
260 BOOST_CHECK_EQUAL( G.nbOR(0)[0], 0 );
261 BOOST_CHECK_EQUAL( G.nbOR(0)[1], 1 );
262 BOOST_CHECK_EQUAL( G.nbOR(1).size(), 2 );
263 BOOST_CHECK_EQUAL( G.nbOR(1)[0], 0 );
264 BOOST_CHECK_EQUAL( G.nbOR(1)[1], 2 );
265 BOOST_CHECK_EQUAL( G.nbOR(2).size(), 2 );
266 BOOST_CHECK_EQUAL( G.nbOR(2)[0], 1 );
267 BOOST_CHECK_EQUAL( G.nbOR(2)[1], 2 );
268 BOOST_CHECK_EQUAL( G.nbIR(0).size(), 2 );
269 BOOST_CHECK_EQUAL( G.nbIR(0)[0], 0 );
270 BOOST_CHECK_EQUAL( G.nbIR(0)[1], 1 );
271 BOOST_CHECK_EQUAL( G.nbIR(1).size(), 2 );
272 BOOST_CHECK_EQUAL( G.nbIR(1)[0], 0 );
273 BOOST_CHECK_EQUAL( G.nbIR(1)[1], 2 );
274 BOOST_CHECK_EQUAL( G.nbIR(2).size(), 2 );
275 BOOST_CHECK_EQUAL( G.nbIR(2)[0], 1 );
276 BOOST_CHECK_EQUAL( G.nbIR(2)[1], 2 );
277 }
278
279
280 BOOST_AUTO_TEST_CASE( QueriesTest ) {
281 Var v0( 0, 2 );
282 Var v1( 1, 2 );
283 Var v2( 2, 2 );
284 VarSet v01( v0, v1 );
285 VarSet v02( v0, v2 );
286 VarSet v12( v1, v2 );
287 VarSet v012 = v01 | v2;
288
289 RegionGraph G0;
290 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
291 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
292 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
293 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
294 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
295 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
296 #ifdef DAI_DBEUG
297 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
298 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
299 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
300 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
301 BOOST_CHECK_THROW( G0.fac2OR( 0 ), Exception );
302 #endif
303 BOOST_CHECK( G0.isConnected() );
304 BOOST_CHECK( G0.isTree() );
305 BOOST_CHECK( G0.isBinary() );
306 BOOST_CHECK( G0.isPairwise() );
307 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
308 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
309 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
310 BOOST_CHECK( G0.DAG() == BipartiteGraph() );
311 BOOST_CHECK( G0.checkCountingNumbers() );
312 BOOST_CHECK_EQUAL( G0.nrORs(), 0 );
313 BOOST_CHECK_EQUAL( G0.nrIRs(), 0 );
314
315 std::vector<Factor> facs;
316 facs.push_back( Factor( v01 ) );
317 facs.push_back( Factor( v12 ) );
318 facs.push_back( Factor( v1 ) );
319 std::vector<Var> vars;
320 vars.push_back( v0 );
321 vars.push_back( v1 );
322 vars.push_back( v2 );
323 GraphAL H(3);
324 H.addEdge( 0, 1 );
325 H.addEdge( 1, 2 );
326 BipartiteGraph K(3, 3);
327 K.addEdge( 0, 0 );
328 K.addEdge( 1, 0 );
329 K.addEdge( 1, 1 );
330 K.addEdge( 2, 1 );
331 K.addEdge( 1, 2 );
332 BipartiteGraph dag(2, 1);
333 dag.addEdge( 0, 0 );
334 dag.addEdge( 1, 0 );
335
336 FactorGraph fg( facs );
337 RegionGraph G1( fg, fg.maximalFactorDomains() );
338 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
339 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
340 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
341 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
342 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
343 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
344 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
345 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
346 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
347 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
348 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
349 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
350 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
351 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
352 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
353 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
354 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
355 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
356 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
357 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
358 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
359 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
360 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
361 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
362 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
363 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
364 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
365 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
366 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
367 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
368 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
369 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
370 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
371 BOOST_CHECK( G1.isConnected() );
372 BOOST_CHECK( G1.isTree() );
373 BOOST_CHECK( G1.isBinary() );
374 BOOST_CHECK( G1.isPairwise() );
375 BOOST_CHECK( G1.MarkovGraph() == H );
376 BOOST_CHECK( G1.bipGraph() == K );
377 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
378 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
379 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
380 BOOST_CHECK( G1.DAG() == dag );
381 BOOST_CHECK( G1.checkCountingNumbers() );
382 BOOST_CHECK_EQUAL( G1.nrORs(), 2 );
383 BOOST_CHECK_EQUAL( G1.OR(0).c(), 1 );
384 BOOST_CHECK_EQUAL( G1.OR(0), facs[0] * facs[2] );
385 BOOST_CHECK_EQUAL( G1.OR(1).c(), 1 );
386 BOOST_CHECK_EQUAL( G1.OR(1), facs[1] );
387 BOOST_CHECK_EQUAL( G1.nrIRs(), 1 );
388 BOOST_CHECK_EQUAL( G1.IR(0).c(), -1 );
389 BOOST_CHECK_EQUAL( G1.IR(0), v1 );
390 BOOST_CHECK_EQUAL( G1.fac2OR(0), 0 );
391 BOOST_CHECK_EQUAL( G1.fac2OR(1), 1 );
392 BOOST_CHECK_EQUAL( G1.fac2OR(2), 0 );
393
394 facs.push_back( Factor( v02 ) );
395 H.addEdge( 0, 2 );
396 K.addNode2();
397 K.addEdge( 0, 3 );
398 K.addEdge( 2, 3 );
399 dag = BipartiteGraph( 3, 3 );
400 dag.addEdge( 0, 0 );
401 dag.addEdge( 0, 1 );
402 dag.addEdge( 1, 1 );
403 dag.addEdge( 1, 2 );
404 dag.addEdge( 2, 0 );
405 dag.addEdge( 2, 2 );
406 fg = FactorGraph( facs );
407 RegionGraph G2( fg, fg.maximalFactorDomains() );
408 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
409 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
410 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
411 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
412 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
413 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
414 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
415 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
416 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
417 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
418 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
419 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
420 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
421 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
422 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
423 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
424 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
425 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
426 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
427 BOOST_CHECK( G2.isConnected() );
428 BOOST_CHECK( !G2.isTree() );
429 BOOST_CHECK( G2.isBinary() );
430 BOOST_CHECK( G2.isPairwise() );
431 BOOST_CHECK( G2.MarkovGraph() == H );
432 BOOST_CHECK( G2.bipGraph() == K );
433 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
434 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
435 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
436 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
437 BOOST_CHECK( G2.DAG() == dag );
438 BOOST_CHECK( G2.checkCountingNumbers() );
439 BOOST_CHECK_EQUAL( G2.nrORs(), 3 );
440 BOOST_CHECK_EQUAL( G2.OR(0).c(), 1 );
441 BOOST_CHECK_EQUAL( G2.OR(0), facs[0] * facs[2] );
442 BOOST_CHECK_EQUAL( G2.OR(1).c(), 1 );
443 BOOST_CHECK_EQUAL( G2.OR(1), facs[1] );
444 BOOST_CHECK_EQUAL( G2.OR(2).c(), 1 );
445 BOOST_CHECK_EQUAL( G2.OR(2), facs[3] );
446 BOOST_CHECK_EQUAL( G2.nrIRs(), 3 );
447 BOOST_CHECK_EQUAL( G2.IR(0).c(), -1.0 );
448 BOOST_CHECK_EQUAL( G2.IR(0), v0 );
449 BOOST_CHECK_EQUAL( G2.IR(1).c(), -1.0 );
450 BOOST_CHECK_EQUAL( G2.IR(1), v1 );
451 BOOST_CHECK_EQUAL( G2.IR(2).c(), -1.0 );
452 BOOST_CHECK_EQUAL( G2.IR(2), v2 );
453 BOOST_CHECK_EQUAL( G2.fac2OR(0), 0 );
454 BOOST_CHECK_EQUAL( G2.fac2OR(1), 1 );
455 BOOST_CHECK_EQUAL( G2.fac2OR(2), 0 );
456 BOOST_CHECK_EQUAL( G2.fac2OR(3), 2 );
457
458 Var v3( 3, 3 );
459 VarSet v03( v0, v3 );
460 VarSet v13( v1, v3 );
461 VarSet v23( v2, v3 );
462 VarSet v013 = v01 | v3;
463 VarSet v023 = v02 | v3;
464 VarSet v123 = v12 | v3;
465 VarSet v0123 = v012 | v3;
466 vars.push_back( v3 );
467 facs.push_back( Factor( v3 ) );
468 H.addNode();
469 K.addNode1();
470 K.addNode2();
471 K.addEdge( 3, 4 );
472 dag.addNode1();
473 fg = FactorGraph( facs );
474 RegionGraph G3( fg, fg.maximalFactorDomains() );
475 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
476 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
477 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
478 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
479 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
480 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
481 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
482 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
483 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
484 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
485 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
486 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
487 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
488 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
489 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
490 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
491 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
492 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
493 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
494 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
495 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
496 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
497 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
498 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
499 BOOST_CHECK( !G3.isConnected() );
500 BOOST_CHECK( !G3.isTree() );
501 BOOST_CHECK( !G3.isBinary() );
502 BOOST_CHECK( G3.isPairwise() );
503 BOOST_CHECK( G3.MarkovGraph() == H );
504 BOOST_CHECK( G3.bipGraph() == K );
505 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
506 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
507 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
508 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
509 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
510 BOOST_CHECK( G3.DAG() == dag );
511 BOOST_CHECK( G3.checkCountingNumbers() );
512 BOOST_CHECK_EQUAL( G3.nrORs(), 4 );
513 BOOST_CHECK_EQUAL( G3.OR(0).c(), 1 );
514 BOOST_CHECK_EQUAL( G3.OR(0), facs[0] * facs[2] );
515 BOOST_CHECK_EQUAL( G3.OR(1).c(), 1 );
516 BOOST_CHECK_EQUAL( G3.OR(1), facs[1] );
517 BOOST_CHECK_EQUAL( G3.OR(2).c(), 1 );
518 BOOST_CHECK_EQUAL( G3.OR(2), facs[3] );
519 BOOST_CHECK_EQUAL( G3.OR(3).c(), 1 );
520 BOOST_CHECK_EQUAL( G3.OR(3), facs[4] );
521 BOOST_CHECK_EQUAL( G3.nrIRs(), 3 );
522 BOOST_CHECK_EQUAL( G3.IR(0).c(), -1.0 );
523 BOOST_CHECK_EQUAL( G3.IR(0), v0 );
524 BOOST_CHECK_EQUAL( G3.IR(1).c(), -1.0 );
525 BOOST_CHECK_EQUAL( G3.IR(1), v1 );
526 BOOST_CHECK_EQUAL( G3.IR(2).c(), -1.0 );
527 BOOST_CHECK_EQUAL( G3.IR(2), v2 );
528 BOOST_CHECK_EQUAL( G3.fac2OR(0), 0 );
529 BOOST_CHECK_EQUAL( G3.fac2OR(1), 1 );
530 BOOST_CHECK_EQUAL( G3.fac2OR(2), 0 );
531 BOOST_CHECK_EQUAL( G3.fac2OR(3), 2 );
532 BOOST_CHECK_EQUAL( G3.fac2OR(4), 3 );
533
534 facs.push_back( Factor( v123 ) );
535 H.addEdge( 1, 3 );
536 H.addEdge( 2, 3 );
537 K.addNode2();
538 K.addEdge( 1, 5 );
539 K.addEdge( 2, 5 );
540 K.addEdge( 3, 5 );
541 dag = BipartiteGraph( 3, 3 );
542 dag.addEdge( 0, 0 );
543 dag.addEdge( 0, 1 );
544 dag.addEdge( 1, 0 );
545 dag.addEdge( 1, 2 );
546 dag.addEdge( 2, 1 );
547 dag.addEdge( 2, 2 );
548 fg = FactorGraph( facs );
549 RegionGraph G4( fg, fg.maximalFactorDomains() );
550 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
551 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
552 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
553 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
554 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
555 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
556 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
557 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
558 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
559 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
560 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
561 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
562 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
563 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
564 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
565 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
566 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
567 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
568 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
569 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
570 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
571 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
572 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
573 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
574 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
575 BOOST_CHECK( G4.isConnected() );
576 BOOST_CHECK( !G4.isTree() );
577 BOOST_CHECK( !G4.isBinary() );
578 BOOST_CHECK( !G4.isPairwise() );
579 BOOST_CHECK( G4.MarkovGraph() == H );
580 BOOST_CHECK( G4.bipGraph() == K );
581 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
582 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
583 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
584 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
585 BOOST_CHECK( G4.DAG() == dag );
586 BOOST_CHECK( G4.checkCountingNumbers() );
587 BOOST_CHECK_EQUAL( G4.nrORs(), 3 );
588 BOOST_CHECK_EQUAL( G4.OR(0).c(), 1 );
589 BOOST_CHECK_EQUAL( G4.OR(0), facs[0] * facs[2] );
590 BOOST_CHECK_EQUAL( G4.OR(1).c(), 1 );
591 BOOST_CHECK_EQUAL( G4.OR(1), facs[3] );
592 BOOST_CHECK_EQUAL( G4.OR(2).c(), 1 );
593 BOOST_CHECK_EQUAL( G4.OR(2), facs[1] * facs[4] * facs[5] );
594 BOOST_CHECK_EQUAL( G4.nrIRs(), 3 );
595 BOOST_CHECK_EQUAL( G4.IR(0).c(), -1.0 );
596 BOOST_CHECK_EQUAL( G4.IR(0), v0 );
597 BOOST_CHECK_EQUAL( G4.IR(1).c(), -1.0 );
598 BOOST_CHECK_EQUAL( G4.IR(1), v1 );
599 BOOST_CHECK_EQUAL( G4.IR(2).c(), -1.0 );
600 BOOST_CHECK_EQUAL( G4.IR(2), v2 );
601 BOOST_CHECK_EQUAL( G4.fac2OR(0), 0 );
602 BOOST_CHECK_EQUAL( G4.fac2OR(1), 2 );
603 BOOST_CHECK_EQUAL( G4.fac2OR(2), 0 );
604 BOOST_CHECK_EQUAL( G4.fac2OR(3), 1 );
605 BOOST_CHECK_EQUAL( G4.fac2OR(4), 2 );
606 BOOST_CHECK_EQUAL( G4.fac2OR(5), 2 );
607 }
608
609
610 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
611 Var v0( 0, 2 );
612 Var v1( 1, 2 );
613 Var v2( 2, 2 );
614 VarSet v01( v0, v1 );
615 VarSet v02( v0, v2 );
616 VarSet v12( v1, v2 );
617 VarSet v012 = v01 | v2;
618
619 std::vector<Factor> facs;
620 facs.push_back( Factor( v01 ) );
621 facs.push_back( Factor( v12 ) );
622 facs.push_back( Factor( v1 ) );
623 std::vector<Var> vars;
624 vars.push_back( v0 );
625 vars.push_back( v1 );
626 vars.push_back( v2 );
627
628 FactorGraph fg( facs );
629 RegionGraph G( fg, fg.maximalFactorDomains() );
630 RegionGraph Gorg( G );
631 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
632 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
633 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
634 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
635
636 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
637 G.setFactor( 0, Factor( v01, 2.0 ), false );
638 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
639 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
640 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
641 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
642 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
643 G.setFactor( 0, Factor( v01, 3.0 ), true );
644 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
645 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
646 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
647 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
648 G.restoreFactor( 0 );
649 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
650 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
651 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
652 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
653 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
654 G.setFactor( 0, Gorg.factor( 0 ), false );
655 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
656 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
657 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
658 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
659 G.backupFactor( 0 );
660 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
661 G.setFactor( 0, Factor( v01, 2.0 ), false );
662 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
663 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
664 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
665 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
666 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
667 G.restoreFactor( 0 );
668 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
669 BOOST_CHECK_EQUAL( G.OR(0).c(), 1 );
670 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
671 BOOST_CHECK_EQUAL( G.OR(1).c(), 1 );
672 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
673
674 std::map<size_t, Factor> fs;
675 fs[0] = Factor( v01, 3.0 );
676 fs[2] = Factor( v1, 2.0 );
677 G.setFactors( fs, false );
678 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
679 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
680 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
681 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
682 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
683 G.restoreFactors();
684 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
685 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
686 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
687 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
688 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
689 G = Gorg;
690 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
691 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
692 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
693 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
694 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
695 G.setFactors( fs, true );
696 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
697 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
698 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
699 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
700 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
701 G.restoreFactors();
702 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
703 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
704 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
705 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
706 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
707 std::set<size_t> fsind;
708 fsind.insert( 0 );
709 fsind.insert( 2 );
710 G.backupFactors( fsind );
711 G.setFactors( fs, false );
712 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
713 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
714 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
715 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
716 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
717 G.restoreFactors();
718 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
719 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
720 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
721 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
722 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
723
724 G.backupFactors( v2 );
725 G.setFactor( 1, Factor(v12, 5.0) );
726 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
727 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
728 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
729 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
730 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
731 G.restoreFactors( v2 );
732 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
733 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
734 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
735 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
736 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
737
738 G.backupFactors( v1 );
739 fs[1] = Factor( v12, 5.0 );
740 G.setFactors( fs, false );
741 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
742 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
743 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
744 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
745 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
746 G.restoreFactors();
747 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
748 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
749 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
750 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
751 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
752 G.setFactors( fs, true );
753 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
754 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
755 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
756 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
757 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
758 G.restoreFactors( v1 );
759 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
760 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
761 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
762 BOOST_CHECK_EQUAL( G.OR(0), G.factor(0) * G.factor(2) );
763 BOOST_CHECK_EQUAL( G.OR(1), G.factor(1) );
764 }
765
766
767 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
768 Var v0( 0, 2 );
769 Var v1( 1, 2 );
770 Var v2( 2, 2 );
771 VarSet v01( v0, v1 );
772 VarSet v02( v0, v2 );
773 VarSet v12( v1, v2 );
774 VarSet v012 = v01 | v2;
775
776 std::vector<Factor> facs;
777 facs.push_back( Factor( v01 ).randomize() );
778 facs.push_back( Factor( v12 ).randomize() );
779 facs.push_back( Factor( v1 ).randomize() );
780 std::vector<Var> vars;
781 vars.push_back( v0 );
782 vars.push_back( v1 );
783 vars.push_back( v2 );
784
785 FactorGraph fg( facs );
786 RegionGraph G( fg, fg.maximalFactorDomains() );
787
788 FactorGraph Gsmall = G.maximalFactors();
789 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
790 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
791 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
792 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
793
794 size_t i = 0;
795 for( size_t x = 0; x < 2; x++ ) {
796 FactorGraph Gcl = G.clamped( i, x );
797 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
798 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
799 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
800 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
801 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
802 }
803 i = 1;
804 for( size_t x = 0; x < 2; x++ ) {
805 FactorGraph Gcl = G.clamped( i, x );
806 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
807 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
808 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
809 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
810 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
811 }
812 i = 2;
813 for( size_t x = 0; x < 2; x++ ) {
814 FactorGraph Gcl = G.clamped( i, x );
815 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
816 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
817 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
818 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
819 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
820 }
821 }
822
823
824 BOOST_AUTO_TEST_CASE( OperationsTest ) {
825 Var v0( 0, 2 );
826 Var v1( 1, 2 );
827 Var v2( 2, 2 );
828 VarSet v01( v0, v1 );
829 VarSet v02( v0, v2 );
830 VarSet v12( v1, v2 );
831 VarSet v012 = v01 | v2;
832
833 std::vector<Factor> facs;
834 facs.push_back( Factor( v01 ).randomize() );
835 facs.push_back( Factor( v12 ).randomize() );
836 facs.push_back( Factor( v1 ).randomize() );
837 std::vector<Var> vars;
838 vars.push_back( v0 );
839 vars.push_back( v1 );
840 vars.push_back( v2 );
841
842 FactorGraph fg( facs );
843 RegionGraph G( fg, fg.maximalFactorDomains() );
844
845 // clamp
846 RegionGraph Gcl = G;
847 for( size_t i = 0; i < 3; i++ )
848 for( size_t x = 0; x < 2; x++ ) {
849 Gcl.clamp( i, x, true );
850 Factor delta = createFactorDelta( vars[i], x );
851 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
852 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
853 for( size_t j = 0; j < 3; j++ )
854 if( G.factor(j).vars().contains( vars[i] ) )
855 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
856 else
857 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
858 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
859 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
860
861 Gcl.restoreFactors();
862 for( size_t j = 0; j < 3; j++ )
863 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
864 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
865 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
866 }
867
868 // clampVar
869 for( size_t i = 0; i < 3; i++ )
870 for( size_t x = 0; x < 2; x++ ) {
871 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
872 Factor delta = createFactorDelta( vars[i], x );
873 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
874 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
875 for( size_t j = 0; j < 3; j++ )
876 if( G.factor(j).vars().contains( vars[i] ) )
877 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
878 else
879 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
880 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
881 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
882
883 Gcl.restoreFactors();
884 for( size_t j = 0; j < 3; j++ )
885 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
886 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
887 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
888 }
889 for( size_t i = 0; i < 3; i++ )
890 for( size_t x = 0; x < 2; x++ ) {
891 Gcl.clampVar( i, std::vector<size_t>(), true );
892 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
893 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
894 for( size_t j = 0; j < 3; j++ )
895 if( G.factor(j).vars().contains( vars[i] ) )
896 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
897 else
898 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
899 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
900 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
901
902 Gcl.restoreFactors();
903 for( size_t j = 0; j < 3; j++ )
904 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
905 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
906 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
907 }
908 std::vector<size_t> both;
909 both.push_back( 0 );
910 both.push_back( 1 );
911 for( size_t i = 0; i < 3; i++ )
912 for( size_t x = 0; x < 2; x++ ) {
913 Gcl.clampVar( i, both, true );
914 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
915 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
916 for( size_t j = 0; j < 3; j++ )
917 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
918 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
919 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
920 Gcl.restoreFactors();
921 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
922 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
923 }
924
925 // clampFactor
926 for( size_t x = 0; x < 4; x++ ) {
927 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
928 Factor mask( v01, 0.0 );
929 mask.set( x, 1.0 );
930 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
931 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
932 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
933 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
934 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
935 Gcl.restoreFactor( 0 );
936 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
937 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
938 }
939 for( size_t x = 0; x < 4; x++ ) {
940 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
941 Factor mask( v12, 0.0 );
942 mask.set( x, 1.0 );
943 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
944 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
945 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
946 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
947 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
948 Gcl.restoreFactor( 1 );
949 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
950 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
951 }
952 for( size_t x = 0; x < 2; x++ ) {
953 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
954 Factor mask( v1, 0.0 );
955 mask.set( x, 1.0 );
956 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
957 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
958 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
959 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
960 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
961 Gcl.restoreFactors();
962 BOOST_CHECK_EQUAL( Gcl.OR(0), Gcl.factor(0) * Gcl.factor(2) );
963 BOOST_CHECK_EQUAL( Gcl.OR(1), Gcl.factor(1) );
964 }
965
966 // makeCavity
967 RegionGraph Gcav( G );
968 Gcav.makeCavity( 0, true );
969 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
970 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
971 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
972 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
973 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
974 Gcav.restoreFactors();
975 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
976 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
977 Gcav.makeCavity( 1, true );
978 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
979 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
980 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
981 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
982 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
983 Gcav.restoreFactors();
984 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
985 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
986 Gcav.makeCavity( 2, true );
987 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
988 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
989 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
990 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
991 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
992 Gcav.restoreFactors();
993 BOOST_CHECK_EQUAL( Gcav.OR(0), Gcav.factor(0) * Gcav.factor(2) );
994 BOOST_CHECK_EQUAL( Gcav.OR(1), Gcav.factor(1) );
995 }
996
997
998 BOOST_AUTO_TEST_CASE( IOTest ) {
999 Var v0( 0, 2 );
1000 Var v1( 1, 2 );
1001 Var v2( 2, 2 );
1002 VarSet v01( v0, v1 );
1003 VarSet v02( v0, v2 );
1004 VarSet v12( v1, v2 );
1005 VarSet v012 = v01 | v2;
1006
1007 std::vector<Factor> facs;
1008 facs.push_back( Factor( v01, 1.0 ) );
1009 facs.push_back( Factor( v12, 2.0 ) );
1010 facs.push_back( Factor( v1, 3.0 ) );
1011 std::vector<Var> vars;
1012 vars.push_back( v0 );
1013 vars.push_back( v1 );
1014 vars.push_back( v2 );
1015
1016 FactorGraph fg( facs );
1017 RegionGraph G( fg, fg.maximalFactorDomains() );
1018 BOOST_CHECK_THROW( G.WriteToFile( "regiongraph_test.fg" ), Exception );
1019 BOOST_CHECK_THROW( G.ReadFromFile( "regiongraph_test.fg" ), Exception );
1020 BOOST_CHECK_THROW( G.printDot( std::cout ), Exception );
1021
1022 std::stringstream ss;
1023 std::string s;
1024 ss << G;
1025 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "digraph RegionGraph {" );
1026 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box];" );
1027 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 [label=\"a0: {x0, x1}, c=1\"];" );
1028 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 [label=\"a1: {x1, x2}, c=1\"];" );
1029 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=ellipse];" );
1030 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tb0 [label=\"b0: {x1}, c=-1\"];" );
1031 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 -> b0;" );
1032 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 -> b0;" );
1033 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
1034 }