matlabs : matlab/dai$(ME) matlab/dai_readfg$(ME) matlab/dai_writefg$(ME) matlab/dai_potstrength$(ME)
-unittests : tests/unit/var$(EE) tests/unit/smallset$(EE) tests/unit/varset$(EE) tests/unit/graph$(EE) tests/unit/bipgraph$(EE) tests/unit/weightedgraph$(EE) tests/unit/enum$(EE) tests/unit/enum$(EE) tests/unit/util$(EE) tests/unit/exceptions$(EE) tests/unit/properties$(EE) tests/unit/index$(EE) tests/unit/prob$(EE) tests/unit/factor$(EE) tests/unit/factorgraph$(EE) tests/unit/clustergraph$(EE) tests/unit/regiongraph$(EE) tests/unit/daialg$(EE) tests/unit/alldai$(EE)
+unittests : tests/unit/var_test$(EE) tests/unit/smallset_test$(EE) tests/unit/varset_test$(EE) tests/unit/graph_test$(EE) tests/unit/bipgraph_test$(EE) tests/unit/weightedgraph_test$(EE) tests/unit/enum_test$(EE) tests/unit/enum_test$(EE) tests/unit/util_test$(EE) tests/unit/exceptions_test$(EE) tests/unit/properties_test$(EE) tests/unit/index_test$(EE) tests/unit/prob_test$(EE) tests/unit/factor_test$(EE) tests/unit/factorgraph_test$(EE) tests/unit/clustergraph_test$(EE) tests/unit/regiongraph_test$(EE) tests/unit/daialg_test$(EE) tests/unit/alldai_test$(EE)
echo Running unit tests...
- tests/unit/var$(EE)
- tests/unit/smallset$(EE)
- tests/unit/varset$(EE)
- tests/unit/graph$(EE)
- tests/unit/bipgraph$(EE)
- tests/unit/weightedgraph$(EE)
- tests/unit/enum$(EE)
- tests/unit/util$(EE)
- tests/unit/exceptions$(EE)
- tests/unit/properties$(EE)
- tests/unit/index$(EE)
- tests/unit/prob$(EE)
- tests/unit/factor$(EE)
- tests/unit/factorgraph$(EE)
- tests/unit/clustergraph$(EE)
- tests/unit/regiongraph$(EE)
- tests/unit/daialg$(EE)
- tests/unit/alldai$(EE)
+ tests/unit/var_test$(EE)
+ tests/unit/smallset_test$(EE)
+ tests/unit/varset_test$(EE)
+ tests/unit/graph_test$(EE)
+ tests/unit/bipgraph_test$(EE)
+ tests/unit/weightedgraph_test$(EE)
+ tests/unit/enum_test$(EE)
+ tests/unit/util_test$(EE)
+ tests/unit/exceptions_test$(EE)
+ tests/unit/properties_test$(EE)
+ tests/unit/index_test$(EE)
+ tests/unit/prob_test$(EE)
+ tests/unit/factor_test$(EE)
+ tests/unit/factorgraph_test$(EE)
+ tests/unit/clustergraph_test$(EE)
+ tests/unit/regiongraph_test$(EE)
+ tests/unit/daialg_test$(EE)
+ tests/unit/alldai_test$(EE)
tests : tests/testdai$(EE) tests/testem/testem$(EE) tests/testbbp$(EE) $(unittests)
-rm matlab/*$(ME)
-rm examples/example$(EE) examples/example_bipgraph$(EE) examples/example_varset$(EE) examples/example_permute$(EE) examples/example_sprinkler$(EE) examples/example_sprinkler_gibbs$(EE) examples/example_sprinkler_em$(EE)
-rm tests/testdai$(EE) tests/testem/testem$(EE) tests/testbbp$(EE)
- -rm tests/unit/var$(EE) tests/unit/smallset$(EE) tests/unit/varset$(EE) tests/unit/graph$(EE) tests/unit/bipgraph$(EE) tests/unit/weightedgraph$(EE) tests/unit/enum$(EE) tests/unit/util$(EE) tests/unit/exceptions$(EE) tests/unit/properties$(EE) tests/unit/index$(EE) tests/unit/prob$(EE) tests/unit/factor$(EE) tests/unit/factorgraph$(EE) tests/unit/clustergraph$(EE) tests/unit/regiongraph$(EE) tests/unit/daialg$(EE) tests/unit/alldai$(EE)
+ -rm tests/unit/var_test$(EE) tests/unit/smallset_test$(EE) tests/unit/varset_test$(EE) tests/unit/graph_test$(EE) tests/unit/bipgraph_test$(EE) tests/unit/weightedgraph_test$(EE) tests/unit/enum_test$(EE) tests/unit/util_test$(EE) tests/unit/exceptions_test$(EE) tests/unit/properties_test$(EE) tests/unit/index_test$(EE) tests/unit/prob_test$(EE) tests/unit/factor_test$(EE) tests/unit/factorgraph_test$(EE) tests/unit/clustergraph_test$(EE) tests/unit/regiongraph_test$(EE) tests/unit/daialg_test$(EE) tests/unit/alldai_test$(EE)
-rm factorgraph_test.fg alldai_test.aliases
-rm utils/fg2dot$(EE) utils/createfg$(EE) utils/fginfo$(EE)
-rm -R doc
-del utils\*$(EE).manifest
-del utils\*.pdb
-del utils\*.ilk
- -del tests\unit\*.ilk
- -del tests\unit\*.pdb
- -del tests\unit\*$(EE)
- -del tests\unit\*$(EE).manifest
+ -del tests\unit\*_test.ilk
+ -del tests\unit\*_test.pdb
+ -del tests\unit\*_test$(EE)
+ -del tests\unit\*_test$(EE).manifest
-del factorgraph_test.fg
-del alldai_test.aliases
-del $(LIB)\libdai$(LE)
# Flags to add in non-debugging mode (if DEBUG=false)
CCNODEBUGFLAGS=/Ox
# Standard include directories
-CCINC=-Iinclude -IE:\boost_1_42_0
+CCINC=-Iinclude -IE:\windows\boost_1_42_0
# LINKER
# Standard libraries to include
LIBS=/link $(LIB)/libdai$(LE)
# For linking with BOOST libraries
-BOOSTLIBS_PO=/LIBPATH:E:\boost_1_42_0\stage\lib
-BOOSTLIBS_UTF=/LIBPATH:E:\boost_1_42_0\stage\lib
+BOOSTLIBS_PO=/LIBPATH:E:\windows\boost_1_42_0\stage\lib
+BOOSTLIBS_UTF=/LIBPATH:E:\windows\boost_1_42_0\stage\lib
# Additional library search paths for linker
# (For some reason, we have to add the VC library path, although it is in the environment)
CCLIB=/LIBPATH:"C:\Program Files\Microsoft Visual Studio 9.0\VC\ATLMFC\LIB" /LIBPATH:"C:\Program Files\Microsoft Visual Studio 9.0\VC\LIB" /LIBPATH:"C:\Program Files\Microsoft SDKs\Windows\v6.0A\lib"
# Location of Python header files
INCLUDE_PYTHON=C:\python2.5
# Location of Boost C++ library header files
-INCLUDE_BOOST=E:\boost_1_42_0
+INCLUDE_BOOST=E:\windows\boost_1_42_0
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/alldai.h>
-#include <strstream>
-#include <fstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE DAIAlgTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( newInfAlgTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- Var v3( 3, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- std::vector<Factor> facs;
- facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
- facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
- facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
- facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
- facs.push_back( createFactorIsing( v0, -1.0 ) );
- facs.push_back( createFactorIsing( v1, -1.0 ) );
- facs.push_back( createFactorIsing( v2, -1.0 ) );
- facs.push_back( createFactorIsing( v3, 1.0 ) );
- Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
- FactorGraph fg( facs );
- VarSet vs = v01;
-
- InfAlg* ia = newInfAlg( "EXACT", fg, PropertySet()("verbose",(size_t)0) );
- ia->init();
- ia->run();
- BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
- BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
- delete ia;
-
- ia = newInfAlgFromString( "EXACT[verbose=0]", fg );
- ia->init();
- ia->run();
- BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
- BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
- delete ia;
-
- std::map<std::string, std::string> aliases;
- aliases["alias"] = "EXACT[verbose=1]";
- ia = newInfAlgFromString( "alias[verbose=0]", fg, aliases );
- ia->init();
- ia->run();
- BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
- BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
- delete ia;
-}
-
-
-BOOST_AUTO_TEST_CASE( parseTest ) {
- std::pair<std::string, PropertySet> nameProps = parseNameProperties( "name" );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 0 );
-
- nameProps = parseNameProperties( "name[]" );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 0 );
-
- nameProps = parseNameProperties( "name[key1=value,key2=0.5]" );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 2 );
- BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
- BOOST_CHECK( nameProps.second.hasKey( "key2" ) );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "value" );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key2"), "0.5" );
-
- std::map<std::string, std::string> aliases;
- aliases["alias"] = "name[key1=other]";
-
- nameProps = parseNameProperties( "alias", aliases );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 1 );
- BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "other" );
-
- nameProps = parseNameProperties( "alias[]", aliases );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 1 );
- BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "other" );
-
- nameProps = parseNameProperties( "alias[key1=value,key2=0.5]", aliases );
- BOOST_CHECK_EQUAL( nameProps.first, "name" );
- BOOST_CHECK_EQUAL( nameProps.second.size(), 2 );
- BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
- BOOST_CHECK( nameProps.second.hasKey( "key2" ) );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "value" );
- BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key2"), "0.5" );
-}
-
-
-BOOST_AUTO_TEST_CASE( readAliasFileTest ) {
- std::ofstream outfile;
- std::string filename( "alldai_test.aliases" );
- outfile.open( filename.c_str() );
- if( outfile.is_open() ) {
- outfile << "alias:\tname[key1=other]" << std::endl;
- outfile.close();
- } else
- DAI_THROWE(CANNOT_WRITE_FILE,"Cannot write to file " + std::string(filename));
-
- std::map<std::string, std::string> aliases;
- aliases["alias"] = "name[key1=other]";
-
- std::map<std::string, std::string> aliases2 = readAliasesFile( filename );
- BOOST_CHECK( aliases == aliases2 );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/alldai.h>
+#include <strstream>
+#include <fstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE DAIAlgTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( newInfAlgTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ Var v3( 3, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ std::vector<Factor> facs;
+ facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
+ facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
+ facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
+ facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
+ facs.push_back( createFactorIsing( v0, -1.0 ) );
+ facs.push_back( createFactorIsing( v1, -1.0 ) );
+ facs.push_back( createFactorIsing( v2, -1.0 ) );
+ facs.push_back( createFactorIsing( v3, 1.0 ) );
+ Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
+ FactorGraph fg( facs );
+ VarSet vs = v01;
+
+ InfAlg* ia = newInfAlg( "EXACT", fg, PropertySet()("verbose",(size_t)0) );
+ ia->init();
+ ia->run();
+ BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
+ BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
+ delete ia;
+
+ ia = newInfAlgFromString( "EXACT[verbose=0]", fg );
+ ia->init();
+ ia->run();
+ BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
+ BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
+ delete ia;
+
+ std::map<std::string, std::string> aliases;
+ aliases["alias"] = "EXACT[verbose=1]";
+ ia = newInfAlgFromString( "alias[verbose=0]", fg, aliases );
+ ia->init();
+ ia->run();
+ BOOST_CHECK( dist( ia->belief( v01 ), joint.marginal( vs ), DISTTV ) < tol );
+ BOOST_CHECK( dist( calcMarginal( *ia, vs, true ), joint.marginal( vs ), DISTTV ) < tol );
+ delete ia;
+}
+
+
+BOOST_AUTO_TEST_CASE( parseTest ) {
+ std::pair<std::string, PropertySet> nameProps = parseNameProperties( "name" );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 0 );
+
+ nameProps = parseNameProperties( "name[]" );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 0 );
+
+ nameProps = parseNameProperties( "name[key1=value,key2=0.5]" );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 2 );
+ BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
+ BOOST_CHECK( nameProps.second.hasKey( "key2" ) );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "value" );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key2"), "0.5" );
+
+ std::map<std::string, std::string> aliases;
+ aliases["alias"] = "name[key1=other]";
+
+ nameProps = parseNameProperties( "alias", aliases );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 1 );
+ BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "other" );
+
+ nameProps = parseNameProperties( "alias[]", aliases );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 1 );
+ BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "other" );
+
+ nameProps = parseNameProperties( "alias[key1=value,key2=0.5]", aliases );
+ BOOST_CHECK_EQUAL( nameProps.first, "name" );
+ BOOST_CHECK_EQUAL( nameProps.second.size(), 2 );
+ BOOST_CHECK( nameProps.second.hasKey( "key1" ) );
+ BOOST_CHECK( nameProps.second.hasKey( "key2" ) );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key1"), "value" );
+ BOOST_CHECK_EQUAL( nameProps.second.getAs<std::string>("key2"), "0.5" );
+}
+
+
+BOOST_AUTO_TEST_CASE( readAliasFileTest ) {
+ std::ofstream outfile;
+ std::string filename( "alldai_test.aliases" );
+ outfile.open( filename.c_str() );
+ if( outfile.is_open() ) {
+ outfile << "alias:\tname[key1=other]" << std::endl;
+ outfile.close();
+ } else
+ DAI_THROWE(CANNOT_WRITE_FILE,"Cannot write to file " + std::string(filename));
+
+ std::map<std::string, std::string> aliases;
+ aliases["alias"] = "name[key1=other]";
+
+ std::map<std::string, std::string> aliases2 = readAliasesFile( filename );
+ BOOST_CHECK( aliases == aliases2 );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/bipgraph.h>
-#include <vector>
-#include <strstream>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE BipartiteGraphTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- // check constructors
- typedef BipartiteGraph::Edge Edge;
-
- BipartiteGraph G;
- BOOST_CHECK_EQUAL( G.nrNodes1(), 0 );
- BOOST_CHECK_EQUAL( G.nrNodes2(), 0 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.checkConsistency();
-
- BipartiteGraph G1( 2, 3 );
- BOOST_CHECK_EQUAL( G1.nrNodes1(), 2 );
- BOOST_CHECK_EQUAL( G1.nrNodes2(), 3 );
- BOOST_CHECK_EQUAL( G1.nrEdges(), 0 );
- BOOST_CHECK( !G1.isConnected() );
- BOOST_CHECK( !G1.isTree() );
- BOOST_CHECK( !(G1 == G) );
- G1.checkConsistency();
-
- std::vector<Edge> edges;
- edges.push_back( Edge(0, 0) );
- edges.push_back( Edge(0, 1) );
- edges.push_back( Edge(1, 1) );
- edges.push_back( Edge(1, 2) );
- edges.push_back( Edge(1, 2) );
- BipartiteGraph G2( 2, 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G2.nrNodes1(), 2 );
- BOOST_CHECK_EQUAL( G2.nrNodes2(), 3 );
- BOOST_CHECK_EQUAL( G2.nrEdges(), 4 );
- BOOST_CHECK( G2.isConnected() );
- BOOST_CHECK( G2.isTree() );
- BOOST_CHECK( !(G2 == G) );
- BOOST_CHECK( !(G2 == G1) );
- G2.checkConsistency();
-
- edges.push_back( Edge(1, 0) );
- BipartiteGraph G3( 2, 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G3.nrNodes1(), 2 );
- BOOST_CHECK_EQUAL( G3.nrNodes2(), 3 );
- BOOST_CHECK_EQUAL( G3.nrEdges(), 5 );
- BOOST_CHECK( G3.isConnected() );
- BOOST_CHECK( !G3.isTree() );
- BOOST_CHECK( !(G3 == G) );
- BOOST_CHECK( !(G3 == G1) );
- BOOST_CHECK( !(G3 == G2) );
- G3.checkConsistency();
-
- BipartiteGraph G4( 3, 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G4.nrNodes1(), 3 );
- BOOST_CHECK_EQUAL( G4.nrNodes2(), 3 );
- BOOST_CHECK_EQUAL( G4.nrEdges(), 5 );
- BOOST_CHECK( !G4.isConnected() );
- BOOST_CHECK( !G4.isTree() );
- BOOST_CHECK( !(G4 == G) );
- BOOST_CHECK( !(G4 == G1) );
- BOOST_CHECK( !(G4 == G2) );
- BOOST_CHECK( !(G4 == G3) );
- G4.checkConsistency();
-
- G.construct( 3, 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G.nrNodes1(), 3 );
- BOOST_CHECK_EQUAL( G.nrNodes2(), 3 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 5 );
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK( !(G == G1) );
- BOOST_CHECK( !(G == G2) );
- BOOST_CHECK( !(G == G3) );
- BOOST_CHECK( G == G4 );
- G.checkConsistency();
-
- BipartiteGraph G5( G4 );
- BOOST_CHECK( G5 == G4 );
-
- BipartiteGraph G6 = G4;
- BOOST_CHECK( G6 == G4 );
-}
-
-
-BOOST_AUTO_TEST_CASE( NeighborTest ) {
- // check nb() accessor / mutator
- typedef BipartiteGraph::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge(0, 0) );
- edges.push_back( Edge(0, 1) );
- edges.push_back( Edge(1, 1) );
- edges.push_back( Edge(1, 2) );
- BipartiteGraph G( 2, 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G.nb1(0).size(), 2 );
- BOOST_CHECK_EQUAL( G.nb1(1).size(), 2 );
- BOOST_CHECK_EQUAL( G.nb2(0).size(), 1 );
- BOOST_CHECK_EQUAL( G.nb2(1).size(), 2 );
- BOOST_CHECK_EQUAL( G.nb2(2).size(), 1 );
- BOOST_CHECK_EQUAL( G.nb1(0,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb1(0,0).node, 0 );
- BOOST_CHECK_EQUAL( G.nb1(0,0).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb1(0,1).iter, 1 );
- BOOST_CHECK_EQUAL( G.nb1(0,1).node, 1 );
- BOOST_CHECK_EQUAL( G.nb1(0,1).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb1(1,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb1(1,0).node, 1 );
- BOOST_CHECK_EQUAL( G.nb1(1,0).dual, 1 );
- BOOST_CHECK_EQUAL( G.nb1(1,1).iter, 1 );
- BOOST_CHECK_EQUAL( G.nb1(1,1).node, 2 );
- BOOST_CHECK_EQUAL( G.nb1(1,1).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb2(0,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb2(0,0).node, 0 );
- BOOST_CHECK_EQUAL( G.nb2(0,0).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb2(1,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb2(1,0).node, 0 );
- BOOST_CHECK_EQUAL( G.nb2(1,0).dual, 1 );
- BOOST_CHECK_EQUAL( G.nb2(1,1).iter, 1 );
- BOOST_CHECK_EQUAL( G.nb2(1,1).node, 1 );
- BOOST_CHECK_EQUAL( G.nb2(1,1).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb2(2,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb2(2,0).node, 1 );
- BOOST_CHECK_EQUAL( G.nb2(2,0).dual, 1 );
-}
-
-
-BOOST_AUTO_TEST_CASE( AddEraseTest ) {
- // check addition and erasure of nodes and edges
- typedef BipartiteGraph::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge( 0, 0 ) );
- edges.push_back( Edge( 0, 1 ) );
- edges.push_back( Edge( 1, 1 ) );
- BipartiteGraph G( 2, 2, edges.begin(), edges.end() );
- G.checkConsistency();
- BOOST_CHECK_EQUAL( G.nrNodes1(), 2 );
- BOOST_CHECK_EQUAL( G.nrNodes2(), 2 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 3 );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- BOOST_CHECK_EQUAL( G.addNode1(), 2 );
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK_EQUAL( G.addNode2(), 2 );
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK_EQUAL( G.addNode1(), 3 );
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.checkConsistency();
- std::vector<size_t> nbs;
- nbs.push_back( 2 );
- nbs.push_back( 0 );
- BOOST_CHECK_EQUAL( G.addNode1( nbs.begin(), nbs.end() ), 4 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK_EQUAL( G.addNode2( nbs.begin(), nbs.end() ), 3 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.addEdge( 3, 3 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.addEdge( 1, 3 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK_EQUAL( G.nrNodes1(), 5 );
- BOOST_CHECK_EQUAL( G.nrNodes2(), 4 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 9 );
- G.eraseEdge( 0, 3 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseEdge( 4, 2 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.eraseNode2( 2 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode1( 0 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.addEdge( 1, 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode1( 2 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode2( 2 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.addEdge( 1, 1 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode2( 1 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.eraseNode1( 1 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.addEdge( 0, 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode2( 0 );
- G.checkConsistency();
- BOOST_CHECK( !G.isConnected() );
- BOOST_CHECK( !G.isTree() );
- G.eraseNode1( 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- G.eraseNode1( 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- BOOST_CHECK_EQUAL( G.nrNodes1(), 0 );
- BOOST_CHECK_EQUAL( G.nrNodes2(), 0 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
-}
-
-
-BOOST_AUTO_TEST_CASE( RandomAddEraseTest ) {
- // check adding and erasing nodes and edges randomly
- BipartiteGraph G;
- for( size_t maxN1 = 2; maxN1 < 10; maxN1++ )
- for( size_t maxN2 = 2; maxN2 < 10; maxN2++ )
- for( size_t repeats = 0; repeats < 100000; repeats++ ) {
- size_t action = rnd( 5 );
- size_t N1 = G.nrNodes1();
- size_t N2 = G.nrNodes2();
- size_t M = G.nrEdges();
- size_t maxM = N1 * N2;
- if( action == 0 ) {
- // add node
- if( rnd( 2 ) == 0 ) {
- // add node of type 1
- if( N1 < maxN1 )
- G.addNode1();
- } else {
- // add node of type 2
- if( N2 < maxN2 )
- G.addNode2();
- }
- } else if( action == 1 ) {
- // erase node
- if( rnd( 2 ) == 0 ) {
- // erase node of type 1
- if( N1 > 0 )
- G.eraseNode1( rnd( N1 ) );
- } else {
- // erase node of type 2
- if( N2 > 0 )
- G.eraseNode2( rnd( N2 ) );
- }
- } else if( action == 2 || action == 3 ) {
- // add edge
- if( N1 >= 1 && N2 >= 1 && M < maxM ) {
- size_t n1 = 0;
- size_t n2 = 0;
- if( rnd( 2 ) == 0 ) {
- do {
- n1 = rnd( N1 );
- } while( G.nb1(n1).size() >= N2 );
- do {
- n2 = rnd( N2 );
- } while( G.hasEdge( n1, n2 ) );
- } else {
- do {
- n2 = rnd( N2 );
- } while( G.nb2(n2).size() >= N1 );
- do {
- n1 = rnd( N1 );
- } while( G.hasEdge( n1, n2 ) );
- }
- G.addEdge( n1, n2 );
- }
- } else if( action == 4 ) {
- // erase edge
- if( M > 0 ) {
- size_t n1 = 0;
- size_t n2 = 0;
- if( rnd( 2 ) == 0 ) {
- do {
- n1 = rnd( N1 );
- } while( G.nb1(n1).size() == 0 );
- do {
- n2 = rnd( N2 );
- } while( !G.hasEdge( n1, n2 ) );
- } else {
- do {
- n2 = rnd( N2 );
- } while( G.nb2(n2).size() == 0 );
- do {
- n1 = rnd( N1 );
- } while( !G.hasEdge( n1, n2 ) );
- }
- G.eraseEdge( n1, n2 );
- }
- }
- G.checkConsistency();
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( QueriesTest ) {
- // check queries which have not been tested in another test case
- typedef BipartiteGraph::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge( 0, 1 ) );
- edges.push_back( Edge( 1, 1 ) );
- edges.push_back( Edge( 0, 0 ) );
- BipartiteGraph G( 3, 2, edges.begin(), edges.end() );
- G.checkConsistency();
- SmallSet<size_t> v;
- SmallSet<size_t> v0( 0 );
- SmallSet<size_t> v1( 1 );
- SmallSet<size_t> v01( 0, 1 );
- BOOST_CHECK_EQUAL( G.delta1( 0, true ), v01 );
- BOOST_CHECK_EQUAL( G.delta1( 1, true ), v01 );
- BOOST_CHECK_EQUAL( G.delta1( 2, true ), v );
- BOOST_CHECK_EQUAL( G.delta2( 0, true ), v01 );
- BOOST_CHECK_EQUAL( G.delta2( 1, true ), v01 );
- BOOST_CHECK_EQUAL( G.delta1( 0, false ), v1 );
- BOOST_CHECK_EQUAL( G.delta1( 1, false ), v0 );
- BOOST_CHECK_EQUAL( G.delta1( 2, false ), v );
- BOOST_CHECK_EQUAL( G.delta2( 0, false ), v1 );
- BOOST_CHECK_EQUAL( G.delta2( 1, false ), v0 );
- BOOST_CHECK( G.hasEdge( 0, 0 ) );
- BOOST_CHECK( G.hasEdge( 0, 1 ) );
- BOOST_CHECK( G.hasEdge( 1, 1 ) );
- BOOST_CHECK( !G.hasEdge( 1, 0 ) );
- BOOST_CHECK( !G.hasEdge( 2, 0 ) );
- BOOST_CHECK( !G.hasEdge( 2, 1 ) );
- BOOST_CHECK_EQUAL( G.findNb1( 0, 0 ), 1 );
- BOOST_CHECK_EQUAL( G.findNb1( 0, 1 ), 0 );
- BOOST_CHECK_EQUAL( G.findNb1( 1, 1 ), 0 );
- BOOST_CHECK_EQUAL( G.findNb2( 0, 0 ), 0 );
- BOOST_CHECK_EQUAL( G.findNb2( 0, 1 ), 0 );
- BOOST_CHECK_EQUAL( G.findNb2( 1, 1 ), 1 );
-}
-
-
-BOOST_AUTO_TEST_CASE( StreamTest ) {
- // check printDot
- typedef BipartiteGraph::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge(0, 0) );
- edges.push_back( Edge(0, 1) );
- edges.push_back( Edge(1, 1) );
- edges.push_back( Edge(1, 2) );
- BipartiteGraph G( 2, 3, edges.begin(), edges.end() );
-
- std::stringstream ss;
- std::string s;
-
- G.printDot( ss );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph BipartiteGraph {" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
-
- ss << G;
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph BipartiteGraph {" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/bipgraph.h>
+#include <vector>
+#include <strstream>
+
+
+using namespace dai;
+
+
+#define BOOST_TEST_MODULE BipartiteGraphTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ // check constructors
+ typedef BipartiteGraph::Edge Edge;
+
+ BipartiteGraph G;
+ BOOST_CHECK_EQUAL( G.nrNodes1(), 0 );
+ BOOST_CHECK_EQUAL( G.nrNodes2(), 0 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.checkConsistency();
+
+ BipartiteGraph G1( 2, 3 );
+ BOOST_CHECK_EQUAL( G1.nrNodes1(), 2 );
+ BOOST_CHECK_EQUAL( G1.nrNodes2(), 3 );
+ BOOST_CHECK_EQUAL( G1.nrEdges(), 0 );
+ BOOST_CHECK( !G1.isConnected() );
+ BOOST_CHECK( !G1.isTree() );
+ BOOST_CHECK( !(G1 == G) );
+ G1.checkConsistency();
+
+ std::vector<Edge> edges;
+ edges.push_back( Edge(0, 0) );
+ edges.push_back( Edge(0, 1) );
+ edges.push_back( Edge(1, 1) );
+ edges.push_back( Edge(1, 2) );
+ edges.push_back( Edge(1, 2) );
+ BipartiteGraph G2( 2, 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G2.nrNodes1(), 2 );
+ BOOST_CHECK_EQUAL( G2.nrNodes2(), 3 );
+ BOOST_CHECK_EQUAL( G2.nrEdges(), 4 );
+ BOOST_CHECK( G2.isConnected() );
+ BOOST_CHECK( G2.isTree() );
+ BOOST_CHECK( !(G2 == G) );
+ BOOST_CHECK( !(G2 == G1) );
+ G2.checkConsistency();
+
+ edges.push_back( Edge(1, 0) );
+ BipartiteGraph G3( 2, 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G3.nrNodes1(), 2 );
+ BOOST_CHECK_EQUAL( G3.nrNodes2(), 3 );
+ BOOST_CHECK_EQUAL( G3.nrEdges(), 5 );
+ BOOST_CHECK( G3.isConnected() );
+ BOOST_CHECK( !G3.isTree() );
+ BOOST_CHECK( !(G3 == G) );
+ BOOST_CHECK( !(G3 == G1) );
+ BOOST_CHECK( !(G3 == G2) );
+ G3.checkConsistency();
+
+ BipartiteGraph G4( 3, 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G4.nrNodes1(), 3 );
+ BOOST_CHECK_EQUAL( G4.nrNodes2(), 3 );
+ BOOST_CHECK_EQUAL( G4.nrEdges(), 5 );
+ BOOST_CHECK( !G4.isConnected() );
+ BOOST_CHECK( !G4.isTree() );
+ BOOST_CHECK( !(G4 == G) );
+ BOOST_CHECK( !(G4 == G1) );
+ BOOST_CHECK( !(G4 == G2) );
+ BOOST_CHECK( !(G4 == G3) );
+ G4.checkConsistency();
+
+ G.construct( 3, 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G.nrNodes1(), 3 );
+ BOOST_CHECK_EQUAL( G.nrNodes2(), 3 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 5 );
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK( !(G == G1) );
+ BOOST_CHECK( !(G == G2) );
+ BOOST_CHECK( !(G == G3) );
+ BOOST_CHECK( G == G4 );
+ G.checkConsistency();
+
+ BipartiteGraph G5( G4 );
+ BOOST_CHECK( G5 == G4 );
+
+ BipartiteGraph G6 = G4;
+ BOOST_CHECK( G6 == G4 );
+}
+
+
+BOOST_AUTO_TEST_CASE( NeighborTest ) {
+ // check nb() accessor / mutator
+ typedef BipartiteGraph::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge(0, 0) );
+ edges.push_back( Edge(0, 1) );
+ edges.push_back( Edge(1, 1) );
+ edges.push_back( Edge(1, 2) );
+ BipartiteGraph G( 2, 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G.nb1(0).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nb1(1).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nb2(0).size(), 1 );
+ BOOST_CHECK_EQUAL( G.nb2(1).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nb2(2).size(), 1 );
+ BOOST_CHECK_EQUAL( G.nb1(0,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb1(0,0).node, 0 );
+ BOOST_CHECK_EQUAL( G.nb1(0,0).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb1(0,1).iter, 1 );
+ BOOST_CHECK_EQUAL( G.nb1(0,1).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb1(0,1).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb1(1,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb1(1,0).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb1(1,0).dual, 1 );
+ BOOST_CHECK_EQUAL( G.nb1(1,1).iter, 1 );
+ BOOST_CHECK_EQUAL( G.nb1(1,1).node, 2 );
+ BOOST_CHECK_EQUAL( G.nb1(1,1).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(0,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(0,0).node, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(0,0).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(1,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(1,0).node, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(1,0).dual, 1 );
+ BOOST_CHECK_EQUAL( G.nb2(1,1).iter, 1 );
+ BOOST_CHECK_EQUAL( G.nb2(1,1).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb2(1,1).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(2,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb2(2,0).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb2(2,0).dual, 1 );
+}
+
+
+BOOST_AUTO_TEST_CASE( AddEraseTest ) {
+ // check addition and erasure of nodes and edges
+ typedef BipartiteGraph::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge( 0, 0 ) );
+ edges.push_back( Edge( 0, 1 ) );
+ edges.push_back( Edge( 1, 1 ) );
+ BipartiteGraph G( 2, 2, edges.begin(), edges.end() );
+ G.checkConsistency();
+ BOOST_CHECK_EQUAL( G.nrNodes1(), 2 );
+ BOOST_CHECK_EQUAL( G.nrNodes2(), 2 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 3 );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ BOOST_CHECK_EQUAL( G.addNode1(), 2 );
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK_EQUAL( G.addNode2(), 2 );
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK_EQUAL( G.addNode1(), 3 );
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.checkConsistency();
+ std::vector<size_t> nbs;
+ nbs.push_back( 2 );
+ nbs.push_back( 0 );
+ BOOST_CHECK_EQUAL( G.addNode1( nbs.begin(), nbs.end() ), 4 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK_EQUAL( G.addNode2( nbs.begin(), nbs.end() ), 3 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.addEdge( 3, 3 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.addEdge( 1, 3 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK_EQUAL( G.nrNodes1(), 5 );
+ BOOST_CHECK_EQUAL( G.nrNodes2(), 4 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 9 );
+ G.eraseEdge( 0, 3 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseEdge( 4, 2 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.eraseNode2( 2 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode1( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.addEdge( 1, 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode1( 2 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode2( 2 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.addEdge( 1, 1 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode2( 1 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.eraseNode1( 1 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.addEdge( 0, 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode2( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isConnected() );
+ BOOST_CHECK( !G.isTree() );
+ G.eraseNode1( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ G.eraseNode1( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ BOOST_CHECK_EQUAL( G.nrNodes1(), 0 );
+ BOOST_CHECK_EQUAL( G.nrNodes2(), 0 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+}
+
+
+BOOST_AUTO_TEST_CASE( RandomAddEraseTest ) {
+ // check adding and erasing nodes and edges randomly
+ BipartiteGraph G;
+ for( size_t maxN1 = 2; maxN1 < 10; maxN1++ )
+ for( size_t maxN2 = 2; maxN2 < 10; maxN2++ )
+ for( size_t repeats = 0; repeats < 100000; repeats++ ) {
+ size_t action = rnd( 5 );
+ size_t N1 = G.nrNodes1();
+ size_t N2 = G.nrNodes2();
+ size_t M = G.nrEdges();
+ size_t maxM = N1 * N2;
+ if( action == 0 ) {
+ // add node
+ if( rnd( 2 ) == 0 ) {
+ // add node of type 1
+ if( N1 < maxN1 )
+ G.addNode1();
+ } else {
+ // add node of type 2
+ if( N2 < maxN2 )
+ G.addNode2();
+ }
+ } else if( action == 1 ) {
+ // erase node
+ if( rnd( 2 ) == 0 ) {
+ // erase node of type 1
+ if( N1 > 0 )
+ G.eraseNode1( rnd( N1 ) );
+ } else {
+ // erase node of type 2
+ if( N2 > 0 )
+ G.eraseNode2( rnd( N2 ) );
+ }
+ } else if( action == 2 || action == 3 ) {
+ // add edge
+ if( N1 >= 1 && N2 >= 1 && M < maxM ) {
+ size_t n1 = 0;
+ size_t n2 = 0;
+ if( rnd( 2 ) == 0 ) {
+ do {
+ n1 = rnd( N1 );
+ } while( G.nb1(n1).size() >= N2 );
+ do {
+ n2 = rnd( N2 );
+ } while( G.hasEdge( n1, n2 ) );
+ } else {
+ do {
+ n2 = rnd( N2 );
+ } while( G.nb2(n2).size() >= N1 );
+ do {
+ n1 = rnd( N1 );
+ } while( G.hasEdge( n1, n2 ) );
+ }
+ G.addEdge( n1, n2 );
+ }
+ } else if( action == 4 ) {
+ // erase edge
+ if( M > 0 ) {
+ size_t n1 = 0;
+ size_t n2 = 0;
+ if( rnd( 2 ) == 0 ) {
+ do {
+ n1 = rnd( N1 );
+ } while( G.nb1(n1).size() == 0 );
+ do {
+ n2 = rnd( N2 );
+ } while( !G.hasEdge( n1, n2 ) );
+ } else {
+ do {
+ n2 = rnd( N2 );
+ } while( G.nb2(n2).size() == 0 );
+ do {
+ n1 = rnd( N1 );
+ } while( !G.hasEdge( n1, n2 ) );
+ }
+ G.eraseEdge( n1, n2 );
+ }
+ }
+ G.checkConsistency();
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( QueriesTest ) {
+ // check queries which have not been tested in another test case
+ typedef BipartiteGraph::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge( 0, 1 ) );
+ edges.push_back( Edge( 1, 1 ) );
+ edges.push_back( Edge( 0, 0 ) );
+ BipartiteGraph G( 3, 2, edges.begin(), edges.end() );
+ G.checkConsistency();
+ SmallSet<size_t> v;
+ SmallSet<size_t> v0( 0 );
+ SmallSet<size_t> v1( 1 );
+ SmallSet<size_t> v01( 0, 1 );
+ BOOST_CHECK_EQUAL( G.delta1( 0, true ), v01 );
+ BOOST_CHECK_EQUAL( G.delta1( 1, true ), v01 );
+ BOOST_CHECK_EQUAL( G.delta1( 2, true ), v );
+ BOOST_CHECK_EQUAL( G.delta2( 0, true ), v01 );
+ BOOST_CHECK_EQUAL( G.delta2( 1, true ), v01 );
+ BOOST_CHECK_EQUAL( G.delta1( 0, false ), v1 );
+ BOOST_CHECK_EQUAL( G.delta1( 1, false ), v0 );
+ BOOST_CHECK_EQUAL( G.delta1( 2, false ), v );
+ BOOST_CHECK_EQUAL( G.delta2( 0, false ), v1 );
+ BOOST_CHECK_EQUAL( G.delta2( 1, false ), v0 );
+ BOOST_CHECK( G.hasEdge( 0, 0 ) );
+ BOOST_CHECK( G.hasEdge( 0, 1 ) );
+ BOOST_CHECK( G.hasEdge( 1, 1 ) );
+ BOOST_CHECK( !G.hasEdge( 1, 0 ) );
+ BOOST_CHECK( !G.hasEdge( 2, 0 ) );
+ BOOST_CHECK( !G.hasEdge( 2, 1 ) );
+ BOOST_CHECK_EQUAL( G.findNb1( 0, 0 ), 1 );
+ BOOST_CHECK_EQUAL( G.findNb1( 0, 1 ), 0 );
+ BOOST_CHECK_EQUAL( G.findNb1( 1, 1 ), 0 );
+ BOOST_CHECK_EQUAL( G.findNb2( 0, 0 ), 0 );
+ BOOST_CHECK_EQUAL( G.findNb2( 0, 1 ), 0 );
+ BOOST_CHECK_EQUAL( G.findNb2( 1, 1 ), 1 );
+}
+
+
+BOOST_AUTO_TEST_CASE( StreamTest ) {
+ // check printDot
+ typedef BipartiteGraph::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge(0, 0) );
+ edges.push_back( Edge(0, 1) );
+ edges.push_back( Edge(1, 1) );
+ edges.push_back( Edge(1, 2) );
+ BipartiteGraph G( 2, 3, edges.begin(), edges.end() );
+
+ std::stringstream ss;
+ std::string s;
+
+ G.printDot( ss );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph BipartiteGraph {" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+ ss << G;
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph BipartiteGraph {" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ty2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- y1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/clustergraph.h>
-#include <vector>
-#include <strstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE ClusterGraphTest
-
-
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- ClusterGraph G;
- BOOST_CHECK_EQUAL( G.clusters(), std::vector<VarSet>() );
- BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
- BOOST_CHECK_EQUAL( G.nrVars(), 0 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
-#ifdef DAI_DEBUG
- BOOST_CHECK_THROW( G.var( 0 ), Exception );
- BOOST_CHECK_THROW( G.cluster( 0 ), Exception );
-#endif
- BOOST_CHECK_THROW( G.findVar( Var( 0, 2 ) ), Exception );
-
- Var v0( 0, 2 );
- Var v1( 1, 3 );
- Var v2( 2, 2 );
- Var v3( 3, 4 );
- std::vector<Var> vs;
- vs.push_back( v0 );
- vs.push_back( v1 );
- vs.push_back( v2 );
- vs.push_back( v3 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- std::vector<VarSet> cl;
- cl.push_back( v01 );
- cl.push_back( v12 );
- cl.push_back( v23 );
- cl.push_back( v13 );
- ClusterGraph G2( cl );
- BOOST_CHECK_EQUAL( G2.nrVars(), 4 );
- BOOST_CHECK_EQUAL( G2.nrClusters(), 4 );
- BOOST_CHECK_EQUAL( G2.vars(), vs );
- BOOST_CHECK_EQUAL( G2.clusters(), cl );
- BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G2.findVar( v3 ), 3 );
-
- ClusterGraph Gb( G );
- BOOST_CHECK( G.bipGraph() == Gb.bipGraph() );
- BOOST_CHECK( G.vars() == Gb.vars() );
- BOOST_CHECK( G.clusters() == Gb.clusters() );
-
- ClusterGraph Gc = G;
- BOOST_CHECK( G.bipGraph() == Gc.bipGraph() );
- BOOST_CHECK( G.vars() == Gc.vars() );
- BOOST_CHECK( G.clusters() == Gc.clusters() );
-
- ClusterGraph G2b( G2 );
- BOOST_CHECK( G2.bipGraph() == G2b.bipGraph() );
- BOOST_CHECK( G2.vars() == G2b.vars() );
- BOOST_CHECK( G2.clusters() == G2b.clusters() );
-
- ClusterGraph G2c = G2;
- BOOST_CHECK( G2.bipGraph() == G2c.bipGraph() );
- BOOST_CHECK( G2.vars() == G2c.vars() );
- BOOST_CHECK( G2.clusters() == G2c.clusters() );
-}
-
-
-BOOST_AUTO_TEST_CASE( QueriesTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 3 );
- Var v2( 2, 2 );
- Var v3( 3, 4 );
- Var v4( 4, 2 );
- std::vector<Var> vs;
- vs.push_back( v0 );
- vs.push_back( v1 );
- vs.push_back( v2 );
- vs.push_back( v3 );
- vs.push_back( v4 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v04( v0, v4 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v14( v1, v4 );
- VarSet v23( v2, v3 );
- VarSet v24( v2, v4 );
- VarSet v34( v3, v4 );
- VarSet v123 = v12 | v3;
- std::vector<VarSet> cl;
- cl.push_back( v01 );
- cl.push_back( v12 );
- cl.push_back( v123 );
- cl.push_back( v34 );
- cl.push_back( v04 );
- ClusterGraph G( cl );
-
- BOOST_CHECK_EQUAL( G.nrVars(), 5 );
- BOOST_CHECK_EQUAL( G.vars(), vs );
- BOOST_CHECK_EQUAL( G.var(0), v0 );
- BOOST_CHECK_EQUAL( G.var(1), v1 );
- BOOST_CHECK_EQUAL( G.var(2), v2 );
- BOOST_CHECK_EQUAL( G.var(3), v3 );
- BOOST_CHECK_EQUAL( G.var(4), v4 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
- BOOST_CHECK_EQUAL( G.clusters(), cl );
- BOOST_CHECK_EQUAL( G.cluster(0), v01 );
- BOOST_CHECK_EQUAL( G.cluster(1), v12 );
- BOOST_CHECK_EQUAL( G.cluster(2), v123 );
- BOOST_CHECK_EQUAL( G.cluster(3), v34 );
- BOOST_CHECK_EQUAL( G.cluster(4), v04 );
- BOOST_CHECK_EQUAL( G.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G.findVar( v3 ), 3 );
- BOOST_CHECK_EQUAL( G.findVar( v4 ), 4 );
- BipartiteGraph H( 5, 5 );
- H.addEdge( 0, 0 );
- H.addEdge( 1, 0 );
- H.addEdge( 1, 1 );
- H.addEdge( 2, 1 );
- H.addEdge( 1, 2 );
- H.addEdge( 2, 2 );
- H.addEdge( 3, 2 );
- H.addEdge( 3, 3 );
- H.addEdge( 4, 3 );
- H.addEdge( 0, 4 );
- H.addEdge( 4, 4 );
- BOOST_CHECK( G.bipGraph() == H );
-
- BOOST_CHECK_EQUAL( G.delta( 0 ), v14 );
- BOOST_CHECK_EQUAL( G.delta( 1 ), v02 | v3 );
- BOOST_CHECK_EQUAL( G.delta( 2 ), v13 );
- BOOST_CHECK_EQUAL( G.delta( 3 ), v12 | v4 );
- BOOST_CHECK_EQUAL( G.delta( 4 ), v03 );
- BOOST_CHECK_EQUAL( G.Delta( 0 ), v14 | v0 );
- BOOST_CHECK_EQUAL( G.Delta( 1 ), v01 | v23 );
- BOOST_CHECK_EQUAL( G.Delta( 2 ), v13 | v2 );
- BOOST_CHECK_EQUAL( G.Delta( 3 ), v12 | v34 );
- BOOST_CHECK_EQUAL( G.Delta( 4 ), v03 | v4 );
-
- BOOST_CHECK( !G.adj( 0, 0 ) );
- BOOST_CHECK( G.adj( 0, 1 ) );
- BOOST_CHECK( !G.adj( 0, 2 ) );
- BOOST_CHECK( !G.adj( 0, 3 ) );
- BOOST_CHECK( G.adj( 0, 4 ) );
- BOOST_CHECK( G.adj( 1, 0 ) );
- BOOST_CHECK( !G.adj( 1, 1 ) );
- BOOST_CHECK( G.adj( 1, 2 ) );
- BOOST_CHECK( G.adj( 1, 3 ) );
- BOOST_CHECK( !G.adj( 1, 4 ) );
- BOOST_CHECK( !G.adj( 2, 0 ) );
- BOOST_CHECK( G.adj( 2, 1 ) );
- BOOST_CHECK( !G.adj( 2, 2 ) );
- BOOST_CHECK( G.adj( 2, 3 ) );
- BOOST_CHECK( !G.adj( 2, 4 ) );
- BOOST_CHECK( !G.adj( 3, 0 ) );
- BOOST_CHECK( G.adj( 3, 1 ) );
- BOOST_CHECK( G.adj( 3, 2 ) );
- BOOST_CHECK( !G.adj( 3, 3 ) );
- BOOST_CHECK( G.adj( 3, 4 ) );
- BOOST_CHECK( G.adj( 4, 0 ) );
- BOOST_CHECK( !G.adj( 4, 1 ) );
- BOOST_CHECK( !G.adj( 4, 2 ) );
- BOOST_CHECK( G.adj( 4, 3 ) );
- BOOST_CHECK( !G.adj( 4, 4 ) );
-
- BOOST_CHECK( G.isMaximal( 0 ) );
- BOOST_CHECK( !G.isMaximal( 1 ) );
- BOOST_CHECK( G.isMaximal( 2 ) );
- BOOST_CHECK( G.isMaximal( 3 ) );
- BOOST_CHECK( G.isMaximal( 4 ) );
-}
-
-
-BOOST_AUTO_TEST_CASE( OperationsTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 3 );
- Var v2( 2, 2 );
- Var v3( 3, 4 );
- Var v4( 4, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v04( v0, v4 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v14( v1, v4 );
- VarSet v23( v2, v3 );
- VarSet v24( v2, v4 );
- VarSet v34( v3, v4 );
- VarSet v123 = v12 | v3;
- std::vector<VarSet> cl;
- cl.push_back( v01 );
- cl.push_back( v12 );
- cl.push_back( v123 );
- cl.push_back( v34 );
- cl.push_back( v04 );
- ClusterGraph G( cl );
-
- BipartiteGraph H( 5, 5 );
- H.addEdge( 0, 0 );
- H.addEdge( 1, 0 );
- H.addEdge( 1, 1 );
- H.addEdge( 2, 1 );
- H.addEdge( 1, 2 );
- H.addEdge( 2, 2 );
- H.addEdge( 3, 2 );
- H.addEdge( 3, 3 );
- H.addEdge( 4, 3 );
- H.addEdge( 0, 4 );
- H.addEdge( 4, 4 );
- BOOST_CHECK( G.bipGraph() == H );
-
- G.eraseNonMaximal();
- BOOST_CHECK_EQUAL( G.nrClusters(), 4 );
- H.eraseNode2( 1 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseSubsuming( 4 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
- H.eraseNode2( 2 );
- H.eraseNode2( 2 );
- BOOST_CHECK( G.bipGraph() == H );
- G.insert( v34 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
- G.insert( v123 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
- H.addNode2();
- H.addEdge( 3, 2 );
- H.addEdge( 4, 2 );
- BOOST_CHECK( G.bipGraph() == H );
- G.insert( v12 );
- G.insert( v23 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
- H.addNode2();
- H.addNode2();
- H.addEdge( 1, 3 );
- H.addEdge( 2, 3 );
- H.addEdge( 2, 4 );
- H.addEdge( 3, 4 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseNonMaximal();
- BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
- H.eraseNode2( 3 );
- H.eraseNode2( 3 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseSubsuming( 2 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
- H.eraseNode2( 1 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseNonMaximal();
- BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseSubsuming( 0 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 1 );
- H.eraseNode2( 0 );
- BOOST_CHECK( G.bipGraph() == H );
- G.eraseSubsuming( 4 );
- BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
- H.eraseNode2( 0 );
- BOOST_CHECK( G.bipGraph() == H );
-}
-
-
-BOOST_AUTO_TEST_CASE( VarElimTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 3 );
- Var v2( 2, 2 );
- Var v3( 3, 4 );
- Var v4( 4, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v04( v0, v4 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v14( v1, v4 );
- VarSet v23( v2, v3 );
- VarSet v24( v2, v4 );
- VarSet v34( v3, v4 );
- VarSet v123 = v12 | v3;
- std::vector<VarSet> cl;
- cl.push_back( v01 );
- cl.push_back( v12 );
- cl.push_back( v123 );
- cl.push_back( v34 );
- cl.push_back( v04 );
- ClusterGraph G( cl );
- ClusterGraph Gorg = G;
-
- BipartiteGraph H( 5, 5 );
- H.addEdge( 0, 0 );
- H.addEdge( 1, 0 );
- H.addEdge( 1, 1 );
- H.addEdge( 2, 1 );
- H.addEdge( 1, 2 );
- H.addEdge( 2, 2 );
- H.addEdge( 3, 2 );
- H.addEdge( 3, 3 );
- H.addEdge( 4, 3 );
- H.addEdge( 0, 4 );
- H.addEdge( 4, 4 );
- BOOST_CHECK( G.bipGraph() == H );
-
- BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 0 ), 1 );
- BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 1 ), 2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 2 ), 0 );
- BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 3 ), 2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 4 ), 1 );
- cl.clear();
- cl.push_back( v123 );
- cl.push_back( v01 | v4 );
- cl.push_back( v13 | v4 );
- cl.push_back( v34 );
- cl.push_back( v4 );
- BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinFill ) ).clusters(), cl );
-
- G = Gorg;
- BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 0 ), 2*3 );
- BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 1 ), 2*2+2*4 );
- BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 2 ), 0 );
- BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 3 ), 3*2+2*2 );
- BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 4 ), 2*4 );
- cl.clear();
- cl.push_back( v123 );
- cl.push_back( v01 | v4 );
- cl.push_back( v13 | v4 );
- cl.push_back( v34 );
- cl.push_back( v4 );
- BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_WeightedMinFill ) ).clusters(), cl );
-
- G = Gorg;
- BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 0 ), 2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 1 ), 3 );
- BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 2 ), 2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 3 ), 3 );
- BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 4 ), 2 );
- cl.clear();
- cl.push_back( v01 | v4 );
- cl.push_back( v123 );
- cl.push_back( v13 | v4 );
- cl.push_back( v34 );
- cl.push_back( v4 );
- BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinNeighbors ) ).clusters(), cl );
-
- G = Gorg;
- BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 0 ), 3*2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 1 ), 2*2*4 );
- BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 2 ), 3*4 );
- BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 3 ), 3*2*2 );
- BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 4 ), 2*4 );
- cl.clear();
- cl.push_back( v01 | v4 );
- cl.push_back( v123 );
- cl.push_back( v13 | v4 );
- cl.push_back( v14 );
- cl.push_back( v4 );
- BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinWeight ) ).clusters(), cl );
-
- G = Gorg;
- std::vector<Var> vs;
- vs.push_back( v4 );
- vs.push_back( v3 );
- vs.push_back( v2 );
- vs.push_back( v1 );
- vs.push_back( v0 );
- cl.clear();
- cl.push_back( v03 | v4 );
- cl.push_back( v01 | v23 );
- cl.push_back( v01 | v2 );
- cl.push_back( v01 );
- cl.push_back( v0 );
- BOOST_CHECK_EQUAL( G.VarElim( sequentialVariableElimination( vs ) ).clusters(), cl );
-}
-
-
-BOOST_AUTO_TEST_CASE( IOTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 3 );
- Var v2( 2, 2 );
- Var v3( 3, 4 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- std::vector<VarSet> cl;
- cl.push_back( v01 );
- cl.push_back( v12 );
- cl.push_back( v23 );
- cl.push_back( v13 );
- ClusterGraph G( cl );
-
- std::stringstream ss;
- ss << G;
- std::string s;
- getline( ss, s );
- BOOST_CHECK_EQUAL( s, "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/clustergraph.h>
+#include <vector>
+#include <strstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE ClusterGraphTest
+
+
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ ClusterGraph G;
+ BOOST_CHECK_EQUAL( G.clusters(), std::vector<VarSet>() );
+ BOOST_CHECK( G.bipGraph() == BipartiteGraph() );
+ BOOST_CHECK_EQUAL( G.nrVars(), 0 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
+#ifdef DAI_DEBUG
+ BOOST_CHECK_THROW( G.var( 0 ), Exception );
+ BOOST_CHECK_THROW( G.cluster( 0 ), Exception );
+#endif
+ BOOST_CHECK_THROW( G.findVar( Var( 0, 2 ) ), Exception );
+
+ Var v0( 0, 2 );
+ Var v1( 1, 3 );
+ Var v2( 2, 2 );
+ Var v3( 3, 4 );
+ std::vector<Var> vs;
+ vs.push_back( v0 );
+ vs.push_back( v1 );
+ vs.push_back( v2 );
+ vs.push_back( v3 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ std::vector<VarSet> cl;
+ cl.push_back( v01 );
+ cl.push_back( v12 );
+ cl.push_back( v23 );
+ cl.push_back( v13 );
+ ClusterGraph G2( cl );
+ BOOST_CHECK_EQUAL( G2.nrVars(), 4 );
+ BOOST_CHECK_EQUAL( G2.nrClusters(), 4 );
+ BOOST_CHECK_EQUAL( G2.vars(), vs );
+ BOOST_CHECK_EQUAL( G2.clusters(), cl );
+ BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G2.findVar( v3 ), 3 );
+
+ ClusterGraph Gb( G );
+ BOOST_CHECK( G.bipGraph() == Gb.bipGraph() );
+ BOOST_CHECK( G.vars() == Gb.vars() );
+ BOOST_CHECK( G.clusters() == Gb.clusters() );
+
+ ClusterGraph Gc = G;
+ BOOST_CHECK( G.bipGraph() == Gc.bipGraph() );
+ BOOST_CHECK( G.vars() == Gc.vars() );
+ BOOST_CHECK( G.clusters() == Gc.clusters() );
+
+ ClusterGraph G2b( G2 );
+ BOOST_CHECK( G2.bipGraph() == G2b.bipGraph() );
+ BOOST_CHECK( G2.vars() == G2b.vars() );
+ BOOST_CHECK( G2.clusters() == G2b.clusters() );
+
+ ClusterGraph G2c = G2;
+ BOOST_CHECK( G2.bipGraph() == G2c.bipGraph() );
+ BOOST_CHECK( G2.vars() == G2c.vars() );
+ BOOST_CHECK( G2.clusters() == G2c.clusters() );
+}
+
+
+BOOST_AUTO_TEST_CASE( QueriesTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 3 );
+ Var v2( 2, 2 );
+ Var v3( 3, 4 );
+ Var v4( 4, 2 );
+ std::vector<Var> vs;
+ vs.push_back( v0 );
+ vs.push_back( v1 );
+ vs.push_back( v2 );
+ vs.push_back( v3 );
+ vs.push_back( v4 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v04( v0, v4 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v14( v1, v4 );
+ VarSet v23( v2, v3 );
+ VarSet v24( v2, v4 );
+ VarSet v34( v3, v4 );
+ VarSet v123 = v12 | v3;
+ std::vector<VarSet> cl;
+ cl.push_back( v01 );
+ cl.push_back( v12 );
+ cl.push_back( v123 );
+ cl.push_back( v34 );
+ cl.push_back( v04 );
+ ClusterGraph G( cl );
+
+ BOOST_CHECK_EQUAL( G.nrVars(), 5 );
+ BOOST_CHECK_EQUAL( G.vars(), vs );
+ BOOST_CHECK_EQUAL( G.var(0), v0 );
+ BOOST_CHECK_EQUAL( G.var(1), v1 );
+ BOOST_CHECK_EQUAL( G.var(2), v2 );
+ BOOST_CHECK_EQUAL( G.var(3), v3 );
+ BOOST_CHECK_EQUAL( G.var(4), v4 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
+ BOOST_CHECK_EQUAL( G.clusters(), cl );
+ BOOST_CHECK_EQUAL( G.cluster(0), v01 );
+ BOOST_CHECK_EQUAL( G.cluster(1), v12 );
+ BOOST_CHECK_EQUAL( G.cluster(2), v123 );
+ BOOST_CHECK_EQUAL( G.cluster(3), v34 );
+ BOOST_CHECK_EQUAL( G.cluster(4), v04 );
+ BOOST_CHECK_EQUAL( G.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G.findVar( v3 ), 3 );
+ BOOST_CHECK_EQUAL( G.findVar( v4 ), 4 );
+ BipartiteGraph H( 5, 5 );
+ H.addEdge( 0, 0 );
+ H.addEdge( 1, 0 );
+ H.addEdge( 1, 1 );
+ H.addEdge( 2, 1 );
+ H.addEdge( 1, 2 );
+ H.addEdge( 2, 2 );
+ H.addEdge( 3, 2 );
+ H.addEdge( 3, 3 );
+ H.addEdge( 4, 3 );
+ H.addEdge( 0, 4 );
+ H.addEdge( 4, 4 );
+ BOOST_CHECK( G.bipGraph() == H );
+
+ BOOST_CHECK_EQUAL( G.delta( 0 ), v14 );
+ BOOST_CHECK_EQUAL( G.delta( 1 ), v02 | v3 );
+ BOOST_CHECK_EQUAL( G.delta( 2 ), v13 );
+ BOOST_CHECK_EQUAL( G.delta( 3 ), v12 | v4 );
+ BOOST_CHECK_EQUAL( G.delta( 4 ), v03 );
+ BOOST_CHECK_EQUAL( G.Delta( 0 ), v14 | v0 );
+ BOOST_CHECK_EQUAL( G.Delta( 1 ), v01 | v23 );
+ BOOST_CHECK_EQUAL( G.Delta( 2 ), v13 | v2 );
+ BOOST_CHECK_EQUAL( G.Delta( 3 ), v12 | v34 );
+ BOOST_CHECK_EQUAL( G.Delta( 4 ), v03 | v4 );
+
+ BOOST_CHECK( !G.adj( 0, 0 ) );
+ BOOST_CHECK( G.adj( 0, 1 ) );
+ BOOST_CHECK( !G.adj( 0, 2 ) );
+ BOOST_CHECK( !G.adj( 0, 3 ) );
+ BOOST_CHECK( G.adj( 0, 4 ) );
+ BOOST_CHECK( G.adj( 1, 0 ) );
+ BOOST_CHECK( !G.adj( 1, 1 ) );
+ BOOST_CHECK( G.adj( 1, 2 ) );
+ BOOST_CHECK( G.adj( 1, 3 ) );
+ BOOST_CHECK( !G.adj( 1, 4 ) );
+ BOOST_CHECK( !G.adj( 2, 0 ) );
+ BOOST_CHECK( G.adj( 2, 1 ) );
+ BOOST_CHECK( !G.adj( 2, 2 ) );
+ BOOST_CHECK( G.adj( 2, 3 ) );
+ BOOST_CHECK( !G.adj( 2, 4 ) );
+ BOOST_CHECK( !G.adj( 3, 0 ) );
+ BOOST_CHECK( G.adj( 3, 1 ) );
+ BOOST_CHECK( G.adj( 3, 2 ) );
+ BOOST_CHECK( !G.adj( 3, 3 ) );
+ BOOST_CHECK( G.adj( 3, 4 ) );
+ BOOST_CHECK( G.adj( 4, 0 ) );
+ BOOST_CHECK( !G.adj( 4, 1 ) );
+ BOOST_CHECK( !G.adj( 4, 2 ) );
+ BOOST_CHECK( G.adj( 4, 3 ) );
+ BOOST_CHECK( !G.adj( 4, 4 ) );
+
+ BOOST_CHECK( G.isMaximal( 0 ) );
+ BOOST_CHECK( !G.isMaximal( 1 ) );
+ BOOST_CHECK( G.isMaximal( 2 ) );
+ BOOST_CHECK( G.isMaximal( 3 ) );
+ BOOST_CHECK( G.isMaximal( 4 ) );
+}
+
+
+BOOST_AUTO_TEST_CASE( OperationsTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 3 );
+ Var v2( 2, 2 );
+ Var v3( 3, 4 );
+ Var v4( 4, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v04( v0, v4 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v14( v1, v4 );
+ VarSet v23( v2, v3 );
+ VarSet v24( v2, v4 );
+ VarSet v34( v3, v4 );
+ VarSet v123 = v12 | v3;
+ std::vector<VarSet> cl;
+ cl.push_back( v01 );
+ cl.push_back( v12 );
+ cl.push_back( v123 );
+ cl.push_back( v34 );
+ cl.push_back( v04 );
+ ClusterGraph G( cl );
+
+ BipartiteGraph H( 5, 5 );
+ H.addEdge( 0, 0 );
+ H.addEdge( 1, 0 );
+ H.addEdge( 1, 1 );
+ H.addEdge( 2, 1 );
+ H.addEdge( 1, 2 );
+ H.addEdge( 2, 2 );
+ H.addEdge( 3, 2 );
+ H.addEdge( 3, 3 );
+ H.addEdge( 4, 3 );
+ H.addEdge( 0, 4 );
+ H.addEdge( 4, 4 );
+ BOOST_CHECK( G.bipGraph() == H );
+
+ G.eraseNonMaximal();
+ BOOST_CHECK_EQUAL( G.nrClusters(), 4 );
+ H.eraseNode2( 1 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseSubsuming( 4 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
+ H.eraseNode2( 2 );
+ H.eraseNode2( 2 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.insert( v34 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+ G.insert( v123 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+ H.addNode2();
+ H.addEdge( 3, 2 );
+ H.addEdge( 4, 2 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.insert( v12 );
+ G.insert( v23 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 5 );
+ H.addNode2();
+ H.addNode2();
+ H.addEdge( 1, 3 );
+ H.addEdge( 2, 3 );
+ H.addEdge( 2, 4 );
+ H.addEdge( 3, 4 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseNonMaximal();
+ BOOST_CHECK_EQUAL( G.nrClusters(), 3 );
+ H.eraseNode2( 3 );
+ H.eraseNode2( 3 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseSubsuming( 2 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
+ H.eraseNode2( 1 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseNonMaximal();
+ BOOST_CHECK_EQUAL( G.nrClusters(), 2 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseSubsuming( 0 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 1 );
+ H.eraseNode2( 0 );
+ BOOST_CHECK( G.bipGraph() == H );
+ G.eraseSubsuming( 4 );
+ BOOST_CHECK_EQUAL( G.nrClusters(), 0 );
+ H.eraseNode2( 0 );
+ BOOST_CHECK( G.bipGraph() == H );
+}
+
+
+BOOST_AUTO_TEST_CASE( VarElimTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 3 );
+ Var v2( 2, 2 );
+ Var v3( 3, 4 );
+ Var v4( 4, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v04( v0, v4 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v14( v1, v4 );
+ VarSet v23( v2, v3 );
+ VarSet v24( v2, v4 );
+ VarSet v34( v3, v4 );
+ VarSet v123 = v12 | v3;
+ std::vector<VarSet> cl;
+ cl.push_back( v01 );
+ cl.push_back( v12 );
+ cl.push_back( v123 );
+ cl.push_back( v34 );
+ cl.push_back( v04 );
+ ClusterGraph G( cl );
+ ClusterGraph Gorg = G;
+
+ BipartiteGraph H( 5, 5 );
+ H.addEdge( 0, 0 );
+ H.addEdge( 1, 0 );
+ H.addEdge( 1, 1 );
+ H.addEdge( 2, 1 );
+ H.addEdge( 1, 2 );
+ H.addEdge( 2, 2 );
+ H.addEdge( 3, 2 );
+ H.addEdge( 3, 3 );
+ H.addEdge( 4, 3 );
+ H.addEdge( 0, 4 );
+ H.addEdge( 4, 4 );
+ BOOST_CHECK( G.bipGraph() == H );
+
+ BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 0 ), 1 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 1 ), 2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 2 ), 0 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 3 ), 2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinFill( G, 4 ), 1 );
+ cl.clear();
+ cl.push_back( v123 );
+ cl.push_back( v01 | v4 );
+ cl.push_back( v13 | v4 );
+ cl.push_back( v34 );
+ cl.push_back( v4 );
+ BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinFill ) ).clusters(), cl );
+
+ G = Gorg;
+ BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 0 ), 2*3 );
+ BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 1 ), 2*2+2*4 );
+ BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 2 ), 0 );
+ BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 3 ), 3*2+2*2 );
+ BOOST_CHECK_EQUAL( eliminationCost_WeightedMinFill( G, 4 ), 2*4 );
+ cl.clear();
+ cl.push_back( v123 );
+ cl.push_back( v01 | v4 );
+ cl.push_back( v13 | v4 );
+ cl.push_back( v34 );
+ cl.push_back( v4 );
+ BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_WeightedMinFill ) ).clusters(), cl );
+
+ G = Gorg;
+ BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 0 ), 2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 1 ), 3 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 2 ), 2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 3 ), 3 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinNeighbors( G, 4 ), 2 );
+ cl.clear();
+ cl.push_back( v01 | v4 );
+ cl.push_back( v123 );
+ cl.push_back( v13 | v4 );
+ cl.push_back( v34 );
+ cl.push_back( v4 );
+ BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinNeighbors ) ).clusters(), cl );
+
+ G = Gorg;
+ BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 0 ), 3*2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 1 ), 2*2*4 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 2 ), 3*4 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 3 ), 3*2*2 );
+ BOOST_CHECK_EQUAL( eliminationCost_MinWeight( G, 4 ), 2*4 );
+ cl.clear();
+ cl.push_back( v01 | v4 );
+ cl.push_back( v123 );
+ cl.push_back( v13 | v4 );
+ cl.push_back( v14 );
+ cl.push_back( v4 );
+ BOOST_CHECK_EQUAL( G.VarElim( greedyVariableElimination( eliminationCost_MinWeight ) ).clusters(), cl );
+
+ G = Gorg;
+ std::vector<Var> vs;
+ vs.push_back( v4 );
+ vs.push_back( v3 );
+ vs.push_back( v2 );
+ vs.push_back( v1 );
+ vs.push_back( v0 );
+ cl.clear();
+ cl.push_back( v03 | v4 );
+ cl.push_back( v01 | v23 );
+ cl.push_back( v01 | v2 );
+ cl.push_back( v01 );
+ cl.push_back( v0 );
+ BOOST_CHECK_EQUAL( G.VarElim( sequentialVariableElimination( vs ) ).clusters(), cl );
+}
+
+
+BOOST_AUTO_TEST_CASE( IOTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 3 );
+ Var v2( 2, 2 );
+ Var v3( 3, 4 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ std::vector<VarSet> cl;
+ cl.push_back( v01 );
+ cl.push_back( v12 );
+ cl.push_back( v23 );
+ cl.push_back( v13 );
+ ClusterGraph G( cl );
+
+ std::stringstream ss;
+ ss << G;
+ std::string s;
+ getline( ss, s );
+ BOOST_CHECK_EQUAL( s, "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/daialg.h>
-#include <dai/alldai.h>
-#include <strstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE DAIAlgTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( calcMarginalTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- Var v3( 3, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- std::vector<Factor> facs;
- facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
- facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
- facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
- facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
- facs.push_back( createFactorIsing( v0, -1.0 ) );
- facs.push_back( createFactorIsing( v1, -1.0 ) );
- facs.push_back( createFactorIsing( v2, -1.0 ) );
- facs.push_back( createFactorIsing( v3, 1.0 ) );
- Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
- FactorGraph fg( facs );
- ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
- ei.init();
- ei.run();
- VarSet vs;
-
- vs = v0; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v1; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v01; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v02; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v03; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v12; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v13; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v01 | v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v01 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v02 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v12 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
- vs = v01 | v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
-}
-
-
-BOOST_AUTO_TEST_CASE( calcPairBeliefsTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- Var v3( 3, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v03( v0, v3 );
- VarSet v12( v1, v2 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- std::vector<Factor> facs;
- facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
- facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
- facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
- facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
- facs.push_back( createFactorIsing( v0, -1.0 ) );
- facs.push_back( createFactorIsing( v1, -1.0 ) );
- facs.push_back( createFactorIsing( v2, -1.0 ) );
- facs.push_back( createFactorIsing( v3, 1.0 ) );
- Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
- FactorGraph fg( facs );
- ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
- ei.init();
- ei.run();
- VarSet vs;
-
- std::vector<Factor> pb = calcPairBeliefs( ei, v01 | v23, false, false );
- BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
-
- pb = calcPairBeliefs( ei, v01 | v23, false, true );
- BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
- BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/daialg.h>
+#include <dai/alldai.h>
+#include <strstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE DAIAlgTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( calcMarginalTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ Var v3( 3, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ std::vector<Factor> facs;
+ facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
+ facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
+ facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
+ facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
+ facs.push_back( createFactorIsing( v0, -1.0 ) );
+ facs.push_back( createFactorIsing( v1, -1.0 ) );
+ facs.push_back( createFactorIsing( v2, -1.0 ) );
+ facs.push_back( createFactorIsing( v3, 1.0 ) );
+ Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
+ FactorGraph fg( facs );
+ ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
+ ei.init();
+ ei.run();
+ VarSet vs;
+
+ vs = v0; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v1; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v01; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v02; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v03; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v12; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v13; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v01 | v2; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v01 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v02 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v12 | v3; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+ vs = v01 | v23; BOOST_CHECK( dist( calcMarginal( ei, vs, false ), joint.marginal( vs ), DISTTV ) < tol );
+}
+
+
+BOOST_AUTO_TEST_CASE( calcPairBeliefsTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ Var v3( 3, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v03( v0, v3 );
+ VarSet v12( v1, v2 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ std::vector<Factor> facs;
+ facs.push_back( createFactorIsing( v0, v1, 1.0 ) );
+ facs.push_back( createFactorIsing( v1, v2, 1.0 ) );
+ facs.push_back( createFactorIsing( v2, v3, 1.0 ) );
+ facs.push_back( createFactorIsing( v3, v0, 1.0 ) );
+ facs.push_back( createFactorIsing( v0, -1.0 ) );
+ facs.push_back( createFactorIsing( v1, -1.0 ) );
+ facs.push_back( createFactorIsing( v2, -1.0 ) );
+ facs.push_back( createFactorIsing( v3, 1.0 ) );
+ Factor joint = facs[0] * facs[1] * facs[2] * facs[3] * facs[4] * facs[5] * facs[6] * facs[7];
+ FactorGraph fg( facs );
+ ExactInf ei( fg, PropertySet()("verbose",(size_t)0) );
+ ei.init();
+ ei.run();
+ VarSet vs;
+
+ std::vector<Factor> pb = calcPairBeliefs( ei, v01 | v23, false, false );
+ BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
+
+ pb = calcPairBeliefs( ei, v01 | v23, false, true );
+ BOOST_CHECK( dist( pb[0], joint.marginal( v01 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[1], joint.marginal( v02 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[2], joint.marginal( v03 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[3], joint.marginal( v12 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[4], joint.marginal( v13 ), DISTTV ) < tol );
+ BOOST_CHECK( dist( pb[5], joint.marginal( v23 ), DISTTV ) < tol );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/enum.h>
-#include <strstream>
-#include <iostream>
-#include <string>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE EnumTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-DAI_ENUM(colors,RED,GREEN,BLUE);
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- colors c;
- BOOST_CHECK( c == colors::RED );
- BOOST_CHECK_EQUAL( static_cast<size_t>(c), 0 );
- BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(c), "RED" ), 0 );
-
- colors d(colors::GREEN);
- BOOST_CHECK( d == colors::GREEN );
- BOOST_CHECK_EQUAL( static_cast<size_t>(d), 1 );
- BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(d), "GREEN" ), 0 );
-
- colors e("BLUE");
- BOOST_CHECK( e == colors::BLUE );
- BOOST_CHECK_EQUAL( static_cast<size_t>(e), 2 );
- BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(e), "BLUE" ), 0 );
-
- BOOST_CHECK_THROW( colors f("BLUEISH"), Exception );
-
- colors f = e;
- colors g(f);
- BOOST_CHECK( f == colors::BLUE );
- BOOST_CHECK( g == colors::BLUE );
- BOOST_CHECK( static_cast<colors::value>(f) == static_cast<colors::value>(e) );
- BOOST_CHECK( static_cast<colors::value>(e) == static_cast<colors::value>(g) );
- BOOST_CHECK( static_cast<colors::value>(f) == static_cast<colors::value>(g) );
-}
-
-
-BOOST_AUTO_TEST_CASE( StreamTest ) {
- std::stringstream ss1, ss2, ss3, ss4, ss5, ss6, ss7;
- std::string s;
-
- ss1 << colors(colors::RED);
- ss1 >> s;
- BOOST_CHECK_EQUAL( s, "RED" );
-
- ss2 << colors(colors::GREEN);
- std::getline( ss2, s );
- BOOST_CHECK_EQUAL( s, "GREEN" );
-
- ss3 << colors(colors::BLUE);
- ss3 >> s;
- BOOST_CHECK_EQUAL( s, "BLUE" );
-
- colors c;
- ss4 << colors(colors::RED);
- ss4 >> c;
- BOOST_CHECK_EQUAL( c, colors::RED );
- ss5 << colors(colors::GREEN);
- ss5 >> c;
- BOOST_CHECK_EQUAL( c, colors::GREEN );
- ss6 << colors(colors::BLUE);
- ss6 >> c;
- BOOST_CHECK_EQUAL( c, colors::BLUE );
-
- ss7 << "BLUEISH";
- BOOST_CHECK_THROW( ss7 >> c, Exception );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/enum.h>
+#include <strstream>
+#include <iostream>
+#include <string>
+
+
+using namespace dai;
+
+
+#define BOOST_TEST_MODULE EnumTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+DAI_ENUM(colors,RED,GREEN,BLUE);
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ colors c;
+ BOOST_CHECK( c == colors::RED );
+ BOOST_CHECK_EQUAL( static_cast<size_t>(c), 0 );
+ BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(c), "RED" ), 0 );
+
+ colors d(colors::GREEN);
+ BOOST_CHECK( d == colors::GREEN );
+ BOOST_CHECK_EQUAL( static_cast<size_t>(d), 1 );
+ BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(d), "GREEN" ), 0 );
+
+ colors e("BLUE");
+ BOOST_CHECK( e == colors::BLUE );
+ BOOST_CHECK_EQUAL( static_cast<size_t>(e), 2 );
+ BOOST_CHECK_EQUAL( strcmp( static_cast<char const *>(e), "BLUE" ), 0 );
+
+ BOOST_CHECK_THROW( colors f("BLUEISH"), Exception );
+
+ colors f = e;
+ colors g(f);
+ BOOST_CHECK( f == colors::BLUE );
+ BOOST_CHECK( g == colors::BLUE );
+ BOOST_CHECK( static_cast<colors::value>(f) == static_cast<colors::value>(e) );
+ BOOST_CHECK( static_cast<colors::value>(e) == static_cast<colors::value>(g) );
+ BOOST_CHECK( static_cast<colors::value>(f) == static_cast<colors::value>(g) );
+}
+
+
+BOOST_AUTO_TEST_CASE( StreamTest ) {
+ std::stringstream ss1, ss2, ss3, ss4, ss5, ss6, ss7;
+ std::string s;
+
+ ss1 << colors(colors::RED);
+ ss1 >> s;
+ BOOST_CHECK_EQUAL( s, "RED" );
+
+ ss2 << colors(colors::GREEN);
+ std::getline( ss2, s );
+ BOOST_CHECK_EQUAL( s, "GREEN" );
+
+ ss3 << colors(colors::BLUE);
+ ss3 >> s;
+ BOOST_CHECK_EQUAL( s, "BLUE" );
+
+ colors c;
+ ss4 << colors(colors::RED);
+ ss4 >> c;
+ BOOST_CHECK_EQUAL( c, colors::RED );
+ ss5 << colors(colors::GREEN);
+ ss5 >> c;
+ BOOST_CHECK_EQUAL( c, colors::GREEN );
+ ss6 << colors(colors::BLUE);
+ ss6 >> c;
+ BOOST_CHECK_EQUAL( c, colors::BLUE );
+
+ ss7 << "BLUEISH";
+ BOOST_CHECK_THROW( ss7 >> c, Exception );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/exceptions.h>
-#include <strstream>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE ExceptionsTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ExceptionsTest ) {
- BOOST_CHECK_THROW( DAI_THROW(NOT_IMPLEMENTED), Exception );
- BOOST_CHECK_THROW( DAI_THROW(NOT_IMPLEMENTED), std::runtime_error );
- BOOST_CHECK_THROW( DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message"), Exception );
- BOOST_CHECK_THROW( DAI_THROWE(NOT_IMPLEMENTED,"Detailed error messgae"), std::runtime_error );
- BOOST_CHECK_THROW( DAI_ASSERT( 0 ), Exception );
- BOOST_CHECK_THROW( DAI_ASSERT( 0 == 1 ), std::runtime_error );
-
- try {
- DAI_THROW(NOT_IMPLEMENTED);
- } catch( Exception& e ) {
- BOOST_CHECK_EQUAL( e.code(), Exception::NOT_IMPLEMENTED );
- BOOST_CHECK_EQUAL( e.message(e.code()), std::string("Feature not implemented") );
- }
-
- try {
- DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message");
- } catch( Exception& e ) {
- BOOST_CHECK_EQUAL( e.code(), Exception::NOT_IMPLEMENTED );
- BOOST_CHECK_EQUAL( e.message(e.code()), std::string("Feature not implemented") );
- }
-
- try {
- DAI_THROW(NOT_IMPLEMENTED);
- } catch( std::runtime_error& e ) {
- BOOST_CHECK_EQUAL( e.what(), std::string("Feature not implemented [tests/unit/exceptions.cpp, line 47]") );
- }
-
- try {
- DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message");
- } catch( std::runtime_error& e ) {
- BOOST_CHECK_EQUAL( e.what(), std::string("Feature not implemented [tests/unit/exceptions.cpp, line 53]") );
- }
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/exceptions.h>
+#include <strstream>
+
+
+using namespace dai;
+
+
+#define BOOST_TEST_MODULE ExceptionsTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ExceptionsTest ) {
+ BOOST_CHECK_THROW( DAI_THROW(NOT_IMPLEMENTED), Exception );
+ BOOST_CHECK_THROW( DAI_THROW(NOT_IMPLEMENTED), std::runtime_error );
+ BOOST_CHECK_THROW( DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message"), Exception );
+ BOOST_CHECK_THROW( DAI_THROWE(NOT_IMPLEMENTED,"Detailed error messgae"), std::runtime_error );
+ BOOST_CHECK_THROW( DAI_ASSERT( 0 ), Exception );
+ BOOST_CHECK_THROW( DAI_ASSERT( 0 == 1 ), std::runtime_error );
+
+ try {
+ DAI_THROW(NOT_IMPLEMENTED);
+ } catch( Exception& e ) {
+ BOOST_CHECK_EQUAL( e.code(), Exception::NOT_IMPLEMENTED );
+ BOOST_CHECK_EQUAL( e.message(e.code()), std::string("Feature not implemented") );
+ }
+
+ try {
+ DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message");
+ } catch( Exception& e ) {
+ BOOST_CHECK_EQUAL( e.code(), Exception::NOT_IMPLEMENTED );
+ BOOST_CHECK_EQUAL( e.message(e.code()), std::string("Feature not implemented") );
+ }
+
+ try {
+ DAI_THROW(NOT_IMPLEMENTED);
+ } catch( std::runtime_error& e ) {
+ BOOST_CHECK_EQUAL( e.what(), std::string("Feature not implemented [tests/unit/exceptions_test.cpp, line 47]") );
+ }
+
+ try {
+ DAI_THROWE(NOT_IMPLEMENTED,"Detailed error message");
+ } catch( std::runtime_error& e ) {
+ BOOST_CHECK_EQUAL( e.what(), std::string("Feature not implemented [tests/unit/exceptions_test.cpp, line 53]") );
+ }
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/factor.h>
-#include <strstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE FactorTest
-
-
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- // check constructors
- Factor x1;
- BOOST_CHECK_EQUAL( x1.nrStates(), 1 );
- BOOST_CHECK( x1.p() == Prob( 1, 1.0 ) );
- BOOST_CHECK( x1.vars() == VarSet() );
-
- Factor x2( 5.0 );
- BOOST_CHECK_EQUAL( x2.nrStates(), 1 );
- BOOST_CHECK( x2.p() == Prob( 1, 5.0 ) );
- BOOST_CHECK( x2.vars() == VarSet() );
-
- Var v1( 0, 3 );
- Factor x3( v1 );
- BOOST_CHECK_EQUAL( x3.nrStates(), 3 );
- BOOST_CHECK( x3.p() == Prob( 3, 1.0 / 3.0 ) );
- BOOST_CHECK( x3.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( x3[0], 1.0 / 3.0 );
- BOOST_CHECK_EQUAL( x3[1], 1.0 / 3.0 );
- BOOST_CHECK_EQUAL( x3[2], 1.0 / 3.0 );
-
- Var v2( 1, 2 );
- Factor x4( VarSet( v1, v2 ) );
- BOOST_CHECK_EQUAL( x4.nrStates(), 6 );
- BOOST_CHECK( x4.p() == Prob( 6, 1.0 / 6.0 ) );
- BOOST_CHECK( x4.vars() == VarSet( v1, v2 ) );
- for( size_t i = 0; i < 6; i++ )
- BOOST_CHECK_EQUAL( x4[i], 1.0 / 6.0 );
-
- Factor x5( VarSet( v1, v2 ), 1.0 );
- BOOST_CHECK_EQUAL( x5.nrStates(), 6 );
- BOOST_CHECK( x5.p() == Prob( 6, 1.0 ) );
- BOOST_CHECK( x5.vars() == VarSet( v1, v2 ) );
- for( size_t i = 0; i < 6; i++ )
- BOOST_CHECK_EQUAL( x5[i], 1.0 );
-
- std::vector<Real> x( 6, 1.0 );
- for( size_t i = 0; i < 6; i++ )
- x[i] = 10.0 - i;
- Factor x6( VarSet( v1, v2 ), x );
- BOOST_CHECK_EQUAL( x6.nrStates(), 6 );
- BOOST_CHECK( x6.vars() == VarSet( v1, v2 ) );
- for( size_t i = 0; i < 6; i++ )
- BOOST_CHECK_EQUAL( x6[i], x[i] );
-
- x.resize( 4 );
- BOOST_CHECK_THROW( Factor x7( VarSet( v1, v2 ), x ), Exception );
-
- x.resize( 6 );
- x[4] = 10.0 - 4; x[5] = 10.0 - 5;
- Factor x8( VarSet( v2, v1 ), &(x[0]) );
- BOOST_CHECK_EQUAL( x8.nrStates(), 6 );
- BOOST_CHECK( x8.vars() == VarSet( v1, v2 ) );
- for( size_t i = 0; i < 6; i++ )
- BOOST_CHECK_EQUAL( x8[i], x[i] );
-
- Prob xx( x );
- Factor x9( VarSet( v2, v1 ), xx );
- BOOST_CHECK_EQUAL( x9.nrStates(), 6 );
- BOOST_CHECK( x9.vars() == VarSet( v1, v2 ) );
- for( size_t i = 0; i < 6; i++ )
- BOOST_CHECK_EQUAL( x9[i], x[i] );
-
- xx.resize( 4 );
- BOOST_CHECK_THROW( Factor x10( VarSet( v2, v1 ), xx ), Exception );
-
- std::vector<Real> w;
- w.push_back( 0.1 );
- w.push_back( 3.5 );
- w.push_back( 2.8 );
- w.push_back( 6.3 );
- w.push_back( 8.4 );
- w.push_back( 0.0 );
- w.push_back( 7.4 );
- w.push_back( 2.4 );
- w.push_back( 8.9 );
- w.push_back( 1.3 );
- w.push_back( 1.6 );
- w.push_back( 2.6 );
- Var v4( 4, 3 );
- Var v8( 8, 2 );
- Var v7( 7, 2 );
- std::vector<Var> vars;
- vars.push_back( v4 );
- vars.push_back( v8 );
- vars.push_back( v7 );
- Factor x11( vars, w );
- BOOST_CHECK_EQUAL( x11.nrStates(), 12 );
- BOOST_CHECK( x11.vars() == VarSet( vars.begin(), vars.end() ) );
- BOOST_CHECK_EQUAL( x11[0], 0.1 );
- BOOST_CHECK_EQUAL( x11[1], 3.5 );
- BOOST_CHECK_EQUAL( x11[2], 2.8 );
- BOOST_CHECK_EQUAL( x11[3], 7.4 );
- BOOST_CHECK_EQUAL( x11[4], 2.4 );
- BOOST_CHECK_EQUAL( x11[5], 8.9 );
- BOOST_CHECK_EQUAL( x11[6], 6.3 );
- BOOST_CHECK_EQUAL( x11[7], 8.4 );
- BOOST_CHECK_EQUAL( x11[8], 0.0 );
- BOOST_CHECK_EQUAL( x11[9], 1.3 );
- BOOST_CHECK_EQUAL( x11[10], 1.6 );
- BOOST_CHECK_EQUAL( x11[11], 2.6 );
-
- Factor x12( x11 );
- BOOST_CHECK( x12 == x11 );
-
- Factor x13 = x12;
- BOOST_CHECK( x13 == x11 );
-}
-
-
-BOOST_AUTO_TEST_CASE( QueriesTest ) {
- Factor x( Var( 5, 5 ), 0.0 );
- for( size_t i = 0; i < x.nrStates(); i++ )
- x.set( i, 2.0 - i );
-
- // test min, max, sum, sumAbs, maxAbs
- BOOST_CHECK_EQUAL( x.sum(), 0.0 );
- BOOST_CHECK_EQUAL( x.max(), 2.0 );
- BOOST_CHECK_EQUAL( x.min(), -2.0 );
- BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
- BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
- x.set( 1, 1.0 );
- BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
- x /= x.sum();
-
- // test entropy
- BOOST_CHECK( x.entropy() < Prob(5).entropy() );
- for( size_t i = 1; i < 100; i++ )
- BOOST_CHECK_CLOSE( Factor( Var(0,i) ).entropy(), std::log((Real)i), tol );
-
- // test hasNaNs and hasNegatives
- BOOST_CHECK( !Factor( 0.0 ).hasNaNs() );
- Real c = 0.0;
- BOOST_CHECK( Factor( c / c ).hasNaNs() );
- BOOST_CHECK( !Factor( 0.0 ).hasNegatives() );
- BOOST_CHECK( !Factor( 1.0 ).hasNegatives() );
- BOOST_CHECK( Factor( -1.0 ).hasNegatives() );
- x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
- BOOST_CHECK( x.hasNegatives() );
- x.set( 2, -INFINITY );
- BOOST_CHECK( x.hasNegatives() );
- x.set( 2, INFINITY );
- BOOST_CHECK( !x.hasNegatives() );
- x.set( 2, -1.0 );
-
- // test strength
- Var x0(0,2);
- Var x1(1,2);
- BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 1.0 ).strength( x0, x1 ), std::tanh( 1.0 ), tol );
- BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, -1.0 ).strength( x0, x1 ), std::tanh( 1.0 ), tol );
- BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 0.5 ).strength( x0, x1 ), std::tanh( 0.5 ), tol );
-
- // test ==
- Factor a(Var(0,3)), b(Var(0,3));
- Factor d(Var(1,3));
- BOOST_CHECK( !(a == d) );
- BOOST_CHECK( !(b == d) );
- BOOST_CHECK( a == b );
- a.set( 0, 0.0 );
- BOOST_CHECK( !(a == b) );
- b.set( 2, 0.0 );
- BOOST_CHECK( !(a == b) );
- b.set( 0, 0.0 );
- BOOST_CHECK( !(a == b) );
- a.set( 1, 0.0 );
- BOOST_CHECK( !(a == b) );
- b.set( 1, 0.0 );
- BOOST_CHECK( !(a == b) );
- a.set( 2, 0.0 );
- BOOST_CHECK( a == b );
-}
-
-
-BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
- Var v( 0, 3 );
- Factor x( v );
- x.set( 0, -2.0 );
- x.set( 1, 0.0 );
- x.set( 2, 2.0 );
-
- Factor y = -x;
- BOOST_CHECK_EQUAL( y[0], 2.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], -2.0 );
-
- y = x.abs();
- BOOST_CHECK_EQUAL( y[0], 2.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 2.0 );
-
- y = x.exp();
- BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
- BOOST_CHECK_EQUAL( y[1], 1.0 );
- BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
-
- y = x.log(false);
- BOOST_CHECK( isnan( y[0] ) );
- BOOST_CHECK_EQUAL( y[1], -INFINITY );
- BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
-
- y = x.log(true);
- BOOST_CHECK( isnan( y[0] ) );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
-
- y = x.inverse(false);
- BOOST_CHECK_EQUAL( y[0], -0.5 );
- BOOST_CHECK_EQUAL( y[1], INFINITY );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- y = x.inverse(true);
- BOOST_CHECK_EQUAL( y[0], -0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- x.set( 0, 2.0 );
- y = x.normalized();
- BOOST_CHECK_EQUAL( y[0], 0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- y = x.normalized( NORMPROB );
- BOOST_CHECK_EQUAL( y[0], 0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- x.set( 0, -2.0 );
- y = x.normalized( NORMLINF );
- BOOST_CHECK_EQUAL( y[0], -1.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 1.0 );
-}
-
-
-BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
- Var v( 0, 3 );
- Factor xorg( v );
- xorg.set( 0, 2.0 );
- xorg.set( 1, 0.0 );
- xorg.set( 2, 1.0 );
- Factor y( v );
-
- Factor x = xorg;
- BOOST_CHECK( x.setUniform() == Factor( v ) );
- BOOST_CHECK( x == Factor( v ) );
-
- y.set( 0, std::exp(2.0) );
- y.set( 1, 1.0 );
- y.set( 2, std::exp(1.0) );
- x = xorg;
- BOOST_CHECK( x.takeExp() == y );
- BOOST_CHECK( x == y );
-
- y.set( 0, std::log(2.0) );
- y.set( 1, -INFINITY );
- y.set( 2, 0.0 );
- x = xorg;
- BOOST_CHECK( x.takeLog() == y );
- BOOST_CHECK( x == y );
- x = xorg;
- BOOST_CHECK( x.takeLog(false) == y );
- BOOST_CHECK( x == y );
-
- y.set( 1, 0.0 );
- x = xorg;
- BOOST_CHECK( x.takeLog(true) == y );
- BOOST_CHECK( x == y );
-
- y.set( 0, 2.0 / 3.0 );
- y.set( 1, 0.0 / 3.0 );
- y.set( 2, 1.0 / 3.0 );
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
- BOOST_CHECK( x == y );
-
- y.set( 0, 2.0 / 2.0 );
- y.set( 1, 0.0 / 2.0 );
- y.set( 2, 1.0 / 2.0 );
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
- BOOST_CHECK( x == y );
-
- xorg.set( 0, -2.0 );
- y.set( 0, 2.0 );
- y.set( 1, 0.0 );
- y.set( 2, 1.0 );
- x = xorg;
- BOOST_CHECK( x.takeAbs() == y );
- BOOST_CHECK( x == y );
-
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- x.randomize();
- for( size_t i = 0; i < x.nrStates(); i++ ) {
- BOOST_CHECK( x[i] < 1.0 );
- BOOST_CHECK( x[i] >= 0.0 );
- }
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
- Var v( 0, 3 );
- Factor xorg( v ), x( v );
- xorg.set( 0, 2.0 );
- xorg.set( 1, 0.0 );
- xorg.set( 2, 1.0 );
- Factor y( v );
-
- x = xorg;
- BOOST_CHECK( x.fill( 1.0 ) == Factor(v, 1.0) );
- BOOST_CHECK( x == Factor(v, 1.0) );
- BOOST_CHECK( x.fill( 2.0 ) == Factor(v, 2.0) );
- BOOST_CHECK( x == Factor(v, 2.0) );
- BOOST_CHECK( x.fill( 0.0 ) == Factor(v, 0.0) );
- BOOST_CHECK( x == Factor(v, 0.0) );
-
- x = xorg;
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x += 1.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x += -2.0) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x -= 1.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x -= -2.0) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK( (x *= 1.0) == x );
- BOOST_CHECK( x == x );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x *= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
- BOOST_CHECK( (x *= -0.25) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK( (x /= 1.0) == x );
- BOOST_CHECK( x == x );
- y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
- BOOST_CHECK( (x /= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
- BOOST_CHECK( (x /= -0.25) == y );
- BOOST_CHECK( x == y );
- BOOST_CHECK( (x /= 0.0) == Factor(v, 0.0) );
- BOOST_CHECK( x == Factor(v, 0.0) );
-
- x = xorg;
- BOOST_CHECK( (x ^= 1.0) == x );
- BOOST_CHECK( x == x );
- BOOST_CHECK( (x ^= 0.0) == Factor(v, 1.0) );
- BOOST_CHECK( x == Factor(v, 1.0) );
- x = xorg;
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^= -0.5) == y );
- BOOST_CHECK( x == y );
-}
-
-
-BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
- Var v( 0, 3 );
- Factor x( v );
- x.set( 0, 2.0 );
- x.set( 1, 0.0 );
- x.set( 2, 1.0 );
- Factor y( v );
-
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x + 1.0) == y );
- y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
- BOOST_CHECK( (x + (-2.0)) == y );
-
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x - 1.0) == y );
- y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
- BOOST_CHECK( (x - (-2.0)) == y );
-
- BOOST_CHECK( (x * 1.0) == x );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x * 2.0) == y );
- y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
- BOOST_CHECK( (x * -0.5) == y );
-
- BOOST_CHECK( (x / 1.0) == x );
- y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
- BOOST_CHECK( (x / 2.0) == y );
- y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
- BOOST_CHECK( (x / -0.5) == y );
- BOOST_CHECK( (x / 0.0) == Factor(v, 0.0) );
-
- BOOST_CHECK( (x ^ 1.0) == x );
- BOOST_CHECK( (x ^ 0.0) == Factor(v, 1.0) );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^ 2.0) == y );
- y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
- Factor z = (x ^ -0.5);
- BOOST_CHECK_CLOSE( z[0], y[0], tol );
- BOOST_CHECK_EQUAL( z[1], y[1] );
- BOOST_CHECK_CLOSE( z[2], y[2], tol );
-}
-
-
-BOOST_AUTO_TEST_CASE( SimilarFactorOperationsTest ) {
- size_t N = 6;
- Var v( 0, N );
- Factor xorg( v ), x( v );
- xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
- Factor y( v );
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
- Factor z( v ), r( v );
-
- z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
- x = xorg;
- r = (x += y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
- x = xorg;
- r = (x -= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
- x = xorg;
- r = (x *= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
- x = xorg;
- r = (x /= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == z );
- BOOST_CHECK( x == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( SimilarFactorTransformationsTest ) {
- size_t N = 6;
- Var v( 0, N );
- Factor x( v );
- x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
- Factor y( v );
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
- Factor z( v ), r( v );
-
- z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
- r = x + y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.binaryTr( y, std::plus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
- r = x - y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.binaryTr( y, std::minus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
- r = x * y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.binaryTr( y, std::multiplies<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
- r = x / y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.binaryTr( y, fo_divides0<Real>() );
- BOOST_CHECK( r == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( FactorOperationsTest ) {
- size_t N = 9;
- Var v1( 1, 3 );
- Var v2( 2, 3 );
- Factor xorg( v1 ), x( v1 );
- xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, -1.0 );
- Factor y( v2 );
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- Factor r;
-
- Factor z( VarSet( v1, v2 ) );
- z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
- z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
- z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
- x = xorg;
- r = (x += y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
- z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
- z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
- x = xorg;
- r = (x -= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
- z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
- z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
- x = xorg;
- r = (x *= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
- z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
- z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
- x = xorg;
- r = (x /= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == z );
- BOOST_CHECK( x == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( FactorTransformationsTest ) {
- size_t N = 9;
- Var v1( 1, 3 );
- Var v2( 2, 3 );
- Factor x( v1 );
- x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 );
- Factor y( v2 );
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- Factor r;
-
- Factor z( VarSet( v1, v2 ) );
- z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
- z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
- z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
- r = x + y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( r == z );
- z = x.binaryTr( y, std::plus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
- z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
- z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
- r = x - y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( r == z );
- z = x.binaryTr( y, std::minus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
- z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
- z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
- r = x * y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( r == z );
- z = x.binaryTr( y, std::multiplies<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
- z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
- z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
- r = x / y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( r == z );
- z = x.binaryOp( y, fo_divides0<Real>() );
- BOOST_CHECK( r == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( MiscOperationsTest ) {
- Var v1(1, 2);
- Var v2(2, 3);
- Factor x( VarSet( v1, v2 ) );
- x.randomize();
-
- // slice
- Factor y = x.slice( v1, 0 );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y.nrStates(), 3 );
- BOOST_CHECK_EQUAL( y[0], x[0] );
- BOOST_CHECK_EQUAL( y[1], x[2] );
- BOOST_CHECK_EQUAL( y[2], x[4] );
- y = x.slice( v1, 1 );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y.nrStates(), 3 );
- BOOST_CHECK_EQUAL( y[0], x[1] );
- BOOST_CHECK_EQUAL( y[1], x[3] );
- BOOST_CHECK_EQUAL( y[2], x[5] );
- y = x.slice( v2, 0 );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y.nrStates(), 2 );
- BOOST_CHECK_EQUAL( y[0], x[0] );
- BOOST_CHECK_EQUAL( y[1], x[1] );
- y = x.slice( v2, 1 );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y.nrStates(), 2 );
- BOOST_CHECK_EQUAL( y[0], x[2] );
- BOOST_CHECK_EQUAL( y[1], x[3] );
- y = x.slice( v2, 2 );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y.nrStates(), 2 );
- BOOST_CHECK_EQUAL( y[0], x[4] );
- BOOST_CHECK_EQUAL( y[1], x[5] );
- for( size_t i = 0; i < x.nrStates(); i++ ) {
- y = x.slice( VarSet( v1, v2 ), 0 );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y.nrStates(), 1 );
- BOOST_CHECK_EQUAL( y[0], x[0] );
- }
- y = x.slice( VarSet(), 0 );
- BOOST_CHECK_EQUAL( y, x );
-
- // embed
- Var v3(3, 4);
- BOOST_CHECK_THROW( x.embed( VarSet( v3 ) ), Exception );
- BOOST_CHECK_THROW( x.embed( VarSet( v3, v2 ) ), Exception );
- y = x.embed( VarSet( v3, v2 ) | v1 );
- for( size_t i = 0; i < y.nrStates(); i++ )
- BOOST_CHECK_EQUAL( y[i], x[i % 6] );
- y = x.embed( VarSet( v1, v2 ) );
- BOOST_CHECK_EQUAL( x, y );
-
- // marginal
- y = x.marginal( v1 );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], (x[0] + x[2] + x[4]) / x.sum() );
- BOOST_CHECK_EQUAL( y[1], (x[1] + x[3] + x[5]) / x.sum() );
- y = x.marginal( v2 );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
- BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
- BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
- y = x.marginal( VarSet() );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], 1.0 );
- y = x.marginal( VarSet( v1, v2 ) );
- BOOST_CHECK( y == x.normalized() );
-
- y = x.marginal( v1, true );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], (x[0] + x[2] + x[4]) / x.sum() );
- BOOST_CHECK_EQUAL( y[1], (x[1] + x[3] + x[5]) / x.sum() );
- y = x.marginal( v2, true );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
- BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
- BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
- y = x.marginal( VarSet(), true );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], 1.0 );
- y = x.marginal( VarSet( v1, v2 ), true );
- BOOST_CHECK( y == x.normalized() );
-
- y = x.marginal( v1, false );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], x[0] + x[2] + x[4] );
- BOOST_CHECK_EQUAL( y[1], x[1] + x[3] + x[5] );
- y = x.marginal( v2, false );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y[0], x[0] + x[1] );
- BOOST_CHECK_EQUAL( y[1], x[2] + x[3] );
- BOOST_CHECK_EQUAL( y[2], x[4] + x[5] );
- y = x.marginal( VarSet(), false );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], x.sum() );
- y = x.marginal( VarSet( v1, v2 ), false );
- BOOST_CHECK( y == x );
-
- // maxMarginal
- y = x.maxMarginal( v1 );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
- BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
- y = x.maxMarginal( v2 );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- y = x.maxMarginal( VarSet() );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], 1.0 );
- y = x.maxMarginal( VarSet( v1, v2 ) );
- BOOST_CHECK( y == x.normalized() );
-
- y = x.maxMarginal( v1, true );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
- BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
- y = x.maxMarginal( v2, true );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
- y = x.maxMarginal( VarSet(), true );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], 1.0 );
- y = x.maxMarginal( VarSet( v1, v2 ), true );
- BOOST_CHECK( y == x.normalized() );
-
- y = x.maxMarginal( v1, false );
- BOOST_CHECK( y.vars() == VarSet( v1 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() );
- BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() );
- y = x.maxMarginal( v2, false );
- BOOST_CHECK( y.vars() == VarSet( v2 ) );
- BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() );
- BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() );
- BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() );
- y = x.maxMarginal( VarSet(), false );
- BOOST_CHECK( y.vars() == VarSet() );
- BOOST_CHECK_EQUAL( y[0], x.max() );
- y = x.maxMarginal( VarSet( v1, v2 ), false );
- BOOST_CHECK( y == x );
-}
-
-
-BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
- Var v( 0, 3 );
- Factor x(v), y(v), z(v);
- x.set( 0, 0.2 );
- x.set( 1, 0.8 );
- x.set( 2, 0.0 );
- y.set( 0, 0.0 );
- y.set( 1, 0.6 );
- y.set( 2, 0.4 );
-
- z = min( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.0 );
- BOOST_CHECK_EQUAL( z[1], 0.6 );
- BOOST_CHECK_EQUAL( z[2], 0.0 );
- z = max( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.2 );
- BOOST_CHECK_EQUAL( z[1], 0.8 );
- BOOST_CHECK_EQUAL( z[2], 0.4 );
-
- BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
- BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
- BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
- BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
- BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
- BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
- x.set( 1, 0.7 ); x.set( 2, 0.1 );
- y.set( 0, 0.1 ); y.set( 1, 0.5 );
- BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
- BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
-
- std::stringstream ss;
- ss << x;
- std::string s;
- std::getline( ss, s );
- BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.2, 0.7, 0.1))") );
- std::stringstream ss2;
- ss2 << y;
- std::getline( ss2, s );
- BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.1, 0.5, 0.4))") );
-
- z = min( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.1 );
- BOOST_CHECK_EQUAL( z[1], 0.5 );
- BOOST_CHECK_EQUAL( z[2], 0.1 );
- z = max( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.2 );
- BOOST_CHECK_EQUAL( z[1], 0.7 );
- BOOST_CHECK_EQUAL( z[2], 0.4 );
-
- for( double J = -1.0; J <= 1.01; J += 0.1 ) {
- Factor x = createFactorIsing( Var(0,2), Var(1,2), J ).normalized();
- BOOST_CHECK_CLOSE( x[0], std::exp(J) / (4.0 * std::cosh(J)), tol );
- BOOST_CHECK_CLOSE( x[1], std::exp(-J) / (4.0 * std::cosh(J)), tol );
- BOOST_CHECK_CLOSE( x[2], std::exp(-J) / (4.0 * std::cosh(J)), tol );
- BOOST_CHECK_CLOSE( x[3], std::exp(J) / (4.0 * std::cosh(J)), tol );
- BOOST_CHECK_SMALL( MutualInfo( x ) - (J * std::tanh(J) - std::log(std::cosh(J))), tol );
- }
- Var v1( 1, 3 );
- Var v2( 2, 4 );
- BOOST_CHECK_SMALL( MutualInfo( (Factor(v1).randomize() * Factor(v2).randomize()).normalized() ), tol );
- BOOST_CHECK_THROW( MutualInfo( createFactorIsing( Var(0,2), 1.0 ).normalized() ), Exception );
- BOOST_CHECK_THROW( createFactorIsing( v1, 0.0 ), Exception );
- BOOST_CHECK_THROW( createFactorIsing( v1, v2, 0.0 ), Exception );
- for( double J = -1.0; J <= 1.01; J += 0.1 ) {
- Factor x = createFactorIsing( Var(0,2), J ).normalized();
- BOOST_CHECK_CLOSE( x[0], std::exp(-J) / (2.0 * std::cosh(J)), tol );
- BOOST_CHECK_CLOSE( x[1], std::exp(J) / (2.0 * std::cosh(J)), tol );
- BOOST_CHECK_SMALL( x.entropy() - (-J * std::tanh(J) + std::log(2.0 * std::cosh(J))), tol );
- }
-
- x = createFactorDelta( v1, 2 );
- BOOST_CHECK_EQUAL( x[0], 0.0 );
- BOOST_CHECK_EQUAL( x[1], 0.0 );
- BOOST_CHECK_EQUAL( x[2], 1.0 );
- x = createFactorDelta( v1, 1 );
- BOOST_CHECK_EQUAL( x[0], 0.0 );
- BOOST_CHECK_EQUAL( x[1], 1.0 );
- BOOST_CHECK_EQUAL( x[2], 0.0 );
- x = createFactorDelta( v1, 0 );
- BOOST_CHECK_EQUAL( x[0], 1.0 );
- BOOST_CHECK_EQUAL( x[1], 0.0 );
- BOOST_CHECK_EQUAL( x[2], 0.0 );
- BOOST_CHECK_THROW( createFactorDelta( v1, 4 ), Exception );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/factor.h>
+#include <strstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE FactorTest
+
+
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ // check constructors
+ Factor x1;
+ BOOST_CHECK_EQUAL( x1.nrStates(), 1 );
+ BOOST_CHECK( x1.p() == Prob( 1, 1.0 ) );
+ BOOST_CHECK( x1.vars() == VarSet() );
+
+ Factor x2( 5.0 );
+ BOOST_CHECK_EQUAL( x2.nrStates(), 1 );
+ BOOST_CHECK( x2.p() == Prob( 1, 5.0 ) );
+ BOOST_CHECK( x2.vars() == VarSet() );
+
+ Var v1( 0, 3 );
+ Factor x3( v1 );
+ BOOST_CHECK_EQUAL( x3.nrStates(), 3 );
+ BOOST_CHECK( x3.p() == Prob( 3, 1.0 / 3.0 ) );
+ BOOST_CHECK( x3.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( x3[0], 1.0 / 3.0 );
+ BOOST_CHECK_EQUAL( x3[1], 1.0 / 3.0 );
+ BOOST_CHECK_EQUAL( x3[2], 1.0 / 3.0 );
+
+ Var v2( 1, 2 );
+ Factor x4( VarSet( v1, v2 ) );
+ BOOST_CHECK_EQUAL( x4.nrStates(), 6 );
+ BOOST_CHECK( x4.p() == Prob( 6, 1.0 / 6.0 ) );
+ BOOST_CHECK( x4.vars() == VarSet( v1, v2 ) );
+ for( size_t i = 0; i < 6; i++ )
+ BOOST_CHECK_EQUAL( x4[i], 1.0 / 6.0 );
+
+ Factor x5( VarSet( v1, v2 ), 1.0 );
+ BOOST_CHECK_EQUAL( x5.nrStates(), 6 );
+ BOOST_CHECK( x5.p() == Prob( 6, 1.0 ) );
+ BOOST_CHECK( x5.vars() == VarSet( v1, v2 ) );
+ for( size_t i = 0; i < 6; i++ )
+ BOOST_CHECK_EQUAL( x5[i], 1.0 );
+
+ std::vector<Real> x( 6, 1.0 );
+ for( size_t i = 0; i < 6; i++ )
+ x[i] = 10.0 - i;
+ Factor x6( VarSet( v1, v2 ), x );
+ BOOST_CHECK_EQUAL( x6.nrStates(), 6 );
+ BOOST_CHECK( x6.vars() == VarSet( v1, v2 ) );
+ for( size_t i = 0; i < 6; i++ )
+ BOOST_CHECK_EQUAL( x6[i], x[i] );
+
+ x.resize( 4 );
+ BOOST_CHECK_THROW( Factor x7( VarSet( v1, v2 ), x ), Exception );
+
+ x.resize( 6 );
+ x[4] = 10.0 - 4; x[5] = 10.0 - 5;
+ Factor x8( VarSet( v2, v1 ), &(x[0]) );
+ BOOST_CHECK_EQUAL( x8.nrStates(), 6 );
+ BOOST_CHECK( x8.vars() == VarSet( v1, v2 ) );
+ for( size_t i = 0; i < 6; i++ )
+ BOOST_CHECK_EQUAL( x8[i], x[i] );
+
+ Prob xx( x );
+ Factor x9( VarSet( v2, v1 ), xx );
+ BOOST_CHECK_EQUAL( x9.nrStates(), 6 );
+ BOOST_CHECK( x9.vars() == VarSet( v1, v2 ) );
+ for( size_t i = 0; i < 6; i++ )
+ BOOST_CHECK_EQUAL( x9[i], x[i] );
+
+ xx.resize( 4 );
+ BOOST_CHECK_THROW( Factor x10( VarSet( v2, v1 ), xx ), Exception );
+
+ std::vector<Real> w;
+ w.push_back( 0.1 );
+ w.push_back( 3.5 );
+ w.push_back( 2.8 );
+ w.push_back( 6.3 );
+ w.push_back( 8.4 );
+ w.push_back( 0.0 );
+ w.push_back( 7.4 );
+ w.push_back( 2.4 );
+ w.push_back( 8.9 );
+ w.push_back( 1.3 );
+ w.push_back( 1.6 );
+ w.push_back( 2.6 );
+ Var v4( 4, 3 );
+ Var v8( 8, 2 );
+ Var v7( 7, 2 );
+ std::vector<Var> vars;
+ vars.push_back( v4 );
+ vars.push_back( v8 );
+ vars.push_back( v7 );
+ Factor x11( vars, w );
+ BOOST_CHECK_EQUAL( x11.nrStates(), 12 );
+ BOOST_CHECK( x11.vars() == VarSet( vars.begin(), vars.end() ) );
+ BOOST_CHECK_EQUAL( x11[0], 0.1 );
+ BOOST_CHECK_EQUAL( x11[1], 3.5 );
+ BOOST_CHECK_EQUAL( x11[2], 2.8 );
+ BOOST_CHECK_EQUAL( x11[3], 7.4 );
+ BOOST_CHECK_EQUAL( x11[4], 2.4 );
+ BOOST_CHECK_EQUAL( x11[5], 8.9 );
+ BOOST_CHECK_EQUAL( x11[6], 6.3 );
+ BOOST_CHECK_EQUAL( x11[7], 8.4 );
+ BOOST_CHECK_EQUAL( x11[8], 0.0 );
+ BOOST_CHECK_EQUAL( x11[9], 1.3 );
+ BOOST_CHECK_EQUAL( x11[10], 1.6 );
+ BOOST_CHECK_EQUAL( x11[11], 2.6 );
+
+ Factor x12( x11 );
+ BOOST_CHECK( x12 == x11 );
+
+ Factor x13 = x12;
+ BOOST_CHECK( x13 == x11 );
+}
+
+
+BOOST_AUTO_TEST_CASE( QueriesTest ) {
+ Factor x( Var( 5, 5 ), 0.0 );
+ for( size_t i = 0; i < x.nrStates(); i++ )
+ x.set( i, 2.0 - i );
+
+ // test min, max, sum, sumAbs, maxAbs
+ BOOST_CHECK_EQUAL( x.sum(), 0.0 );
+ BOOST_CHECK_EQUAL( x.max(), 2.0 );
+ BOOST_CHECK_EQUAL( x.min(), -2.0 );
+ BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
+ BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
+ x.set( 1, 1.0 );
+ BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
+ x /= x.sum();
+
+ // test entropy
+ BOOST_CHECK( x.entropy() < Prob(5).entropy() );
+ for( size_t i = 1; i < 100; i++ )
+ BOOST_CHECK_CLOSE( Factor( Var(0,i) ).entropy(), std::log((Real)i), tol );
+
+ // test hasNaNs and hasNegatives
+ BOOST_CHECK( !Factor( 0.0 ).hasNaNs() );
+ Real c = 0.0;
+ BOOST_CHECK( Factor( c / c ).hasNaNs() );
+ BOOST_CHECK( !Factor( 0.0 ).hasNegatives() );
+ BOOST_CHECK( !Factor( 1.0 ).hasNegatives() );
+ BOOST_CHECK( Factor( -1.0 ).hasNegatives() );
+ x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
+ BOOST_CHECK( x.hasNegatives() );
+ x.set( 2, -INFINITY );
+ BOOST_CHECK( x.hasNegatives() );
+ x.set( 2, INFINITY );
+ BOOST_CHECK( !x.hasNegatives() );
+ x.set( 2, -1.0 );
+
+ // test strength
+ Var x0(0,2);
+ Var x1(1,2);
+ BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 1.0 ).strength( x0, x1 ), std::tanh( 1.0 ), tol );
+ BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, -1.0 ).strength( x0, x1 ), std::tanh( 1.0 ), tol );
+ BOOST_CHECK_CLOSE( createFactorIsing( x0, x1, 0.5 ).strength( x0, x1 ), std::tanh( 0.5 ), tol );
+
+ // test ==
+ Factor a(Var(0,3)), b(Var(0,3));
+ Factor d(Var(1,3));
+ BOOST_CHECK( !(a == d) );
+ BOOST_CHECK( !(b == d) );
+ BOOST_CHECK( a == b );
+ a.set( 0, 0.0 );
+ BOOST_CHECK( !(a == b) );
+ b.set( 2, 0.0 );
+ BOOST_CHECK( !(a == b) );
+ b.set( 0, 0.0 );
+ BOOST_CHECK( !(a == b) );
+ a.set( 1, 0.0 );
+ BOOST_CHECK( !(a == b) );
+ b.set( 1, 0.0 );
+ BOOST_CHECK( !(a == b) );
+ a.set( 2, 0.0 );
+ BOOST_CHECK( a == b );
+}
+
+
+BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
+ Var v( 0, 3 );
+ Factor x( v );
+ x.set( 0, -2.0 );
+ x.set( 1, 0.0 );
+ x.set( 2, 2.0 );
+
+ Factor y = -x;
+ BOOST_CHECK_EQUAL( y[0], 2.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], -2.0 );
+
+ y = x.abs();
+ BOOST_CHECK_EQUAL( y[0], 2.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 2.0 );
+
+ y = x.exp();
+ BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
+ BOOST_CHECK_EQUAL( y[1], 1.0 );
+ BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
+
+ y = x.log(false);
+ BOOST_CHECK( isnan( y[0] ) );
+ BOOST_CHECK_EQUAL( y[1], -INFINITY );
+ BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
+
+ y = x.log(true);
+ BOOST_CHECK( isnan( y[0] ) );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
+
+ y = x.inverse(false);
+ BOOST_CHECK_EQUAL( y[0], -0.5 );
+ BOOST_CHECK_EQUAL( y[1], INFINITY );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ y = x.inverse(true);
+ BOOST_CHECK_EQUAL( y[0], -0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ x.set( 0, 2.0 );
+ y = x.normalized();
+ BOOST_CHECK_EQUAL( y[0], 0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ y = x.normalized( NORMPROB );
+ BOOST_CHECK_EQUAL( y[0], 0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ x.set( 0, -2.0 );
+ y = x.normalized( NORMLINF );
+ BOOST_CHECK_EQUAL( y[0], -1.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 1.0 );
+}
+
+
+BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
+ Var v( 0, 3 );
+ Factor xorg( v );
+ xorg.set( 0, 2.0 );
+ xorg.set( 1, 0.0 );
+ xorg.set( 2, 1.0 );
+ Factor y( v );
+
+ Factor x = xorg;
+ BOOST_CHECK( x.setUniform() == Factor( v ) );
+ BOOST_CHECK( x == Factor( v ) );
+
+ y.set( 0, std::exp(2.0) );
+ y.set( 1, 1.0 );
+ y.set( 2, std::exp(1.0) );
+ x = xorg;
+ BOOST_CHECK( x.takeExp() == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, std::log(2.0) );
+ y.set( 1, -INFINITY );
+ y.set( 2, 0.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeLog() == y );
+ BOOST_CHECK( x == y );
+ x = xorg;
+ BOOST_CHECK( x.takeLog(false) == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 1, 0.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeLog(true) == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, 2.0 / 3.0 );
+ y.set( 1, 0.0 / 3.0 );
+ y.set( 2, 1.0 / 3.0 );
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, 2.0 / 2.0 );
+ y.set( 1, 0.0 / 2.0 );
+ y.set( 2, 1.0 / 2.0 );
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
+ BOOST_CHECK( x == y );
+
+ xorg.set( 0, -2.0 );
+ y.set( 0, 2.0 );
+ y.set( 1, 0.0 );
+ y.set( 2, 1.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeAbs() == y );
+ BOOST_CHECK( x == y );
+
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ x.randomize();
+ for( size_t i = 0; i < x.nrStates(); i++ ) {
+ BOOST_CHECK( x[i] < 1.0 );
+ BOOST_CHECK( x[i] >= 0.0 );
+ }
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
+ Var v( 0, 3 );
+ Factor xorg( v ), x( v );
+ xorg.set( 0, 2.0 );
+ xorg.set( 1, 0.0 );
+ xorg.set( 2, 1.0 );
+ Factor y( v );
+
+ x = xorg;
+ BOOST_CHECK( x.fill( 1.0 ) == Factor(v, 1.0) );
+ BOOST_CHECK( x == Factor(v, 1.0) );
+ BOOST_CHECK( x.fill( 2.0 ) == Factor(v, 2.0) );
+ BOOST_CHECK( x == Factor(v, 2.0) );
+ BOOST_CHECK( x.fill( 0.0 ) == Factor(v, 0.0) );
+ BOOST_CHECK( x == Factor(v, 0.0) );
+
+ x = xorg;
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x += 1.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x += -2.0) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x -= 1.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x -= -2.0) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK( (x *= 1.0) == x );
+ BOOST_CHECK( x == x );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x *= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
+ BOOST_CHECK( (x *= -0.25) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK( (x /= 1.0) == x );
+ BOOST_CHECK( x == x );
+ y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
+ BOOST_CHECK( (x /= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
+ BOOST_CHECK( (x /= -0.25) == y );
+ BOOST_CHECK( x == y );
+ BOOST_CHECK( (x /= 0.0) == Factor(v, 0.0) );
+ BOOST_CHECK( x == Factor(v, 0.0) );
+
+ x = xorg;
+ BOOST_CHECK( (x ^= 1.0) == x );
+ BOOST_CHECK( x == x );
+ BOOST_CHECK( (x ^= 0.0) == Factor(v, 1.0) );
+ BOOST_CHECK( x == Factor(v, 1.0) );
+ x = xorg;
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^= -0.5) == y );
+ BOOST_CHECK( x == y );
+}
+
+
+BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
+ Var v( 0, 3 );
+ Factor x( v );
+ x.set( 0, 2.0 );
+ x.set( 1, 0.0 );
+ x.set( 2, 1.0 );
+ Factor y( v );
+
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x + 1.0) == y );
+ y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
+ BOOST_CHECK( (x + (-2.0)) == y );
+
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x - 1.0) == y );
+ y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
+ BOOST_CHECK( (x - (-2.0)) == y );
+
+ BOOST_CHECK( (x * 1.0) == x );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x * 2.0) == y );
+ y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
+ BOOST_CHECK( (x * -0.5) == y );
+
+ BOOST_CHECK( (x / 1.0) == x );
+ y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
+ BOOST_CHECK( (x / 2.0) == y );
+ y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
+ BOOST_CHECK( (x / -0.5) == y );
+ BOOST_CHECK( (x / 0.0) == Factor(v, 0.0) );
+
+ BOOST_CHECK( (x ^ 1.0) == x );
+ BOOST_CHECK( (x ^ 0.0) == Factor(v, 1.0) );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^ 2.0) == y );
+ y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
+ Factor z = (x ^ -0.5);
+ BOOST_CHECK_CLOSE( z[0], y[0], tol );
+ BOOST_CHECK_EQUAL( z[1], y[1] );
+ BOOST_CHECK_CLOSE( z[2], y[2], tol );
+}
+
+
+BOOST_AUTO_TEST_CASE( SimilarFactorOperationsTest ) {
+ size_t N = 6;
+ Var v( 0, N );
+ Factor xorg( v ), x( v );
+ xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
+ Factor y( v );
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
+ Factor z( v ), r( v );
+
+ z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
+ x = xorg;
+ r = (x += y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
+ x = xorg;
+ r = (x -= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
+ x = xorg;
+ r = (x *= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ x = xorg;
+ r = (x /= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == z );
+ BOOST_CHECK( x == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( SimilarFactorTransformationsTest ) {
+ size_t N = 6;
+ Var v( 0, N );
+ Factor x( v );
+ x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
+ Factor y( v );
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
+ Factor z( v ), r( v );
+
+ z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
+ r = x + y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.binaryTr( y, std::plus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
+ r = x - y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.binaryTr( y, std::minus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
+ r = x * y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.binaryTr( y, std::multiplies<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ r = x / y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.binaryTr( y, fo_divides0<Real>() );
+ BOOST_CHECK( r == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( FactorOperationsTest ) {
+ size_t N = 9;
+ Var v1( 1, 3 );
+ Var v2( 2, 3 );
+ Factor xorg( v1 ), x( v1 );
+ xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, -1.0 );
+ Factor y( v2 );
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ Factor r;
+
+ Factor z( VarSet( v1, v2 ) );
+ z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
+ z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
+ z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
+ x = xorg;
+ r = (x += y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::plus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
+ z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
+ z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
+ x = xorg;
+ r = (x -= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::minus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
+ z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
+ z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
+ x = xorg;
+ r = (x *= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, std::multiplies<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
+ z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
+ z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
+ x = xorg;
+ r = (x /= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.binaryOp( y, fo_divides0<Real>() ) == z );
+ BOOST_CHECK( x == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( FactorTransformationsTest ) {
+ size_t N = 9;
+ Var v1( 1, 3 );
+ Var v2( 2, 3 );
+ Factor x( v1 );
+ x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 );
+ Factor y( v2 );
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ Factor r;
+
+ Factor z( VarSet( v1, v2 ) );
+ z.set( 0, 2.5 ); z.set( 1, 0.5 ); z.set( 2, -0.5 );
+ z.set( 3, 1.0 ); z.set( 4, -1.0 ); z.set( 5, -2.0 );
+ z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
+ r = x + y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( r == z );
+ z = x.binaryTr( y, std::plus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.5 ); z.set( 1, -0.5 ); z.set( 2, -1.5 );
+ z.set( 3, 3.0 ); z.set( 4, 1.0 ); z.set( 5, 0.0 );
+ z.set( 6, 2.0 ); z.set( 7, 0.0 ); z.set( 8, -1.0 );
+ r = x - y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( r == z );
+ z = x.binaryTr( y, std::minus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, -0.5 );
+ z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
+ z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
+ r = x * y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( r == z );
+ z = x.binaryTr( y, std::multiplies<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, -2.0 );
+ z.set( 3, -2.0 ); z.set( 4, 0.0 ); z.set( 5, 1.0 );
+ z.set( 6, 0.0 ); z.set( 7, 0.0 ); z.set( 8, 0.0 );
+ r = x / y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( r == z );
+ z = x.binaryOp( y, fo_divides0<Real>() );
+ BOOST_CHECK( r == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( MiscOperationsTest ) {
+ Var v1(1, 2);
+ Var v2(2, 3);
+ Factor x( VarSet( v1, v2 ) );
+ x.randomize();
+
+ // slice
+ Factor y = x.slice( v1, 0 );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y.nrStates(), 3 );
+ BOOST_CHECK_EQUAL( y[0], x[0] );
+ BOOST_CHECK_EQUAL( y[1], x[2] );
+ BOOST_CHECK_EQUAL( y[2], x[4] );
+ y = x.slice( v1, 1 );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y.nrStates(), 3 );
+ BOOST_CHECK_EQUAL( y[0], x[1] );
+ BOOST_CHECK_EQUAL( y[1], x[3] );
+ BOOST_CHECK_EQUAL( y[2], x[5] );
+ y = x.slice( v2, 0 );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y.nrStates(), 2 );
+ BOOST_CHECK_EQUAL( y[0], x[0] );
+ BOOST_CHECK_EQUAL( y[1], x[1] );
+ y = x.slice( v2, 1 );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y.nrStates(), 2 );
+ BOOST_CHECK_EQUAL( y[0], x[2] );
+ BOOST_CHECK_EQUAL( y[1], x[3] );
+ y = x.slice( v2, 2 );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y.nrStates(), 2 );
+ BOOST_CHECK_EQUAL( y[0], x[4] );
+ BOOST_CHECK_EQUAL( y[1], x[5] );
+ for( size_t i = 0; i < x.nrStates(); i++ ) {
+ y = x.slice( VarSet( v1, v2 ), 0 );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y.nrStates(), 1 );
+ BOOST_CHECK_EQUAL( y[0], x[0] );
+ }
+ y = x.slice( VarSet(), 0 );
+ BOOST_CHECK_EQUAL( y, x );
+
+ // embed
+ Var v3(3, 4);
+ BOOST_CHECK_THROW( x.embed( VarSet( v3 ) ), Exception );
+ BOOST_CHECK_THROW( x.embed( VarSet( v3, v2 ) ), Exception );
+ y = x.embed( VarSet( v3, v2 ) | v1 );
+ for( size_t i = 0; i < y.nrStates(); i++ )
+ BOOST_CHECK_EQUAL( y[i], x[i % 6] );
+ y = x.embed( VarSet( v1, v2 ) );
+ BOOST_CHECK_EQUAL( x, y );
+
+ // marginal
+ y = x.marginal( v1 );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], (x[0] + x[2] + x[4]) / x.sum() );
+ BOOST_CHECK_EQUAL( y[1], (x[1] + x[3] + x[5]) / x.sum() );
+ y = x.marginal( v2 );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
+ BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
+ BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
+ y = x.marginal( VarSet() );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], 1.0 );
+ y = x.marginal( VarSet( v1, v2 ) );
+ BOOST_CHECK( y == x.normalized() );
+
+ y = x.marginal( v1, true );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], (x[0] + x[2] + x[4]) / x.sum() );
+ BOOST_CHECK_EQUAL( y[1], (x[1] + x[3] + x[5]) / x.sum() );
+ y = x.marginal( v2, true );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_CLOSE( y[0], (x[0] + x[1]) / x.sum(), tol );
+ BOOST_CHECK_CLOSE( y[1], (x[2] + x[3]) / x.sum(), tol );
+ BOOST_CHECK_CLOSE( y[2], (x[4] + x[5]) / x.sum(), tol );
+ y = x.marginal( VarSet(), true );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], 1.0 );
+ y = x.marginal( VarSet( v1, v2 ), true );
+ BOOST_CHECK( y == x.normalized() );
+
+ y = x.marginal( v1, false );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], x[0] + x[2] + x[4] );
+ BOOST_CHECK_EQUAL( y[1], x[1] + x[3] + x[5] );
+ y = x.marginal( v2, false );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y[0], x[0] + x[1] );
+ BOOST_CHECK_EQUAL( y[1], x[2] + x[3] );
+ BOOST_CHECK_EQUAL( y[2], x[4] + x[5] );
+ y = x.marginal( VarSet(), false );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], x.sum() );
+ y = x.marginal( VarSet( v1, v2 ), false );
+ BOOST_CHECK( y == x );
+
+ // maxMarginal
+ y = x.maxMarginal( v1 );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
+ y = x.maxMarginal( v2 );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ y = x.maxMarginal( VarSet() );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], 1.0 );
+ y = x.maxMarginal( VarSet( v1, v2 ) );
+ BOOST_CHECK( y == x.normalized() );
+
+ y = x.maxMarginal( v1, true );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() / (x.slice( v1, 0 ).max() + x.slice( v1, 1 ).max()) );
+ y = x.maxMarginal( v2, true );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() / (x.slice( v2, 0 ).max() + x.slice( v2, 1 ).max() + x.slice( v2, 2 ).max()) );
+ y = x.maxMarginal( VarSet(), true );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], 1.0 );
+ y = x.maxMarginal( VarSet( v1, v2 ), true );
+ BOOST_CHECK( y == x.normalized() );
+
+ y = x.maxMarginal( v1, false );
+ BOOST_CHECK( y.vars() == VarSet( v1 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v1, 0 ).max() );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v1, 1 ).max() );
+ y = x.maxMarginal( v2, false );
+ BOOST_CHECK( y.vars() == VarSet( v2 ) );
+ BOOST_CHECK_EQUAL( y[0], x.slice( v2, 0 ).max() );
+ BOOST_CHECK_EQUAL( y[1], x.slice( v2, 1 ).max() );
+ BOOST_CHECK_EQUAL( y[2], x.slice( v2, 2 ).max() );
+ y = x.maxMarginal( VarSet(), false );
+ BOOST_CHECK( y.vars() == VarSet() );
+ BOOST_CHECK_EQUAL( y[0], x.max() );
+ y = x.maxMarginal( VarSet( v1, v2 ), false );
+ BOOST_CHECK( y == x );
+}
+
+
+BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
+ Var v( 0, 3 );
+ Factor x(v), y(v), z(v);
+ x.set( 0, 0.2 );
+ x.set( 1, 0.8 );
+ x.set( 2, 0.0 );
+ y.set( 0, 0.0 );
+ y.set( 1, 0.6 );
+ y.set( 2, 0.4 );
+
+ z = min( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.0 );
+ BOOST_CHECK_EQUAL( z[1], 0.6 );
+ BOOST_CHECK_EQUAL( z[2], 0.0 );
+ z = max( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.2 );
+ BOOST_CHECK_EQUAL( z[1], 0.8 );
+ BOOST_CHECK_EQUAL( z[2], 0.4 );
+
+ BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
+ x.set( 1, 0.7 ); x.set( 2, 0.1 );
+ y.set( 0, 0.1 ); y.set( 1, 0.5 );
+ BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
+ BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
+
+ std::stringstream ss;
+ ss << x;
+ std::string s;
+ std::getline( ss, s );
+ BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.2, 0.7, 0.1))") );
+ std::stringstream ss2;
+ ss2 << y;
+ std::getline( ss2, s );
+ BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.1, 0.5, 0.4))") );
+
+ z = min( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.1 );
+ BOOST_CHECK_EQUAL( z[1], 0.5 );
+ BOOST_CHECK_EQUAL( z[2], 0.1 );
+ z = max( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.2 );
+ BOOST_CHECK_EQUAL( z[1], 0.7 );
+ BOOST_CHECK_EQUAL( z[2], 0.4 );
+
+ for( double J = -1.0; J <= 1.01; J += 0.1 ) {
+ Factor x = createFactorIsing( Var(0,2), Var(1,2), J ).normalized();
+ BOOST_CHECK_CLOSE( x[0], std::exp(J) / (4.0 * std::cosh(J)), tol );
+ BOOST_CHECK_CLOSE( x[1], std::exp(-J) / (4.0 * std::cosh(J)), tol );
+ BOOST_CHECK_CLOSE( x[2], std::exp(-J) / (4.0 * std::cosh(J)), tol );
+ BOOST_CHECK_CLOSE( x[3], std::exp(J) / (4.0 * std::cosh(J)), tol );
+ BOOST_CHECK_SMALL( MutualInfo( x ) - (J * std::tanh(J) - std::log(std::cosh(J))), tol );
+ }
+ Var v1( 1, 3 );
+ Var v2( 2, 4 );
+ BOOST_CHECK_SMALL( MutualInfo( (Factor(v1).randomize() * Factor(v2).randomize()).normalized() ), tol );
+ BOOST_CHECK_THROW( MutualInfo( createFactorIsing( Var(0,2), 1.0 ).normalized() ), Exception );
+ BOOST_CHECK_THROW( createFactorIsing( v1, 0.0 ), Exception );
+ BOOST_CHECK_THROW( createFactorIsing( v1, v2, 0.0 ), Exception );
+ for( double J = -1.0; J <= 1.01; J += 0.1 ) {
+ Factor x = createFactorIsing( Var(0,2), J ).normalized();
+ BOOST_CHECK_CLOSE( x[0], std::exp(-J) / (2.0 * std::cosh(J)), tol );
+ BOOST_CHECK_CLOSE( x[1], std::exp(J) / (2.0 * std::cosh(J)), tol );
+ BOOST_CHECK_SMALL( x.entropy() - (-J * std::tanh(J) + std::log(2.0 * std::cosh(J))), tol );
+ }
+
+ x = createFactorDelta( v1, 2 );
+ BOOST_CHECK_EQUAL( x[0], 0.0 );
+ BOOST_CHECK_EQUAL( x[1], 0.0 );
+ BOOST_CHECK_EQUAL( x[2], 1.0 );
+ x = createFactorDelta( v1, 1 );
+ BOOST_CHECK_EQUAL( x[0], 0.0 );
+ BOOST_CHECK_EQUAL( x[1], 1.0 );
+ BOOST_CHECK_EQUAL( x[2], 0.0 );
+ x = createFactorDelta( v1, 0 );
+ BOOST_CHECK_EQUAL( x[0], 1.0 );
+ BOOST_CHECK_EQUAL( x[1], 0.0 );
+ BOOST_CHECK_EQUAL( x[2], 0.0 );
+ BOOST_CHECK_THROW( createFactorDelta( v1, 4 ), Exception );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/bipgraph.h>
-#include <dai/factorgraph.h>
-#include <vector>
-#include <strstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE FactorGraphTest
-
-
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- FactorGraph G;
- BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
- BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
-
- std::vector<Factor> facs;
- facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
- std::vector<Var> vars;
- vars.push_back( Var( 0, 2 ) );
- vars.push_back( Var( 1, 2 ) );
- vars.push_back( Var( 2, 2 ) );
-
- FactorGraph G1( facs );
- BOOST_CHECK_EQUAL( G1.vars(), vars );
- BOOST_CHECK_EQUAL( G1.factors(), facs );
-
- FactorGraph G2( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
- BOOST_CHECK_EQUAL( G2.vars(), vars );
- BOOST_CHECK_EQUAL( G2.factors(), facs );
-
- FactorGraph *G3 = G2.clone();
- BOOST_CHECK_EQUAL( G3->vars(), vars );
- BOOST_CHECK_EQUAL( G3->factors(), facs );
- delete G3;
-
- FactorGraph G4 = G2;
- BOOST_CHECK_EQUAL( G4.vars(), vars );
- BOOST_CHECK_EQUAL( G4.factors(), facs );
-
- FactorGraph G5( G2 );
- BOOST_CHECK_EQUAL( G5.vars(), vars );
- BOOST_CHECK_EQUAL( G5.factors(), facs );
-}
-
-
-BOOST_AUTO_TEST_CASE( AccMutTest ) {
- std::vector<Factor> facs;
- facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
- facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
- std::vector<Var> vars;
- vars.push_back( Var( 0, 2 ) );
- vars.push_back( Var( 1, 2 ) );
- vars.push_back( Var( 2, 2 ) );
-
- FactorGraph G( facs );
- BOOST_CHECK_EQUAL( G.var(0), Var(0, 2) );
- BOOST_CHECK_EQUAL( G.var(1), Var(1, 2) );
- BOOST_CHECK_EQUAL( G.var(2), Var(2, 2) );
- BOOST_CHECK_EQUAL( G.vars(), vars );
- BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
- BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
- BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
- BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
- BOOST_CHECK_EQUAL( G.factors(), facs );
- BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
- BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
- BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
- BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
- BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
- BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
- BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
- BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
- BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
- BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
- BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
- BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
- BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
- BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
- BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
- BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
- BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
- BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
- BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
- BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
- BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
-}
-
-
-BOOST_AUTO_TEST_CASE( QueriesTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v12( v1, v2 );
- VarSet v012 = v01 | v2;
-
- FactorGraph G0;
- BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
- BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
- BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
- BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
- BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
- BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
-#ifdef DAI_DBEUG
- BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
- BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
- BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
- BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
-#endif
- BOOST_CHECK( G0.isConnected() );
- BOOST_CHECK( G0.isTree() );
- BOOST_CHECK( G0.isBinary() );
- BOOST_CHECK( G0.isPairwise() );
- BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
- BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
- BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
-
- std::vector<Factor> facs;
- facs.push_back( Factor( v01 ) );
- facs.push_back( Factor( v12 ) );
- facs.push_back( Factor( v1 ) );
- std::vector<Var> vars;
- vars.push_back( v0 );
- vars.push_back( v1 );
- vars.push_back( v2 );
- GraphAL H(3);
- H.addEdge( 0, 1 );
- H.addEdge( 1, 2 );
- BipartiteGraph K(3, 3);
- K.addEdge( 0, 0 );
- K.addEdge( 1, 0 );
- K.addEdge( 1, 1 );
- K.addEdge( 2, 1 );
- K.addEdge( 1, 2 );
-
- FactorGraph G1( facs );
- BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
- BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
- BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
- BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
- BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
- BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
- BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
- BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
- BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
- BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
- BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
- BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
- BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
- BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
- BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
- BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
- BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
- BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
- BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
- BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
- BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
- BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
- BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
- BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
- BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
- BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
- BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
- BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
- BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
- BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
- BOOST_CHECK( G1.isConnected() );
- BOOST_CHECK( G1.isTree() );
- BOOST_CHECK( G1.isBinary() );
- BOOST_CHECK( G1.isPairwise() );
- BOOST_CHECK( G1.MarkovGraph() == H );
- BOOST_CHECK( G1.bipGraph() == K );
- BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
- BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
- BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
-
- facs.push_back( Factor( v02 ) );
- H.addEdge( 0, 2 );
- K.addNode2();
- K.addEdge( 0, 3 );
- K.addEdge( 2, 3 );
- FactorGraph G2( facs );
- BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
- BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
- BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
- BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
- BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
- BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
- BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
- BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
- BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
- BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
- BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
- BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
- BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
- BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
- BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
- BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
- BOOST_CHECK( G2.isConnected() );
- BOOST_CHECK( !G2.isTree() );
- BOOST_CHECK( G2.isBinary() );
- BOOST_CHECK( G2.isPairwise() );
- BOOST_CHECK( G2.MarkovGraph() == H );
- BOOST_CHECK( G2.bipGraph() == K );
- BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
- BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
- BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
- BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
-
- Var v3( 3, 3 );
- VarSet v03( v0, v3 );
- VarSet v13( v1, v3 );
- VarSet v23( v2, v3 );
- VarSet v013 = v01 | v3;
- VarSet v023 = v02 | v3;
- VarSet v123 = v12 | v3;
- VarSet v0123 = v012 | v3;
- vars.push_back( v3 );
- facs.push_back( Factor( v3 ) );
- H.addNode();
- K.addNode1();
- K.addNode2();
- K.addEdge( 3, 4 );
- FactorGraph G3( facs );
- BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
- BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
- BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
- BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
- BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
- BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
- BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
- BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
- BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
- BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
- BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
- BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
- BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
- BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
- BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
- BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
- BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
- BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
- BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
- BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
- BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
- BOOST_CHECK( !G3.isConnected() );
- BOOST_CHECK( !G3.isTree() );
- BOOST_CHECK( !G3.isBinary() );
- BOOST_CHECK( G3.isPairwise() );
- BOOST_CHECK( G3.MarkovGraph() == H );
- BOOST_CHECK( G3.bipGraph() == K );
- BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
- BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
- BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
- BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
- BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
-
- facs.push_back( Factor( v123 ) );
- H.addEdge( 1, 3 );
- H.addEdge( 2, 3 );
- K.addNode2();
- K.addEdge( 1, 5 );
- K.addEdge( 2, 5 );
- K.addEdge( 3, 5 );
- FactorGraph G4( facs );
- BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
- BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
- BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
- BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
- BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
- BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
- BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
- BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
- BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
- BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
- BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
- BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
- BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
- BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
- BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
- BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
- BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
- BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
- BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
- BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
- BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
- BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
- BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
- BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
- BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
- BOOST_CHECK( G4.isConnected() );
- BOOST_CHECK( !G4.isTree() );
- BOOST_CHECK( !G4.isBinary() );
- BOOST_CHECK( !G4.isPairwise() );
- BOOST_CHECK( G4.MarkovGraph() == H );
- BOOST_CHECK( G4.bipGraph() == K );
- BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
- BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
- BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
- BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
-}
-
-
-BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v12( v1, v2 );
- VarSet v012 = v01 | v2;
-
- std::vector<Factor> facs;
- facs.push_back( Factor( v01 ) );
- facs.push_back( Factor( v12 ) );
- facs.push_back( Factor( v1 ) );
- std::vector<Var> vars;
- vars.push_back( v0 );
- vars.push_back( v1 );
- vars.push_back( v2 );
-
- FactorGraph G( facs );
- FactorGraph Gorg( G );
-
- BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
- G.setFactor( 0, Factor( v01, 2.0 ), false );
- BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
- G.setFactor( 0, Factor( v01, 3.0 ), true );
- G.restoreFactor( 0 );
- BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
- G.setFactor( 0, Gorg.factor( 0 ), false );
- G.backupFactor( 0 );
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- G.setFactor( 0, Factor( v01, 2.0 ), false );
- BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
- G.restoreFactor( 0 );
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
-
- std::map<size_t, Factor> fs;
- fs[0] = Factor( v01, 3.0 );
- fs[2] = Factor( v1, 2.0 );
- G.setFactors( fs, false );
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- G.restoreFactors();
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- G = Gorg;
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
- G.setFactors( fs, true );
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- G.restoreFactors();
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
- std::set<size_t> fsind;
- fsind.insert( 0 );
- fsind.insert( 2 );
- G.backupFactors( fsind );
- G.setFactors( fs, false );
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- G.restoreFactors();
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
-
- G.backupFactors( v2 );
- G.setFactor( 1, Factor(v12, 5.0) );
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
- G.restoreFactors( v2 );
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
-
- G.backupFactors( v1 );
- fs[1] = Factor( v12, 5.0 );
- G.setFactors( fs, false );
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- G.restoreFactors();
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
- G.setFactors( fs, true );
- BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
- BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
- BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
- G.restoreFactors( v1 );
- BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
- BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
- BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
-}
-
-
-BOOST_AUTO_TEST_CASE( TransformationsTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v12( v1, v2 );
- VarSet v012 = v01 | v2;
-
- std::vector<Factor> facs;
- facs.push_back( Factor( v01 ).randomize() );
- facs.push_back( Factor( v12 ).randomize() );
- facs.push_back( Factor( v1 ).randomize() );
- std::vector<Var> vars;
- vars.push_back( v0 );
- vars.push_back( v1 );
- vars.push_back( v2 );
-
- FactorGraph G( facs );
-
- FactorGraph Gsmall = G.maximalFactors();
- BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
- BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
- BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
-
- size_t i = 0;
- for( size_t x = 0; x < 2; x++ ) {
- FactorGraph Gcl = G.clamped( i, x );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
- }
- i = 1;
- for( size_t x = 0; x < 2; x++ ) {
- FactorGraph Gcl = G.clamped( i, x );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
- }
- i = 2;
- for( size_t x = 0; x < 2; x++ ) {
- FactorGraph Gcl = G.clamped( i, x );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( OperationsTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v12( v1, v2 );
- VarSet v012 = v01 | v2;
-
- std::vector<Factor> facs;
- facs.push_back( Factor( v01 ).randomize() );
- facs.push_back( Factor( v12 ).randomize() );
- facs.push_back( Factor( v1 ).randomize() );
- std::vector<Var> vars;
- vars.push_back( v0 );
- vars.push_back( v1 );
- vars.push_back( v2 );
-
- FactorGraph G( facs );
-
- // clamp
- FactorGraph Gcl = G;
- for( size_t i = 0; i < 3; i++ )
- for( size_t x = 0; x < 2; x++ ) {
- Gcl.clamp( i, x, true );
- Factor delta = createFactorDelta( vars[i], x );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- for( size_t j = 0; j < 3; j++ )
- if( G.factor(j).vars().contains( vars[i] ) )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
- else
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
-
- Gcl.restoreFactors();
- for( size_t j = 0; j < 3; j++ )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
- }
-
- // clampVar
- for( size_t i = 0; i < 3; i++ )
- for( size_t x = 0; x < 2; x++ ) {
- Gcl.clampVar( i, std::vector<size_t>(1, x), true );
- Factor delta = createFactorDelta( vars[i], x );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- for( size_t j = 0; j < 3; j++ )
- if( G.factor(j).vars().contains( vars[i] ) )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
- else
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
-
- Gcl.restoreFactors();
- for( size_t j = 0; j < 3; j++ )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
- }
- for( size_t i = 0; i < 3; i++ )
- for( size_t x = 0; x < 2; x++ ) {
- Gcl.clampVar( i, std::vector<size_t>(), true );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- for( size_t j = 0; j < 3; j++ )
- if( G.factor(j).vars().contains( vars[i] ) )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
- else
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
-
- Gcl.restoreFactors();
- for( size_t j = 0; j < 3; j++ )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
- }
- std::vector<size_t> both;
- both.push_back( 0 );
- both.push_back( 1 );
- for( size_t i = 0; i < 3; i++ )
- for( size_t x = 0; x < 2; x++ ) {
- Gcl.clampVar( i, both, true );
- BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
- BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
- for( size_t j = 0; j < 3; j++ )
- BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
- Gcl.restoreFactors();
- }
-
- // clampFactor
- for( size_t x = 0; x < 4; x++ ) {
- Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
- Factor mask( v01, 0.0 );
- mask.set( x, 1.0 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
- Gcl.restoreFactor( 0 );
- }
- for( size_t x = 0; x < 4; x++ ) {
- Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
- Factor mask( v12, 0.0 );
- mask.set( x, 1.0 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
- Gcl.restoreFactor( 1 );
- }
- for( size_t x = 0; x < 2; x++ ) {
- Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
- Factor mask( v1, 0.0 );
- mask.set( x, 1.0 );
- BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
- BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
- BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
- Gcl.restoreFactors();
- }
-
- // makeCavity
- FactorGraph Gcav( G );
- Gcav.makeCavity( 0, true );
- BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
- BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
- BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
- Gcav.restoreFactors();
- Gcav.makeCavity( 1, true );
- BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
- BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
- BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
- Gcav.restoreFactors();
- Gcav.makeCavity( 2, true );
- BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
- BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
- BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
- Gcav.restoreFactors();
-}
-
-
-BOOST_AUTO_TEST_CASE( IOTest ) {
- Var v0( 0, 2 );
- Var v1( 1, 2 );
- Var v2( 2, 2 );
- VarSet v01( v0, v1 );
- VarSet v02( v0, v2 );
- VarSet v12( v1, v2 );
- VarSet v012 = v01 | v2;
-
- std::vector<Factor> facs;
- facs.push_back( Factor( v01 ).randomize() );
- facs.push_back( Factor( v12 ).randomize() );
- facs.push_back( Factor( v1 ).randomize() );
- std::vector<Var> vars;
- vars.push_back( v0 );
- vars.push_back( v1 );
- vars.push_back( v2 );
-
- FactorGraph G( facs );
-
- G.WriteToFile( "factorgraph_test.fg" );
- FactorGraph G2;
- G2.ReadFromFile( "factorgraph_test.fg" );
-
- BOOST_CHECK( G.vars() == G2.vars() );
- BOOST_CHECK( G.bipGraph() == G2.bipGraph() );
- BOOST_CHECK_EQUAL( G.nrFactors(), G2.nrFactors() );
- for( size_t I = 0; I < G.nrFactors(); I++ ) {
- BOOST_CHECK( G.factor(I).vars() == G2.factor(I).vars() );
- for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
- BOOST_CHECK_CLOSE( G.factor(I)[s], G2.factor(I)[s], tol );
- }
-
- std::stringstream ss;
- std::string s;
- G.printDot( ss );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph FactorGraph {" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0 -- f0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2 -- f1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
-
- G.setFactor( 0, Factor( G.factor(0).vars(), 1.0 ) );
- G.setFactor( 1, Factor( G.factor(1).vars(), 2.0 ) );
- G.setFactor( 2, Factor( G.factor(2).vars(), 3.0 ) );
- ss << G;
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 1" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 1" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 1" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 " );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 3" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 3" );
-
- ss << G;
- FactorGraph G3;
- ss >> G3;
- BOOST_CHECK( G.vars() == G3.vars() );
- BOOST_CHECK( G.bipGraph() == G3.bipGraph() );
- BOOST_CHECK_EQUAL( G.nrFactors(), G3.nrFactors() );
- for( size_t I = 0; I < G.nrFactors(); I++ ) {
- BOOST_CHECK( G.factor(I).vars() == G3.factor(I).vars() );
- for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
- BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
- }
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/bipgraph.h>
+#include <dai/factorgraph.h>
+#include <vector>
+#include <strstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE FactorGraphTest
+
+
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ FactorGraph G;
+ BOOST_CHECK_EQUAL( G.vars(), std::vector<Var>() );
+ BOOST_CHECK_EQUAL( G.factors(), std::vector<Factor>() );
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
+ std::vector<Var> vars;
+ vars.push_back( Var( 0, 2 ) );
+ vars.push_back( Var( 1, 2 ) );
+ vars.push_back( Var( 2, 2 ) );
+
+ FactorGraph G1( facs );
+ BOOST_CHECK_EQUAL( G1.vars(), vars );
+ BOOST_CHECK_EQUAL( G1.factors(), facs );
+
+ FactorGraph G2( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
+ BOOST_CHECK_EQUAL( G2.vars(), vars );
+ BOOST_CHECK_EQUAL( G2.factors(), facs );
+
+ FactorGraph *G3 = G2.clone();
+ BOOST_CHECK_EQUAL( G3->vars(), vars );
+ BOOST_CHECK_EQUAL( G3->factors(), facs );
+ delete G3;
+
+ FactorGraph G4 = G2;
+ BOOST_CHECK_EQUAL( G4.vars(), vars );
+ BOOST_CHECK_EQUAL( G4.factors(), facs );
+
+ FactorGraph G5( G2 );
+ BOOST_CHECK_EQUAL( G5.vars(), vars );
+ BOOST_CHECK_EQUAL( G5.factors(), facs );
+}
+
+
+BOOST_AUTO_TEST_CASE( AccMutTest ) {
+ std::vector<Factor> facs;
+ facs.push_back( Factor( VarSet( Var(0, 2), Var(1, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(0, 2), Var(2, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(1, 2), Var(2, 2) ) ) );
+ facs.push_back( Factor( VarSet( Var(1, 2) ) ) );
+ std::vector<Var> vars;
+ vars.push_back( Var( 0, 2 ) );
+ vars.push_back( Var( 1, 2 ) );
+ vars.push_back( Var( 2, 2 ) );
+
+ FactorGraph G( facs );
+ BOOST_CHECK_EQUAL( G.var(0), Var(0, 2) );
+ BOOST_CHECK_EQUAL( G.var(1), Var(1, 2) );
+ BOOST_CHECK_EQUAL( G.var(2), Var(2, 2) );
+ BOOST_CHECK_EQUAL( G.vars(), vars );
+ BOOST_CHECK_EQUAL( G.factor(0), facs[0] );
+ BOOST_CHECK_EQUAL( G.factor(1), facs[1] );
+ BOOST_CHECK_EQUAL( G.factor(2), facs[2] );
+ BOOST_CHECK_EQUAL( G.factor(3), facs[3] );
+ BOOST_CHECK_EQUAL( G.factors(), facs );
+ BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nbV(0,0), 0 );
+ BOOST_CHECK_EQUAL( G.nbV(0,1), 1 );
+ BOOST_CHECK_EQUAL( G.nbV(1).size(), 3 );
+ BOOST_CHECK_EQUAL( G.nbV(1,0), 0 );
+ BOOST_CHECK_EQUAL( G.nbV(1,1), 2 );
+ BOOST_CHECK_EQUAL( G.nbV(1,2), 3 );
+ BOOST_CHECK_EQUAL( G.nbV(0).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nbV(2,0), 1 );
+ BOOST_CHECK_EQUAL( G.nbV(2,1), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(0).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(0,0), 0 );
+ BOOST_CHECK_EQUAL( G.nbF(0,1), 1 );
+ BOOST_CHECK_EQUAL( G.nbF(1).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(1,0), 0 );
+ BOOST_CHECK_EQUAL( G.nbF(1,1), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(2).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(2,0), 1 );
+ BOOST_CHECK_EQUAL( G.nbF(2,1), 2 );
+ BOOST_CHECK_EQUAL( G.nbF(3).size(), 1 );
+ BOOST_CHECK_EQUAL( G.nbF(3,0), 1 );
+}
+
+
+BOOST_AUTO_TEST_CASE( QueriesTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v12( v1, v2 );
+ VarSet v012 = v01 | v2;
+
+ FactorGraph G0;
+ BOOST_CHECK_EQUAL( G0.nrVars(), 0 );
+ BOOST_CHECK_EQUAL( G0.nrFactors(), 0 );
+ BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
+ BOOST_CHECK_THROW( G0.findVar( v0 ), Exception );
+ BOOST_CHECK_THROW( G0.findVars( v01 ), Exception );
+ BOOST_CHECK_THROW( G0.findFactor( v01 ), Exception );
+#ifdef DAI_DBEUG
+ BOOST_CHECK_THROW( G0.delta( 0 ), Exception );
+ BOOST_CHECK_THROW( G0.Delta( 0 ), Exception );
+ BOOST_CHECK_THROW( G0.delta( v0 ), Exception );
+ BOOST_CHECK_THROW( G0.Delta( v0 ), Exception );
+#endif
+ BOOST_CHECK( G0.isConnected() );
+ BOOST_CHECK( G0.isTree() );
+ BOOST_CHECK( G0.isBinary() );
+ BOOST_CHECK( G0.isPairwise() );
+ BOOST_CHECK( G0.MarkovGraph() == GraphAL() );
+ BOOST_CHECK( G0.bipGraph() == BipartiteGraph() );
+ BOOST_CHECK_EQUAL( G0.maximalFactorDomains().size(), 0 );
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( v01 ) );
+ facs.push_back( Factor( v12 ) );
+ facs.push_back( Factor( v1 ) );
+ std::vector<Var> vars;
+ vars.push_back( v0 );
+ vars.push_back( v1 );
+ vars.push_back( v2 );
+ GraphAL H(3);
+ H.addEdge( 0, 1 );
+ H.addEdge( 1, 2 );
+ BipartiteGraph K(3, 3);
+ K.addEdge( 0, 0 );
+ K.addEdge( 1, 0 );
+ K.addEdge( 1, 1 );
+ K.addEdge( 2, 1 );
+ K.addEdge( 1, 2 );
+
+ FactorGraph G1( facs );
+ BOOST_CHECK_EQUAL( G1.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( G1.nrFactors(), 3 );
+ BOOST_CHECK_EQUAL( G1.nrEdges(), 5 );
+ BOOST_CHECK_EQUAL( G1.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G1.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G1.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G1.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
+ BOOST_CHECK_EQUAL( G1.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
+ BOOST_CHECK_EQUAL( G1.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
+ BOOST_CHECK_EQUAL( G1.findFactor( v01 ), 0 );
+ BOOST_CHECK_EQUAL( G1.findFactor( v12 ), 1 );
+ BOOST_CHECK_EQUAL( G1.findFactor( v1 ), 2 );
+ BOOST_CHECK_THROW( G1.findFactor( v02 ), Exception );
+ BOOST_CHECK_EQUAL( G1.delta( 0 ), v1 );
+ BOOST_CHECK_EQUAL( G1.delta( 1 ), v02 );
+ BOOST_CHECK_EQUAL( G1.delta( 2 ), v1 );
+ BOOST_CHECK_EQUAL( G1.Delta( 0 ), v01 );
+ BOOST_CHECK_EQUAL( G1.Delta( 1 ), v012 );
+ BOOST_CHECK_EQUAL( G1.Delta( 2 ), v12 );
+ BOOST_CHECK_EQUAL( G1.delta( v0 ), v1 );
+ BOOST_CHECK_EQUAL( G1.delta( v1 ), v02 );
+ BOOST_CHECK_EQUAL( G1.delta( v2 ), v1 );
+ BOOST_CHECK_EQUAL( G1.delta( v01 ), v2 );
+ BOOST_CHECK_EQUAL( G1.delta( v02 ), v1 );
+ BOOST_CHECK_EQUAL( G1.delta( v12 ), v0 );
+ BOOST_CHECK_EQUAL( G1.delta( v012 ), VarSet() );
+ BOOST_CHECK_EQUAL( G1.Delta( v0 ), v01 );
+ BOOST_CHECK_EQUAL( G1.Delta( v1 ), v012 );
+ BOOST_CHECK_EQUAL( G1.Delta( v2 ), v12 );
+ BOOST_CHECK_EQUAL( G1.Delta( v01 ), v012 );
+ BOOST_CHECK_EQUAL( G1.Delta( v02 ), v012 );
+ BOOST_CHECK_EQUAL( G1.Delta( v12 ), v012 );
+ BOOST_CHECK_EQUAL( G1.Delta( v012 ), v012 );
+ BOOST_CHECK( G1.isConnected() );
+ BOOST_CHECK( G1.isTree() );
+ BOOST_CHECK( G1.isBinary() );
+ BOOST_CHECK( G1.isPairwise() );
+ BOOST_CHECK( G1.MarkovGraph() == H );
+ BOOST_CHECK( G1.bipGraph() == K );
+ BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
+ BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
+ BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
+
+ facs.push_back( Factor( v02 ) );
+ H.addEdge( 0, 2 );
+ K.addNode2();
+ K.addEdge( 0, 3 );
+ K.addEdge( 2, 3 );
+ FactorGraph G2( facs );
+ BOOST_CHECK_EQUAL( G2.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( G2.nrFactors(), 4 );
+ BOOST_CHECK_EQUAL( G2.nrEdges(), 7 );
+ BOOST_CHECK_EQUAL( G2.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G2.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G2.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G2.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
+ BOOST_CHECK_EQUAL( G2.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
+ BOOST_CHECK_EQUAL( G2.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
+ BOOST_CHECK_EQUAL( G2.findFactor( v01 ), 0 );
+ BOOST_CHECK_EQUAL( G2.findFactor( v12 ), 1 );
+ BOOST_CHECK_EQUAL( G2.findFactor( v1 ), 2 );
+ BOOST_CHECK_EQUAL( G2.findFactor( v02 ), 3 );
+ BOOST_CHECK_EQUAL( G2.delta( 0 ), v12 );
+ BOOST_CHECK_EQUAL( G2.delta( 1 ), v02 );
+ BOOST_CHECK_EQUAL( G2.delta( 2 ), v01 );
+ BOOST_CHECK_EQUAL( G2.Delta( 0 ), v012 );
+ BOOST_CHECK_EQUAL( G2.Delta( 1 ), v012 );
+ BOOST_CHECK_EQUAL( G2.Delta( 2 ), v012 );
+ BOOST_CHECK( G2.isConnected() );
+ BOOST_CHECK( !G2.isTree() );
+ BOOST_CHECK( G2.isBinary() );
+ BOOST_CHECK( G2.isPairwise() );
+ BOOST_CHECK( G2.MarkovGraph() == H );
+ BOOST_CHECK( G2.bipGraph() == K );
+ BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
+ BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
+ BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
+ BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[2], v02 );
+
+ Var v3( 3, 3 );
+ VarSet v03( v0, v3 );
+ VarSet v13( v1, v3 );
+ VarSet v23( v2, v3 );
+ VarSet v013 = v01 | v3;
+ VarSet v023 = v02 | v3;
+ VarSet v123 = v12 | v3;
+ VarSet v0123 = v012 | v3;
+ vars.push_back( v3 );
+ facs.push_back( Factor( v3 ) );
+ H.addNode();
+ K.addNode1();
+ K.addNode2();
+ K.addEdge( 3, 4 );
+ FactorGraph G3( facs );
+ BOOST_CHECK_EQUAL( G3.nrVars(), 4 );
+ BOOST_CHECK_EQUAL( G3.nrFactors(), 5 );
+ BOOST_CHECK_EQUAL( G3.nrEdges(), 8 );
+ BOOST_CHECK_EQUAL( G3.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G3.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G3.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G3.findVar( v3 ), 3 );
+ BOOST_CHECK_EQUAL( G3.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
+ BOOST_CHECK_EQUAL( G3.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
+ BOOST_CHECK_EQUAL( G3.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
+ BOOST_CHECK_EQUAL( G3.findFactor( v01 ), 0 );
+ BOOST_CHECK_EQUAL( G3.findFactor( v12 ), 1 );
+ BOOST_CHECK_EQUAL( G3.findFactor( v1 ), 2 );
+ BOOST_CHECK_EQUAL( G3.findFactor( v02 ), 3 );
+ BOOST_CHECK_EQUAL( G3.findFactor( v3 ), 4 );
+ BOOST_CHECK_THROW( G3.findFactor( v23 ), Exception );
+ BOOST_CHECK_EQUAL( G3.delta( 0 ), v12 );
+ BOOST_CHECK_EQUAL( G3.delta( 1 ), v02 );
+ BOOST_CHECK_EQUAL( G3.delta( 2 ), v01 );
+ BOOST_CHECK_EQUAL( G3.delta( 3 ), VarSet() );
+ BOOST_CHECK_EQUAL( G3.Delta( 0 ), v012 );
+ BOOST_CHECK_EQUAL( G3.Delta( 1 ), v012 );
+ BOOST_CHECK_EQUAL( G3.Delta( 2 ), v012 );
+ BOOST_CHECK_EQUAL( G3.Delta( 3 ), v3 );
+ BOOST_CHECK( !G3.isConnected() );
+ BOOST_CHECK( !G3.isTree() );
+ BOOST_CHECK( !G3.isBinary() );
+ BOOST_CHECK( G3.isPairwise() );
+ BOOST_CHECK( G3.MarkovGraph() == H );
+ BOOST_CHECK( G3.bipGraph() == K );
+ BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
+ BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
+ BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
+ BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[2], v02 );
+ BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[3], v3 );
+
+ facs.push_back( Factor( v123 ) );
+ H.addEdge( 1, 3 );
+ H.addEdge( 2, 3 );
+ K.addNode2();
+ K.addEdge( 1, 5 );
+ K.addEdge( 2, 5 );
+ K.addEdge( 3, 5 );
+ FactorGraph G4( facs );
+ BOOST_CHECK_EQUAL( G4.nrVars(), 4 );
+ BOOST_CHECK_EQUAL( G4.nrFactors(), 6 );
+ BOOST_CHECK_EQUAL( G4.nrEdges(), 11 );
+ BOOST_CHECK_EQUAL( G4.findVar( v0 ), 0 );
+ BOOST_CHECK_EQUAL( G4.findVar( v1 ), 1 );
+ BOOST_CHECK_EQUAL( G4.findVar( v2 ), 2 );
+ BOOST_CHECK_EQUAL( G4.findVar( v3 ), 3 );
+ BOOST_CHECK_EQUAL( G4.findVars( v01 ), SmallSet<size_t>( 0, 1 ) );
+ BOOST_CHECK_EQUAL( G4.findVars( v02 ), SmallSet<size_t>( 0, 2 ) );
+ BOOST_CHECK_EQUAL( G4.findVars( v12 ), SmallSet<size_t>( 1, 2 ) );
+ BOOST_CHECK_EQUAL( G4.findFactor( v01 ), 0 );
+ BOOST_CHECK_EQUAL( G4.findFactor( v12 ), 1 );
+ BOOST_CHECK_EQUAL( G4.findFactor( v1 ), 2 );
+ BOOST_CHECK_EQUAL( G4.findFactor( v02 ), 3 );
+ BOOST_CHECK_EQUAL( G4.findFactor( v3 ), 4 );
+ BOOST_CHECK_EQUAL( G4.findFactor( v123 ), 5 );
+ BOOST_CHECK_THROW( G4.findFactor( v23 ), Exception );
+ BOOST_CHECK_EQUAL( G4.delta( 0 ), v12 );
+ BOOST_CHECK_EQUAL( G4.delta( 1 ), v023 );
+ BOOST_CHECK_EQUAL( G4.delta( 2 ), v013 );
+ BOOST_CHECK_EQUAL( G4.delta( 3 ), v12 );
+ BOOST_CHECK_EQUAL( G4.Delta( 0 ), v012 );
+ BOOST_CHECK_EQUAL( G4.Delta( 1 ), v0123 );
+ BOOST_CHECK_EQUAL( G4.Delta( 2 ), v0123 );
+ BOOST_CHECK_EQUAL( G4.Delta( 3 ), v123 );
+ BOOST_CHECK( G4.isConnected() );
+ BOOST_CHECK( !G4.isTree() );
+ BOOST_CHECK( !G4.isBinary() );
+ BOOST_CHECK( !G4.isPairwise() );
+ BOOST_CHECK( G4.MarkovGraph() == H );
+ BOOST_CHECK( G4.bipGraph() == K );
+ BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
+ BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
+ BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
+ BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[2], v123 );
+}
+
+
+BOOST_AUTO_TEST_CASE( BackupRestoreTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v12( v1, v2 );
+ VarSet v012 = v01 | v2;
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( v01 ) );
+ facs.push_back( Factor( v12 ) );
+ facs.push_back( Factor( v1 ) );
+ std::vector<Var> vars;
+ vars.push_back( v0 );
+ vars.push_back( v1 );
+ vars.push_back( v2 );
+
+ FactorGraph G( facs );
+ FactorGraph Gorg( G );
+
+ BOOST_CHECK_THROW( G.setFactor( 0, Factor( v0 ), false ), Exception );
+ G.setFactor( 0, Factor( v01, 2.0 ), false );
+ BOOST_CHECK_THROW( G.restoreFactor( 0 ), Exception );
+ G.setFactor( 0, Factor( v01, 3.0 ), true );
+ G.restoreFactor( 0 );
+ BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
+ G.setFactor( 0, Gorg.factor( 0 ), false );
+ G.backupFactor( 0 );
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ G.setFactor( 0, Factor( v01, 2.0 ), false );
+ BOOST_CHECK_EQUAL( G.factor( 0 )[0], 2.0 );
+ G.restoreFactor( 0 );
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+
+ std::map<size_t, Factor> fs;
+ fs[0] = Factor( v01, 3.0 );
+ fs[2] = Factor( v1, 2.0 );
+ G.setFactors( fs, false );
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ G.restoreFactors();
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ G = Gorg;
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+ G.setFactors( fs, true );
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ G.restoreFactors();
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+ std::set<size_t> fsind;
+ fsind.insert( 0 );
+ fsind.insert( 2 );
+ G.backupFactors( fsind );
+ G.setFactors( fs, false );
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ G.restoreFactors();
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+
+ G.backupFactors( v2 );
+ G.setFactor( 1, Factor(v12, 5.0) );
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Factor(v12, 5.0) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+ G.restoreFactors( v2 );
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+
+ G.backupFactors( v1 );
+ fs[1] = Factor( v12, 5.0 );
+ G.setFactors( fs, false );
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ G.restoreFactors();
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+ G.setFactors( fs, true );
+ BOOST_CHECK_EQUAL( G.factor(0), fs[0] );
+ BOOST_CHECK_EQUAL( G.factor(1), fs[1] );
+ BOOST_CHECK_EQUAL( G.factor(2), fs[2] );
+ G.restoreFactors( v1 );
+ BOOST_CHECK_EQUAL( G.factor(0), Gorg.factor(0) );
+ BOOST_CHECK_EQUAL( G.factor(1), Gorg.factor(1) );
+ BOOST_CHECK_EQUAL( G.factor(2), Gorg.factor(2) );
+}
+
+
+BOOST_AUTO_TEST_CASE( TransformationsTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v12( v1, v2 );
+ VarSet v012 = v01 | v2;
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( v01 ).randomize() );
+ facs.push_back( Factor( v12 ).randomize() );
+ facs.push_back( Factor( v1 ).randomize() );
+ std::vector<Var> vars;
+ vars.push_back( v0 );
+ vars.push_back( v1 );
+ vars.push_back( v2 );
+
+ FactorGraph G( facs );
+
+ FactorGraph Gsmall = G.maximalFactors();
+ BOOST_CHECK_EQUAL( Gsmall.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gsmall.nrFactors(), 2 );
+ BOOST_CHECK_EQUAL( Gsmall.factor( 0 ), G.factor( 0 ) * G.factor( 2 ) );
+ BOOST_CHECK_EQUAL( Gsmall.factor( 1 ), G.factor( 1 ) );
+
+ size_t i = 0;
+ for( size_t x = 0; x < 2; x++ ) {
+ FactorGraph Gcl = G.clamped( i, x );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) * G.factor(2) );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1) );
+ }
+ i = 1;
+ for( size_t x = 0; x < 2; x++ ) {
+ FactorGraph Gcl = G.clamped( i, x );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) * G.factor(2).slice(vars[i],x) );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0).slice(vars[i], x) );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) );
+ }
+ i = 2;
+ for( size_t x = 0; x < 2; x++ ) {
+ FactorGraph Gcl = G.clamped( i, x );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), createFactorDelta(vars[i], x) );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(0) );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(1).slice(vars[i], x) * G.factor(2) );
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( OperationsTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v12( v1, v2 );
+ VarSet v012 = v01 | v2;
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( v01 ).randomize() );
+ facs.push_back( Factor( v12 ).randomize() );
+ facs.push_back( Factor( v1 ).randomize() );
+ std::vector<Var> vars;
+ vars.push_back( v0 );
+ vars.push_back( v1 );
+ vars.push_back( v2 );
+
+ FactorGraph G( facs );
+
+ // clamp
+ FactorGraph Gcl = G;
+ for( size_t i = 0; i < 3; i++ )
+ for( size_t x = 0; x < 2; x++ ) {
+ Gcl.clamp( i, x, true );
+ Factor delta = createFactorDelta( vars[i], x );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ for( size_t j = 0; j < 3; j++ )
+ if( G.factor(j).vars().contains( vars[i] ) )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
+ else
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+
+ Gcl.restoreFactors();
+ for( size_t j = 0; j < 3; j++ )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+ }
+
+ // clampVar
+ for( size_t i = 0; i < 3; i++ )
+ for( size_t x = 0; x < 2; x++ ) {
+ Gcl.clampVar( i, std::vector<size_t>(1, x), true );
+ Factor delta = createFactorDelta( vars[i], x );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ for( size_t j = 0; j < 3; j++ )
+ if( G.factor(j).vars().contains( vars[i] ) )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * delta );
+ else
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+
+ Gcl.restoreFactors();
+ for( size_t j = 0; j < 3; j++ )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+ }
+ for( size_t i = 0; i < 3; i++ )
+ for( size_t x = 0; x < 2; x++ ) {
+ Gcl.clampVar( i, std::vector<size_t>(), true );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ for( size_t j = 0; j < 3; j++ )
+ if( G.factor(j).vars().contains( vars[i] ) )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) * 0.0 );
+ else
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+
+ Gcl.restoreFactors();
+ for( size_t j = 0; j < 3; j++ )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+ }
+ std::vector<size_t> both;
+ both.push_back( 0 );
+ both.push_back( 1 );
+ for( size_t i = 0; i < 3; i++ )
+ for( size_t x = 0; x < 2; x++ ) {
+ Gcl.clampVar( i, both, true );
+ BOOST_CHECK_EQUAL( Gcl.nrVars(), 3 );
+ BOOST_CHECK_EQUAL( Gcl.nrFactors(), 3 );
+ for( size_t j = 0; j < 3; j++ )
+ BOOST_CHECK_EQUAL( Gcl.factor(j), G.factor(j) );
+ Gcl.restoreFactors();
+ }
+
+ // clampFactor
+ for( size_t x = 0; x < 4; x++ ) {
+ Gcl.clampFactor( 0, std::vector<size_t>(1,x), true );
+ Factor mask( v01, 0.0 );
+ mask.set( x, 1.0 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) * mask );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
+ Gcl.restoreFactor( 0 );
+ }
+ for( size_t x = 0; x < 4; x++ ) {
+ Gcl.clampFactor( 1, std::vector<size_t>(1,x), true );
+ Factor mask( v12, 0.0 );
+ mask.set( x, 1.0 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) * mask );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) );
+ Gcl.restoreFactor( 1 );
+ }
+ for( size_t x = 0; x < 2; x++ ) {
+ Gcl.clampFactor( 2, std::vector<size_t>(1,x), true );
+ Factor mask( v1, 0.0 );
+ mask.set( x, 1.0 );
+ BOOST_CHECK_EQUAL( Gcl.factor(0), G.factor(0) );
+ BOOST_CHECK_EQUAL( Gcl.factor(1), G.factor(1) );
+ BOOST_CHECK_EQUAL( Gcl.factor(2), G.factor(2) * mask );
+ Gcl.restoreFactors();
+ }
+
+ // makeCavity
+ FactorGraph Gcav( G );
+ Gcav.makeCavity( 0, true );
+ BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
+ BOOST_CHECK_EQUAL( Gcav.factor(1), G.factor(1) );
+ BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
+ Gcav.restoreFactors();
+ Gcav.makeCavity( 1, true );
+ BOOST_CHECK_EQUAL( Gcav.factor(0), Factor( v01, 1.0 ) );
+ BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
+ BOOST_CHECK_EQUAL( Gcav.factor(2), Factor( v1, 1.0 ) );
+ Gcav.restoreFactors();
+ Gcav.makeCavity( 2, true );
+ BOOST_CHECK_EQUAL( Gcav.factor(0), G.factor(0) );
+ BOOST_CHECK_EQUAL( Gcav.factor(1), Factor( v12, 1.0 ) );
+ BOOST_CHECK_EQUAL( Gcav.factor(2), G.factor(2) );
+ Gcav.restoreFactors();
+}
+
+
+BOOST_AUTO_TEST_CASE( IOTest ) {
+ Var v0( 0, 2 );
+ Var v1( 1, 2 );
+ Var v2( 2, 2 );
+ VarSet v01( v0, v1 );
+ VarSet v02( v0, v2 );
+ VarSet v12( v1, v2 );
+ VarSet v012 = v01 | v2;
+
+ std::vector<Factor> facs;
+ facs.push_back( Factor( v01 ).randomize() );
+ facs.push_back( Factor( v12 ).randomize() );
+ facs.push_back( Factor( v1 ).randomize() );
+ std::vector<Var> vars;
+ vars.push_back( v0 );
+ vars.push_back( v1 );
+ vars.push_back( v2 );
+
+ FactorGraph G( facs );
+
+ G.WriteToFile( "factorgraph_test.fg" );
+ FactorGraph G2;
+ G2.ReadFromFile( "factorgraph_test.fg" );
+
+ BOOST_CHECK( G.vars() == G2.vars() );
+ BOOST_CHECK( G.bipGraph() == G2.bipGraph() );
+ BOOST_CHECK_EQUAL( G.nrFactors(), G2.nrFactors() );
+ for( size_t I = 0; I < G.nrFactors(); I++ ) {
+ BOOST_CHECK( G.factor(I).vars() == G2.factor(I).vars() );
+ for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
+ BOOST_CHECK_CLOSE( G.factor(I)[s], G2.factor(I)[s], tol );
+ }
+
+ std::stringstream ss;
+ std::string s;
+ G.printDot( ss );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph FactorGraph {" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=box,width=0.3,height=0.3,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tf2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv0 -- f0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv1 -- f2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tv2 -- f1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+ G.setFactor( 0, Factor( G.factor(0).vars(), 1.0 ) );
+ G.setFactor( 1, Factor( G.factor(1).vars(), 2.0 ) );
+ G.setFactor( 2, Factor( G.factor(2).vars(), 3.0 ) );
+ ss << G;
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 1" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 1" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 1" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 1" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "4" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "3 2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2 " );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "2" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0 3" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1 3" );
+
+ ss << G;
+ FactorGraph G3;
+ ss >> G3;
+ BOOST_CHECK( G.vars() == G3.vars() );
+ BOOST_CHECK( G.bipGraph() == G3.bipGraph() );
+ BOOST_CHECK_EQUAL( G.nrFactors(), G3.nrFactors() );
+ for( size_t I = 0; I < G.nrFactors(); I++ ) {
+ BOOST_CHECK( G.factor(I).vars() == G3.factor(I).vars() );
+ for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
+ BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
+ }
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/graph.h>
-#include <vector>
-#include <strstream>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE GraphALTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- // check constructors
- GraphAL G0;
- BOOST_CHECK_EQUAL( G0.nrNodes(), 0 );
- BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
- BOOST_CHECK_EQUAL( G0.isConnected(), true );
- BOOST_CHECK_EQUAL( G0.isTree(), true );
- G0.checkConsistency();
-
- GraphAL G2( 2 );
- BOOST_CHECK_EQUAL( G2.nrNodes(), 2 );
- BOOST_CHECK_EQUAL( G2.nrEdges(), 0 );
- BOOST_CHECK_EQUAL( G2.isConnected(), false );
- BOOST_CHECK_EQUAL( G2.isTree(), false );
- G2.checkConsistency();
- BOOST_CHECK( !(G2 == G0) );
-
- typedef GraphAL::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge( 0, 1 ) );
- edges.push_back( Edge( 1, 2 ) );
- edges.push_back( Edge( 2, 1 ) );
- edges.push_back( Edge( 1, 2 ) );
- GraphAL G3( 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G3.nrNodes(), 3 );
- BOOST_CHECK_EQUAL( G3.nrEdges(), 2 );
- BOOST_CHECK_EQUAL( G3.isConnected(), true );
- BOOST_CHECK_EQUAL( G3.isTree(), true );
- G3.checkConsistency();
- BOOST_CHECK( !(G3 == G0) );
- BOOST_CHECK( !(G3 == G2) );
-
- GraphAL G4( G3 );
- BOOST_CHECK( !(G4 == G0) );
- BOOST_CHECK( !(G4 == G2) );
- BOOST_CHECK( G4 == G3 );
-
- GraphAL G5 = G3;
- BOOST_CHECK( !(G5 == G0) );
- BOOST_CHECK( !(G5 == G2) );
- BOOST_CHECK( G5 == G3 );
-}
-
-
-BOOST_AUTO_TEST_CASE( NeighborTest ) {
- // check nb() accessor / mutator
- typedef GraphAL::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge( 0, 1 ) );
- edges.push_back( Edge( 1, 2 ) );
- GraphAL G( 3, edges.begin(), edges.end() );
- BOOST_CHECK_EQUAL( G.nb(0).size(), 1 );
- BOOST_CHECK_EQUAL( G.nb(1).size(), 2 );
- BOOST_CHECK_EQUAL( G.nb(2).size(), 1 );
- BOOST_CHECK_EQUAL( G.nb(0,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb(0,0).node, 1 );
- BOOST_CHECK_EQUAL( G.nb(0,0).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb(1,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb(1,0).node, 0 );
- BOOST_CHECK_EQUAL( G.nb(1,0).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb(1,1).iter, 1 );
- BOOST_CHECK_EQUAL( G.nb(1,1).node, 2 );
- BOOST_CHECK_EQUAL( G.nb(1,1).dual, 0 );
- BOOST_CHECK_EQUAL( G.nb(2,0).iter, 0 );
- BOOST_CHECK_EQUAL( G.nb(2,0).node, 1 );
- BOOST_CHECK_EQUAL( G.nb(2,0).dual, 1 );
-}
-
-
-BOOST_AUTO_TEST_CASE( AddEraseTest ) {
- // check addition and erasure of nodes and edges
- typedef GraphAL::Edge Edge;
- std::vector<Edge> edges;
- edges.push_back( Edge( 0, 1 ) );
- edges.push_back( Edge( 1, 2 ) );
- edges.push_back( Edge( 1, 0 ) );
- GraphAL G( 2 );
- G.construct( 3, edges.begin(), edges.end() );
- G.checkConsistency();
- BOOST_CHECK_EQUAL( G.nrNodes(), 3 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 2 );
- BOOST_CHECK_EQUAL( G.addNode(), 3 );
- G.checkConsistency();
- std::vector<size_t> nbs;
- nbs.push_back( 3 );
- BOOST_CHECK_EQUAL( G.addNode( nbs.begin(), nbs.end() ), 4 );
- BOOST_CHECK_EQUAL( G.addNode(), 5 );
- G.checkConsistency();
- G.addEdge( 0, 4 );
- G.checkConsistency();
- G.addEdge( 0, 5 );
- BOOST_CHECK( G.isTree() );
- G.checkConsistency();
- BOOST_CHECK_EQUAL( G.nrNodes(), 6 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 5 );
- G.addEdge( 2, 3 );
- BOOST_CHECK( !G.isTree() );
-
- G.addEdge( 5, 3 );
- G.eraseNode( 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isTree() );
- G.eraseEdge( 0, 1 );
- G.checkConsistency();
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK( !G.isConnected() );
- G.eraseNode( 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isTree() );
- G.addEdge( 3, 2 );
- G.checkConsistency();
- BOOST_CHECK( !G.isTree() );
- G.eraseNode( 1 );
- G.checkConsistency();
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK( !G.isConnected() );
- G.eraseNode( 2 );
- G.checkConsistency();
- BOOST_CHECK( !G.isTree() );
- BOOST_CHECK( !G.isConnected() );
- G.addEdge( 1, 0 );
- G.checkConsistency();
- BOOST_CHECK( G.isTree() );
- BOOST_CHECK( G.isConnected() );
- G.eraseNode( 1 );
- G.checkConsistency();
- BOOST_CHECK( G.isTree() );
- BOOST_CHECK( G.isConnected() );
- G.eraseNode( 0 );
- BOOST_CHECK( G.isTree() );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.nrNodes(), 0 );
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
-
- G.addNode();
- G.addNode();
- G.addNode();
- G.addNode();
- G.addEdge( 0, 1 );
- G.addEdge( 2, 3 );
- G.addEdge( 0, 3 );
- G.checkConsistency();
- G.eraseNode( 2 );
- G.checkConsistency();
-}
-
-
-BOOST_AUTO_TEST_CASE( RandomAddEraseTest ) {
- // check adding and erasing nodes and edges randomly
- GraphAL G;
- for( size_t maxN = 2; maxN < 50; maxN++ )
- for( size_t repeats = 0; repeats < 10000; repeats++ ) {
- size_t action = rnd( 5 );
- size_t N = G.nrNodes();
- size_t M = G.nrEdges();
- size_t maxM = N * (N - 1) / 2;
- if( action == 0 ) {
- // add node
- if( N < maxN )
- G.addNode();
- } else if( action == 1 ) {
- // erase node
- if( N > 0 )
- G.eraseNode( rnd( N ) );
- } else if( action == 2 || action == 3 ) {
- // add edge
- if( N >= 2 && M < maxM ) {
- size_t n1 = 0;
- do {
- n1 = rnd( N );
- } while( G.nb(n1).size() >= N - 1 );
- size_t n2 = 0;
- do {
- n2 = rnd( N );
- } while( G.hasEdge( n1, n2 ) );
- G.addEdge( n1, n2 );
- }
- } else if( action == 4 ) {
- // erase edge
- if( M > 0 ) {
- size_t n1 = 0;
- do {
- n1 = rnd( N );
- } while( G.nb(n1).size() == 0 );
- size_t n2 = 0;
- do {
- n2 = rnd( N );
- } while( !G.hasEdge( n1, n2 ) );
- G.eraseEdge( n1, n2 );
- }
- }
- G.checkConsistency();
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( QueriesAndCreationTest ) {
- // check queries and createGraph* functions
-
- // createGraphFull
- for( size_t N = 0; N < 20; N++ ) {
- GraphAL G = createGraphFull( N );
- BOOST_CHECK_EQUAL( G.nrNodes(), N );
- BOOST_CHECK_EQUAL( G.nrEdges(), N > 0 ? N * (N-1) / 2 : 0 );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), N < 3 );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
-
- // createGraphGrid
- for( size_t N1 = 0; N1 < 10; N1++ )
- for( size_t N2 = 0; N2 < 10; N2++ ) {
- GraphAL G = createGraphGrid( N1, N2, false );
- BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 );
- BOOST_CHECK_EQUAL( G.nrEdges(), (N1 > 0 && N2 > 0) ? 2 * (N1-1) * (N2-1) + (N1-1) + (N2-1) : 0 );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), (N1 <= 1) || (N2 <= 1) );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
-
- G = createGraphGrid( N1, N2, true );
- BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 );
- if( N1 == 0 || N2 == 0 )
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
- else
- BOOST_CHECK_EQUAL( G.nrEdges(), (N1 <= 2 ? (N1-1) : N1) * N2 + N1 * (N2 <= 2 ? (N2-1) : N2) );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() <= 2) );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
-
- // createGraphGrid3D
- for( size_t N1 = 0; N1 < 8; N1++ )
- for( size_t N2 = 0; N2 < 8; N2++ )
- for( size_t N3 = 0; N3 < 8; N3++ ) {
- GraphAL G = createGraphGrid3D( N1, N2, N3, false );
- BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 * N3 );
- BOOST_CHECK_EQUAL( G.nrEdges(), (N1 > 0 && N2 > 0 && N3 > 0) ? 3 * (N1-1) * (N2-1) * (N3-1) + 2 * (N1-1) * (N2-1) + 2 * (N1-1) * (N3-1) + 2 * (N2-1) * (N3-1) + (N1-1) + (N2-1) + (N3-1) : 0 );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() == 0) || (N1 <= 1 && N2 <= 1) || (N1 <= 1 && N3 <= 1) || (N2 <= 1 && N3 <= 1) );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
-
- G = createGraphGrid3D( N1, N2, N3, true );
- BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 * N3 );
- if( N1 == 0 || N2 == 0 || N3 == 0 )
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
- else
- BOOST_CHECK_EQUAL( G.nrEdges(), (N1 <= 2 ? (N1-1) : N1) * N2 * N3 + N1 * (N2 <= 2 ? (N2-1) : N2) * N3 + N1 * N2 * (N3 <= 2 ? (N3-1) : N3) );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() <= 2) );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
-
- // createGraphLoop
- for( size_t N = 0; N < 100; N++ ) {
- GraphAL G = createGraphLoop( N );
- BOOST_CHECK_EQUAL( G.nrNodes(), N );
- if( N == 0 )
- BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
- else if( N <= 2 )
- BOOST_CHECK_EQUAL( G.nrEdges(), N-1 );
- else
- BOOST_CHECK_EQUAL( G.nrEdges(), N );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK_EQUAL( G.isTree(), N <= 2 );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
-
- // createGraphTree
- for( size_t N = 0; N < 100; N++ ) {
- GraphAL G = createGraphTree( N );
- BOOST_CHECK_EQUAL( G.nrNodes(), N );
- BOOST_CHECK_EQUAL( G.nrEdges(), N > 0 ? N - 1 : N );
- BOOST_CHECK( G.isConnected() );
- BOOST_CHECK( G.isTree() );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
-
- // createGraphRegular
- for( size_t N = 0; N < 50; N++ ) {
- for( size_t d = 0; d < N && d <= 15; d++ ) {
- if( (N * d) % 2 == 0 ) {
- GraphAL G = createGraphRegular( N, d );
- BOOST_CHECK_EQUAL( G.nrNodes(), N );
- BOOST_CHECK_EQUAL( G.nrEdges(), d * N / 2 );
- for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
- BOOST_CHECK_EQUAL( G.nb(n1).size(), d );
- foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
- BOOST_CHECK( G.hasEdge( n1, n2 ) );
- BOOST_CHECK( G.hasEdge( n2, n1 ) );
- }
- for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
- if( G.hasEdge( n1, n2 ) ) {
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
- BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
- BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
- }
- }
- G.checkConsistency();
- }
- }
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( StreamTest ) {
- // check printDot
- GraphAL G( 4 );
- G.addEdge( 0, 1 );
- G.addEdge( 0, 2 );
- G.addEdge( 1, 3 );
- G.addEdge( 2, 3 );
- G.addEdge( 2, 2 );
- G.addEdge( 3, 2 );
-
- std::stringstream ss;
- std::string s;
-
- G.printDot( ss );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph GraphAL {" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- x3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -- x3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
-
- ss << G;
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph GraphAL {" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x1;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x2;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- x3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -- x3;" );
- std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/graph.h>
+#include <vector>
+#include <strstream>
+
+
+using namespace dai;
+
+
+#define BOOST_TEST_MODULE GraphALTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ // check constructors
+ GraphAL G0;
+ BOOST_CHECK_EQUAL( G0.nrNodes(), 0 );
+ BOOST_CHECK_EQUAL( G0.nrEdges(), 0 );
+ BOOST_CHECK_EQUAL( G0.isConnected(), true );
+ BOOST_CHECK_EQUAL( G0.isTree(), true );
+ G0.checkConsistency();
+
+ GraphAL G2( 2 );
+ BOOST_CHECK_EQUAL( G2.nrNodes(), 2 );
+ BOOST_CHECK_EQUAL( G2.nrEdges(), 0 );
+ BOOST_CHECK_EQUAL( G2.isConnected(), false );
+ BOOST_CHECK_EQUAL( G2.isTree(), false );
+ G2.checkConsistency();
+ BOOST_CHECK( !(G2 == G0) );
+
+ typedef GraphAL::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge( 0, 1 ) );
+ edges.push_back( Edge( 1, 2 ) );
+ edges.push_back( Edge( 2, 1 ) );
+ edges.push_back( Edge( 1, 2 ) );
+ GraphAL G3( 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G3.nrNodes(), 3 );
+ BOOST_CHECK_EQUAL( G3.nrEdges(), 2 );
+ BOOST_CHECK_EQUAL( G3.isConnected(), true );
+ BOOST_CHECK_EQUAL( G3.isTree(), true );
+ G3.checkConsistency();
+ BOOST_CHECK( !(G3 == G0) );
+ BOOST_CHECK( !(G3 == G2) );
+
+ GraphAL G4( G3 );
+ BOOST_CHECK( !(G4 == G0) );
+ BOOST_CHECK( !(G4 == G2) );
+ BOOST_CHECK( G4 == G3 );
+
+ GraphAL G5 = G3;
+ BOOST_CHECK( !(G5 == G0) );
+ BOOST_CHECK( !(G5 == G2) );
+ BOOST_CHECK( G5 == G3 );
+}
+
+
+BOOST_AUTO_TEST_CASE( NeighborTest ) {
+ // check nb() accessor / mutator
+ typedef GraphAL::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge( 0, 1 ) );
+ edges.push_back( Edge( 1, 2 ) );
+ GraphAL G( 3, edges.begin(), edges.end() );
+ BOOST_CHECK_EQUAL( G.nb(0).size(), 1 );
+ BOOST_CHECK_EQUAL( G.nb(1).size(), 2 );
+ BOOST_CHECK_EQUAL( G.nb(2).size(), 1 );
+ BOOST_CHECK_EQUAL( G.nb(0,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb(0,0).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb(0,0).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb(1,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb(1,0).node, 0 );
+ BOOST_CHECK_EQUAL( G.nb(1,0).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb(1,1).iter, 1 );
+ BOOST_CHECK_EQUAL( G.nb(1,1).node, 2 );
+ BOOST_CHECK_EQUAL( G.nb(1,1).dual, 0 );
+ BOOST_CHECK_EQUAL( G.nb(2,0).iter, 0 );
+ BOOST_CHECK_EQUAL( G.nb(2,0).node, 1 );
+ BOOST_CHECK_EQUAL( G.nb(2,0).dual, 1 );
+}
+
+
+BOOST_AUTO_TEST_CASE( AddEraseTest ) {
+ // check addition and erasure of nodes and edges
+ typedef GraphAL::Edge Edge;
+ std::vector<Edge> edges;
+ edges.push_back( Edge( 0, 1 ) );
+ edges.push_back( Edge( 1, 2 ) );
+ edges.push_back( Edge( 1, 0 ) );
+ GraphAL G( 2 );
+ G.construct( 3, edges.begin(), edges.end() );
+ G.checkConsistency();
+ BOOST_CHECK_EQUAL( G.nrNodes(), 3 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 2 );
+ BOOST_CHECK_EQUAL( G.addNode(), 3 );
+ G.checkConsistency();
+ std::vector<size_t> nbs;
+ nbs.push_back( 3 );
+ BOOST_CHECK_EQUAL( G.addNode( nbs.begin(), nbs.end() ), 4 );
+ BOOST_CHECK_EQUAL( G.addNode(), 5 );
+ G.checkConsistency();
+ G.addEdge( 0, 4 );
+ G.checkConsistency();
+ G.addEdge( 0, 5 );
+ BOOST_CHECK( G.isTree() );
+ G.checkConsistency();
+ BOOST_CHECK_EQUAL( G.nrNodes(), 6 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 5 );
+ G.addEdge( 2, 3 );
+ BOOST_CHECK( !G.isTree() );
+
+ G.addEdge( 5, 3 );
+ G.eraseNode( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isTree() );
+ G.eraseEdge( 0, 1 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK( !G.isConnected() );
+ G.eraseNode( 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isTree() );
+ G.addEdge( 3, 2 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isTree() );
+ G.eraseNode( 1 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK( !G.isConnected() );
+ G.eraseNode( 2 );
+ G.checkConsistency();
+ BOOST_CHECK( !G.isTree() );
+ BOOST_CHECK( !G.isConnected() );
+ G.addEdge( 1, 0 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isTree() );
+ BOOST_CHECK( G.isConnected() );
+ G.eraseNode( 1 );
+ G.checkConsistency();
+ BOOST_CHECK( G.isTree() );
+ BOOST_CHECK( G.isConnected() );
+ G.eraseNode( 0 );
+ BOOST_CHECK( G.isTree() );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.nrNodes(), 0 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+
+ G.addNode();
+ G.addNode();
+ G.addNode();
+ G.addNode();
+ G.addEdge( 0, 1 );
+ G.addEdge( 2, 3 );
+ G.addEdge( 0, 3 );
+ G.checkConsistency();
+ G.eraseNode( 2 );
+ G.checkConsistency();
+}
+
+
+BOOST_AUTO_TEST_CASE( RandomAddEraseTest ) {
+ // check adding and erasing nodes and edges randomly
+ GraphAL G;
+ for( size_t maxN = 2; maxN < 50; maxN++ )
+ for( size_t repeats = 0; repeats < 10000; repeats++ ) {
+ size_t action = rnd( 5 );
+ size_t N = G.nrNodes();
+ size_t M = G.nrEdges();
+ size_t maxM = N * (N - 1) / 2;
+ if( action == 0 ) {
+ // add node
+ if( N < maxN )
+ G.addNode();
+ } else if( action == 1 ) {
+ // erase node
+ if( N > 0 )
+ G.eraseNode( rnd( N ) );
+ } else if( action == 2 || action == 3 ) {
+ // add edge
+ if( N >= 2 && M < maxM ) {
+ size_t n1 = 0;
+ do {
+ n1 = rnd( N );
+ } while( G.nb(n1).size() >= N - 1 );
+ size_t n2 = 0;
+ do {
+ n2 = rnd( N );
+ } while( G.hasEdge( n1, n2 ) );
+ G.addEdge( n1, n2 );
+ }
+ } else if( action == 4 ) {
+ // erase edge
+ if( M > 0 ) {
+ size_t n1 = 0;
+ do {
+ n1 = rnd( N );
+ } while( G.nb(n1).size() == 0 );
+ size_t n2 = 0;
+ do {
+ n2 = rnd( N );
+ } while( !G.hasEdge( n1, n2 ) );
+ G.eraseEdge( n1, n2 );
+ }
+ }
+ G.checkConsistency();
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( QueriesAndCreationTest ) {
+ // check queries and createGraph* functions
+
+ // createGraphFull
+ for( size_t N = 0; N < 20; N++ ) {
+ GraphAL G = createGraphFull( N );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N );
+ BOOST_CHECK_EQUAL( G.nrEdges(), N > 0 ? N * (N-1) / 2 : 0 );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), N < 3 );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+
+ // createGraphGrid
+ for( size_t N1 = 0; N1 < 10; N1++ )
+ for( size_t N2 = 0; N2 < 10; N2++ ) {
+ GraphAL G = createGraphGrid( N1, N2, false );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), (N1 > 0 && N2 > 0) ? 2 * (N1-1) * (N2-1) + (N1-1) + (N2-1) : 0 );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), (N1 <= 1) || (N2 <= 1) );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+
+ G = createGraphGrid( N1, N2, true );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 );
+ if( N1 == 0 || N2 == 0 )
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+ else
+ BOOST_CHECK_EQUAL( G.nrEdges(), (N1 <= 2 ? (N1-1) : N1) * N2 + N1 * (N2 <= 2 ? (N2-1) : N2) );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() <= 2) );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+
+ // createGraphGrid3D
+ for( size_t N1 = 0; N1 < 8; N1++ )
+ for( size_t N2 = 0; N2 < 8; N2++ )
+ for( size_t N3 = 0; N3 < 8; N3++ ) {
+ GraphAL G = createGraphGrid3D( N1, N2, N3, false );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 * N3 );
+ BOOST_CHECK_EQUAL( G.nrEdges(), (N1 > 0 && N2 > 0 && N3 > 0) ? 3 * (N1-1) * (N2-1) * (N3-1) + 2 * (N1-1) * (N2-1) + 2 * (N1-1) * (N3-1) + 2 * (N2-1) * (N3-1) + (N1-1) + (N2-1) + (N3-1) : 0 );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() == 0) || (N1 <= 1 && N2 <= 1) || (N1 <= 1 && N3 <= 1) || (N2 <= 1 && N3 <= 1) );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+
+ G = createGraphGrid3D( N1, N2, N3, true );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N1 * N2 * N3 );
+ if( N1 == 0 || N2 == 0 || N3 == 0 )
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+ else
+ BOOST_CHECK_EQUAL( G.nrEdges(), (N1 <= 2 ? (N1-1) : N1) * N2 * N3 + N1 * (N2 <= 2 ? (N2-1) : N2) * N3 + N1 * N2 * (N3 <= 2 ? (N3-1) : N3) );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), (G.nrNodes() <= 2) );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+
+ // createGraphLoop
+ for( size_t N = 0; N < 100; N++ ) {
+ GraphAL G = createGraphLoop( N );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N );
+ if( N == 0 )
+ BOOST_CHECK_EQUAL( G.nrEdges(), 0 );
+ else if( N <= 2 )
+ BOOST_CHECK_EQUAL( G.nrEdges(), N-1 );
+ else
+ BOOST_CHECK_EQUAL( G.nrEdges(), N );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK_EQUAL( G.isTree(), N <= 2 );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+
+ // createGraphTree
+ for( size_t N = 0; N < 100; N++ ) {
+ GraphAL G = createGraphTree( N );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N );
+ BOOST_CHECK_EQUAL( G.nrEdges(), N > 0 ? N - 1 : N );
+ BOOST_CHECK( G.isConnected() );
+ BOOST_CHECK( G.isTree() );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+
+ // createGraphRegular
+ for( size_t N = 0; N < 50; N++ ) {
+ for( size_t d = 0; d < N && d <= 15; d++ ) {
+ if( (N * d) % 2 == 0 ) {
+ GraphAL G = createGraphRegular( N, d );
+ BOOST_CHECK_EQUAL( G.nrNodes(), N );
+ BOOST_CHECK_EQUAL( G.nrEdges(), d * N / 2 );
+ for( size_t n1 = 0; n1 < G.nrNodes(); n1++ ) {
+ BOOST_CHECK_EQUAL( G.nb(n1).size(), d );
+ foreach( const GraphAL::Neighbor &n2, G.nb(n1) ) {
+ BOOST_CHECK( G.hasEdge( n1, n2 ) );
+ BOOST_CHECK( G.hasEdge( n2, n1 ) );
+ }
+ for( size_t n2 = 0; n2 < G.nrNodes(); n2++ )
+ if( G.hasEdge( n1, n2 ) ) {
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ), n2 );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ), n1 );
+ BOOST_CHECK_EQUAL( G.nb( n1, G.findNb( n1, n2 ) ).iter, G.findNb( n1, n2 ) );
+ BOOST_CHECK_EQUAL( G.nb( n2, G.findNb( n2, n1 ) ).iter, G.findNb( n2, n1 ) );
+ }
+ }
+ G.checkConsistency();
+ }
+ }
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( StreamTest ) {
+ // check printDot
+ GraphAL G( 4 );
+ G.addEdge( 0, 1 );
+ G.addEdge( 0, 2 );
+ G.addEdge( 1, 3 );
+ G.addEdge( 2, 3 );
+ G.addEdge( 2, 2 );
+ G.addEdge( 3, 2 );
+
+ std::stringstream ss;
+ std::string s;
+
+ G.printDot( ss );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph GraphAL {" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- x3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -- x3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+ ss << G;
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "graph GraphAL {" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "node[shape=circle,width=0.4,fixedsize=true];" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x1;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx0 -- x2;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- x3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -- x3;" );
+ std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/index.h>
-#include <strstream>
-#include <map>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE IndexTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( IndexForTest ) {
- IndexFor x;
- BOOST_CHECK( !x.valid() );
- x.reset();
- BOOST_CHECK( x.valid() );
-
- size_t nrVars = 5;
- std::vector<Var> vars;
- for( size_t i = 0; i < nrVars; i++ )
- vars.push_back( Var( i, i+2 ) );
-
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- VarSet indexVars;
- VarSet forVars;
- for( size_t i = 0; i < 5; i++ ) {
- if( rnd(2) == 0 )
- indexVars |= vars[i];
- if( rnd(2) == 0 )
- forVars |= vars[i];
- }
- IndexFor ind( indexVars, forVars );
- size_t iter = 0;
- for( ; ind.valid(); ind++, iter++ )
- BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
- BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
- iter = 0;
- ind.reset();
- for( ; ind.valid(); ++ind, iter++ )
- BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
- BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( PermuteTest ) {
- Permute x;
-
- Var x0(0, 2);
- Var x1(1, 3);
- Var x2(2, 2);
- std::vector<Var> V;
- V.push_back( x1 );
- V.push_back( x2 );
- V.push_back( x0 );
- VarSet X( V.begin(), V.end() );
- Permute sigma(V);
- BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
- BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
- BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
- BOOST_CHECK_EQUAL( sigma[0], 2 );
- BOOST_CHECK_EQUAL( sigma[1], 0 );
- BOOST_CHECK_EQUAL( sigma[2], 1 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 3 ), 6 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 4 ), 8 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 5 ), 10 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 6 ), 1 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 7 ), 3 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 8 ), 5 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 9 ), 7 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
- BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
-
- std::vector<size_t> rs, sig;
- rs.push_back(2);
- rs.push_back(3);
- rs.push_back(2);
- sig.push_back(2);
- sig.push_back(0);
- sig.push_back(1);
- Permute tau( rs, sig );
- BOOST_CHECK( tau.sigma() == sig );
- BOOST_CHECK_EQUAL( tau[0], 2 );
- BOOST_CHECK_EQUAL( tau[1], 0 );
- BOOST_CHECK_EQUAL( tau[2], 1 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 0 ), 0 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 1 ), 2 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 2 ), 4 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 3 ), 6 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 4 ), 8 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 5 ), 10 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 6 ), 1 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 7 ), 3 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 8 ), 5 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
- BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
-}
-
-
-BOOST_AUTO_TEST_CASE( multiforTest ) {
- multifor x;
- BOOST_CHECK( x.valid() );
-
- std::vector<size_t> ranges;
- ranges.push_back( 3 );
- ranges.push_back( 4 );
- ranges.push_back( 5 );
- multifor S(ranges);
- size_t s = 0;
- for( size_t s2 = 0; s2 < 5; s2++ )
- for( size_t s1 = 0; s1 < 4; s1++ )
- for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
- BOOST_CHECK( S.valid() );
- BOOST_CHECK_EQUAL( s, (size_t)S );
- BOOST_CHECK_EQUAL( S[0], s0 );
- BOOST_CHECK_EQUAL( S[1], s1 );
- BOOST_CHECK_EQUAL( S[2], s2 );
- }
- BOOST_CHECK( !S.valid() );
-
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- std::vector<size_t> dims;
- size_t total = 1;
- for( size_t i = 0; i < 4; i++ ) {
- dims.push_back( rnd(3) + 1 );
- total *= dims.back();
- }
- multifor ind( dims );
- size_t iter = 0;
- for( ; ind.valid(); ind++, iter++ ) {
- BOOST_CHECK_EQUAL( (size_t)ind, iter );
- BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
- BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
- BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
- BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
- }
- BOOST_CHECK_EQUAL( iter, total );
- iter = 0;
- ind.reset();
- for( ; ind.valid(); ++ind, iter++ ) {
- BOOST_CHECK_EQUAL( (size_t)ind, iter );
- BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
- BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
- BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
- BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
- }
- BOOST_CHECK_EQUAL( iter, total );
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( StateTest ) {
- State x;
- BOOST_CHECK( x.valid() );
-
- Var v0( 0, 3 );
- Var v1( 1, 4 );
- Var v2( 3, 5 );
- VarSet vars;
- vars |= v2;
- vars |= v1;
- vars |= v0;
- State S( vars );
- size_t s = 0;
- for( size_t s2 = 0; s2 < 5; s2++ )
- for( size_t s1 = 0; s1 < 4; s1++ )
- for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
- BOOST_CHECK( S.valid() );
- BOOST_CHECK_EQUAL( s, (size_t)S );
- BOOST_CHECK_EQUAL( S(v0), s0 );
- BOOST_CHECK_EQUAL( S(v1), s1 );
- BOOST_CHECK_EQUAL( S(v2), s2 );
- BOOST_CHECK_EQUAL( S( Var( 2, 2 ) ), 0 );
- }
- BOOST_CHECK( !S.valid() );
- S.reset();
- std::vector<std::pair<Var, size_t> > ps;
- ps.push_back( std::make_pair( Var( 2, 2 ), 1 ) );
- ps.push_back( std::make_pair( Var( 4, 2 ), 1 ) );
- S.insert( ps.begin(), ps.end() );
- BOOST_CHECK( S.valid() );
- BOOST_CHECK_EQUAL( (size_t)S, 132 );
-
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- std::vector<size_t> dims;
- size_t total = 1;
- for( size_t i = 0; i < 4; i++ ) {
- dims.push_back( rnd(3) + 1 );
- total *= dims.back();
- }
- std::vector<Var> vs;
- for( size_t i = 0; i < 4; i++ )
- vs.push_back( Var( i, dims[i] ) );
- State ind( VarSet( vs.begin(), vs.end() ) );
- size_t iter = 0;
- for( ; ind.valid(); ind++, iter++ ) {
- BOOST_CHECK_EQUAL( (size_t)ind, iter );
- BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
- BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
- BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
- BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
- BOOST_CHECK_EQUAL( ind(VarSet(vs[0], vs[1])), iter % (dims[0] * dims[1]) );
- BOOST_CHECK_EQUAL( ind(VarSet(vs[1], vs[2])), (iter / dims[0]) % (dims[1] * dims[2]) );
- BOOST_CHECK_EQUAL( ind(VarSet(vs[2], vs[3])), (iter / (dims[0] * dims[1])) % (dims[2] * dims[3]) );
- BOOST_CHECK_EQUAL( ind(VarSet(vs.begin(), vs.end())), iter );
- State indcopy( VarSet(vs.begin(), vs.end()), (size_t)ind );
- BOOST_CHECK_EQUAL( ind(vs[0]), indcopy(vs[0]) );
- BOOST_CHECK_EQUAL( ind(vs[1]), indcopy(vs[1]) );
- BOOST_CHECK_EQUAL( ind(vs[2]), indcopy(vs[2]) );
- BOOST_CHECK_EQUAL( ind(vs[3]), indcopy(vs[3]) );
- State indcopy2( indcopy.get() );
- BOOST_CHECK_EQUAL( ind(vs[0]), indcopy2(vs[0]) );
- BOOST_CHECK_EQUAL( ind(vs[1]), indcopy2(vs[1]) );
- BOOST_CHECK_EQUAL( ind(vs[2]), indcopy2(vs[2]) );
- BOOST_CHECK_EQUAL( ind(vs[3]), indcopy2(vs[3]) );
- std::map<Var,size_t> indmap( ind );
- State indcopy3( indmap );
- BOOST_CHECK_EQUAL( ind(vs[0]), indcopy3(vs[0]) );
- BOOST_CHECK_EQUAL( ind(vs[1]), indcopy3(vs[1]) );
- BOOST_CHECK_EQUAL( ind(vs[2]), indcopy3(vs[2]) );
- BOOST_CHECK_EQUAL( ind(vs[3]), indcopy3(vs[3]) );
- }
- BOOST_CHECK_EQUAL( iter, total );
- iter = 0;
- ind.reset();
- for( ; ind.valid(); ++ind, iter++ ) {
- BOOST_CHECK_EQUAL( (size_t)ind, iter );
- BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
- BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
- BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
- BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
- State::const_iterator ci = ind.begin();
- BOOST_CHECK_EQUAL( (ci++)->second, iter % dims[0] );
- BOOST_CHECK_EQUAL( (ci++)->second, (iter / dims[0]) % dims[1] );
- BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1])) % dims[2] );
- BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
- BOOST_CHECK( ci == ind.end() );
- }
- BOOST_CHECK_EQUAL( iter, total );
- State::const_iterator ci = ind.begin();
- BOOST_CHECK_EQUAL( (ci++)->first, vs[0] );
- BOOST_CHECK_EQUAL( (ci++)->first, vs[1] );
- BOOST_CHECK_EQUAL( (ci++)->first, vs[2] );
- BOOST_CHECK_EQUAL( (ci++)->first, vs[3] );
- BOOST_CHECK( ci == ind.end() );
- }
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/index.h>
+#include <strstream>
+#include <map>
+
+
+using namespace dai;
+
+
+#define BOOST_TEST_MODULE IndexTest
+
+
+#include <boost/test/unit_test.hpp>
+
+
+BOOST_AUTO_TEST_CASE( IndexForTest ) {
+ IndexFor x;
+ BOOST_CHECK( !x.valid() );
+ x.reset();
+ BOOST_CHECK( x.valid() );
+
+ size_t nrVars = 5;
+ std::vector<Var> vars;
+ for( size_t i = 0; i < nrVars; i++ )
+ vars.push_back( Var( i, i+2 ) );
+
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ VarSet indexVars;
+ VarSet forVars;
+ for( size_t i = 0; i < 5; i++ ) {
+ if( rnd(2) == 0 )
+ indexVars |= vars[i];
+ if( rnd(2) == 0 )
+ forVars |= vars[i];
+ }
+ IndexFor ind( indexVars, forVars );
+ size_t iter = 0;
+ for( ; ind.valid(); ind++, iter++ )
+ BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
+ BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
+ iter = 0;
+ ind.reset();
+ for( ; ind.valid(); ++ind, iter++ )
+ BOOST_CHECK_EQUAL( calcLinearState( indexVars, calcState( forVars, iter ) ), (size_t)ind );
+ BOOST_CHECK_EQUAL( iter, forVars.nrStates() );
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( PermuteTest ) {
+ Permute x;
+
+ Var x0(0, 2);
+ Var x1(1, 3);
+ Var x2(2, 2);
+ std::vector<Var> V;
+ V.push_back( x1 );
+ V.push_back( x2 );
+ V.push_back( x0 );
+ VarSet X( V.begin(), V.end() );
+ Permute sigma(V);
+ BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
+ BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
+ BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
+ BOOST_CHECK_EQUAL( sigma[0], 2 );
+ BOOST_CHECK_EQUAL( sigma[1], 0 );
+ BOOST_CHECK_EQUAL( sigma[2], 1 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 3 ), 6 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 4 ), 8 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 5 ), 10 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 6 ), 1 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 7 ), 3 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 8 ), 5 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 9 ), 7 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
+ BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
+
+ std::vector<size_t> rs, sig;
+ rs.push_back(2);
+ rs.push_back(3);
+ rs.push_back(2);
+ sig.push_back(2);
+ sig.push_back(0);
+ sig.push_back(1);
+ Permute tau( rs, sig );
+ BOOST_CHECK( tau.sigma() == sig );
+ BOOST_CHECK_EQUAL( tau[0], 2 );
+ BOOST_CHECK_EQUAL( tau[1], 0 );
+ BOOST_CHECK_EQUAL( tau[2], 1 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 0 ), 0 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 1 ), 2 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 2 ), 4 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 3 ), 6 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 4 ), 8 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 5 ), 10 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 6 ), 1 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 7 ), 3 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 8 ), 5 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
+ BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
+}
+
+
+BOOST_AUTO_TEST_CASE( multiforTest ) {
+ multifor x;
+ BOOST_CHECK( x.valid() );
+
+ std::vector<size_t> ranges;
+ ranges.push_back( 3 );
+ ranges.push_back( 4 );
+ ranges.push_back( 5 );
+ multifor S(ranges);
+ size_t s = 0;
+ for( size_t s2 = 0; s2 < 5; s2++ )
+ for( size_t s1 = 0; s1 < 4; s1++ )
+ for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
+ BOOST_CHECK( S.valid() );
+ BOOST_CHECK_EQUAL( s, (size_t)S );
+ BOOST_CHECK_EQUAL( S[0], s0 );
+ BOOST_CHECK_EQUAL( S[1], s1 );
+ BOOST_CHECK_EQUAL( S[2], s2 );
+ }
+ BOOST_CHECK( !S.valid() );
+
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ std::vector<size_t> dims;
+ size_t total = 1;
+ for( size_t i = 0; i < 4; i++ ) {
+ dims.push_back( rnd(3) + 1 );
+ total *= dims.back();
+ }
+ multifor ind( dims );
+ size_t iter = 0;
+ for( ; ind.valid(); ind++, iter++ ) {
+ BOOST_CHECK_EQUAL( (size_t)ind, iter );
+ BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
+ BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
+ BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
+ BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
+ }
+ BOOST_CHECK_EQUAL( iter, total );
+ iter = 0;
+ ind.reset();
+ for( ; ind.valid(); ++ind, iter++ ) {
+ BOOST_CHECK_EQUAL( (size_t)ind, iter );
+ BOOST_CHECK_EQUAL( ind[0], iter % dims[0] );
+ BOOST_CHECK_EQUAL( ind[1], (iter / dims[0]) % dims[1] );
+ BOOST_CHECK_EQUAL( ind[2], (iter / (dims[0] * dims[1])) % dims[2] );
+ BOOST_CHECK_EQUAL( ind[3], (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
+ }
+ BOOST_CHECK_EQUAL( iter, total );
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( StateTest ) {
+ State x;
+ BOOST_CHECK( x.valid() );
+
+ Var v0( 0, 3 );
+ Var v1( 1, 4 );
+ Var v2( 3, 5 );
+ VarSet vars;
+ vars |= v2;
+ vars |= v1;
+ vars |= v0;
+ State S( vars );
+ size_t s = 0;
+ for( size_t s2 = 0; s2 < 5; s2++ )
+ for( size_t s1 = 0; s1 < 4; s1++ )
+ for( size_t s0 = 0; s0 < 3; s0++, s++, S++ ) {
+ BOOST_CHECK( S.valid() );
+ BOOST_CHECK_EQUAL( s, (size_t)S );
+ BOOST_CHECK_EQUAL( S(v0), s0 );
+ BOOST_CHECK_EQUAL( S(v1), s1 );
+ BOOST_CHECK_EQUAL( S(v2), s2 );
+ BOOST_CHECK_EQUAL( S( Var( 2, 2 ) ), 0 );
+ }
+ BOOST_CHECK( !S.valid() );
+ S.reset();
+ std::vector<std::pair<Var, size_t> > ps;
+ ps.push_back( std::make_pair( Var( 2, 2 ), 1 ) );
+ ps.push_back( std::make_pair( Var( 4, 2 ), 1 ) );
+ S.insert( ps.begin(), ps.end() );
+ BOOST_CHECK( S.valid() );
+ BOOST_CHECK_EQUAL( (size_t)S, 132 );
+
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ std::vector<size_t> dims;
+ size_t total = 1;
+ for( size_t i = 0; i < 4; i++ ) {
+ dims.push_back( rnd(3) + 1 );
+ total *= dims.back();
+ }
+ std::vector<Var> vs;
+ for( size_t i = 0; i < 4; i++ )
+ vs.push_back( Var( i, dims[i] ) );
+ State ind( VarSet( vs.begin(), vs.end() ) );
+ size_t iter = 0;
+ for( ; ind.valid(); ind++, iter++ ) {
+ BOOST_CHECK_EQUAL( (size_t)ind, iter );
+ BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
+ BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
+ BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
+ BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
+ BOOST_CHECK_EQUAL( ind(VarSet(vs[0], vs[1])), iter % (dims[0] * dims[1]) );
+ BOOST_CHECK_EQUAL( ind(VarSet(vs[1], vs[2])), (iter / dims[0]) % (dims[1] * dims[2]) );
+ BOOST_CHECK_EQUAL( ind(VarSet(vs[2], vs[3])), (iter / (dims[0] * dims[1])) % (dims[2] * dims[3]) );
+ BOOST_CHECK_EQUAL( ind(VarSet(vs.begin(), vs.end())), iter );
+ State indcopy( VarSet(vs.begin(), vs.end()), (size_t)ind );
+ BOOST_CHECK_EQUAL( ind(vs[0]), indcopy(vs[0]) );
+ BOOST_CHECK_EQUAL( ind(vs[1]), indcopy(vs[1]) );
+ BOOST_CHECK_EQUAL( ind(vs[2]), indcopy(vs[2]) );
+ BOOST_CHECK_EQUAL( ind(vs[3]), indcopy(vs[3]) );
+ State indcopy2( indcopy.get() );
+ BOOST_CHECK_EQUAL( ind(vs[0]), indcopy2(vs[0]) );
+ BOOST_CHECK_EQUAL( ind(vs[1]), indcopy2(vs[1]) );
+ BOOST_CHECK_EQUAL( ind(vs[2]), indcopy2(vs[2]) );
+ BOOST_CHECK_EQUAL( ind(vs[3]), indcopy2(vs[3]) );
+ std::map<Var,size_t> indmap( ind );
+ State indcopy3( indmap );
+ BOOST_CHECK_EQUAL( ind(vs[0]), indcopy3(vs[0]) );
+ BOOST_CHECK_EQUAL( ind(vs[1]), indcopy3(vs[1]) );
+ BOOST_CHECK_EQUAL( ind(vs[2]), indcopy3(vs[2]) );
+ BOOST_CHECK_EQUAL( ind(vs[3]), indcopy3(vs[3]) );
+ }
+ BOOST_CHECK_EQUAL( iter, total );
+ iter = 0;
+ ind.reset();
+ for( ; ind.valid(); ++ind, iter++ ) {
+ BOOST_CHECK_EQUAL( (size_t)ind, iter );
+ BOOST_CHECK_EQUAL( ind(vs[0]), iter % dims[0] );
+ BOOST_CHECK_EQUAL( ind(vs[1]), (iter / dims[0]) % dims[1] );
+ BOOST_CHECK_EQUAL( ind(vs[2]), (iter / (dims[0] * dims[1])) % dims[2] );
+ BOOST_CHECK_EQUAL( ind(vs[3]), (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
+ State::const_iterator ci = ind.begin();
+ BOOST_CHECK_EQUAL( (ci++)->second, iter % dims[0] );
+ BOOST_CHECK_EQUAL( (ci++)->second, (iter / dims[0]) % dims[1] );
+ BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1])) % dims[2] );
+ BOOST_CHECK_EQUAL( (ci++)->second, (iter / (dims[0] * dims[1] * dims[2])) % dims[3] );
+ BOOST_CHECK( ci == ind.end() );
+ }
+ BOOST_CHECK_EQUAL( iter, total );
+ State::const_iterator ci = ind.begin();
+ BOOST_CHECK_EQUAL( (ci++)->first, vs[0] );
+ BOOST_CHECK_EQUAL( (ci++)->first, vs[1] );
+ BOOST_CHECK_EQUAL( (ci++)->first, vs[2] );
+ BOOST_CHECK_EQUAL( (ci++)->first, vs[3] );
+ BOOST_CHECK( ci == ind.end() );
+ }
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/prob.h>
-#include <strstream>
-
-
-using namespace dai;
-
-
-const double tol = 1e-8;
-
-
-#define BOOST_TEST_MODULE ProbTest
-
-
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- // check constructors
- Prob x1;
- BOOST_CHECK_EQUAL( x1.size(), 0 );
- BOOST_CHECK( x1.p() == Prob::container_type() );
-
- Prob x2( 3 );
- BOOST_CHECK_EQUAL( x2.size(), 3 );
- BOOST_CHECK( x2.p() == Prob::container_type( 3, 1.0 / 3.0 ) );
- BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
- BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
- BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
-
- Prob x3( 4, 1.0 );
- BOOST_CHECK_EQUAL( x3.size(), 4 );
- BOOST_CHECK( x3.p() == Prob::container_type( 4, 1.0 ) );
- BOOST_CHECK_EQUAL( x3[0], 1.0 );
- BOOST_CHECK_EQUAL( x3[1], 1.0 );
- BOOST_CHECK_EQUAL( x3[2], 1.0 );
- BOOST_CHECK_EQUAL( x3[3], 1.0 );
- x3.set( 0, 0.5 );
- x3.set( 1, 1.0 );
- x3.set( 2, 2.0 );
- x3.set( 3, 4.0 );
-
- std::vector<Real> v;
- v.push_back( 0.5 );
- v.push_back( 1.0 );
- v.push_back( 2.0 );
- v.push_back( 4.0 );
- Prob x4( v.begin(), v.end(), 0 );
- BOOST_CHECK_EQUAL( x4.size(), 4 );
- BOOST_CHECK( x4.p() == x3.p() );
- BOOST_CHECK( x4 == x3 );
- BOOST_CHECK_EQUAL( x4[0], 0.5 );
- BOOST_CHECK_EQUAL( x4[1], 1.0 );
- BOOST_CHECK_EQUAL( x4[2], 2.0 );
- BOOST_CHECK_EQUAL( x4[3], 4.0 );
-
- Prob x5( v.begin(), v.end(), v.size() );
- BOOST_CHECK_EQUAL( x5.size(), 4 );
- BOOST_CHECK( x5.p() == x3.p() );
- BOOST_CHECK( x5 == x3 );
- BOOST_CHECK_EQUAL( x5[0], 0.5 );
- BOOST_CHECK_EQUAL( x5[1], 1.0 );
- BOOST_CHECK_EQUAL( x5[2], 2.0 );
- BOOST_CHECK_EQUAL( x5[3], 4.0 );
-
- std::vector<int> y( 3, 2 );
- Prob x6( y );
- BOOST_CHECK_EQUAL( x6.size(), 3 );
- BOOST_CHECK_EQUAL( x6[0], 2.0 );
- BOOST_CHECK_EQUAL( x6[1], 2.0 );
- BOOST_CHECK_EQUAL( x6[2], 2.0 );
-
- Prob x7( x6 );
- BOOST_CHECK( x7 == x6 );
-
- Prob x8 = x6;
- BOOST_CHECK( x8 == x6 );
-
- x7.resize( 5 );
- BOOST_CHECK_EQUAL( x7.size(), 5 );
- BOOST_CHECK_EQUAL( x7[0], 2.0 );
- BOOST_CHECK_EQUAL( x7[1], 2.0 );
- BOOST_CHECK_EQUAL( x7[2], 2.0 );
- BOOST_CHECK_EQUAL( x7[3], 0.0 );
- BOOST_CHECK_EQUAL( x7[4], 0.0 );
-
- x8.resize( 1 );
- BOOST_CHECK_EQUAL( x8.size(), 1 );
- BOOST_CHECK_EQUAL( x8[0], 2.0 );
-}
-
-
-#ifndef DAI_SPARSE
-BOOST_AUTO_TEST_CASE( IteratorTest ) {
- Prob x( 5, 0.0 );
- size_t i;
- for( i = 0; i < x.size(); i++ )
- x.set( i, i );
-
- i = 0;
- for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
- BOOST_CHECK_EQUAL( *cit, i );
-
- i = 0;
- for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
- *it = 4 - i;
-
- i = 0;
- for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
- BOOST_CHECK_EQUAL( *it, 4 - i );
-
- i = 0;
- for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
- BOOST_CHECK_EQUAL( *crit, i );
-
- i = 0;
- for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
- *rit = 2 * i;
-
- i = 0;
- for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
- BOOST_CHECK_EQUAL( *crit, 2 * i );
-}
-#endif
-
-
-BOOST_AUTO_TEST_CASE( QueriesTest ) {
- Prob x( 5, 0.0 );
- for( size_t i = 0; i < x.size(); i++ )
- x.set( i, 2.0 - i );
-
- // test accumulate, min, max, sum, sumAbs, maxAbs
- BOOST_CHECK_EQUAL( x.sum(), 0.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_id<Real>() ), 0.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_id<Real>() ), 1.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_id<Real>() ), -1.0 );
- BOOST_CHECK_EQUAL( x.max(), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -INFINITY, fo_id<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_id<Real>(), false ), 3.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -5.0, fo_id<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.min(), -2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( INFINITY, fo_id<Real>(), true ), -2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_id<Real>(), true ), -3.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 5.0, fo_id<Real>(), true ), -2.0 );
- BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_abs<Real>() ), 6.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_abs<Real>() ), 7.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_abs<Real>() ), 7.0 );
- BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
- x.set( 1, 1.0 );
- BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
- BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
- for( size_t i = 0; i < x.size(); i++ )
- x.set( i, i ? (1.0 / i) : 0.0 );
- BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_inv0<Real>() ), 10.0 );
- x /= x.sum();
-
- // test entropy
- BOOST_CHECK( x.entropy() < Prob(5).entropy() );
- for( size_t i = 1; i < 100; i++ )
- BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log((Real)i), tol );
-
- // test hasNaNs and hasNegatives
- BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
- Real c = 0.0;
- BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
- BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
- BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
- BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
- x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
- BOOST_CHECK( x.hasNegatives() );
- x.set( 2, -INFINITY );
- BOOST_CHECK( x.hasNegatives() );
- x.set( 2, INFINITY );
- BOOST_CHECK( !x.hasNegatives() );
- x.set( 2, -1.0 );
-
- // test argmax
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
- x.set( 4, 0.5 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
- x.set( 3, -2.0 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
- x.set( 4, -1.0 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
- x.set( 0, -2.0 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
- x.set( 1, -3.0 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
- x.set( 2, -2.0 );
- BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
-
- // test draw
- for( size_t i = 0; i < x.size(); i++ )
- x.set( i, i ? (1.0 / i) : 0.0 );
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- BOOST_CHECK( x.draw() < x.size() );
- BOOST_CHECK( x.draw() != 0 );
- }
- x.set( 2, 0.0 );
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- BOOST_CHECK( x.draw() < x.size() );
- BOOST_CHECK( x.draw() != 0 );
- BOOST_CHECK( x.draw() != 2 );
- }
- x.set( 4, 0.0 );
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- BOOST_CHECK( x.draw() < x.size() );
- BOOST_CHECK( x.draw() != 0 );
- BOOST_CHECK( x.draw() != 2 );
- BOOST_CHECK( x.draw() != 4 );
- }
- x.set( 1, 0.0 );
- for( size_t repeat = 0; repeat < 10000; repeat++ )
- BOOST_CHECK( x.draw() == 3 );
-
- // test <, ==
- Prob a(3, 1.0), b(3, 1.0);
- BOOST_CHECK( !(a < b) );
- BOOST_CHECK( !(b < a) );
- BOOST_CHECK( a == b );
- a.set( 0, 0.0 );
- BOOST_CHECK( a < b );
- BOOST_CHECK( !(b < a) );
- BOOST_CHECK( !(a == b) );
- b.set( 2, 0.0 );
- BOOST_CHECK( a < b );
- BOOST_CHECK( !(b < a) );
- BOOST_CHECK( !(a == b) );
- b.set( 0, 0.0 );
- BOOST_CHECK( !(a < b) );
- BOOST_CHECK( b < a );
- BOOST_CHECK( !(a == b) );
- a.set( 1, 0.0 );
- BOOST_CHECK( a < b );
- BOOST_CHECK( !(b < a) );
- BOOST_CHECK( !(a == b) );
- b.set( 1, 0.0 );
- BOOST_CHECK( !(a < b) );
- BOOST_CHECK( b < a );
- BOOST_CHECK( !(a == b) );
- a.set( 2, 0.0 );
- BOOST_CHECK( !(a < b) );
- BOOST_CHECK( !(b < a) );
- BOOST_CHECK( a == b );
-}
-
-
-BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
- Prob x( 3 );
- x.set( 0, -2.0 );
- x.set( 1, 0.0 );
- x.set( 2, 2.0 );
-
- Prob y = -x;
- Prob z = x.pwUnaryTr( std::negate<Real>() );
- BOOST_CHECK_EQUAL( y[0], 2.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], -2.0 );
- BOOST_CHECK( y == z );
-
- y = x.abs();
- z = x.pwUnaryTr( fo_abs<Real>() );
- BOOST_CHECK_EQUAL( y[0], 2.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 2.0 );
- BOOST_CHECK( y == z );
-
- y = x.exp();
- z = x.pwUnaryTr( fo_exp<Real>() );
- BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
- BOOST_CHECK_EQUAL( y[1], 1.0 );
- BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
- BOOST_CHECK( y == z );
-
- y = x.log(false);
- z = x.pwUnaryTr( fo_log<Real>() );
- BOOST_CHECK( isnan( y[0] ) );
- BOOST_CHECK_EQUAL( y[1], -INFINITY );
- BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
- BOOST_CHECK( !(y == z) );
- y.set( 0, 0.0 );
- z.set( 0, 0.0 );
- BOOST_CHECK( y == z );
-
- y = x.log(true);
- z = x.pwUnaryTr( fo_log0<Real>() );
- BOOST_CHECK( isnan( y[0] ) );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
- BOOST_CHECK( !(y == z) );
- y.set( 0, 0.0 );
- z.set( 0, 0.0 );
- BOOST_CHECK( y == z );
-
- y = x.inverse(false);
- z = x.pwUnaryTr( fo_inv<Real>() );
- BOOST_CHECK_EQUAL( y[0], -0.5 );
- BOOST_CHECK_EQUAL( y[1], INFINITY );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
- BOOST_CHECK( y == z );
-
- y = x.inverse(true);
- z = x.pwUnaryTr( fo_inv0<Real>() );
- BOOST_CHECK_EQUAL( y[0], -0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
- BOOST_CHECK( y == z );
-
- x.set( 0, 2.0 );
- y = x.normalized();
- BOOST_CHECK_EQUAL( y[0], 0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- y = x.normalized( NORMPROB );
- BOOST_CHECK_EQUAL( y[0], 0.5 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 0.5 );
-
- x.set( 0, -2.0 );
- y = x.normalized( NORMLINF );
- BOOST_CHECK_EQUAL( y[0], -1.0 );
- BOOST_CHECK_EQUAL( y[1], 0.0 );
- BOOST_CHECK_EQUAL( y[2], 1.0 );
-}
-
-
-BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
- Prob xorg(3);
- xorg.set( 0, 2.0 );
- xorg.set( 1, 0.0 );
- xorg.set( 2, 1.0 );
- Prob y(3);
-
- Prob x = xorg;
- BOOST_CHECK( x.setUniform() == Prob(3) );
- BOOST_CHECK( x == Prob(3) );
-
- y.set( 0, std::exp(2.0) );
- y.set( 1, 1.0 );
- y.set( 2, std::exp(1.0) );
- x = xorg;
- BOOST_CHECK( x.takeExp() == y );
- BOOST_CHECK( x == y );
- x = xorg;
- BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
- BOOST_CHECK( x == y );
-
- y.set( 0, std::log(2.0) );
- y.set( 1, -INFINITY );
- y.set( 2, 0.0 );
- x = xorg;
- BOOST_CHECK( x.takeLog() == y );
- BOOST_CHECK( x == y );
- x = xorg;
- BOOST_CHECK( x.takeLog(false) == y );
- BOOST_CHECK( x == y );
- x = xorg;
- BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
- BOOST_CHECK( x == y );
-
- y.set( 1, 0.0 );
- x = xorg;
- BOOST_CHECK( x.takeLog(true) == y );
- BOOST_CHECK( x == y );
- x = xorg;
- BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
- BOOST_CHECK( x == y );
-
- y.set( 0, 2.0 / 3.0 );
- y.set( 1, 0.0 / 3.0 );
- y.set( 2, 1.0 / 3.0 );
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
- BOOST_CHECK( x == y );
-
- y.set( 0, 2.0 / 2.0 );
- y.set( 1, 0.0 / 2.0 );
- y.set( 2, 1.0 / 2.0 );
- x = xorg;
- BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
- BOOST_CHECK( x == y );
-
- xorg.set( 0, -2.0 );
- y.set( 0, 2.0 );
- y.set( 1, 0.0 );
- y.set( 2, 1.0 );
- x = xorg;
- BOOST_CHECK( x.takeAbs() == y );
- BOOST_CHECK( x == y );
-
- for( size_t repeat = 0; repeat < 10000; repeat++ ) {
- x.randomize();
- for( size_t i = 0; i < x.size(); i++ ) {
- BOOST_CHECK( x[i] < 1.0 );
- BOOST_CHECK( x[i] >= 0.0 );
- }
- }
-}
-
-
-BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
- Prob x(3);
- x.set( 0, 2.0 );
- x.set( 1, 0.0 );
- x.set( 2, 1.0 );
- Prob y(3);
-
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x + 1.0) == y );
- y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
- BOOST_CHECK( (x + (-2.0)) == y );
-
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x - 1.0) == y );
- y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
- BOOST_CHECK( (x - (-2.0)) == y );
-
- BOOST_CHECK( (x * 1.0) == x );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x * 2.0) == y );
- y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
- BOOST_CHECK( (x * -0.5) == y );
-
- BOOST_CHECK( (x / 1.0) == x );
- y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
- BOOST_CHECK( (x / 2.0) == y );
- y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
- BOOST_CHECK( (x / -0.5) == y );
- BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
-
- BOOST_CHECK( (x ^ 1.0) == x );
- BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^ 2.0) == y );
- y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
- Prob z = (x ^ -0.5);
- BOOST_CHECK_CLOSE( z[0], y[0], tol );
- BOOST_CHECK_EQUAL( z[1], y[1] );
- BOOST_CHECK_CLOSE( z[2], y[2], tol );
-}
-
-
-BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
- Prob xorg(3), x(3);
- xorg.set( 0, 2.0 );
- xorg.set( 1, 0.0 );
- xorg.set( 2, 1.0 );
- Prob y(3);
-
- x = xorg;
- BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
- BOOST_CHECK( x == Prob(3, 1.0) );
- BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
- BOOST_CHECK( x == Prob(3, 2.0) );
- BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
- BOOST_CHECK( x == Prob(3, 0.0) );
-
- x = xorg;
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x += 1.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x += -2.0) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
- BOOST_CHECK( (x -= 1.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x -= -2.0) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK( (x *= 1.0) == x );
- BOOST_CHECK( x == x );
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
- BOOST_CHECK( (x *= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
- BOOST_CHECK( (x *= -0.25) == y );
- BOOST_CHECK( x == y );
-
- x = xorg;
- BOOST_CHECK( (x /= 1.0) == x );
- BOOST_CHECK( x == x );
- y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
- BOOST_CHECK( (x /= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
- BOOST_CHECK( (x /= -0.25) == y );
- BOOST_CHECK( x == y );
- BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
- BOOST_CHECK( x == Prob(3, 0.0) );
-
- x = xorg;
- BOOST_CHECK( (x ^= 1.0) == x );
- BOOST_CHECK( x == x );
- BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
- BOOST_CHECK( x == Prob(3, 1.0) );
- x = xorg;
- y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^= 2.0) == y );
- BOOST_CHECK( x == y );
- y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
- BOOST_CHECK( (x ^= -0.5) == y );
- BOOST_CHECK( x == y );
-}
-
-
-BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
- size_t N = 6;
- Prob xorg(N), x(N);
- xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
- Prob y(N);
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
- Prob z(N), r(N);
-
- z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
- x = xorg;
- r = (x += y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
- x = xorg;
- r = (x -= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
- x = xorg;
- r = (x *= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
- x = xorg;
- r = (x /= y);
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
- BOOST_CHECK( x == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
- x = xorg;
- r = (x.divide( y ));
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_CLOSE( r[1], z[1], tol );
- BOOST_CHECK_EQUAL( r[2], z[2] );
- BOOST_CHECK( isnan(r[3]) );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
- x.set( 3, 0.0 ); r.set( 3, 0.0 );
- BOOST_CHECK( x == r );
- x = xorg;
- r = x.pwBinaryOp( y, std::divides<Real>() );
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_CLOSE( r[1], z[1], tol );
- BOOST_CHECK_EQUAL( r[2], z[2] );
- BOOST_CHECK( isnan(r[3]) );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
- x.set( 3, 0.0 ); r.set( 3, 0.0 );
- BOOST_CHECK( x == r );
-
- z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
- x = xorg;
- r = (x ^= y);
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_EQUAL( r[1], z[1] );
- BOOST_CHECK_CLOSE( r[2], z[2], tol );
- BOOST_CHECK_CLOSE( r[3], z[3], tol );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
- BOOST_CHECK( x == z );
- x = xorg;
- BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
- BOOST_CHECK( x == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
- size_t N = 6;
- Prob x(N);
- x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
- Prob y(N);
- y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
- Prob z(N), r(N);
-
- z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
- r = x + y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.pwBinaryTr( y, std::plus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
- r = x - y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.pwBinaryTr( y, std::minus<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
- r = x * y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.pwBinaryTr( y, std::multiplies<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
- r = x / y;
- for( size_t i = 0; i < N; i++ )
- BOOST_CHECK_CLOSE( r[i], z[i], tol );
- z = x.pwBinaryTr( y, fo_divides0<Real>() );
- BOOST_CHECK( r == z );
-
- z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
- r = x.divided_by( y );
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_CLOSE( r[1], z[1], tol );
- BOOST_CHECK_EQUAL( r[2], z[2] );
- BOOST_CHECK( isnan(r[3]) );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
- z = x.pwBinaryTr( y, std::divides<Real>() );
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_CLOSE( r[1], z[1], tol );
- BOOST_CHECK_EQUAL( r[2], z[2] );
- BOOST_CHECK( isnan(r[3]) );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
-
- z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
- r = x ^ y;
- BOOST_CHECK_CLOSE( r[0], z[0], tol );
- BOOST_CHECK_EQUAL( r[1], z[1] );
- BOOST_CHECK_CLOSE( r[2], z[2], tol );
- BOOST_CHECK_CLOSE( r[3], z[3], tol );
- BOOST_CHECK_CLOSE( r[4], z[4], tol );
- BOOST_CHECK_CLOSE( r[5], z[5], tol );
- z = x.pwBinaryTr( y, fo_pow<Real>() );
- BOOST_CHECK( r == z );
-}
-
-
-BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
- Prob x(3), y(3), z(3);
- x.set( 0, 0.2 );
- x.set( 1, 0.8 );
- x.set( 2, 0.0 );
- y.set( 0, 0.0 );
- y.set( 1, 0.6 );
- y.set( 2, 0.4 );
-
- z = min( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.0 );
- BOOST_CHECK_EQUAL( z[1], 0.6 );
- BOOST_CHECK_EQUAL( z[2], 0.0 );
- z = max( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.2 );
- BOOST_CHECK_EQUAL( z[1], 0.8 );
- BOOST_CHECK_EQUAL( z[2], 0.4 );
-
- BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
- BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
- BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
- BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
- BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
- BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
- BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
- BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
- BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
- BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
- BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
- x.set( 1, 0.7 ); x.set( 2, 0.1 );
- y.set( 0, 0.1 ); y.set( 1, 0.5 );
- BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
- BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
- BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
- BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
-
- Prob xx(4), yy(4);
- for( size_t i = 0; i < 3; i++ ) {
- xx.set( i, x[i] );
- yy.set( i, y[i] );
- }
- std::stringstream ss;
- ss << xx;
- std::string s;
- std::getline( ss, s );
-#ifdef DAI_SPARSE
- BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
-#else
- BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
-#endif
- std::stringstream ss2;
- ss2 << yy;
- std::getline( ss2, s );
-#ifdef DAI_SPARSE
- BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
-#else
- BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
-#endif
-
- z = min( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.1 );
- BOOST_CHECK_EQUAL( z[1], 0.5 );
- BOOST_CHECK_EQUAL( z[2], 0.1 );
- z = max( x, y );
- BOOST_CHECK_EQUAL( z[0], 0.2 );
- BOOST_CHECK_EQUAL( z[1], 0.7 );
- BOOST_CHECK_EQUAL( z[2], 0.4 );
-
- BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
- BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
- BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
- BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
- BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
- BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
-}
--- /dev/null
+/* This file is part of libDAI - http://www.libdai.org/
+ *
+ * libDAI is licensed under the terms of the GNU General Public License version
+ * 2, or (at your option) any later version. libDAI is distributed without any
+ * warranty. See the file COPYING for more details.
+ *
+ * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/prob.h>
+#include <strstream>
+
+
+using namespace dai;
+
+
+const double tol = 1e-8;
+
+
+#define BOOST_TEST_MODULE ProbTest
+
+
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+
+BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
+ // check constructors
+ Prob x1;
+ BOOST_CHECK_EQUAL( x1.size(), 0 );
+ BOOST_CHECK( x1.p() == Prob::container_type() );
+
+ Prob x2( 3 );
+ BOOST_CHECK_EQUAL( x2.size(), 3 );
+ BOOST_CHECK( x2.p() == Prob::container_type( 3, 1.0 / 3.0 ) );
+ BOOST_CHECK_EQUAL( x2[0], 1.0 / 3.0 );
+ BOOST_CHECK_EQUAL( x2[1], 1.0 / 3.0 );
+ BOOST_CHECK_EQUAL( x2[2], 1.0 / 3.0 );
+
+ Prob x3( 4, 1.0 );
+ BOOST_CHECK_EQUAL( x3.size(), 4 );
+ BOOST_CHECK( x3.p() == Prob::container_type( 4, 1.0 ) );
+ BOOST_CHECK_EQUAL( x3[0], 1.0 );
+ BOOST_CHECK_EQUAL( x3[1], 1.0 );
+ BOOST_CHECK_EQUAL( x3[2], 1.0 );
+ BOOST_CHECK_EQUAL( x3[3], 1.0 );
+ x3.set( 0, 0.5 );
+ x3.set( 1, 1.0 );
+ x3.set( 2, 2.0 );
+ x3.set( 3, 4.0 );
+
+ std::vector<Real> v;
+ v.push_back( 0.5 );
+ v.push_back( 1.0 );
+ v.push_back( 2.0 );
+ v.push_back( 4.0 );
+ Prob x4( v.begin(), v.end(), 0 );
+ BOOST_CHECK_EQUAL( x4.size(), 4 );
+ BOOST_CHECK( x4.p() == x3.p() );
+ BOOST_CHECK( x4 == x3 );
+ BOOST_CHECK_EQUAL( x4[0], 0.5 );
+ BOOST_CHECK_EQUAL( x4[1], 1.0 );
+ BOOST_CHECK_EQUAL( x4[2], 2.0 );
+ BOOST_CHECK_EQUAL( x4[3], 4.0 );
+
+ Prob x5( v.begin(), v.end(), v.size() );
+ BOOST_CHECK_EQUAL( x5.size(), 4 );
+ BOOST_CHECK( x5.p() == x3.p() );
+ BOOST_CHECK( x5 == x3 );
+ BOOST_CHECK_EQUAL( x5[0], 0.5 );
+ BOOST_CHECK_EQUAL( x5[1], 1.0 );
+ BOOST_CHECK_EQUAL( x5[2], 2.0 );
+ BOOST_CHECK_EQUAL( x5[3], 4.0 );
+
+ std::vector<int> y( 3, 2 );
+ Prob x6( y );
+ BOOST_CHECK_EQUAL( x6.size(), 3 );
+ BOOST_CHECK_EQUAL( x6[0], 2.0 );
+ BOOST_CHECK_EQUAL( x6[1], 2.0 );
+ BOOST_CHECK_EQUAL( x6[2], 2.0 );
+
+ Prob x7( x6 );
+ BOOST_CHECK( x7 == x6 );
+
+ Prob x8 = x6;
+ BOOST_CHECK( x8 == x6 );
+
+ x7.resize( 5 );
+ BOOST_CHECK_EQUAL( x7.size(), 5 );
+ BOOST_CHECK_EQUAL( x7[0], 2.0 );
+ BOOST_CHECK_EQUAL( x7[1], 2.0 );
+ BOOST_CHECK_EQUAL( x7[2], 2.0 );
+ BOOST_CHECK_EQUAL( x7[3], 0.0 );
+ BOOST_CHECK_EQUAL( x7[4], 0.0 );
+
+ x8.resize( 1 );
+ BOOST_CHECK_EQUAL( x8.size(), 1 );
+ BOOST_CHECK_EQUAL( x8[0], 2.0 );
+}
+
+
+#ifndef DAI_SPARSE
+BOOST_AUTO_TEST_CASE( IteratorTest ) {
+ Prob x( 5, 0.0 );
+ size_t i;
+ for( i = 0; i < x.size(); i++ )
+ x.set( i, i );
+
+ i = 0;
+ for( Prob::const_iterator cit = x.begin(); cit != x.end(); cit++, i++ )
+ BOOST_CHECK_EQUAL( *cit, i );
+
+ i = 0;
+ for( Prob::iterator it = x.begin(); it != x.end(); it++, i++ )
+ *it = 4 - i;
+
+ i = 0;
+ for( Prob::const_iterator it = x.begin(); it != x.end(); it++, i++ )
+ BOOST_CHECK_EQUAL( *it, 4 - i );
+
+ i = 0;
+ for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
+ BOOST_CHECK_EQUAL( *crit, i );
+
+ i = 0;
+ for( Prob::reverse_iterator rit = x.rbegin(); rit != x.rend(); rit++, i++ )
+ *rit = 2 * i;
+
+ i = 0;
+ for( Prob::const_reverse_iterator crit = x.rbegin(); crit != x.rend(); crit++, i++ )
+ BOOST_CHECK_EQUAL( *crit, 2 * i );
+}
+#endif
+
+
+BOOST_AUTO_TEST_CASE( QueriesTest ) {
+ Prob x( 5, 0.0 );
+ for( size_t i = 0; i < x.size(); i++ )
+ x.set( i, 2.0 - i );
+
+ // test accumulate, min, max, sum, sumAbs, maxAbs
+ BOOST_CHECK_EQUAL( x.sum(), 0.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_id<Real>() ), 0.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_id<Real>() ), 1.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_id<Real>() ), -1.0 );
+ BOOST_CHECK_EQUAL( x.max(), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -INFINITY, fo_id<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_id<Real>(), false ), 3.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -5.0, fo_id<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.min(), -2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( INFINITY, fo_id<Real>(), true ), -2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_id<Real>(), true ), -3.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 5.0, fo_id<Real>(), true ), -2.0 );
+ BOOST_CHECK_EQUAL( x.sumAbs(), 6.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_abs<Real>() ), 6.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( 1.0, fo_abs<Real>() ), 7.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( -1.0, fo_abs<Real>() ), 7.0 );
+ BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
+ x.set( 1, 1.0 );
+ BOOST_CHECK_EQUAL( x.maxAbs(), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 0.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 1.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -1.0, fo_abs<Real>(), false ), 2.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( 3.0, fo_abs<Real>(), false ), 3.0 );
+ BOOST_CHECK_EQUAL( x.accumulateMax( -3.0, fo_abs<Real>(), false ), 3.0 );
+ for( size_t i = 0; i < x.size(); i++ )
+ x.set( i, i ? (1.0 / i) : 0.0 );
+ BOOST_CHECK_EQUAL( x.accumulateSum( 0.0, fo_inv0<Real>() ), 10.0 );
+ x /= x.sum();
+
+ // test entropy
+ BOOST_CHECK( x.entropy() < Prob(5).entropy() );
+ for( size_t i = 1; i < 100; i++ )
+ BOOST_CHECK_CLOSE( Prob(i).entropy(), std::log((Real)i), tol );
+
+ // test hasNaNs and hasNegatives
+ BOOST_CHECK( !Prob( 3, 0.0 ).hasNaNs() );
+ Real c = 0.0;
+ BOOST_CHECK( Prob( 3, c / c ).hasNaNs() );
+ BOOST_CHECK( !Prob( 3, 0.0 ).hasNegatives() );
+ BOOST_CHECK( !Prob( 3, 1.0 ).hasNegatives() );
+ BOOST_CHECK( Prob( 3, -1.0 ).hasNegatives() );
+ x.set( 0, 0.0 ); x.set( 1, 0.0 ); x.set( 2, -1.0 ); x.set( 3, 1.0 ); x.set( 4, 100.0 );
+ BOOST_CHECK( x.hasNegatives() );
+ x.set( 2, -INFINITY );
+ BOOST_CHECK( x.hasNegatives() );
+ x.set( 2, INFINITY );
+ BOOST_CHECK( !x.hasNegatives() );
+ x.set( 2, -1.0 );
+
+ // test argmax
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)100.0 ) );
+ x.set( 4, 0.5 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)3, (Real)1.0 ) );
+ x.set( 3, -2.0 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)0.5 ) );
+ x.set( 4, -1.0 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)0, (Real)0.0 ) );
+ x.set( 0, -2.0 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)1, (Real)0.0 ) );
+ x.set( 1, -3.0 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)2, (Real)-1.0 ) );
+ x.set( 2, -2.0 );
+ BOOST_CHECK( x.argmax() == std::make_pair( (size_t)4, (Real)-1.0 ) );
+
+ // test draw
+ for( size_t i = 0; i < x.size(); i++ )
+ x.set( i, i ? (1.0 / i) : 0.0 );
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ BOOST_CHECK( x.draw() < x.size() );
+ BOOST_CHECK( x.draw() != 0 );
+ }
+ x.set( 2, 0.0 );
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ BOOST_CHECK( x.draw() < x.size() );
+ BOOST_CHECK( x.draw() != 0 );
+ BOOST_CHECK( x.draw() != 2 );
+ }
+ x.set( 4, 0.0 );
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ BOOST_CHECK( x.draw() < x.size() );
+ BOOST_CHECK( x.draw() != 0 );
+ BOOST_CHECK( x.draw() != 2 );
+ BOOST_CHECK( x.draw() != 4 );
+ }
+ x.set( 1, 0.0 );
+ for( size_t repeat = 0; repeat < 10000; repeat++ )
+ BOOST_CHECK( x.draw() == 3 );
+
+ // test <, ==
+ Prob a(3, 1.0), b(3, 1.0);
+ BOOST_CHECK( !(a < b) );
+ BOOST_CHECK( !(b < a) );
+ BOOST_CHECK( a == b );
+ a.set( 0, 0.0 );
+ BOOST_CHECK( a < b );
+ BOOST_CHECK( !(b < a) );
+ BOOST_CHECK( !(a == b) );
+ b.set( 2, 0.0 );
+ BOOST_CHECK( a < b );
+ BOOST_CHECK( !(b < a) );
+ BOOST_CHECK( !(a == b) );
+ b.set( 0, 0.0 );
+ BOOST_CHECK( !(a < b) );
+ BOOST_CHECK( b < a );
+ BOOST_CHECK( !(a == b) );
+ a.set( 1, 0.0 );
+ BOOST_CHECK( a < b );
+ BOOST_CHECK( !(b < a) );
+ BOOST_CHECK( !(a == b) );
+ b.set( 1, 0.0 );
+ BOOST_CHECK( !(a < b) );
+ BOOST_CHECK( b < a );
+ BOOST_CHECK( !(a == b) );
+ a.set( 2, 0.0 );
+ BOOST_CHECK( !(a < b) );
+ BOOST_CHECK( !(b < a) );
+ BOOST_CHECK( a == b );
+}
+
+
+BOOST_AUTO_TEST_CASE( UnaryTransformationsTest ) {
+ Prob x( 3 );
+ x.set( 0, -2.0 );
+ x.set( 1, 0.0 );
+ x.set( 2, 2.0 );
+
+ Prob y = -x;
+ Prob z = x.pwUnaryTr( std::negate<Real>() );
+ BOOST_CHECK_EQUAL( y[0], 2.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], -2.0 );
+ BOOST_CHECK( y == z );
+
+ y = x.abs();
+ z = x.pwUnaryTr( fo_abs<Real>() );
+ BOOST_CHECK_EQUAL( y[0], 2.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 2.0 );
+ BOOST_CHECK( y == z );
+
+ y = x.exp();
+ z = x.pwUnaryTr( fo_exp<Real>() );
+ BOOST_CHECK_CLOSE( y[0], std::exp(-2.0), tol );
+ BOOST_CHECK_EQUAL( y[1], 1.0 );
+ BOOST_CHECK_CLOSE( y[2], 1.0 / y[0], tol );
+ BOOST_CHECK( y == z );
+
+ y = x.log(false);
+ z = x.pwUnaryTr( fo_log<Real>() );
+ BOOST_CHECK( isnan( y[0] ) );
+ BOOST_CHECK_EQUAL( y[1], -INFINITY );
+ BOOST_CHECK_CLOSE( y[2], std::log(2.0), tol );
+ BOOST_CHECK( !(y == z) );
+ y.set( 0, 0.0 );
+ z.set( 0, 0.0 );
+ BOOST_CHECK( y == z );
+
+ y = x.log(true);
+ z = x.pwUnaryTr( fo_log0<Real>() );
+ BOOST_CHECK( isnan( y[0] ) );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], std::log(2.0) );
+ BOOST_CHECK( !(y == z) );
+ y.set( 0, 0.0 );
+ z.set( 0, 0.0 );
+ BOOST_CHECK( y == z );
+
+ y = x.inverse(false);
+ z = x.pwUnaryTr( fo_inv<Real>() );
+ BOOST_CHECK_EQUAL( y[0], -0.5 );
+ BOOST_CHECK_EQUAL( y[1], INFINITY );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+ BOOST_CHECK( y == z );
+
+ y = x.inverse(true);
+ z = x.pwUnaryTr( fo_inv0<Real>() );
+ BOOST_CHECK_EQUAL( y[0], -0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+ BOOST_CHECK( y == z );
+
+ x.set( 0, 2.0 );
+ y = x.normalized();
+ BOOST_CHECK_EQUAL( y[0], 0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ y = x.normalized( NORMPROB );
+ BOOST_CHECK_EQUAL( y[0], 0.5 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 0.5 );
+
+ x.set( 0, -2.0 );
+ y = x.normalized( NORMLINF );
+ BOOST_CHECK_EQUAL( y[0], -1.0 );
+ BOOST_CHECK_EQUAL( y[1], 0.0 );
+ BOOST_CHECK_EQUAL( y[2], 1.0 );
+}
+
+
+BOOST_AUTO_TEST_CASE( UnaryOperationsTest ) {
+ Prob xorg(3);
+ xorg.set( 0, 2.0 );
+ xorg.set( 1, 0.0 );
+ xorg.set( 2, 1.0 );
+ Prob y(3);
+
+ Prob x = xorg;
+ BOOST_CHECK( x.setUniform() == Prob(3) );
+ BOOST_CHECK( x == Prob(3) );
+
+ y.set( 0, std::exp(2.0) );
+ y.set( 1, 1.0 );
+ y.set( 2, std::exp(1.0) );
+ x = xorg;
+ BOOST_CHECK( x.takeExp() == y );
+ BOOST_CHECK( x == y );
+ x = xorg;
+ BOOST_CHECK( x.pwUnaryOp( fo_exp<Real>() ) == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, std::log(2.0) );
+ y.set( 1, -INFINITY );
+ y.set( 2, 0.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeLog() == y );
+ BOOST_CHECK( x == y );
+ x = xorg;
+ BOOST_CHECK( x.takeLog(false) == y );
+ BOOST_CHECK( x == y );
+ x = xorg;
+ BOOST_CHECK( x.pwUnaryOp( fo_log<Real>() ) == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 1, 0.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeLog(true) == y );
+ BOOST_CHECK( x == y );
+ x = xorg;
+ BOOST_CHECK( x.pwUnaryOp( fo_log0<Real>() ) == y );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, 2.0 / 3.0 );
+ y.set( 1, 0.0 / 3.0 );
+ y.set( 2, 1.0 / 3.0 );
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize(), 3.0 );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize( NORMPROB ), 3.0 );
+ BOOST_CHECK( x == y );
+
+ y.set( 0, 2.0 / 2.0 );
+ y.set( 1, 0.0 / 2.0 );
+ y.set( 2, 1.0 / 2.0 );
+ x = xorg;
+ BOOST_CHECK_EQUAL( x.normalize( NORMLINF ), 2.0 );
+ BOOST_CHECK( x == y );
+
+ xorg.set( 0, -2.0 );
+ y.set( 0, 2.0 );
+ y.set( 1, 0.0 );
+ y.set( 2, 1.0 );
+ x = xorg;
+ BOOST_CHECK( x.takeAbs() == y );
+ BOOST_CHECK( x == y );
+
+ for( size_t repeat = 0; repeat < 10000; repeat++ ) {
+ x.randomize();
+ for( size_t i = 0; i < x.size(); i++ ) {
+ BOOST_CHECK( x[i] < 1.0 );
+ BOOST_CHECK( x[i] >= 0.0 );
+ }
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE( ScalarTransformationsTest ) {
+ Prob x(3);
+ x.set( 0, 2.0 );
+ x.set( 1, 0.0 );
+ x.set( 2, 1.0 );
+ Prob y(3);
+
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x + 1.0) == y );
+ y.set( 0, 0.0 ); y.set( 1, -2.0 ); y.set( 2, -1.0 );
+ BOOST_CHECK( (x + (-2.0)) == y );
+
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x - 1.0) == y );
+ y.set( 0, 4.0 ); y.set( 1, 2.0 ); y.set( 2, 3.0 );
+ BOOST_CHECK( (x - (-2.0)) == y );
+
+ BOOST_CHECK( (x * 1.0) == x );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x * 2.0) == y );
+ y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
+ BOOST_CHECK( (x * -0.5) == y );
+
+ BOOST_CHECK( (x / 1.0) == x );
+ y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
+ BOOST_CHECK( (x / 2.0) == y );
+ y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
+ BOOST_CHECK( (x / -0.5) == y );
+ BOOST_CHECK( (x / 0.0) == Prob(3, 0.0) );
+
+ BOOST_CHECK( (x ^ 1.0) == x );
+ BOOST_CHECK( (x ^ 0.0) == Prob(3, 1.0) );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^ 2.0) == y );
+ y.set( 0, 1.0 / std::sqrt(2.0) ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
+ Prob z = (x ^ -0.5);
+ BOOST_CHECK_CLOSE( z[0], y[0], tol );
+ BOOST_CHECK_EQUAL( z[1], y[1] );
+ BOOST_CHECK_CLOSE( z[2], y[2], tol );
+}
+
+
+BOOST_AUTO_TEST_CASE( ScalarOperationsTest ) {
+ Prob xorg(3), x(3);
+ xorg.set( 0, 2.0 );
+ xorg.set( 1, 0.0 );
+ xorg.set( 2, 1.0 );
+ Prob y(3);
+
+ x = xorg;
+ BOOST_CHECK( x.fill( 1.0 ) == Prob(3, 1.0) );
+ BOOST_CHECK( x == Prob(3, 1.0) );
+ BOOST_CHECK( x.fill( 2.0 ) == Prob(3, 2.0) );
+ BOOST_CHECK( x == Prob(3, 2.0) );
+ BOOST_CHECK( x.fill( 0.0 ) == Prob(3, 0.0) );
+ BOOST_CHECK( x == Prob(3, 0.0) );
+
+ x = xorg;
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x += 1.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x += -2.0) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ y.set( 0, 1.0 ); y.set( 1, -1.0 ); y.set( 2, 0.0 );
+ BOOST_CHECK( (x -= 1.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 3.0 ); y.set( 1, 1.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x -= -2.0) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK( (x *= 1.0) == x );
+ BOOST_CHECK( x == x );
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 2.0 );
+ BOOST_CHECK( (x *= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, -1.0 ); y.set( 1, 0.0 ); y.set( 2, -0.5 );
+ BOOST_CHECK( (x *= -0.25) == y );
+ BOOST_CHECK( x == y );
+
+ x = xorg;
+ BOOST_CHECK( (x /= 1.0) == x );
+ BOOST_CHECK( x == x );
+ y.set( 0, 1.0 ); y.set( 1, 0.0 ); y.set( 2, 0.5 );
+ BOOST_CHECK( (x /= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, -4.0 ); y.set( 1, 0.0 ); y.set( 2, -2.0 );
+ BOOST_CHECK( (x /= -0.25) == y );
+ BOOST_CHECK( x == y );
+ BOOST_CHECK( (x /= 0.0) == Prob(3, 0.0) );
+ BOOST_CHECK( x == Prob(3, 0.0) );
+
+ x = xorg;
+ BOOST_CHECK( (x ^= 1.0) == x );
+ BOOST_CHECK( x == x );
+ BOOST_CHECK( (x ^= 0.0) == Prob(3, 1.0) );
+ BOOST_CHECK( x == Prob(3, 1.0) );
+ x = xorg;
+ y.set( 0, 4.0 ); y.set( 1, 0.0 ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^= 2.0) == y );
+ BOOST_CHECK( x == y );
+ y.set( 0, 0.5 ); y.set( 1, INFINITY ); y.set( 2, 1.0 );
+ BOOST_CHECK( (x ^= -0.5) == y );
+ BOOST_CHECK( x == y );
+}
+
+
+BOOST_AUTO_TEST_CASE( VectorOperationsTest ) {
+ size_t N = 6;
+ Prob xorg(N), x(N);
+ xorg.set( 0, 2.0 ); xorg.set( 1, 0.0 ); xorg.set( 2, 1.0 ); xorg.set( 3, 0.0 ); xorg.set( 4, 2.0 ); xorg.set( 5, 3.0 );
+ Prob y(N);
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
+ Prob z(N), r(N);
+
+ z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
+ x = xorg;
+ r = (x += y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.pwBinaryOp( y, std::plus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
+ x = xorg;
+ r = (x -= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.pwBinaryOp( y, std::minus<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
+ x = xorg;
+ r = (x *= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.pwBinaryOp( y, std::multiplies<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ x = xorg;
+ r = (x /= y);
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.pwBinaryOp( y, fo_divides0<Real>() ) == z );
+ BOOST_CHECK( x == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ x = xorg;
+ r = (x.divide( y ));
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_CLOSE( r[1], z[1], tol );
+ BOOST_CHECK_EQUAL( r[2], z[2] );
+ BOOST_CHECK( isnan(r[3]) );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+ x.set( 3, 0.0 ); r.set( 3, 0.0 );
+ BOOST_CHECK( x == r );
+ x = xorg;
+ r = x.pwBinaryOp( y, std::divides<Real>() );
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_CLOSE( r[1], z[1], tol );
+ BOOST_CHECK_EQUAL( r[2], z[2] );
+ BOOST_CHECK( isnan(r[3]) );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+ x.set( 3, 0.0 ); r.set( 3, 0.0 );
+ BOOST_CHECK( x == r );
+
+ z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
+ x = xorg;
+ r = (x ^= y);
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_EQUAL( r[1], z[1] );
+ BOOST_CHECK_CLOSE( r[2], z[2], tol );
+ BOOST_CHECK_CLOSE( r[3], z[3], tol );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+ BOOST_CHECK( x == z );
+ x = xorg;
+ BOOST_CHECK( x.pwBinaryOp( y, fo_pow<Real>() ) == z );
+ BOOST_CHECK( x == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( VectorTransformationsTest ) {
+ size_t N = 6;
+ Prob x(N);
+ x.set( 0, 2.0 ); x.set( 1, 0.0 ); x.set( 2, 1.0 ); x.set( 3, 0.0 ); x.set( 4, 2.0 ); x.set( 5, 3.0 );
+ Prob y(N);
+ y.set( 0, 0.5 ); y.set( 1, -1.0 ); y.set( 2, 0.0 ); y.set( 3, 0.0 ); y.set( 4, -2.0 ); y.set( 5, 3.0 );
+ Prob z(N), r(N);
+
+ z.set( 0, 2.5 ); z.set( 1, -1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 0.0 ); z.set( 5, 6.0 );
+ r = x + y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.pwBinaryTr( y, std::plus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.5 ); z.set( 1, 1.0 ); z.set( 2, 1.0 ); z.set( 3, 0.0 ); z.set( 4, 4.0 ); z.set( 5, 0.0 );
+ r = x - y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.pwBinaryTr( y, std::minus<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 1.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -4.0 ); z.set( 5, 9.0 );
+ r = x * y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.pwBinaryTr( y, std::multiplies<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, 0.0 ); z.set( 3, 0.0 ); z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ r = x / y;
+ for( size_t i = 0; i < N; i++ )
+ BOOST_CHECK_CLOSE( r[i], z[i], tol );
+ z = x.pwBinaryTr( y, fo_divides0<Real>() );
+ BOOST_CHECK( r == z );
+
+ z.set( 0, 4.0 ); z.set( 1, 0.0 ); z.set( 2, INFINITY ); /*z.set( 3, INFINITY );*/ z.set( 4, -1.0 ); z.set( 5, 1.0 );
+ r = x.divided_by( y );
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_CLOSE( r[1], z[1], tol );
+ BOOST_CHECK_EQUAL( r[2], z[2] );
+ BOOST_CHECK( isnan(r[3]) );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+ z = x.pwBinaryTr( y, std::divides<Real>() );
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_CLOSE( r[1], z[1], tol );
+ BOOST_CHECK_EQUAL( r[2], z[2] );
+ BOOST_CHECK( isnan(r[3]) );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+
+ z.set( 0, std::sqrt(2.0) ); z.set( 1, INFINITY ); z.set( 2, 1.0 ); z.set( 3, 1.0 ); z.set( 4, 0.25 ); z.set( 5, 27.0 );
+ r = x ^ y;
+ BOOST_CHECK_CLOSE( r[0], z[0], tol );
+ BOOST_CHECK_EQUAL( r[1], z[1] );
+ BOOST_CHECK_CLOSE( r[2], z[2], tol );
+ BOOST_CHECK_CLOSE( r[3], z[3], tol );
+ BOOST_CHECK_CLOSE( r[4], z[4], tol );
+ BOOST_CHECK_CLOSE( r[5], z[5], tol );
+ z = x.pwBinaryTr( y, fo_pow<Real>() );
+ BOOST_CHECK( r == z );
+}
+
+
+BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
+ Prob x(3), y(3), z(3);
+ x.set( 0, 0.2 );
+ x.set( 1, 0.8 );
+ x.set( 2, 0.0 );
+ y.set( 0, 0.0 );
+ y.set( 1, 0.6 );
+ y.set( 2, 0.4 );
+
+ z = min( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.0 );
+ BOOST_CHECK_EQUAL( z[1], 0.6 );
+ BOOST_CHECK_EQUAL( z[2], 0.0 );
+ z = max( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.2 );
+ BOOST_CHECK_EQUAL( z[1], 0.8 );
+ BOOST_CHECK_EQUAL( z[2], 0.4 );
+
+ BOOST_CHECK_EQUAL( dist( x, x, DISTL1 ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTL1 ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), 0.2 + 0.2 + 0.4 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), 0.2 + 0.2 + 0.4 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTL1 ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTL1 ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTLINF ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTLINF ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), 0.4 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), 0.4 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTLINF ), x.innerProduct( y, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTLINF ), y.innerProduct( x, 0.0, fo_max<Real>(), fo_absdiff<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTTV ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTTV ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), 0.5 * (0.2 + 0.2 + 0.4) );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTTV ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTTV ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_absdiff<Real>() ) / 2.0 );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTKL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTKL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), INFINITY );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), INFINITY );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( x, x, DISTHEL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( y, y, DISTHEL ), 0.0 );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), 0.5 * (0.2 + std::pow(std::sqrt(0.8) - std::sqrt(0.6), 2.0) + 0.4) );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTHEL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTHEL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_Hellinger<Real>() ) / 2.0 );
+ x.set( 1, 0.7 ); x.set( 2, 0.1 );
+ y.set( 0, 0.1 ); y.set( 1, 0.5 );
+ BOOST_CHECK_CLOSE( dist( x, y, DISTKL ), 0.2 * std::log(0.2 / 0.1) + 0.7 * std::log(0.7 / 0.5) + 0.1 * std::log(0.1 / 0.4), tol );
+ BOOST_CHECK_CLOSE( dist( y, x, DISTKL ), 0.1 * std::log(0.1 / 0.2) + 0.5 * std::log(0.5 / 0.7) + 0.4 * std::log(0.4 / 0.1), tol );
+ BOOST_CHECK_EQUAL( dist( x, y, DISTKL ), x.innerProduct( y, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
+ BOOST_CHECK_EQUAL( dist( y, x, DISTKL ), y.innerProduct( x, 0.0, std::plus<Real>(), fo_KL<Real>() ) );
+
+ Prob xx(4), yy(4);
+ for( size_t i = 0; i < 3; i++ ) {
+ xx.set( i, x[i] );
+ yy.set( i, y[i] );
+ }
+ std::stringstream ss;
+ ss << xx;
+ std::string s;
+ std::getline( ss, s );
+#ifdef DAI_SPARSE
+ BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
+#else
+ BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
+#endif
+ std::stringstream ss2;
+ ss2 << yy;
+ std::getline( ss2, s );
+#ifdef DAI_SPARSE
+ BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
+#else
+ BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
+#endif
+
+ z = min( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.1 );
+ BOOST_CHECK_EQUAL( z[1], 0.5 );
+ BOOST_CHECK_EQUAL( z[2], 0.1 );
+ z = max( x, y );
+ BOOST_CHECK_EQUAL( z[0], 0.2 );
+ BOOST_CHECK_EQUAL( z[1], 0.7 );
+ BOOST_CHECK_EQUAL( z[2], 0.4 );
+
+ BOOST_CHECK_CLOSE( x.innerProduct( y, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+ BOOST_CHECK_CLOSE( y.innerProduct( x, 0.0, std::plus<Real>(), std::multiplies<Real>() ), 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+ BOOST_CHECK_CLOSE( x.innerProduct( y, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+ BOOST_CHECK_CLOSE( y.innerProduct( x, 1.0, std::plus<Real>(), std::multiplies<Real>() ), 1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+ BOOST_CHECK_CLOSE( x.innerProduct( y, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+ BOOST_CHECK_CLOSE( y.innerProduct( x, -1.0, std::plus<Real>(), std::multiplies<Real>() ), -1.0 + 0.2*0.1 + 0.7*0.5 + 0.1*0.4, tol );
+}
+++ /dev/null
-/* This file is part of libDAI - http://www.libdai.org/
- *
- * libDAI is licensed under the terms of the GNU General Public License version
- * 2, or (at your option) any later version. libDAI is distributed without any
- * warranty. See the file COPYING for more details.
- *
- * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
- */
-
-
-#include <dai/properties.h>
-#include <strstream>
-
-
-using namespace dai;
-
-
-#define BOOST_TEST_MODULE PropertiesTest
-
-
-#include <boost/test/unit_test.hpp>
-
-
-BOOST_AUTO_TEST_CASE( PropertyTest ) {
- std::stringstream str1, str2, str3, str4, str5, str6, str7, str8;
- std::string s;
-
- Property p;
- p.first = "key";
-
- p.second = (int)-5;
- str1 << p;
- str1 >> s;
- BOOST_CHECK_EQUAL( s, "key=-5" );
-
- p.second = (size_t)5;
- str2 << p;
- str2 >> s;
- BOOST_CHECK_EQUAL( s, "key=5" );
-
- p.second = std::string("value");
- str3 << p;
- str3 >> s;
- BOOST_CHECK_EQUAL( s, "key=value" );
-
- p.second = 3.141;
- str4 << p;
- str4 >> s;
- BOOST_CHECK_EQUAL( s, "key=3.141" );
-
- p.second = (long double)3.141;
- str5 << p;
- str5 >> s;
- BOOST_CHECK_EQUAL( s, "key=3.141" );
-
- p.second = false;
- str6 << p;
- str6 >> s;
- BOOST_CHECK_EQUAL( s, "key=0" );
-
- p.second = PropertySet()("prop2",(size_t)5)("prop1",std::string("hi"));
- str7 << p;
- str7 >> s;
- BOOST_CHECK_EQUAL( s, "key=[prop1=hi,prop2=5]" );
-
- p.second = std::vector<int>();
- BOOST_CHECK_THROW( str7 << p, Exception );
-}
-
-
-BOOST_AUTO_TEST_CASE( ConstructorsTest ) {
- PropertySet x;
- PropertySet y("[key1=val1,key2=5,key3=[key3a=val,key3b=7.0],key4=1.0]");
- PropertySet z = PropertySet()("key1",std::string("val1"))("key2",5)("key3",PropertySet("[key3a=val,key3b=7.0]"))("key4",1.0);
-
- BOOST_CHECK( !x.hasKey("key") );
- BOOST_CHECK( !y.hasKey("key") );
- BOOST_CHECK( !z.hasKey("key") );
- BOOST_CHECK( y.hasKey("key1") );
- BOOST_CHECK( y.hasKey("key2") );
- BOOST_CHECK( y.hasKey("key3") );
- BOOST_CHECK( !y.hasKey("key3a") );
- BOOST_CHECK( !y.hasKey("key3b") );
- BOOST_CHECK( y.hasKey("key4") );
- BOOST_CHECK( z.hasKey("key1") );
- BOOST_CHECK( z.hasKey("key2") );
- BOOST_CHECK( z.hasKey("key3") );
- BOOST_CHECK( !z.hasKey("key3a") );
- BOOST_CHECK( !z.hasKey("key3b") );
- BOOST_CHECK( z.hasKey("key4") );
-
- BOOST_CHECK_EQUAL( x.size(), 0 );
- BOOST_CHECK_EQUAL( y.size(), 4 );
- BOOST_CHECK_EQUAL( z.size(), 4 );
- std::set<std::string> keys;
- keys.insert( "key1" );
- keys.insert( "key2" );
- keys.insert( "key3" );
- keys.insert( "key4" );
- BOOST_CHECK( y.keys() == keys );
- BOOST_CHECK( z.keys() == keys );
-
- BOOST_CHECK_THROW( x.get( "key" ), Exception );
- BOOST_CHECK_THROW( y.get( "key" ), Exception );
- BOOST_CHECK_THROW( z.get( "key" ), Exception );
- BOOST_CHECK_THROW( x.getAs<std::string>( "key" ), Exception );
- BOOST_CHECK_THROW( y.getAs<int>( "key" ), Exception );
- BOOST_CHECK_THROW( z.getAs<double>( "key" ), Exception );
- BOOST_CHECK_THROW( x.getStringAs<int>( "key" ), Exception );
- BOOST_CHECK_THROW( y.getStringAs<int>( "key" ), Exception );
- BOOST_CHECK_THROW( z.getStringAs<int>( "key" ), Exception );
-
- BOOST_CHECK_EQUAL( boost::any_cast<std::string>(y.get( "key1" )), std::string("val1") );
- BOOST_CHECK_EQUAL( boost::any_cast<std::string>(y.get( "key2" )), std::string("5") );
- BOOST_CHECK_EQUAL( boost::any_cast<std::string>(y.get( "key3" )), std::string("[key3a=val,key3b=7.0]") );
- BOOST_CHECK_EQUAL( boost::any_cast<std::string>(y.get( "key4" )), std::string("1.0") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key2" ), std::string("5") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key3" ), std::string("[key3a=val,key3b=7.0]") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key4" ), std::string("1.0") );
- BOOST_CHECK_EQUAL( y.getStringAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( y.getStringAs<size_t>( "key2" ), 5 );
- BOOST_CHECK_EQUAL( y.getStringAs<int>( "key2" ), 5 );
- BOOST_CHECK_EQUAL( y.getStringAs<double>( "key4" ), 1.0 );
- PropertySet key3val = y.getStringAs<PropertySet>( "key3" );
- BOOST_CHECK_EQUAL( key3val.size(), 2 );
- BOOST_CHECK( key3val.hasKey( "key3a" ) );
- BOOST_CHECK( key3val.hasKey( "key3b" ) );
- BOOST_CHECK_EQUAL( key3val.getStringAs<std::string>( "key3a" ), std::string("val") );
- BOOST_CHECK_EQUAL( key3val.getStringAs<double>( "key3b" ), 7.0 );
- BOOST_CHECK_THROW( y.getAs<int>( "key2" ), Exception );
- BOOST_CHECK_THROW( y.getStringAs<int>( "key4" ), Exception );
-
- BOOST_CHECK_EQUAL( boost::any_cast<std::string>(z.get( "key1" )), std::string("val1") );
- BOOST_CHECK_EQUAL( boost::any_cast<int>(z.get( "key2" )), 5 );
- BOOST_CHECK_EQUAL( boost::any_cast<PropertySet>(z.get( "key3" )).size(), 2 );
- BOOST_CHECK_EQUAL( boost::any_cast<double>(z.get( "key4" )), 1.0 );
- BOOST_CHECK_EQUAL( z.getAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( z.getAs<int>( "key2" ), 5 );
- BOOST_CHECK_EQUAL( z.getAs<PropertySet>( "key3" ).size(), 2 );
- BOOST_CHECK_EQUAL( z.getAs<double>( "key4" ), 1.0 );
- BOOST_CHECK_EQUAL( z.getStringAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( z.getStringAs<int>( "key2" ), 5 );
- BOOST_CHECK_EQUAL( z.getStringAs<double>( "key4" ), 1.0 );
- key3val = z.getAs<PropertySet>( "key3" );
- BOOST_CHECK_EQUAL( key3val.size(), 2 );
- BOOST_CHECK( key3val.hasKey( "key3a" ) );
- BOOST_CHECK( key3val.hasKey( "key3b" ) );
- BOOST_CHECK_EQUAL( key3val.getStringAs<std::string>( "key3a" ), std::string("val") );
- BOOST_CHECK_EQUAL( key3val.getStringAs<double>( "key3b" ), 7.0 );
- BOOST_CHECK_THROW( z.getAs<size_t>( "key2" ), Exception );
- BOOST_CHECK_THROW( z.getStringAs<size_t>( "key4" ), Exception );
-}
-
-
-BOOST_AUTO_TEST_CASE( SetTest ) {
- PropertySet x;
- PropertySet y("[key1=val1,key2=5]");
-
- x.set( "key1", 5 );
- BOOST_CHECK( x.hasKey( "key1" ) );
- BOOST_CHECK( !x.hasKey( "key2" ) );
- BOOST_CHECK_EQUAL( x.getAs<int>( "key1" ), 5 );
- BOOST_CHECK_THROW( x.getAs<double>( "key1" ), Exception );
- BOOST_CHECK_THROW( x.getAs<int>( "key2" ), Exception );
-
- x.set( y );
- BOOST_CHECK( x.hasKey( "key1" ) );
- BOOST_CHECK( x.hasKey( "key2" ) );
- BOOST_CHECK( !x.hasKey( "key" ) );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key2" ), std::string("5") );
- x.setAsString<int>( "key1", 5 );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("5") );
- x.convertTo<size_t>( "key1" );
- BOOST_CHECK_EQUAL( x.getAs<size_t>( "key1" ), 5 );
- x.setAsString( "key1", -5 );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("-5") );
- x.convertTo<int>( "key1" );
- BOOST_CHECK_EQUAL( x.getAs<int>( "key1" ), -5 );
- x.setAsString( "key1", 1.234 );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("1.234") );
- BOOST_CHECK_THROW( x.convertTo<int>( "key1" ), Exception );
- x.convertTo<double>( "key1" );
- BOOST_CHECK_EQUAL( x.getAs<double>( "key1" ), 1.234 );
- x.setAsString( "key1", "val1");
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_THROW( x.convertTo<int>( "key1" ), Exception );
- x.convertTo<std::string>( "key1" );
- BOOST_CHECK_EQUAL( x.getAs<std::string>( "key1" ), std::string("val1") );
-
- BOOST_CHECK_EQUAL( x.erase( "key1" ), 1 );
- BOOST_CHECK( !x.hasKey( "key1" ) );
-}
-
-
-BOOST_AUTO_TEST_CASE( StreamTest ) {
- std::stringstream ss1, ss2, ss3;
- std::string s;
-
- PropertySet z = PropertySet()("key1",std::string("val1"))("key2",5)("key3",PropertySet("[key3a=val,key3b=7.0]"))("key4",1.0);
- ss1 << z;
- ss1 >> s;
- BOOST_CHECK_EQUAL( s, std::string("[key1=val1,key2=5,key3=[key3a=val,key3b=7.0],key4=1]") );
- PropertySet y;
- ss2 << z;
- ss2 >> y;
- BOOST_CHECK( y.hasKey( "key1" ) );
- BOOST_CHECK( y.hasKey( "key2" ) );
- BOOST_CHECK( y.hasKey( "key3" ) );
- BOOST_CHECK( y.hasKey( "key4" ) );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key1" ), std::string("val1") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key2" ), std::string("5") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key3" ), std::string("[key3a=val,key3b=7.0]") );
- BOOST_CHECK_EQUAL( y.getAs<std::string>( "key4" ), std::string("1") );
-
- z.set( "key5", std::vector<int>() );
- BOOST_CHECK_THROW( ss1 << z, Exception );