Miscellaneous improvements to regiongraph, factorgraph and bipgraph, and finished...
[libdai.git] / tests / unit / factorgraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #define BOOST_TEST_DYN_LINK
12
13
#include <cstdio>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <strstream>
#include <vector>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
18
19
20 using namespace dai;
21
22
// Tolerance handed to BOOST_CHECK_CLOSE in the tests below. Boost.Test
// interprets that argument as a percentage, so 1e-8 corresponds to a relative
// difference of 1e-10.
static const double tol = 1.0e-8;
24
25
26 #define BOOST_TEST_MODULE FactorGraphTest
27
28
29 #include <boost/test/unit_test.hpp>
30 #include <boost/test/floating_point_comparison.hpp>
31
32
33 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
34 FactorGraph G;
35 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
36 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
37
38 std::vector<Factor> facs;
39 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
40 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
41 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
42 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
43 std::vector<Var> vars;
44 vars.push_back( Var( 0, 2 ) );
45 vars.push_back( Var( 1, 2 ) );
46 vars.push_back( Var( 2, 2 ) );
47
48 FactorGraph G1( facs );
49 BOOST_CHECK_EQUAL( G1.vars(), vars );
50 BOOST_CHECK_EQUAL( G1.factors(), facs );
51
52 FactorGraph G2( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
53 BOOST_CHECK_EQUAL( G2.vars(), vars );
54 BOOST_CHECK_EQUAL( G2.factors(), facs );
55
56 FactorGraph *G3 = G2.clone();
57 BOOST_CHECK_EQUAL( G3->vars(), vars );
58 BOOST_CHECK_EQUAL( G3->factors(), facs );
59 delete G3;
60
61 FactorGraph G4 = G2;
62 BOOST_CHECK_EQUAL( G4.vars(), vars );
63 BOOST_CHECK_EQUAL( G4.factors(), facs );
64
65 FactorGraph G5( G2 );
66 BOOST_CHECK_EQUAL( G5.vars(), vars );
67 BOOST_CHECK_EQUAL( G5.factors(), facs );
68 }
69
70
71 BOOST_AUTO_TEST_CASE( AccMutTest ) {
72 std::vector<Factor> facs;
73 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
74 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
75 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
76 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
77 std::vector<Var> vars;
78 vars.push_back( Var( 0, 2 ) );
79 vars.push_back( Var( 1, 2 ) );
80 vars.push_back( Var( 2, 2 ) );
81
82 FactorGraph G( facs );
83 BOOST_CHECK_EQUAL( G.var(0), Var(0, 2) );
84 BOOST_CHECK_EQUAL( G.var(1), Var(1, 2) );
85 BOOST_CHECK_EQUAL( G.var(2), Var(2, 2) );
86 BOOST_CHECK_EQUAL( G.vars(), vars );
87 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
88 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
89 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
90 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
91 BOOST_CHECK_EQUAL( G.factors(), facs );
92 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
93 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
94 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
95 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
96 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
97 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
98 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
99 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
100 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
101 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
102 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
103 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
104 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
105 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
106 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
107 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
108 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
109 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
110 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
111 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
112 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
113 }
114
115
116 BOOST_AUTO_TEST_CASE( QueriesTest ) {
117 Var v0( 0, 2 );
118 Var v1( 1, 2 );
119 Var v2( 2, 2 );
120 VarSet v01( v0, v1 );
121 VarSet v02( v0, v2 );
122 VarSet v12( v1, v2 );
123 VarSet v012 = v01 | v2;
124
125 FactorGraph G0;
126 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
127 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
128 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
129 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
130 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
131 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
132 #ifdef DAI_DBEUG
133 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
134 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
135 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
136 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
137 #endif
138 BOOST_CHECK( G0.isConnected() );
139 BOOST_CHECK( G0.isTree() );
140 BOOST_CHECK( G0.isBinary() );
141 BOOST_CHECK( G0.isPairwise() );
142 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
143 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
144 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
145
146 std::vector<Factor> facs;
147 facs.push_back( Factor( v01 ) );
148 facs.push_back( Factor( v12 ) );
149 facs.push_back( Factor( v1 ) );
150 std::vector<Var> vars;
151 vars.push_back( v0 );
152 vars.push_back( v1 );
153 vars.push_back( v2 );
154 GraphAL H(3);
155 H.addEdge( 0, 1 );
156 H.addEdge( 1, 2 );
157 BipartiteGraph K(3, 3);
158 K.addEdge( 0, 0 );
159 K.addEdge( 1, 0 );
160 K.addEdge( 1, 1 );
161 K.addEdge( 2, 1 );
162 K.addEdge( 1, 2 );
163
164 FactorGraph G1( facs );
165 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
166 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
167 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
168 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
169 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
170 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
171 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
172 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
173 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
174 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
175 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
176 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
177 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
178 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
179 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
180 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
181 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
182 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
183 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
184 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
185 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
186 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
187 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
188 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
189 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
190 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
191 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
192 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
193 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
194 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
195 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
196 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
197 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
198 BOOST_CHECK( G1.isConnected() );
199 BOOST_CHECK( G1.isTree() );
200 BOOST_CHECK( G1.isBinary() );
201 BOOST_CHECK( G1.isPairwise() );
202 BOOST_CHECK( G1.MarkovGraph() == H );
203 BOOST_CHECK( G1.bipGraph() == K );
204 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
205 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
206 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
207
208 facs.push_back( Factor( v02 ) );
209 H.addEdge( 0, 2 );
210 K.addNode2();
211 K.addEdge( 0, 3 );
212 K.addEdge( 2, 3 );
213 FactorGraph G2( facs );
214 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
215 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
216 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
217 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
218 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
219 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
220 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
221 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
222 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
223 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
224 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
225 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
226 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
227 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
228 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
229 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
230 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
231 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
232 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
233 BOOST_CHECK( G2.isConnected() );
234 BOOST_CHECK( !G2.isTree() );
235 BOOST_CHECK( G2.isBinary() );
236 BOOST_CHECK( G2.isPairwise() );
237 BOOST_CHECK( G2.MarkovGraph() == H );
238 BOOST_CHECK( G2.bipGraph() == K );
239 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
240 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
241 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
242 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
243
244 Var v3( 3, 3 );
245 VarSet v03( v0, v3 );
246 VarSet v13( v1, v3 );
247 VarSet v23( v2, v3 );
248 VarSet v013 = v01 | v3;
249 VarSet v023 = v02 | v3;
250 VarSet v123 = v12 | v3;
251 VarSet v0123 = v012 | v3;
252 vars.push_back( v3 );
253 facs.push_back( Factor( v3 ) );
254 H.addNode();
255 K.addNode1();
256 K.addNode2();
257 K.addEdge( 3, 4 );
258 FactorGraph G3( facs );
259 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
260 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
261 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
262 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
263 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
264 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
265 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
266 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
267 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
268 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
269 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
270 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
271 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
272 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
273 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
274 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
275 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
276 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
277 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
278 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
279 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
280 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
281 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
282 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
283 BOOST_CHECK( !G3.isConnected() );
284 BOOST_CHECK( !G3.isTree() );
285 BOOST_CHECK( !G3.isBinary() );
286 BOOST_CHECK( G3.isPairwise() );
287 BOOST_CHECK( G3.MarkovGraph() == H );
288 BOOST_CHECK( G3.bipGraph() == K );
289 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
290 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
291 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
292 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
293 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
294
295 facs.push_back( Factor( v123 ) );
296 H.addEdge( 1, 3 );
297 H.addEdge( 2, 3 );
298 K.addNode2();
299 K.addEdge( 1, 5 );
300 K.addEdge( 2, 5 );
301 K.addEdge( 3, 5 );
302 FactorGraph G4( facs );
303 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
304 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
305 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
306 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
307 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
308 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
309 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
310 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
311 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
312 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
313 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
314 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
315 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
316 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
317 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
318 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
319 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
320 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
321 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
322 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
323 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
324 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
325 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
326 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
327 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
328 BOOST_CHECK( G4.isConnected() );
329 BOOST_CHECK( !G4.isTree() );
330 BOOST_CHECK( !G4.isBinary() );
331 BOOST_CHECK( !G4.isPairwise() );
332 BOOST_CHECK( G4.MarkovGraph() == H );
333 BOOST_CHECK( G4.bipGraph() == K );
334 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
335 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
336 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
337 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
338 }
339
340
341 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
342 Var v0( 0, 2 );
343 Var v1( 1, 2 );
344 Var v2( 2, 2 );
345 VarSet v01( v0, v1 );
346 VarSet v02( v0, v2 );
347 VarSet v12( v1, v2 );
348 VarSet v012 = v01 | v2;
349
350 std::vector<Factor> facs;
351 facs.push_back( Factor( v01 ) );
352 facs.push_back( Factor( v12 ) );
353 facs.push_back( Factor( v1 ) );
354 std::vector<Var> vars;
355 vars.push_back( v0 );
356 vars.push_back( v1 );
357 vars.push_back( v2 );
358
359 FactorGraph G( facs );
360 FactorGraph Gorg( G );
361
362 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
363 G.setFactor( 0, Factor( v01, 2.0 ), false );
364 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
365 G.setFactor( 0, Factor( v01, 3.0 ), true );
366 G.restoreFactor( 0 );
367 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
368 G.setFactor( 0, Gorg.factor( 0 ), false );
369 G.backupFactor( 0 );
370 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
371 G.setFactor( 0, Factor( v01, 2.0 ), false );
372 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
373 G.restoreFactor( 0 );
374 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
375
376 std::map<size_t, Factor> fs;
377 fs[0] = Factor( v01, 3.0 );
378 fs[2] = Factor( v1, 2.0 );
379 G.setFactors( fs, false );
380 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
381 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
382 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
383 G.restoreFactors();
384 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
385 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
386 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
387 G = Gorg;
388 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
389 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
390 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
391 G.setFactors( fs, true );
392 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
393 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
394 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
395 G.restoreFactors();
396 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
397 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
398 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
399 std::set<size_t> fsind;
400 fsind.insert( 0 );
401 fsind.insert( 2 );
402 G.backupFactors( fsind );
403 G.setFactors( fs, false );
404 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
405 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
406 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
407 G.restoreFactors();
408 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
409 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
410 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
411
412 G.backupFactors( v2 );
413 G.setFactor( 1, Factor(v12, 5.0) );
414 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
415 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
416 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
417 G.restoreFactors( v2 );
418 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
419 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
420 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
421
422 G.backupFactors( v1 );
423 fs[1] = Factor( v12, 5.0 );
424 G.setFactors( fs, false );
425 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
426 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
427 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
428 G.restoreFactors();
429 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
430 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
431 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
432 G.setFactors( fs, true );
433 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
434 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
435 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
436 G.restoreFactors( v1 );
437 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
438 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
439 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
440 }
441
442
443 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
444 Var v0( 0, 2 );
445 Var v1( 1, 2 );
446 Var v2( 2, 2 );
447 VarSet v01( v0, v1 );
448 VarSet v02( v0, v2 );
449 VarSet v12( v1, v2 );
450 VarSet v012 = v01 | v2;
451
452 std::vector<Factor> facs;
453 facs.push_back( Factor( v01 ).randomize() );
454 facs.push_back( Factor( v12 ).randomize() );
455 facs.push_back( Factor( v1 ).randomize() );
456 std::vector<Var> vars;
457 vars.push_back( v0 );
458 vars.push_back( v1 );
459 vars.push_back( v2 );
460
461 FactorGraph G( facs );
462
463 FactorGraph Gsmall = G.maximalFactors();
464 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
465 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
466 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
467 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
468
469 size_t i = 0;
470 for( size_t x = 0; x < 2; x++ ) {
471 FactorGraph Gcl = G.clamped( i, x );
472 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
473 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
474 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
475 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
476 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
477 }
478 i = 1;
479 for( size_t x = 0; x < 2; x++ ) {
480 FactorGraph Gcl = G.clamped( i, x );
481 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
482 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
483 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
484 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
485 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
486 }
487 i = 2;
488 for( size_t x = 0; x < 2; x++ ) {
489 FactorGraph Gcl = G.clamped( i, x );
490 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
491 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
492 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
493 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
494 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
495 }
496 }
497
498
499 BOOST_AUTO_TEST_CASE( OperationsTest ) {
500 Var v0( 0, 2 );
501 Var v1( 1, 2 );
502 Var v2( 2, 2 );
503 VarSet v01( v0, v1 );
504 VarSet v02( v0, v2 );
505 VarSet v12( v1, v2 );
506 VarSet v012 = v01 | v2;
507
508 std::vector<Factor> facs;
509 facs.push_back( Factor( v01 ).randomize() );
510 facs.push_back( Factor( v12 ).randomize() );
511 facs.push_back( Factor( v1 ).randomize() );
512 std::vector<Var> vars;
513 vars.push_back( v0 );
514 vars.push_back( v1 );
515 vars.push_back( v2 );
516
517 FactorGraph G( facs );
518
519 // clamp
520 FactorGraph Gcl = G;
521 for( size_t i = 0; i < 3; i++ )
522 for( size_t x = 0; x < 2; x++ ) {
523 Gcl.clamp( i, x, true );
524 Factor delta = createFactorDelta( vars[i], x );
525 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
526 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
527 for( size_t j = 0; j < 3; j++ )
528 if( G.factor(j).vars().contains( vars[i] ) )
529 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
530 else
531 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
532
533 Gcl.restoreFactors();
534 for( size_t j = 0; j < 3; j++ )
535 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
536 }
537
538 // clampVar
539 for( size_t i = 0; i < 3; i++ )
540 for( size_t x = 0; x < 2; x++ ) {
541 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
542 Factor delta = createFactorDelta( vars[i], x );
543 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
544 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
545 for( size_t j = 0; j < 3; j++ )
546 if( G.factor(j).vars().contains( vars[i] ) )
547 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
548 else
549 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
550
551 Gcl.restoreFactors();
552 for( size_t j = 0; j < 3; j++ )
553 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
554 }
555 for( size_t i = 0; i < 3; i++ )
556 for( size_t x = 0; x < 2; x++ ) {
557 Gcl.clampVar( i, std::vector<size_t>(), true );
558 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
559 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
560 for( size_t j = 0; j < 3; j++ )
561 if( G.factor(j).vars().contains( vars[i] ) )
562 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
563 else
564 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
565
566 Gcl.restoreFactors();
567 for( size_t j = 0; j < 3; j++ )
568 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
569 }
570 std::vector<size_t> both;
571 both.push_back( 0 );
572 both.push_back( 1 );
573 for( size_t i = 0; i < 3; i++ )
574 for( size_t x = 0; x < 2; x++ ) {
575 Gcl.clampVar( i, both, true );
576 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
577 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
578 for( size_t j = 0; j < 3; j++ )
579 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
580 Gcl.restoreFactors();
581 }
582
583 // clampFactor
584 for( size_t x = 0; x < 4; x++ ) {
585 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
586 Factor mask( v01, 0.0 );
587 mask.set( x, 1.0 );
588 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
589 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
590 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
591 Gcl.restoreFactor( 0 );
592 }
593 for( size_t x = 0; x < 4; x++ ) {
594 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
595 Factor mask( v12, 0.0 );
596 mask.set( x, 1.0 );
597 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
598 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
599 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
600 Gcl.restoreFactor( 1 );
601 }
602 for( size_t x = 0; x < 2; x++ ) {
603 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
604 Factor mask( v1, 0.0 );
605 mask.set( x, 1.0 );
606 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
607 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
608 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
609 Gcl.restoreFactors();
610 }
611
612 // makeCavity
613 FactorGraph Gcav( G );
614 Gcav.makeCavity( 0, true );
615 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
616 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
617 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
618 Gcav.restoreFactors();
619 Gcav.makeCavity( 1, true );
620 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
621 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
622 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
623 Gcav.restoreFactors();
624 Gcav.makeCavity( 2, true );
625 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
626 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
627 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
628 Gcav.restoreFactors();
629 }
630
631
632 BOOST_AUTO_TEST_CASE( IOTest ) {
633 Var v0( 0, 2 );
634 Var v1( 1, 2 );
635 Var v2( 2, 2 );
636 VarSet v01( v0, v1 );
637 VarSet v02( v0, v2 );
638 VarSet v12( v1, v2 );
639 VarSet v012 = v01 | v2;
640
641 std::vector<Factor> facs;
642 facs.push_back( Factor( v01 ).randomize() );
643 facs.push_back( Factor( v12 ).randomize() );
644 facs.push_back( Factor( v1 ).randomize() );
645 std::vector<Var> vars;
646 vars.push_back( v0 );
647 vars.push_back( v1 );
648 vars.push_back( v2 );
649
650 FactorGraph G( facs );
651
652 G.WriteToFile( "factorgraph_test.fg" );
653 FactorGraph G2;
654 G2.ReadFromFile( "factorgraph_test.fg" );
655
656 BOOST_CHECK( G.vars() == G2.vars() );
657 BOOST_CHECK( G.bipGraph() == G2.bipGraph() );
658 BOOST_CHECK_EQUAL( G.nrFactors(), G2.nrFactors() );
659 for( size_t I = 0; I < G.nrFactors(); I++ ) {
660 BOOST_CHECK( G.factor(I).vars() == G2.factor(I).vars() );
661 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
662 BOOST_CHECK_CLOSE( G.factor(I)[s], G2.factor(I)[s], tol );
663 }
664
665 std::stringstream ss;
666 std::string s;
667 G.printDot( ss );
668 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph FactorGraph {" );
669 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
670 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0;" );
671 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1;" );
672 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2;" );
673 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
674 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf0;" );
675 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf1;" );
676 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf2;" );
677 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0 -- f0;" );
678 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f0;" );
679 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f1;" );
680 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f2;" );
681 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2 -- f1;" );
682 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
683
684 G.setFactor( 0, Factor( G.factor(0).vars(), 1.0 ) );
685 G.setFactor( 1, Factor( G.factor(1).vars(), 2.0 ) );
686 G.setFactor( 2, Factor( G.factor(2).vars(), 3.0 ) );
687 ss << G;
688 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3" );
689 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
690 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
691 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1 " );
692 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
693 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
694 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1" );
695 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 1" );
696 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 1" );
697 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 1" );
698 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
699 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
700 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2 " );
701 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
702 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
703 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 2" );
704 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2" );
705 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2" );
706 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 2" );
707 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
708 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1" );
709 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 " );
710 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 " );
711 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
712 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 3" );
713 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 3" );
714
715 ss << G;
716 FactorGraph G3;
717 ss >> G3;
718 BOOST_CHECK( G.vars() == G3.vars() );
719 BOOST_CHECK( G.bipGraph() == G3.bipGraph() );
720 BOOST_CHECK_EQUAL( G.nrFactors(), G3.nrFactors() );
721 for( size_t I = 0; I < G.nrFactors(); I++ ) {
722 BOOST_CHECK( G.factor(I).vars() == G3.factor(I).vars() );
723 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
724 BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
725 }
726 }