// Source: libdai.git, tests/unit/factorgraph_test.cpp (commit 1ef18b1821e6ac153128b42db42375abf43ac087)
/* This file is part of libDAI - http://www.libdai.org/
 *
 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
 *
 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
 */
7
8
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
#include <cstdio>     // std::remove (IOTest scratch-file cleanup)
#include <map>
#include <set>
#include <sstream>    // std::stringstream
#include <string>     // std::string, std::getline
#include <strstream>  // deprecated; kept for compatibility, std::stringstream comes from <sstream>
#include <vector>
13
14
15 using namespace dai;
16
17
// Tolerance handed to every BOOST_CHECK_CLOSE in this file. Boost.Test reads
// that argument as a percentage, so 1e-8 allows a relative difference of 1e-10.
static const double tol = 1.0e-8;
19
20
21 #define BOOST_TEST_MODULE FactorGraphTest
22
23
24 #include <boost/test/unit_test.hpp>
25 #include <boost/test/floating_point_comparison.hpp>
26
27
28 BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
29 FactorGraph G;
30 BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
31 BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
32
33 std::vector<Factor> facs;
34 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
35 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
36 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
37 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
38 std::vector<Var> vars;
39 vars.push_back( Var( 0, 2 ) );
40 vars.push_back( Var( 1, 2 ) );
41 vars.push_back( Var( 2, 2 ) );
42
43 FactorGraph G1( facs );
44 BOOST_CHECK_EQUAL( G1.vars(), vars );
45 BOOST_CHECK_EQUAL( G1.factors(), facs );
46
47 FactorGraph G2( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
48 BOOST_CHECK_EQUAL( G2.vars(), vars );
49 BOOST_CHECK_EQUAL( G2.factors(), facs );
50
51 FactorGraph *G3 = G2.clone();
52 BOOST_CHECK_EQUAL( G3->vars(), vars );
53 BOOST_CHECK_EQUAL( G3->factors(), facs );
54 delete G3;
55
56 FactorGraph G4 = G2;
57 BOOST_CHECK_EQUAL( G4.vars(), vars );
58 BOOST_CHECK_EQUAL( G4.factors(), facs );
59
60 FactorGraph G5( G2 );
61 BOOST_CHECK_EQUAL( G5.vars(), vars );
62 BOOST_CHECK_EQUAL( G5.factors(), facs );
63 }
64
65
66 BOOST_AUTO_TEST_CASE( AccMutTest ) {
67 std::vector<Factor> facs;
68 facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
69 facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
70 facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
71 facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
72 std::vector<Var> vars;
73 vars.push_back( Var( 0, 2 ) );
74 vars.push_back( Var( 1, 2 ) );
75 vars.push_back( Var( 2, 2 ) );
76
77 FactorGraph G( facs );
78 BOOST_CHECK_EQUAL( G.var(0), Var(0, 2) );
79 BOOST_CHECK_EQUAL( G.var(1), Var(1, 2) );
80 BOOST_CHECK_EQUAL( G.var(2), Var(2, 2) );
81 BOOST_CHECK_EQUAL( G.vars(), vars );
82 BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
83 BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
84 BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
85 BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
86 BOOST_CHECK_EQUAL( G.factors(), facs );
87 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
88 BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
89 BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
90 BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
91 BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
92 BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
93 BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
94 BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
95 BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
96 BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
97 BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
98 BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
99 BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
100 BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
101 BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
102 BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
103 BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
104 BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
105 BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
106 BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
107 BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
108 }
109
110
111 BOOST_AUTO_TEST_CASE( QueriesTest ) {
112 Var v0( 0, 2 );
113 Var v1( 1, 2 );
114 Var v2( 2, 2 );
115 VarSet v01( v0, v1 );
116 VarSet v02( v0, v2 );
117 VarSet v12( v1, v2 );
118 VarSet v012 = v01 | v2;
119
120 FactorGraph G0;
121 BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
122 BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
123 BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
124 BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
125 BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
126 BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
127 #ifdef DAI_DBEUG
128 BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
129 BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
130 BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
131 BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
132 #endif
133 BOOST_CHECK( G0.isConnected() );
134 BOOST_CHECK( G0.isTree() );
135 BOOST_CHECK( G0.isBinary() );
136 BOOST_CHECK( G0.isPairwise() );
137 BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
138 BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
139 BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 1 );
140 BOOST_CHECK_CLOSE( G0.logScore( std::vector<size_t>() ), (Real)0.0, tol );
141
142 std::vector<Factor> facs;
143 facs.push_back( Factor( v01 ) );
144 facs.push_back( Factor( v12 ) );
145 facs.push_back( Factor( v1 ) );
146 std::vector<Var> vars;
147 vars.push_back( v0 );
148 vars.push_back( v1 );
149 vars.push_back( v2 );
150 GraphAL H(3);
151 H.addEdge( 0, 1 );
152 H.addEdge( 1, 2 );
153 BipartiteGraph K(3, 3);
154 K.addEdge( 0, 0 );
155 K.addEdge( 1, 0 );
156 K.addEdge( 1, 1 );
157 K.addEdge( 2, 1 );
158 K.addEdge( 1, 2 );
159
160 FactorGraph G1( facs );
161 BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
162 BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
163 BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
164 BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
165 BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
166 BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
167 BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
168 BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
169 BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
170 BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
171 BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
172 BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
173 BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
174 BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
175 BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
176 BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
177 BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
178 BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
179 BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
180 BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
181 BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
182 BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
183 BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
184 BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
185 BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
186 BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
187 BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
188 BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
189 BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
190 BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
191 BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
192 BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
193 BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
194 BOOST_CHECK( G1.isConnected() );
195 BOOST_CHECK( G1.isTree() );
196 BOOST_CHECK( G1.isBinary() );
197 BOOST_CHECK( G1.isPairwise() );
198 BOOST_CHECK( G1.MarkovGraph() == H );
199 BOOST_CHECK( G1.bipGraph() == K );
200 BOOST_CHECK( G1.isMaximal( 0 ) );
201 BOOST_CHECK( G1.isMaximal( 1 ) );
202 BOOST_CHECK( !G1.isMaximal( 2 ) );
203 BOOST_CHECK_EQUAL( G1.maximalFactor( 0 ), 0 );
204 BOOST_CHECK_EQUAL( G1.maximalFactor( 1 ), 1 );
205 BOOST_CHECK_EQUAL( G1.maximalFactor( 2 ), 0 );
206 BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
207 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
208 BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
209 BOOST_CHECK_CLOSE( G1.logScore( std::vector<size_t>(3,0) ), -dai::log((Real)32.0), tol );
210
211 facs.push_back( Factor( v02 ) );
212 H.addEdge( 0, 2 );
213 K.addNode2();
214 K.addEdge( 0, 3 );
215 K.addEdge( 2, 3 );
216 FactorGraph G2( facs );
217 BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
218 BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
219 BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
220 BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
221 BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
222 BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
223 BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
224 BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
225 BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
226 BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
227 BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
228 BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
229 BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
230 BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
231 BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
232 BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
233 BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
234 BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
235 BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
236 BOOST_CHECK( G2.isConnected() );
237 BOOST_CHECK( !G2.isTree() );
238 BOOST_CHECK( G2.isBinary() );
239 BOOST_CHECK( G2.isPairwise() );
240 BOOST_CHECK( G2.MarkovGraph() == H );
241 BOOST_CHECK( G2.bipGraph() == K );
242 BOOST_CHECK( G2.isMaximal( 0 ) );
243 BOOST_CHECK( G2.isMaximal( 1 ) );
244 BOOST_CHECK( !G2.isMaximal( 2 ) );
245 BOOST_CHECK( G2.isMaximal( 3 ) );
246 BOOST_CHECK_EQUAL( G2.maximalFactor( 0 ), 0 );
247 BOOST_CHECK_EQUAL( G2.maximalFactor( 1 ), 1 );
248 BOOST_CHECK_EQUAL( G2.maximalFactor( 2 ), 0 );
249 BOOST_CHECK_EQUAL( G2.maximalFactor( 3 ), 3 );
250 BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
251 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
252 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
253 BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
254 BOOST_CHECK_CLOSE( G2.logScore( std::vector<size_t>(3,0) ), -dai::log((Real)128.0), tol );
255
256 Var v3( 3, 3 );
257 VarSet v03( v0, v3 );
258 VarSet v13( v1, v3 );
259 VarSet v23( v2, v3 );
260 VarSet v013 = v01 | v3;
261 VarSet v023 = v02 | v3;
262 VarSet v123 = v12 | v3;
263 VarSet v0123 = v012 | v3;
264 vars.push_back( v3 );
265 facs.push_back( Factor( v3 ) );
266 H.addNode();
267 K.addNode1();
268 K.addNode2();
269 K.addEdge( 3, 4 );
270 FactorGraph G3( facs );
271 BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
272 BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
273 BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
274 BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
275 BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
276 BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
277 BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
278 BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
279 BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
280 BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
281 BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
282 BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
283 BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
284 BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
285 BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
286 BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
287 BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
288 BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
289 BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
290 BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
291 BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
292 BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
293 BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
294 BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
295 BOOST_CHECK( !G3.isConnected() );
296 BOOST_CHECK( !G3.isTree() );
297 BOOST_CHECK( !G3.isBinary() );
298 BOOST_CHECK( G3.isPairwise() );
299 BOOST_CHECK( G3.MarkovGraph() == H );
300 BOOST_CHECK( G3.bipGraph() == K );
301 BOOST_CHECK( G3.isMaximal( 0 ) );
302 BOOST_CHECK( G3.isMaximal( 1 ) );
303 BOOST_CHECK( !G3.isMaximal( 2 ) );
304 BOOST_CHECK( G3.isMaximal( 3 ) );
305 BOOST_CHECK( G3.isMaximal( 4 ) );
306 BOOST_CHECK_EQUAL( G3.maximalFactor( 0 ), 0 );
307 BOOST_CHECK_EQUAL( G3.maximalFactor( 1 ), 1 );
308 BOOST_CHECK_EQUAL( G3.maximalFactor( 2 ), 0 );
309 BOOST_CHECK_EQUAL( G3.maximalFactor( 3 ), 3 );
310 BOOST_CHECK_EQUAL( G3.maximalFactor( 4 ), 4 );
311 BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
312 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
313 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
314 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
315 BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
316 BOOST_CHECK_CLOSE( G3.logScore( std::vector<size_t>(4,0) ), -dai::log((Real)384.0), tol );
317
318 facs.push_back( Factor( v123 ) );
319 H.addEdge( 1, 3 );
320 H.addEdge( 2, 3 );
321 K.addNode2();
322 K.addEdge( 1, 5 );
323 K.addEdge( 2, 5 );
324 K.addEdge( 3, 5 );
325 FactorGraph G4( facs );
326 BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
327 BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
328 BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
329 BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
330 BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
331 BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
332 BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
333 BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
334 BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
335 BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
336 BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
337 BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
338 BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
339 BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
340 BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
341 BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
342 BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
343 BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
344 BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
345 BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
346 BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
347 BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
348 BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
349 BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
350 BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
351 BOOST_CHECK( G4.isConnected() );
352 BOOST_CHECK( !G4.isTree() );
353 BOOST_CHECK( !G4.isBinary() );
354 BOOST_CHECK( !G4.isPairwise() );
355 BOOST_CHECK( G4.MarkovGraph() == H );
356 BOOST_CHECK( G4.bipGraph() == K );
357 BOOST_CHECK( G4.isMaximal( 0 ) );
358 BOOST_CHECK( !G4.isMaximal( 1 ) );
359 BOOST_CHECK( !G4.isMaximal( 2 ) );
360 BOOST_CHECK( G4.isMaximal( 3 ) );
361 BOOST_CHECK( !G4.isMaximal( 4 ) );
362 BOOST_CHECK( G4.isMaximal( 5 ) );
363 BOOST_CHECK_EQUAL( G4.maximalFactor( 0 ), 0 );
364 BOOST_CHECK_EQUAL( G4.maximalFactor( 1 ), 5 );
365 BOOST_CHECK_EQUAL( G4.maximalFactor( 2 ), 0 );
366 BOOST_CHECK_EQUAL( G4.maximalFactor( 3 ), 3 );
367 BOOST_CHECK_EQUAL( G4.maximalFactor( 4 ), 5 );
368 BOOST_CHECK_EQUAL( G4.maximalFactor( 5 ), 5 );
369 BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
370 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
371 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
372 BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
373 BOOST_CHECK_CLOSE( G4.logScore( std::vector<size_t>(4,0) ), -dai::log((Real)4608.0), tol );
374 }
375
376
377 BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
378 Var v0( 0, 2 );
379 Var v1( 1, 2 );
380 Var v2( 2, 2 );
381 VarSet v01( v0, v1 );
382 VarSet v02( v0, v2 );
383 VarSet v12( v1, v2 );
384 VarSet v012 = v01 | v2;
385
386 std::vector<Factor> facs;
387 facs.push_back( Factor( v01 ) );
388 facs.push_back( Factor( v12 ) );
389 facs.push_back( Factor( v1 ) );
390 std::vector<Var> vars;
391 vars.push_back( v0 );
392 vars.push_back( v1 );
393 vars.push_back( v2 );
394
395 FactorGraph G( facs );
396 FactorGraph Gorg( G );
397
398 BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
399 G.setFactor( 0, Factor( v01, 2.0 ), false );
400 BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
401 G.setFactor( 0, Factor( v01, 3.0 ), true );
402 G.restoreFactor( 0 );
403 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
404 G.setFactor( 0, Gorg.factor( 0 ), false );
405 G.backupFactor( 0 );
406 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
407 G.setFactor( 0, Factor( v01, 2.0 ), false );
408 BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
409 G.restoreFactor( 0 );
410 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
411
412 std::map<size_t, Factor> fs;
413 fs[0] = Factor( v01, 3.0 );
414 fs[2] = Factor( v1, 2.0 );
415 G.setFactors( fs, false );
416 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
417 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
418 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
419 G.restoreFactors();
420 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
421 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
422 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
423 G = Gorg;
424 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
425 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
426 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
427 G.setFactors( fs, true );
428 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
429 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
430 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
431 G.restoreFactors();
432 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
433 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
434 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
435 std::set<size_t> fsind;
436 fsind.insert( 0 );
437 fsind.insert( 2 );
438 G.backupFactors( fsind );
439 G.setFactors( fs, false );
440 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
441 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
442 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
443 G.restoreFactors();
444 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
445 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
446 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
447
448 G.backupFactors( v2 );
449 G.setFactor( 1, Factor(v12, 5.0) );
450 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
451 BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
452 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
453 G.restoreFactors( v2 );
454 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
455 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
456 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
457
458 G.backupFactors( v1 );
459 fs[1] = Factor( v12, 5.0 );
460 G.setFactors( fs, false );
461 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
462 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
463 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
464 G.restoreFactors();
465 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
466 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
467 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
468 G.setFactors( fs, true );
469 BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
470 BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
471 BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
472 G.restoreFactors( v1 );
473 BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
474 BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
475 BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
476 }
477
478
479 BOOST_AUTO_TEST_CASE( TransformationsTest ) {
480 Var v0( 0, 2 );
481 Var v1( 1, 2 );
482 Var v2( 2, 2 );
483 VarSet v01( v0, v1 );
484 VarSet v02( v0, v2 );
485 VarSet v12( v1, v2 );
486 VarSet v012 = v01 | v2;
487
488 std::vector<Factor> facs;
489 facs.push_back( Factor( v01 ).randomize() );
490 facs.push_back( Factor( v12 ).randomize() );
491 facs.push_back( Factor( v1 ).randomize() );
492 std::vector<Var> vars;
493 vars.push_back( v0 );
494 vars.push_back( v1 );
495 vars.push_back( v2 );
496
497 FactorGraph G( facs );
498
499 FactorGraph Gsmall = G.maximalFactors();
500 BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
501 BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
502 BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
503 BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
504
505 size_t i = 0;
506 for( size_t x = 0; x < 2; x++ ) {
507 FactorGraph Gcl = G.clamped( i, x );
508 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
509 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
510 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
511 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
512 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
513 }
514 i = 1;
515 for( size_t x = 0; x < 2; x++ ) {
516 FactorGraph Gcl = G.clamped( i, x );
517 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
518 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
519 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
520 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
521 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
522 }
523 i = 2;
524 for( size_t x = 0; x < 2; x++ ) {
525 FactorGraph Gcl = G.clamped( i, x );
526 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
527 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
528 BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
529 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
530 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
531 }
532 }
533
534
535 BOOST_AUTO_TEST_CASE( OperationsTest ) {
536 Var v0( 0, 2 );
537 Var v1( 1, 2 );
538 Var v2( 2, 2 );
539 VarSet v01( v0, v1 );
540 VarSet v02( v0, v2 );
541 VarSet v12( v1, v2 );
542 VarSet v012 = v01 | v2;
543
544 std::vector<Factor> facs;
545 facs.push_back( Factor( v01 ).randomize() );
546 facs.push_back( Factor( v12 ).randomize() );
547 facs.push_back( Factor( v1 ).randomize() );
548 std::vector<Var> vars;
549 vars.push_back( v0 );
550 vars.push_back( v1 );
551 vars.push_back( v2 );
552
553 FactorGraph G( facs );
554
555 // clamp
556 FactorGraph Gcl = G;
557 for( size_t i = 0; i < 3; i++ )
558 for( size_t x = 0; x < 2; x++ ) {
559 Gcl.clamp( i, x, true );
560 Factor delta = createFactorDelta( vars[i], x );
561 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
562 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
563 for( size_t j = 0; j < 3; j++ )
564 if( G.factor(j).vars().contains( vars[i] ) )
565 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
566 else
567 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
568
569 Gcl.restoreFactors();
570 for( size_t j = 0; j < 3; j++ )
571 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
572 }
573
574 // clampVar
575 for( size_t i = 0; i < 3; i++ )
576 for( size_t x = 0; x < 2; x++ ) {
577 Gcl.clampVar( i, std::vector<size_t>(1, x), true );
578 Factor delta = createFactorDelta( vars[i], x );
579 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
580 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
581 for( size_t j = 0; j < 3; j++ )
582 if( G.factor(j).vars().contains( vars[i] ) )
583 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
584 else
585 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
586
587 Gcl.restoreFactors();
588 for( size_t j = 0; j < 3; j++ )
589 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
590 }
591 for( size_t i = 0; i < 3; i++ )
592 for( size_t x = 0; x < 2; x++ ) {
593 Gcl.clampVar( i, std::vector<size_t>(), true );
594 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
595 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
596 for( size_t j = 0; j < 3; j++ )
597 if( G.factor(j).vars().contains( vars[i] ) )
598 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
599 else
600 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
601
602 Gcl.restoreFactors();
603 for( size_t j = 0; j < 3; j++ )
604 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
605 }
606 std::vector<size_t> both;
607 both.push_back( 0 );
608 both.push_back( 1 );
609 for( size_t i = 0; i < 3; i++ )
610 for( size_t x = 0; x < 2; x++ ) {
611 Gcl.clampVar( i, both, true );
612 BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
613 BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
614 for( size_t j = 0; j < 3; j++ )
615 BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
616 Gcl.restoreFactors();
617 }
618
619 // clampFactor
620 for( size_t x = 0; x < 4; x++ ) {
621 Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
622 Factor mask( v01, 0.0 );
623 mask.set( x, 1.0 );
624 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
625 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
626 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
627 Gcl.restoreFactor( 0 );
628 }
629 for( size_t x = 0; x < 4; x++ ) {
630 Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
631 Factor mask( v12, 0.0 );
632 mask.set( x, 1.0 );
633 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
634 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
635 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
636 Gcl.restoreFactor( 1 );
637 }
638 for( size_t x = 0; x < 2; x++ ) {
639 Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
640 Factor mask( v1, 0.0 );
641 mask.set( x, 1.0 );
642 BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
643 BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
644 BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
645 Gcl.restoreFactors();
646 }
647
648 // makeCavity
649 FactorGraph Gcav( G );
650 Gcav.makeCavity( 0, true );
651 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
652 BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
653 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
654 Gcav.restoreFactors();
655 Gcav.makeCavity( 1, true );
656 BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
657 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
658 BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
659 Gcav.restoreFactors();
660 Gcav.makeCavity( 2, true );
661 BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
662 BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
663 BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
664 Gcav.restoreFactors();
665 }
666
667
668 BOOST_AUTO_TEST_CASE( IOTest ) {
669 Var v0( 0, 2 );
670 Var v1( 1, 2 );
671 Var v2( 2, 2 );
672 VarSet v01( v0, v1 );
673 VarSet v02( v0, v2 );
674 VarSet v12( v1, v2 );
675 VarSet v012 = v01 | v2;
676
677 std::vector<Factor> facs;
678 facs.push_back( Factor( v01 ).randomize() );
679 facs.push_back( Factor( v12 ).randomize() );
680 facs.push_back( Factor( v1 ).randomize() );
681 std::vector<Var> vars;
682 vars.push_back( v0 );
683 vars.push_back( v1 );
684 vars.push_back( v2 );
685
686 FactorGraph G( facs );
687
688 G.WriteToFile( "factorgraph_test.fg" );
689 FactorGraph G2;
690 G2.ReadFromFile( "factorgraph_test.fg" );
691
692 BOOST_CHECK( G.vars() == G2.vars() );
693 BOOST_CHECK( G.bipGraph() == G2.bipGraph() );
694 BOOST_CHECK_EQUAL( G.nrFactors(), G2.nrFactors() );
695 for( size_t I = 0; I < G.nrFactors(); I++ ) {
696 BOOST_CHECK( G.factor(I).vars() == G2.factor(I).vars() );
697 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
698 BOOST_CHECK_CLOSE( G.factor(I)[s], G2.factor(I)[s], tol );
699 }
700
701 std::stringstream ss;
702 std::string s;
703 G.printDot( ss );
704 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph FactorGraph {" );
705 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
706 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0;" );
707 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1;" );
708 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2;" );
709 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
710 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf0;" );
711 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf1;" );
712 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf2;" );
713 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0 -- f0;" );
714 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f0;" );
715 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f1;" );
716 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f2;" );
717 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2 -- f1;" );
718 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
719
720 G.setFactor( 0, Factor( G.factor(0).vars(), 1.0 ) );
721 G.setFactor( 1, Factor( G.factor(1).vars(), 2.0 ) );
722 G.setFactor( 2, Factor( G.factor(2).vars(), 3.0 ) );
723 ss << G;
724 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3" );
725 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
726 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
727 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1 " );
728 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
729 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
730 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1" );
731 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 1" );
732 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 1" );
733 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 1" );
734 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
735 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
736 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2 " );
737 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
738 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
739 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 2" );
740 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2" );
741 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2" );
742 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 2" );
743 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
744 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1" );
745 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 " );
746 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 " );
747 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
748 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 3" );
749 std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 3" );
750
751 ss << G;
752 FactorGraph G3;
753 ss >> G3;
754 BOOST_CHECK( G.vars() == G3.vars() );
755 BOOST_CHECK( G.bipGraph() == G3.bipGraph() );
756 BOOST_CHECK_EQUAL( G.nrFactors(), G3.nrFactors() );
757 for( size_t I = 0; I < G.nrFactors(); I++ ) {
758 BOOST_CHECK( G.factor(I).vars() == G3.factor(I).vars() );
759 for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
760 BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
761 }
762 }