Some improvements to jtree and regiongraph and started work on regiongraph unit tests
[libdai.git] / tests / unit / regiongraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
#include <dai/regiongraph.h>
#include <cstdio>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <strstream>
#include <vector>
17
18
19 using namespace dai;
20
21
// Tolerance handed to BOOST_CHECK_CLOSE when comparing factor entries after a
// file round trip (Boost.Test interprets this argument as a percentage).
static const double tol = 1.0e-8;
23
24
25 #define BOOST_TEST_MODULE RegionGraphTest
26
27
28 #include <boost/test/unit_test.hpp>
29 #include <boost/test/floating_point_comparison.hpp>
30
31
32 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
33 RegionGraph G;
34 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
35 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
36
37 std::vector<Factor> facs;
38 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
39 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
40 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
41 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
42 std::vector<Var> vars;
43 vars.push_back( Var( 0, 2 ) );
44 vars.push_back( Var( 1, 2 ) );
45 vars.push_back( Var( 2, 2 ) );
46
47 FactorGraph fg(facs);
48 RegionGraph G1( fg, fg.maximalFactorDomains() );
49 BOOST_CHECK_EQUAL( G1.vars(), vars );
50 BOOST_CHECK_EQUAL( G1.factors(), facs );
51
52 RegionGraph *G2 = G1.clone();
53 BOOST_CHECK_EQUAL( G2->vars(), vars );
54 BOOST_CHECK_EQUAL( G2->factors(), facs );
55 delete G2;
56
57 RegionGraph G3 = G1;
58 BOOST_CHECK_EQUAL( G3.vars(), vars );
59 BOOST_CHECK_EQUAL( G3.factors(), facs );
60
61 RegionGraph G4( G1 );
62 BOOST_CHECK_EQUAL( G4.vars(), vars );
63 BOOST_CHECK_EQUAL( G4.factors(), facs );
64 }
65
66
67 BOOST_AUTO_TEST_CASE( AccMutIterTest ) {
68 std::vector<Factor> facs;
69 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
70 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
71 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
72 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
73 std::vector<Var> vars;
74 vars.push_back( Var( 0, 2 ) );
75 vars.push_back( Var( 1, 2 ) );
76 vars.push_back( Var( 2, 2 ) );
77
78 FactorGraph fg( facs );
79 RegionGraph G( fg, fg.maximalFactorDomains() );
80 BOOST_CHECK_EQUAL( G.var(0), Var(0, 2) );
81 BOOST_CHECK_EQUAL( G.var(1), Var(1, 2) );
82 BOOST_CHECK_EQUAL( G.var(2), Var(2, 2) );
83 BOOST_CHECK_EQUAL( G.vars(), vars );
84 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
85 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
86 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
87 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
88 BOOST_CHECK_EQUAL( G.factors(), facs );
89 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
90 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
91 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
92 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
93 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
94 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
95 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
96 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
97 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
98 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
99 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
100 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
101 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
102 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
103 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
104 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
105 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
106 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
107 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
108 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
109 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
110
111 RegionGraph::const_iterator cit = G.begin();
112 RegionGraph::iterator it = G.begin();
113 for( size_t I = 0; I < G.nrFactors(); I++, cit++, it++ ) {
114 BOOST_CHECK_EQUAL( *cit, G.factor(I) );
115 BOOST_CHECK_EQUAL( *it, G.factor(I) );
116 }
117 BOOST_CHECK( cit == G.end() );
118 BOOST_CHECK( it == G.end() );
119 }
120
121
122 BOOST_AUTO_TEST_CASE( QueriesTest ) {
123 Var v0( 0, 2 );
124 Var v1( 1, 2 );
125 Var v2( 2, 2 );
126 VarSet v01( v0, v1 );
127 VarSet v02( v0, v2 );
128 VarSet v12( v1, v2 );
129 VarSet v012 = v01 | v2;
130
131 RegionGraph G0;
132 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
133 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
134 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
135 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
136 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
137 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
138 #ifdef DAI_DBEUG
139 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
140 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
141 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
142 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
143 #endif
144 BOOST_CHECK( G0.isConnected() );
145 BOOST_CHECK( G0.isTree() );
146 BOOST_CHECK( G0.isBinary() );
147 BOOST_CHECK( G0.isPairwise() );
148 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
149 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
150 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
151
152 std::vector<Factor> facs;
153 facs.push_back( Factor( v01 ) );
154 facs.push_back( Factor( v12 ) );
155 facs.push_back( Factor( v1 ) );
156 std::vector<Var> vars;
157 vars.push_back( v0 );
158 vars.push_back( v1 );
159 vars.push_back( v2 );
160 GraphAL H(3);
161 H.addEdge( 0, 1 );
162 H.addEdge( 1, 2 );
163 BipartiteGraph K(3, 3);
164 K.addEdge( 0, 0 );
165 K.addEdge( 1, 0 );
166 K.addEdge( 1, 1 );
167 K.addEdge( 2, 1 );
168 K.addEdge( 1, 2 );
169
170 FactorGraph fg( facs );
171 RegionGraph G1( fg, fg.maximalFactorDomains() );
172 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
173 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
174 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
175 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
176 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
177 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
178 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
179 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
180 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
181 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
182 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
183 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
184 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
185 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
186 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
187 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
188 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
189 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
190 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
191 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
192 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
193 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
194 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
195 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
196 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
197 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
198 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
199 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
200 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
201 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
202 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
203 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
204 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
205 BOOST_CHECK( G1.isConnected() );
206 BOOST_CHECK( G1.isTree() );
207 BOOST_CHECK( G1.isBinary() );
208 BOOST_CHECK( G1.isPairwise() );
209 BOOST_CHECK( G1.MarkovGraph() == H );
210 BOOST_CHECK( G1.bipGraph() == K );
211 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
212 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
213 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
214
215 facs.push_back( Factor( v02 ) );
216 H.addEdge( 0, 2 );
217 K.addNode2();
218 K.addEdge( 0, 3 );
219 K.addEdge( 2, 3 );
220 fg = FactorGraph( facs );
221 RegionGraph G2( fg, fg.maximalFactorDomains() );
222 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
223 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
224 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
225 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
226 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
227 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
228 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
229 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
230 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
231 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
232 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
233 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
234 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
235 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
236 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
237 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
238 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
239 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
240 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
241 BOOST_CHECK( G2.isConnected() );
242 BOOST_CHECK( !G2.isTree() );
243 BOOST_CHECK( G2.isBinary() );
244 BOOST_CHECK( G2.isPairwise() );
245 BOOST_CHECK( G2.MarkovGraph() == H );
246 BOOST_CHECK( G2.bipGraph() == K );
247 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
248 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
249 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
250 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
251
252 Var v3( 3, 3 );
253 VarSet v03( v0, v3 );
254 VarSet v13( v1, v3 );
255 VarSet v23( v2, v3 );
256 VarSet v013 = v01 | v3;
257 VarSet v023 = v02 | v3;
258 VarSet v123 = v12 | v3;
259 VarSet v0123 = v012 | v3;
260 vars.push_back( v3 );
261 facs.push_back( Factor( v3 ) );
262 H.addNode();
263 K.addNode1();
264 K.addNode2();
265 K.addEdge( 3, 4 );
266 fg = FactorGraph( facs );
267 RegionGraph G3( fg, fg.maximalFactorDomains() );
268 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
269 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
270 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
271 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
272 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
273 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
274 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
275 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
276 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
277 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
278 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
279 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
280 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
281 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
282 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
283 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
284 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
285 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
286 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
287 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
288 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
289 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
290 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
291 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
292 BOOST_CHECK( !G3.isConnected() );
293 BOOST_CHECK( !G3.isTree() );
294 BOOST_CHECK( !G3.isBinary() );
295 BOOST_CHECK( G3.isPairwise() );
296 BOOST_CHECK( G3.MarkovGraph() == H );
297 BOOST_CHECK( G3.bipGraph() == K );
298 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
299 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
300 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
301 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
302 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
303
304 facs.push_back( Factor( v123 ) );
305 H.addEdge( 1, 3 );
306 H.addEdge( 2, 3 );
307 K.addNode2();
308 K.addEdge( 1, 5 );
309 K.addEdge( 2, 5 );
310 K.addEdge( 3, 5 );
311 fg = FactorGraph( facs );
312 RegionGraph G4( fg, fg.maximalFactorDomains() );
313 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
314 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
315 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
316 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
317 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
318 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
319 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
320 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
321 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
322 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
323 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
324 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
325 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
326 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
327 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
328 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
329 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
330 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
331 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
332 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
333 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
334 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
335 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
336 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
337 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
338 BOOST_CHECK( G4.isConnected() );
339 BOOST_CHECK( !G4.isTree() );
340 BOOST_CHECK( !G4.isBinary() );
341 BOOST_CHECK( !G4.isPairwise() );
342 BOOST_CHECK( G4.MarkovGraph() == H );
343 BOOST_CHECK( G4.bipGraph() == K );
344 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
345 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
346 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
347 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
348 }
349
350
351 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
352 Var v0( 0, 2 );
353 Var v1( 1, 2 );
354 Var v2( 2, 2 );
355 VarSet v01( v0, v1 );
356 VarSet v02( v0, v2 );
357 VarSet v12( v1, v2 );
358 VarSet v012 = v01 | v2;
359
360 std::vector<Factor> facs;
361 facs.push_back( Factor( v01 ) );
362 facs.push_back( Factor( v12 ) );
363 facs.push_back( Factor( v1 ) );
364 std::vector<Var> vars;
365 vars.push_back( v0 );
366 vars.push_back( v1 );
367 vars.push_back( v2 );
368
369 FactorGraph fg( facs );
370 RegionGraph G( fg, fg.maximalFactorDomains() );
371 RegionGraph Gorg( G );
372
373 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
374 G.setFactor( 0, Factor( v01, 2.0 ), false );
375 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
376 G.setFactor( 0, Factor( v01, 3.0 ), true );
377 G.restoreFactor( 0 );
378 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
379 G.setFactor( 0, Gorg.factor( 0 ), false );
380 G.backupFactor( 0 );
381 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
382 G.setFactor( 0, Factor( v01, 2.0 ), false );
383 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
384 G.restoreFactor( 0 );
385 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
386
387 std::map<size_t, Factor> fs;
388 fs[0] = Factor( v01, 3.0 );
389 fs[2] = Factor( v1, 2.0 );
390 G.setFactors( fs, false );
391 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
392 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
393 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
394 G.restoreFactors();
395 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
396 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
397 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
398 G = Gorg;
399 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
400 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
401 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
402 G.setFactors( fs, true );
403 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
404 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
405 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
406 G.restoreFactors();
407 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
408 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
409 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
410 std::set<size_t> fsind;
411 fsind.insert( 0 );
412 fsind.insert( 2 );
413 G.backupFactors( fsind );
414 G.setFactors( fs, false );
415 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
416 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
417 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
418 G.restoreFactors();
419 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
420 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
421 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
422
423 G.backupFactors( v2 );
424 G.setFactor( 1, Factor(v12, 5.0) );
425 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
426 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
427 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
428 G.restoreFactors( v2 );
429 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
430 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
431 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
432
433 G.backupFactors( v1 );
434 fs[1] = Factor( v12, 5.0 );
435 G.setFactors( fs, false );
436 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
437 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
438 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
439 G.restoreFactors();
440 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
441 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
442 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
443 G.setFactors( fs, true );
444 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
445 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
446 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
447 G.restoreFactors( v1 );
448 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
449 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
450 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
451 }
452
453
454 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
455 Var v0( 0, 2 );
456 Var v1( 1, 2 );
457 Var v2( 2, 2 );
458 VarSet v01( v0, v1 );
459 VarSet v02( v0, v2 );
460 VarSet v12( v1, v2 );
461 VarSet v012 = v01 | v2;
462
463 std::vector<Factor> facs;
464 facs.push_back( Factor( v01 ).randomize() );
465 facs.push_back( Factor( v12 ).randomize() );
466 facs.push_back( Factor( v1 ).randomize() );
467 std::vector<Var> vars;
468 vars.push_back( v0 );
469 vars.push_back( v1 );
470 vars.push_back( v2 );
471
472 FactorGraph fg( facs );
473 RegionGraph G( fg, fg.maximalFactorDomains() );
474
475 FactorGraph Gsmall = G.maximalFactors();
476 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
477 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
478 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
479 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
480
481 size_t i = 0;
482 for( size_t x = 0; x < 2; x++ ) {
483 FactorGraph Gcl = G.clamped( i, x );
484 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
485 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
486 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
487 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
488 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
489 }
490 i = 1;
491 for( size_t x = 0; x < 2; x++ ) {
492 FactorGraph Gcl = G.clamped( i, x );
493 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
494 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
495 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
496 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
497 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
498 }
499 i = 2;
500 for( size_t x = 0; x < 2; x++ ) {
501 FactorGraph Gcl = G.clamped( i, x );
502 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
503 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
504 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
505 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
506 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
507 }
508 }
509
510
511 BOOST_AUTO_TEST_CASE( OperationsTest ) {
512 Var v0( 0, 2 );
513 Var v1( 1, 2 );
514 Var v2( 2, 2 );
515 VarSet v01( v0, v1 );
516 VarSet v02( v0, v2 );
517 VarSet v12( v1, v2 );
518 VarSet v012 = v01 | v2;
519
520 std::vector<Factor> facs;
521 facs.push_back( Factor( v01 ).randomize() );
522 facs.push_back( Factor( v12 ).randomize() );
523 facs.push_back( Factor( v1 ).randomize() );
524 std::vector<Var> vars;
525 vars.push_back( v0 );
526 vars.push_back( v1 );
527 vars.push_back( v2 );
528
529 FactorGraph fg( facs );
530 RegionGraph G( fg, fg.maximalFactorDomains() );
531
532 // clamp
533 RegionGraph Gcl = G;
534 for( size_t i = 0; i < 3; i++ )
535 for( size_t x = 0; x < 2; x++ ) {
536 Gcl.clamp( i, x, true );
537 Factor delta = createFactorDelta( vars[i], x );
538 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
539 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
540 for( size_t j = 0; j < 3; j++ )
541 if( G.factor(j).vars().contains( vars[i] ) )
542 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
543 else
544 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
545
546 Gcl.restoreFactors();
547 for( size_t j = 0; j < 3; j++ )
548 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
549 }
550
551 // clampVar
552 for( size_t i = 0; i < 3; i++ )
553 for( size_t x = 0; x < 2; x++ ) {
554 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
555 Factor delta = createFactorDelta( vars[i], x );
556 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
557 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
558 for( size_t j = 0; j < 3; j++ )
559 if( G.factor(j).vars().contains( vars[i] ) )
560 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
561 else
562 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
563
564 Gcl.restoreFactors();
565 for( size_t j = 0; j < 3; j++ )
566 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
567 }
568 for( size_t i = 0; i < 3; i++ )
569 for( size_t x = 0; x < 2; x++ ) {
570 Gcl.clampVar( i, std::vector<size_t>(), true );
571 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
572 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
573 for( size_t j = 0; j < 3; j++ )
574 if( G.factor(j).vars().contains( vars[i] ) )
575 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
576 else
577 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
578
579 Gcl.restoreFactors();
580 for( size_t j = 0; j < 3; j++ )
581 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
582 }
583 std::vector<size_t> both;
584 both.push_back( 0 );
585 both.push_back( 1 );
586 for( size_t i = 0; i < 3; i++ )
587 for( size_t x = 0; x < 2; x++ ) {
588 Gcl.clampVar( i, both, true );
589 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
590 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
591 for( size_t j = 0; j < 3; j++ )
592 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
593 Gcl.restoreFactors();
594 }
595
596 // clampFactor
597 for( size_t x = 0; x < 4; x++ ) {
598 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
599 Factor mask( v01, 0.0 );
600 mask.set( x, 1.0 );
601 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
602 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
603 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
604 Gcl.restoreFactor( 0 );
605 }
606 for( size_t x = 0; x < 4; x++ ) {
607 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
608 Factor mask( v12, 0.0 );
609 mask.set( x, 1.0 );
610 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
611 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
612 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
613 Gcl.restoreFactor( 1 );
614 }
615 for( size_t x = 0; x < 2; x++ ) {
616 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
617 Factor mask( v1, 0.0 );
618 mask.set( x, 1.0 );
619 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
620 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
621 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
622 Gcl.restoreFactors();
623 }
624
625 // makeCavity
626 RegionGraph Gcav( G );
627 Gcav.makeCavity( 0, true );
628 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
629 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
630 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
631 Gcav.restoreFactors();
632 Gcav.makeCavity( 1, true );
633 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
634 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
635 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
636 Gcav.restoreFactors();
637 Gcav.makeCavity( 2, true );
638 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
639 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
640 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
641 Gcav.restoreFactors();
642 }
643
644
645 BOOST_AUTO_TEST_CASE( IOTest ) {
646 Var v0( 0, 2 );
647 Var v1( 1, 2 );
648 Var v2( 2, 2 );
649 VarSet v01( v0, v1 );
650 VarSet v02( v0, v2 );
651 VarSet v12( v1, v2 );
652 VarSet v012 = v01 | v2;
653
654 std::vector<Factor> facs;
655 facs.push_back( Factor( v01 ).randomize() );
656 facs.push_back( Factor( v12 ).randomize() );
657 facs.push_back( Factor( v1 ).randomize() );
658 std::vector<Var> vars;
659 vars.push_back( v0 );
660 vars.push_back( v1 );
661 vars.push_back( v2 );
662
663 FactorGraph fg( facs );
664 RegionGraph G( fg, fg.maximalFactorDomains() );
665
666 G.WriteToFile( "regiongraph_test.fg" );
667 RegionGraph G2;
668 G2.ReadFromFile( "regiongraph_test.fg" );
669
670 BOOST_CHECK( G.vars() == G2.vars() );
671 BOOST_CHECK( G.bipGraph() == G2.bipGraph() );
672 BOOST_CHECK_EQUAL( G.nrFactors(), G2.nrFactors() );
673 for( size_t I = 0; I < G.nrFactors(); I++ ) {
674 BOOST_CHECK( G.factor(I).vars() == G2.factor(I).vars() );
675 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
676 BOOST_CHECK_CLOSE( G.factor(I)[s], G2.factor(I)[s], tol );
677 }
678
679 std::stringstream ss;
680 std::string s;
681 G.printDot( ss );
682 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph G {" );
683 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
684 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0;" );
685 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1;" );
686 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2;" );
687 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
688 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf0;" );
689 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf1;" );
690 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf2;" );
691 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0 -- f0;" );
692 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f0;" );
693 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f1;" );
694 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f2;" );
695 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2 -- f1;" );
696 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
697
698 /* G.factor(0).fill(1.0);
699 G.factor(1).fill(2.0);
700 G.factor(2).fill(3.0);
701 ss << G;
702 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3" );
703 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
704 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
705 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1 " );
706 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
707 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
708 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1" );
709 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 1" );
710 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 1" );
711 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 1" );
712 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
713 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
714 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2 " );
715 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
716 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
717 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 2" );
718 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2" );
719 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2" );
720 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 2" );
721 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
722 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1" );
723 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 " );
724 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 " );
725 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
726 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 3" );
727 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 3" );
728 */
729 /* ss << G;
730 RegionGraph G3;
731 ss >> G3;
732 BOOST_CHECK( G.vars() == G3.vars() );
733 BOOST_CHECK( G.bipGraph() == G3.bipGraph() );
734 BOOST_CHECK_EQUAL( G.nrFactors(), G3.nrFactors() );
735 for( size_t I = 0; I < G.nrFactors(); I++ ) {
736 BOOST_CHECK( G.factor(I).vars() == G3.factor(I).vars() );
737 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
738 BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
739 }*/
740 }