Extended SWIG python interface (inspired by Kyle Ellrott): inference is possible...
authorJoris Mooij <j.mooij@cs.ru.nl>
Thu, 22 Nov 2012 16:20:48 +0000 (17:20 +0100)
committerJoris Mooij <j.mooij@cs.ru.nl>
Thu, 22 Nov 2012 16:20:48 +0000 (17:20 +0100)
42 files changed:
ChangeLog
include/dai/bipgraph.h
include/dai/bp.h
include/dai/clustergraph.h
include/dai/cobwebgraph.h
include/dai/dag.h
include/dai/daialg.h
include/dai/exactinf.h
include/dai/factor.h
include/dai/factorgraph.h
include/dai/gibbs.h
include/dai/glc.h
include/dai/graph.h
include/dai/jtree.h
include/dai/prob.h
include/dai/properties.h
include/dai/regiongraph.h
include/dai/smallset.h
include/dai/util.h
include/dai/var.h
include/dai/varset.h
include/dai/weightedgraph.h
src/cobwebgraph.cpp
src/exactinf.cpp
src/matlab/dai.cpp
src/matlab/dai_jtree.cpp
swig/README
swig/dai.i
swig/example.py [new file with mode: 0644]
tests/unit/bipgraph_test.cpp
tests/unit/clustergraph_test.cpp
tests/unit/dag_test.cpp
tests/unit/factor_test.cpp
tests/unit/factorgraph_test.cpp
tests/unit/graph_test.cpp
tests/unit/prob_test.cpp
tests/unit/properties_test.cpp
tests/unit/regiongraph_test.cpp
tests/unit/smallset_test.cpp
tests/unit/var_test.cpp
tests/unit/varset_test.cpp
tests/unit/weightedgraph_test.cpp

index 10f6e80..1701530 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,8 @@
 git HEAD
 --------
+* Extended SWIG python interface (inspired by Kyle Ellrott): inference is possible in Python now!
+* Added toString() formatting functions to the classes that had an output streaming operator<<
+* Replaced all occurrences of std::size_t by size_t to avoid SWIG problems
 * Fixed division-by-zero issue in pow() which caused problems for GLC+ with sparse factors
 * Fixed bugs in unit test prob_test.cpp: replaced all occurences of
   BOOST_CHECK_CLOSE(...,(Real)0.0,tol) with BOOST_CHECK_SMALL(...,tol)
index 569f815..6c86bd2 100644 (file)
@@ -324,14 +324,21 @@ class BipartiteGraph {
 
     /// \name Input and output
     //@{
-        /// Writes this BipartiteGraph to an output stream in GraphViz .dot syntax
+        /// Writes a BipartiteGraph to an output stream in GraphViz .dot syntax
         void printDot( std::ostream& os ) const;
 
-        /// Writes this BipartiteGraph to an output stream
+        /// Writes a BipartiteGraph to an output stream
         friend std::ostream& operator<<( std::ostream& os, const BipartiteGraph& g ) {
             g.printDot( os );
             return os;
         }
+
+        /// Formats a BipartiteGraph as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 };
 
index 0769fc3..4d0936e 100644 (file)
@@ -74,7 +74,7 @@ class BP : public DAIAlgFG {
         /// Stores all edge properties
         std::vector<std::vector<EdgeProp> > _edges;
         /// Type of lookup table (only used for maximum-residual BP)
-        typedef std::multimap<Real, std::pair<std::size_t, std::size_t> > LutType;
+        typedef std::multimap<Real, std::pair<size_t, size_t> > LutType;
         /// Lookup table (only used for maximum-residual BP)
         std::vector<std::vector<LutType::iterator> > _edge2lut;
         /// Lookup table (only used for maximum-residual BP)
@@ -84,7 +84,7 @@ class BP : public DAIAlgFG {
         /// Number of iterations needed
         size_t _iters;
         /// The history of message updates (only recorded if \a recordSentMessages is \c true)
-        std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
+        std::vector<std::pair<size_t, size_t> > _sentMessages;
         /// Stores variable beliefs of previous iteration
         std::vector<Factor> _oldBeliefsV;
         /// Stores factor beliefs of previous iteration
@@ -194,7 +194,7 @@ class BP : public DAIAlgFG {
         virtual Real logZ() const;
         /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
          */
-        std::vector<std::size_t> findMaximum() const { return dai::findMaximum( *this ); }
+        std::vector<size_t> findMaximum() const { return dai::findMaximum( *this ); }
         virtual void init();
         virtual void init( const VarSet &ns );
         virtual Real run();
@@ -209,7 +209,7 @@ class BP : public DAIAlgFG {
     /// \name Additional interface specific for BP
     //@{
         /// Returns history of which messages have been updated
-        const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
+        const std::vector<std::pair<size_t, size_t> >& getSentMessages() const {
             return _sentMessages;
         }
 
index 75ab339..4c10288 100644 (file)
@@ -209,6 +209,13 @@ namespace dai {
                 os << cl.clusters();
                 return os;
             }
+
+            /// Formats a ClusterGraph as a string
+            std::string toString() const {
+                std::stringstream ss;
+                ss << *this;
+                return ss.str();
+            }
         //@}
 
         /// \name Variable elimination
index fc68e1b..e2f4250 100644 (file)
@@ -29,22 +29,10 @@ namespace dai {
 
 /// A CobwebGraph is a special type of region graph used by the GLC algorithm
 /** \author Siamak Ravanbakhsh
+ *  \todo Implement unit test for Cobwebgraph
  */
 class CobwebGraph : public FactorGraph {
-    protected:
-        /// Vector of variable indices internal to each region (r)
-        std::vector<SmallSet<size_t> >        _INRs;
-        /// Vector of variable indices on the boundary of each region (\ominus r)
-        std::vector<SmallSet<size_t> >        _EXRs;
-        /// Index of factors in each region
-        std::vector<SmallSet<size_t> >        _Rfs;
-        /// Index of factors internal to each region, i.e., all its variables are internal to the region
-        std::vector<SmallSet<size_t> >        _Rifs;
-        /// Index of factors that bridge each region, i.e., not all its variables are internal to the region
-        std::vector<SmallSet<size_t> >        _Rxfs;
-        /// The vector of domain of messages leaving each region (\ominus r_{p,q})
-        std::vector<std::vector<VarSet> >     _outM;
-
+    public:
         /// The information in connection between two regions
         struct Connection {
             /// Index of the first region (p)
@@ -78,6 +66,19 @@ class CobwebGraph : public FactorGraph {
          */
         DAI_ENUM(NeighborType,ALL,TOP,CLOSEST);
 
+    protected:
+        /// Vector of variable indices internal to each region (r)
+        std::vector<SmallSet<size_t> >        _INRs;
+        /// Vector of variable indices on the boundary of each region (\ominus r)
+        std::vector<SmallSet<size_t> >        _EXRs;
+        /// Index of factors in each region
+        std::vector<SmallSet<size_t> >        _Rfs;
+        /// Index of factors internal to each region, i.e., all its variables are internal to the region
+        std::vector<SmallSet<size_t> >        _Rifs;
+        /// Index of factors that bridge each region, i.e., not all its variables are internal to the region
+        std::vector<SmallSet<size_t> >        _Rxfs;
+        /// The vector of domain of messages leaving each region (\ominus r_{p,q})
+        std::vector<std::vector<VarSet> >     _outM;
         /// Vector of all connections to each region
         std::vector<std::vector<Connection> > _M;
 
@@ -93,7 +94,6 @@ class CobwebGraph : public FactorGraph {
         /// Whether a given set of region vars makes a partitioning or not
         bool isPartition;
 
-
     public:
     /// \name Constructors and destructors
     //@{
@@ -241,7 +241,16 @@ class CobwebGraph : public FactorGraph {
         }
 
         /// Writes a cobweb graph to an output stream
-        friend std::ostream& operator<< ( std::ostream& os, const CobwebGraph& rg );
+        friend std::ostream& operator<< ( std::ostream& /*os*/, const CobwebGraph& /*rg*/ ) {
+            DAI_THROW(NOT_IMPLEMENTED);
+        }
+
+        /// Formats a cobweb graph as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
 
         /// Writes a cobweb graph to a GraphViz .dot file
         /** \note Not implemented yet
index fbfb828..494948c 100644 (file)
@@ -279,14 +279,21 @@ class DAG {
 
     /// \name Input and output
     //@{
-        /// Writes this DAG to an output stream in GraphViz .dot syntax
+        /// Writes a DAG to an output stream in GraphViz .dot syntax
         void printDot( std::ostream& os ) const;
 
-        /// Writes this DAG to an output stream
+        /// Writes a DAG to an output stream
         friend std::ostream& operator<<( std::ostream& os, const DAG& g ) {
             g.printDot( os );
             return os;
         }
+
+        /// Formats a DAG as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 };
 
index e58a830..6581bc0 100644 (file)
@@ -124,7 +124,7 @@ class InfAlg {
         /** \note Before this method is called, run() should have been called.
          *  \throw NOT_IMPLEMENTED if not implemented/supported
          */
-        virtual std::vector<std::size_t> findMaximum() const { DAI_THROW(NOT_IMPLEMENTED); }
+        virtual std::vector<size_t> findMaximum() const { DAI_THROW(NOT_IMPLEMENTED); }
 
         /// Returns maximum difference between single variable beliefs in the last iteration.
         /** \throw NOT_IMPLEMENTED if not implemented/supported
index 764d787..52268f4 100644 (file)
@@ -74,7 +74,7 @@ class ExactInf : public DAIAlgFG {
         virtual Real logZ() const { return _logZ; }
         /** \note The complexity of this calculation is exponential in the number of variables.
          */
-        std::vector<std::size_t> findMaximum() const;
+        std::vector<size_t> findMaximum() const;
         virtual void init();
         virtual void init( const VarSet &/*ns*/ ) {}
         virtual Real run();
index b157586..ec236f4 100644 (file)
@@ -177,6 +177,13 @@ class TFactor {
         bool operator==( const TFactor<T>& y ) const {
             return (_vs == y._vs) && (_p == y._p);
         }
+
+        /// Formats a factor as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 
     /// \name Unary transformations
index 5aff6a5..f0adad3 100644 (file)
@@ -361,6 +361,19 @@ class FactorGraph {
 
         /// Writes a factor graph to a GraphViz .dot file
         virtual void printDot( std::ostream& os ) const;
+
+        /// Formats a factor graph as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
+        
+        /// Reads a factor graph from a string
+        void fromString( const std::string& s ) {
+            std::stringstream ss( s );
+            ss >> *this;
+        }
     //@}
 
     private:
index 4f5b633..4eb9617 100644 (file)
@@ -90,7 +90,7 @@ class Gibbs : public DAIAlgFG {
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
         virtual Real logZ() const { DAI_THROW(NOT_IMPLEMENTED); return 0.0; }
-        std::vector<std::size_t> findMaximum() const { return _max_state; }
+        std::vector<size_t> findMaximum() const { return _max_state; }
         virtual void init();
         virtual void init( const VarSet &/*ns*/ ) { init(); }
         virtual Real run();
index f22d27c..009a101 100644 (file)
@@ -17,7 +17,6 @@
 
 #include <algorithm>
 #include <set>
-#include <iostream>
 #include <string>
 #include <dai/util.h>
 #include <dai/daialg.h>
index ab6857a..d8ffb61 100644 (file)
@@ -295,14 +295,21 @@ class GraphAL {
 
     /// \name Input and output
     //@{
-        /// Writes this GraphAL to an output stream in GraphViz .dot syntax
+        /// Writes a GraphAL to an output stream in GraphViz .dot syntax
         void printDot( std::ostream& os ) const;
 
-        /// Writes this GraphAL to an output stream
+        /// Writes a GraphAL to an output stream
         friend std::ostream& operator<<( std::ostream& os, const GraphAL& g ) {
             g.printDot( os );
             return os;
         }
+
+        /// Formats a GraphAL as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 };
 
index b7041e2..de329e5 100644 (file)
@@ -125,7 +125,7 @@ class JTree : public DAIAlgRG {
         virtual Real logZ() const;
         /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
          */
-        std::vector<std::size_t> findMaximum() const;
+        std::vector<size_t> findMaximum() const;
         virtual void init() {}
         virtual void init( const VarSet &/*ns*/ ) {}
         virtual Real run();
index d1b6895..51dfe47 100644 (file)
@@ -420,6 +420,13 @@ class TProb {
                 return false;
             return p() == q.p();
         }
+
+        /// Formats a TProb as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 
     /// \name Unary transformations
index cdabd8c..2becfbf 100644 (file)
@@ -241,12 +241,25 @@ PropertySet p()("method","BP")("verbose",1)("tol",1e-9)
          */
         friend std::ostream& operator<< ( std::ostream& os, const PropertySet& ps );
 
+        /// Formats a PropertySet as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
+
         /// Reads a PropertySet object from an input stream.
         /** It expects a string in the format <tt>"[key1=val1,key2=val2,...,keyn=valn]"</tt>.
          *  Values are stored as strings.
          *  \throw MALFORMED_PROPERTY if the string is not in the expected format
          */
         friend std::istream& operator>> ( std::istream& is, PropertySet& ps );
+        
+        /// Reads a PropertySet from a string
+        void fromString( const std::string& s ) {
+            std::stringstream ss( s );
+            ss >> *this;
+        }
     //@}
 };
 
index aef3ed3..2fd82a9 100644 (file)
@@ -244,6 +244,13 @@ class RegionGraph : public FactorGraph {
         /// Writes a region graph to an output stream
         friend std::ostream& operator<< ( std::ostream& os, const RegionGraph& rg );
 
+        /// Formats a region graph as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
+
         /// Writes a region graph to a GraphViz .dot file
         /** \note Not implemented yet
          */
index 03963b4..61db611 100644 (file)
@@ -17,6 +17,7 @@
 #include <vector>
 #include <algorithm>
 #include <iostream>
+#include <sstream>
 
 
 namespace dai {
@@ -237,7 +238,7 @@ class SmallSet {
         }
     //@}
 
-    /// \name Streaming input/output
+    /// \name Input/output
     //@{
         /// Writes a SmallSet to an output stream
         friend std::ostream& operator << ( std::ostream& os, const SmallSet& x ) {
@@ -247,6 +248,13 @@ class SmallSet {
             os << "}";
             return os;
         }
+
+        /// Formats a SmallSet as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 };
 
index dba4002..c59ceb8 100644 (file)
@@ -102,7 +102,7 @@ typedef mpz_class BigInt;
 
 /// Safe down-cast of big integer to size_t
 inline size_t BigInt_size_t( const BigInt &N ) {
-    DAI_ASSERT( N <= (BigInt)std::numeric_limits<std::size_t>::max() );
+    DAI_ASSERT( N <= (BigInt)std::numeric_limits<size_t>::max() );
     return N.get_ui();
 }
 
index 4815da5..0157148 100644 (file)
@@ -15,6 +15,7 @@
 
 
 #include <iostream>
+#include <sstream>
 #include <dai/exceptions.h>
 
 
@@ -115,6 +116,13 @@ class Var {
         friend std::ostream& operator << ( std::ostream& os, const Var& n ) {
             return( os << "x" << n.label() );
         }
+
+        /// Formats a Var as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
 };
 
 
index d32b1e4..c51f7d9 100644 (file)
@@ -145,6 +145,13 @@ class VarSet : public SmallSet<Var> {
             os << "}";
             return( os );
         }
+
+        /// Formats a VarSet as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
     //@}
 };
 
index 1546715..b89ec7b 100644 (file)
@@ -63,6 +63,13 @@ class DEdge {
             os << "(" << e.first << "->" << e.second << ")";
             return os;
         }
+
+        /// Formats a directed edge as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
 };
 
 
@@ -106,6 +113,13 @@ class UEdge {
                 os << "{" << e.second << "--" << e.first << "}";
             return os;
         }
+
+        /// Formats an undirected edge as a string
+        std::string toString() const {
+            std::stringstream ss;
+            ss << *this;
+            return ss.str();
+        }
 };
 
 
index 73c368d..28e5aa8 100644 (file)
@@ -264,9 +264,4 @@ void CobwebGraph::setExtnFact() {
 }
 
 
-ostream& operator<< ( ostream& os, const CobwebGraph& rg ) {
-    return os;
-}
-
-
 } // end of namespace dai
index 4c66019..54ecc2a 100644 (file)
@@ -92,7 +92,7 @@ Factor ExactInf::calcMarginal( const VarSet &vs ) const {
 }
 
         
-std::vector<std::size_t> ExactInf::findMaximum() const {
+std::vector<size_t> ExactInf::findMaximum() const {
     Factor P;
     for( size_t I = 0; I < nrFactors(); I++ )
         P *= factor(I);
index a3f2f53..e2d26a3 100644 (file)
@@ -138,7 +138,7 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
     }
 
     if( nlhs >= 6 ) {
-        std::vector<std::size_t> map_state;
+        std::vector<size_t> map_state;
         bool supported = true;
         try {
             map_state = obj->findMaximum();
index 6e59648..c532f09 100644 (file)
@@ -189,7 +189,7 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
     }
 
     if( nlhs >= 5 ) {
-        std::vector<std::size_t> map_state;
+        std::vector<size_t> map_state;
         bool supported = true;
         try {
             map_state = jt.findMaximum();
index bebf640..a05475c 100644 (file)
@@ -1,3 +1,3 @@
 This directory contains preliminary experimental SWIG wrappers for libDAI
-written by Patrick Pletscher. They enable usage of libDAI functionality
-directly from python and octave.
+written originally by Patrick Pletscher, extended later by Kyle Ellrott and Joris Mooij. 
+They enable usage of libDAI functionality directly from python and octave.
index d60c9f1..5faec4e 100644 (file)
 
 %module dai
 
+%include "std_string.i"
+%include "std_vector.i"
+%template(IntVector) std::vector<size_t>;
+//%include "std_set.i"  /* for python */
+
 %{
-#include "../include/dai/var.h"
-#include "../include/dai/smallset.h"
-#include "../include/dai/varset.h"
-#include "../include/dai/prob.h"
-#include "../include/dai/factor.h"
-#include "../include/dai/graph.h"
-#include "../include/dai/bipgraph.h"
-#include "../include/dai/factorgraph.h"
-#include "../include/dai/util.h"
+#define DAI_WITH_BP 1
+#define DAI_WITH_FBP 1
+#define DAI_WITH_TRWBP 1
+#define DAI_WITH_MF 1
+#define DAI_WITH_HAK 1
+#define DAI_WITH_LC 1
+#define DAI_WITH_TREEEP 1
+#define DAI_WITH_JTREE 1
+#define DAI_WITH_MR 1
+#define DAI_WITH_GIBBS 1
+#define DAI_WITH_CBP 1
+#define DAI_WITH_DECMAP 1
+#define DAI_WITH_GLC 1
+#include "../include/dai/alldai.h"
+
+using namespace dai;
 %}
 
-%ignore dai::TProb::operator[];
-%ignore dai::TFactor::operator[];
+// ************************************************************************************************
+%include "../include/dai/util.h"
 
+// ************************************************************************************************
 %ignore dai::Var::label() const;
 %ignore dai::Var::states() const;
-
-%include "../include/dai/util.h"
+%ignore operator<<(std::ostream&, const Var&);
 %include "../include/dai/var.h"
+%extend dai::Var {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+};
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const SmallSet&);
+%rename(__eq__) operator==(const SmallSet&, const SmallSet&); /* for python */
+%rename(__ne__) operator!=(const SmallSet&, const SmallSet&); /* for python */
+%rename(__lt__) operator<(const SmallSet&, const SmallSet&);  /* for python */
 %include "../include/dai/smallset.h"
 %template(SmallSetVar) dai::SmallSet< dai::Var >;
+%extend dai::SmallSet<dai::Var> {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+}
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const VarSet&);
 %include "../include/dai/varset.h"
 %extend dai::VarSet {
-        inline void append(const dai::Var &v) { (*self) |= v; }   /* for python, octave */
+    inline void append(const dai::Var &v) { (*self) |= v; }   /* for python, octave */
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
 };
 
+// ************************************************************************************************
+%ignore dai::TProb::operator[];
 %include "../include/dai/prob.h"
 %template(Prob) dai::TProb<dai::Real>;
 %extend dai::TProb<dai::Real> {
-        inline dai::Real __getitem__(int i) const {return (*self).get(i);} /* for python */
-        inline void __setitem__(int i,dai::Real d) {(*self).set(i,d);}   /* for python */
-        inline dai::Real __paren(int i) const {return (*self).get(i);}     /* for octave */
-        inline void __paren_asgn(int i,dai::Real d) {(*self).set(i,d);}  /* for octave */
+    inline dai::Real __getitem__(int i) const {return (*self).get(i);} /* for python */
+    inline void __setitem__(int i,dai::Real d) {(*self).set(i,d);}   /* for python */
+    inline dai::Real __paren(int i) const {return (*self).get(i);}     /* for octave */
+    inline void __paren_asgn(int i,dai::Real d) {(*self).set(i,d);}  /* for octave */
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
 };
+
+// ************************************************************************************************
+%ignore dai::TFactor::operator[];
 %include "../include/dai/factor.h"
 %extend dai::TFactor<dai::Real> {
-        inline dai::Real __getitem__(int i) const {return (*self).get(i);} /* for python */
-        inline void __setitem__(int i,dai::Real d) {(*self).set(i,d);}   /* for python */
-        inline dai::Real __paren__(int i) const {return (*self).get(i);}     /* for octave */
-        inline void __paren_asgn__(int i,dai::Real d) {(*self).set(i,d);}  /* for octave */
+    inline dai::Real __getitem__(int i) const {return (*self).get(i);} /* for python */
+    inline void __setitem__(int i,dai::Real d) {(*self).set(i,d);}   /* for python */
+    inline dai::Real __paren__(int i) const {return (*self).get(i);}     /* for octave */
+    inline void __paren_asgn__(int i,dai::Real d) {(*self).set(i,d);}  /* for octave */
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
 };
-
 %template(Factor) dai::TFactor<dai::Real>;
+%inline %{
+typedef std::vector<dai::Factor> VecFactor;
+typedef std::vector<VecFactor> VecVecFactor;
+%}
+%template(VecFactor) std::vector<dai::Factor>;
+%template(VecVecFactor) std::vector<VecFactor>;
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const GraphAL&);
+%rename(toInt) dai::Neighbor::operator size_t() const;
 %include "../include/dai/graph.h"
+%extend dai::GraphAL {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+}
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const BipartiteGraph&);
 %include "../include/dai/bipgraph.h"
+%extend dai::BipartiteGraph {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+}
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const FactorGraph&);
+%ignore operator>>(std::istream&,FactorGraph&);
 %include "../include/dai/factorgraph.h"
-%include "std_vector.i"
-// TODO: typemaps for the vectors (input/output python arrays)
-%inline{
-typedef std::vector<dai::Factor> VecFactor;
-typedef std::vector< VecFactor > VecVecFactor;
+%extend dai::FactorGraph {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
 }
-%template(VecFactor) std::vector< dai::Factor >;
-%template(VecVecFactor) std::vector< VecFactor >;
 
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const RegionGraph&);
+%include "../include/dai/regiongraph.h"
+%extend dai::RegionGraph {
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+}
+
+// ************************************************************************************************
+//%ignore operator<<(std::ostream&, const CobwebGraph&);
+//%include "../include/dai/cobwebgraph.h"
+//%extend dai::CobwebGraph {
+//    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+//    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+//}
+//TODO fix problems with CobwebGraph
+
+// ************************************************************************************************
+%ignore operator<<(std::ostream&, const PropertySet&);
+%ignore operator>>(std::istream&,PropertySet&);
+%include "../include/dai/properties.h"
+%extend dai::PropertySet {
+    inline void __setitem__(char *name, char *val) {
+        self->set(std::string(name), std::string(val));
+    }
+    inline const char* __str__() const { return (*self).toString().c_str(); }  /* for python */
+    inline std::string __str() const { return (*self).toString(); }  /* for octave */
+}
+
+// ************************************************************************************************
+%ignore dai::IndexFor::operator++;
+%rename(toInt) dai::IndexFor::operator size_t() const;
+%ignore dai::Permute::operator[];
+%ignore dai::multifor::operator++;
+%ignore dai::multifor::operator[];
+%rename(toInt) dai::multifor::operator size_t() const;
+%ignore dai::State::operator++;
+%rename(toInt) dai::State::operator size_t() const;
+%ignore dai::State::operator const std::map<Var,size_t>&;
 %include "../include/dai/index.h"
+%extend dai::IndexFor {
+    inline void next() { return (*self)++; }
+};
+%extend dai::Permute {
+    inline size_t __getitem__(int i) const {return (*self)[i];} /* for python */
+    inline size_t __paren__(int i) const {return (*self)[i];}   /* for octave */
+};
 %extend dai::multifor {
-    inline size_t __getitem__(int i) const {
-        return (*self)[i];
-    }
-    inline void next() {
-        return (*self)++;
-    }
+    inline void next() { return (*self)++; }
+    inline size_t __getitem__(int i) const {return (*self)[i];} /* for python */
+    inline size_t __paren__(int i) const {return (*self)[i];}   /* for octave */
+};
+%extend dai::State {
+    inline void next() { return (*self)++; }
 };
+
+// ************************************************************************************************
+%include "../include/dai/daialg.h"
+//TODO: why do the following lines not work?
+//%template(DAIAlgFG) dai::DAIAlg<dai::FactorGraph>;
+//%template(DAIAlgRG) dai::DAIAlg<dai::RegionGraph>;
+//%template(DAIAlgCG) dai::DAIAlg<dai::CobwebGraph>;
+
+// ************************************************************************************************
+%include "../include/dai/alldai.h"
+
+// ************************************************************************************************
+%ignore dai::BP::operator=;
+%include "../include/dai/bp.h"
+
+// ************************************************************************************************
+%include "../include/dai/fbp.h"
+
+// ************************************************************************************************
+%include "../include/dai/trwbp.h"
+
+// ************************************************************************************************
+%include "../include/dai/mf.h"
+
+// ************************************************************************************************
+%include "../include/dai/hak.h"
+
+// ************************************************************************************************
+%include "../include/dai/lc.h"
+
+// ************************************************************************************************
+%include "../include/dai/jtree.h"
+
+// ************************************************************************************************
+%ignore dai::TreeEP::operator=;
+%include "../include/dai/treeep.h"
+
+// ************************************************************************************************
+%include "../include/dai/mr.h"
+
+// ************************************************************************************************
+%include "../include/dai/gibbs.h"
+
+// ************************************************************************************************
+%include "../include/dai/cbp.h"
+
+// ************************************************************************************************
+%include "../include/dai/decmap.h"
+
+// ************************************************************************************************
+%include "../include/dai/glc.h"
diff --git a/swig/example.py b/swig/example.py
new file mode 100644 (file)
index 0000000..d205736
--- /dev/null
@@ -0,0 +1,216 @@
+# This file is part of libDAI - http:#www.libdai.org/
+#
+# Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
+#
+# Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
+
+
+# This example program illustrates how to read a factrograph from
+# a file and run Belief Propagation, Max-Product and JunctionTree on it.
+# This version uses the SWIG python wrapper of libDAI
+
+
+import dai
+import sys
+
+a = dai.IntVector()
+
+if len(sys.argv) != 2 and len(sys.argv) != 3:
+    print 'Usage:', sys.argv[0], "<filename.fg> [maxstates]"
+    print 'Reads factor graph <filename.fg> and runs'
+    print 'Belief Propagation, Max-Product and JunctionTree on it.'
+    print 'JunctionTree is only run if a junction tree is found with'
+    print 'total number of states less than <maxstates> (where 0 means unlimited).'
+    sys.exit(1)
+else:
+    # Report inference algorithms built into libDAI
+#   print 'Builtin inference algorithms:', dai.builtinInfAlgNames()
+#   TODO THIS CRASHES
+
+    # Read FactorGraph from the file specified by the first command line argument
+    fg = dai.FactorGraph()
+    fg.ReadFromFile(sys.argv[1])
+    maxstates = 1000000
+    if len(sys.argv) == 3:
+        maxstates = int(sys.argv[2])
+
+    # Set some constants
+    maxiter = 10000
+    tol = 1e-9
+    verb = 1
+
+    # Store the constants in a PropertySet object
+    opts = dai.PropertySet()
+    opts["maxiter"] = str(maxiter)   # Maximum number of iterations
+    opts["tol"] = str(tol)           # Tolerance for convergence
+    opts["verbose"] = str(verb)      # Verbosity (amount of output generated)
+
+    # Bound treewidth for junctiontree
+    do_jt = True
+    # TODO
+    #    try {
+    #        boundTreewidth(fg, &eliminationCost_MinFill, maxstates );
+    #    } catch( Exception &e ) {
+    #        if( e.getCode() == Exception::OUT_OF_MEMORY ) {
+    #            do_jt = false;
+    #            cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl;
+    #        }
+    #        else
+    #            throw;
+    #    }
+
+    if do_jt:
+        # Construct a JTree (junction tree) object from the FactorGraph fg
+        # using the parameters specified by opts and an additional property
+        # that specifies the type of updates the JTree algorithm should perform
+        jtopts = opts
+        jtopts["updates"] = "HUGIN"
+        jt = dai.JTree( fg, jtopts )
+        # Initialize junction tree algorithm
+        jt.init()
+        # Run junction tree algorithm
+        jt.run()
+
+        # Construct another JTree (junction tree) object that is used to calculate
+        # the joint configuration of variables that has maximum probability (MAP state)
+        jtmapopts = opts
+        jtmapopts["updates"] = "HUGIN"
+        jtmapopts["inference"] = "MAXPROD"
+        jtmap = dai.JTree( fg, jtmapopts )
+        # Initialize junction tree algorithm
+        jtmap.init()
+        # Run junction tree algorithm
+        jtmap.run()
+        # Calculate joint state of all variables that has maximum probability
+        jtmapstate = jtmap.findMaximum()
+
+        # Construct a BP (belief propagation) object from the FactorGraph fg
+        # using the parameters specified by opts and two additional properties,
+        # specifying the type of updates the BP algorithm should perform and
+        # whether they should be done in the real or in the logdomain
+        bpopts = opts
+        bpopts["updates"] = "SEQRND"
+        bpopts["logdomain"] = "0"
+        bp = dai.BP( fg, bpopts )
+        # Initialize belief propagation algorithm
+        bp.init()
+        # Run belief propagation algorithm
+        bp.run()
+
+        # Construct a BP (belief propagation) object from the FactorGraph fg
+        # using the parameters specified by opts and two additional properties,
+        # specifying the type of updates the BP algorithm should perform and
+        # whether they should be done in the real or in the logdomain
+        #
+        # Note that inference is set to MAXPROD, which means that the object
+        # will perform the max-product algorithm instead of the sum-product algorithm
+        mpopts = opts
+        mpopts["updates"] = "SEQRND"
+        mpopts["logdomain"] = "0"
+        mpopts["inference"] = "MAXPROD"
+        mpopts["damping"] = "0.1"
+        mp = dai.BP( fg, mpopts )
+        # Initialize max-product algorithm
+        mp.init()
+        # Run max-product algorithm
+        mp.run()
+        # Calculate joint state of all variables that has maximum probability
+        # based on the max-product result
+        mpstate = mp.findMaximum()
+
+        # Construct a decimation algorithm object from the FactorGraph fg
+        # using the parameters specified by opts and three additional properties,
+        # specifying that the decimation algorithm should use the max-product
+        # algorithm and should completely reinitalize its state at every step
+        decmapopts = opts
+        decmapopts["reinit"] = "1"
+        decmapopts["ianame"] = "BP"
+        decmapopts["iaopts"] = "[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]"
+        decmap = dai.DecMAP( fg, decmapopts )
+        decmap.init()
+        decmap.run()
+        decmapstate = decmap.findMaximum()
+
+        if do_jt:
+            # Report variable marginals for fg, calculated by the junction tree algorithm
+            print 'Exact variable marginals:'
+            for i in range(fg.nrVars()):                 # iterate over all variables in fg
+                print jt.belief(dai.VarSet(fg.var(i)))   # display the "belief" of jt for that variable
+
+        # Report variable marginals for fg, calculated by the belief propagation algorithm
+        print 'Approximate (loopy belief propagation) variable marginals:'
+        for i in range(fg.nrVars()):                     # iterate over all variables in fg
+            print bp.belief(dai.VarSet(fg.var(i)))       # display the belief of bp for that variable
+
+        if do_jt:
+            # Report factor marginals for fg, calculated by the junction tree algorithm
+            print 'Exact factor marginals:'
+            for I in range(fg.nrFactors()):              # iterate over all factors in fg
+                print jt.belief(fg.factor(I).vars())     # display the "belief" of jt for the variables in that factor
+
+        # Report factor marginals for fg, calculated by the belief propagation algorithm
+        print 'Approximate (loopy belief propagation) factor marginals:'
+        for I in range(fg.nrFactors()):                  # iterate over all factors in fg
+            print bp.belief(fg.factor(I).vars())         # display the belief of bp for the variables in that factor
+
+        if do_jt:
+            # Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
+            print 'Exact log partition sum:', jt.logZ()
+
+        # Report log partition sum of fg, approximated by the belief propagation algorithm
+        print 'Approximate (loopy belief propagation) log partition sum:', bp.logZ()
+
+        if do_jt:
+            # Report exact MAP variable marginals
+            print 'Exact MAP variable marginals:'
+            for i in range(fg.nrVars()):
+                print jtmap.belief(dai.VarSet(fg.var(i)))
+
+        # Report max-product variable marginals
+        print 'Approximate (max-product) MAP variable marginals:'
+        for i in range(fg.nrVars()):
+            print mp.belief(dai.VarSet(fg.var(i)))
+
+        if do_jt:
+            # Report exact MAP factor marginals
+            print 'Exact MAP factor marginals:'
+            for I in range(fg.nrFactors()):
+                print jtmap.belief(fg.factor(I).vars()), '==', jtmap.belief(fg.factor(I).vars())
+
+        # Report max-product factor marginals
+        print 'Approximate (max-product) MAP factor marginals:'
+        for I in range(fg.nrFactors()):
+            print mp.belief(fg.factor(I).vars()), '==', mp.belief(fg.factor(I).vars())
+
+        if do_jt:
+            # Report exact MAP joint state
+            hoie = dai.IntVector()
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            hoie.push_back( 0 )
+            print 'Exact MAP state (log score =', fg.logScore( hoie ), '):'
+            for i in range(len(jtmapstate)):
+                print fg.var(i), ':', jtmapstate[i]
+
+        # Report max-product MAP joint state
+        print 'Approximate (max-product) MAP state (log score =', fg.logScore( mpstate ), '):'
+        for i in range(len(mpstate)):
+            print fg.var(i), ':', mpstate[i]
+
+        # Report DecMAP joint state
+        print 'Approximate DecMAP state (log score =', fg.logScore( decmapstate ), '):'
+        for i in range(len(decmapstate)):
+            print fg.var(i), ':', decmapstate[i]
index f85809e..19dd7a5 100644 (file)
@@ -417,4 +417,6 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y1;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- y2;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+    BOOST_CHECK_EQUAL( G.toString(), "graph BipartiteGraph {\nnode[shape=circle,width=0.4,fixedsize=true];\n\tx0;\n\tx1;\nnode[shape=box,width=0.3,height=0.3,fixedsize=true];\n\ty0;\n\ty1;\n\ty2;\n\tx0 -- y0;\n\tx0 -- y1;\n\tx1 -- y1;\n\tx1 -- y2;\n}\n" );
 }
index dc1a86c..c18416c 100644 (file)
@@ -540,4 +540,6 @@ BOOST_AUTO_TEST_CASE( IOTest ) {
     std::string s;
     getline( ss, s );
     BOOST_CHECK_EQUAL( s, "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
+
+    BOOST_CHECK_EQUAL( G.toString(), "({x0, x1}, {x1, x2}, {x2, x3}, {x1, x3})" );
 }
index 83417cc..bb7b3b2 100644 (file)
@@ -391,4 +391,6 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -> x3;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -> x3;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+    BOOST_CHECK_EQUAL( G.toString(), "digraph DAG {\nnode[shape=circle,width=0.4,fixedsize=true];\n\tx0;\n\tx1;\n\tx2;\n\tx3;\n\tx0 -> x1;\n\tx0 -> x2;\n\tx1 -> x3;\n\tx2 -> x3;\n}\n" );
 }
index 2832673..2041b20 100644 (file)
@@ -845,6 +845,9 @@ BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
     std::getline( ss2, s );
     BOOST_CHECK_EQUAL( s, std::string("({x0}, (0.1, 0.5, 0.4))") );
 
+    BOOST_CHECK_EQUAL( x.toString(), "({x0}, (0.2, 0.7, 0.1))" );
+    BOOST_CHECK_EQUAL( y.toString(), "({x0}, (0.1, 0.5, 0.4))" );
+
     z = min( x, y );
     BOOST_CHECK_EQUAL( z[0], (Real)0.1 );
     BOOST_CHECK_EQUAL( z[1], (Real)0.5 );
index 1ef18b1..b182391 100644 (file)
@@ -748,6 +748,8 @@ BOOST_AUTO_TEST_CASE( IOTest ) {
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "0          3" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "1          3" );
 
+    BOOST_CHECK_EQUAL( G.toString(), "3\n\n2\n0 1 \n2 2 \n4\n0          1\n1          1\n2          1\n3          1\n\n2\n1 2 \n2 2 \n4\n0          2\n1          2\n2          2\n3          2\n\n1\n1 \n2 \n2\n0          3\n1          3\n" );
+
     ss << G;
     FactorGraph G3;
     ss >> G3;
@@ -759,4 +761,15 @@ BOOST_AUTO_TEST_CASE( IOTest ) {
         for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
             BOOST_CHECK_CLOSE( G.factor(I)[s], G3.factor(I)[s], tol );
     }
+
+    FactorGraph G4;
+    G4.fromString( G.toString() );
+    BOOST_CHECK( G.vars() == G4.vars() );
+    BOOST_CHECK( G.bipGraph() == G4.bipGraph() );
+    BOOST_CHECK_EQUAL( G.nrFactors(), G4.nrFactors() );
+    for( size_t I = 0; I < G.nrFactors(); I++ ) {
+        BOOST_CHECK( G.factor(I).vars() == G4.factor(I).vars() );
+        for( size_t s = 0; s < G.factor(I).nrStates(); s++ )
+            BOOST_CHECK_CLOSE( G.factor(I)[s], G4.factor(I)[s], tol );
+    }
 }
index 10fd00e..cc2c82c 100644 (file)
@@ -454,4 +454,6 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx1 -- x3;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\tx2 -- x3;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+    BOOST_CHECK_EQUAL( G.toString(), "graph GraphAL {\nnode[shape=circle,width=0.4,fixedsize=true];\n\tx0;\n\tx1;\n\tx2;\n\tx3;\n\tx0 -- x1;\n\tx0 -- x2;\n\tx1 -- x3;\n\tx2 -- x3;\n}\n" );
 }
index 9b6bc12..f9215ef 100644 (file)
@@ -751,18 +751,20 @@ BOOST_AUTO_TEST_CASE( RelatedFunctionsTest ) {
     std::string s;
     std::getline( ss, s );
 #ifdef DAI_SPARSE
-    BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)") );
+    BOOST_CHECK_EQUAL( s, "(size:4, def:0.25, 0:0.2, 1:0.7, 2:0.1)" );
 #else
-    BOOST_CHECK_EQUAL( s, std::string("(0.2, 0.7, 0.1, 0.25)") );
+    BOOST_CHECK_EQUAL( s, "(0.2, 0.7, 0.1, 0.25)" );
 #endif
+    BOOST_CHECK_EQUAL( xx.toString(), s );
     std::stringstream ss2;
     ss2 << yy;
     std::getline( ss2, s );
 #ifdef DAI_SPARSE
-    BOOST_CHECK_EQUAL( s, std::string("(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)") );
+    BOOST_CHECK_EQUAL( s, "(size:4, def:0.25, 0:0.1, 1:0.5, 2:0.4)" );
 #else
-    BOOST_CHECK_EQUAL( s, std::string("(0.1, 0.5, 0.4, 0.25)") );
+    BOOST_CHECK_EQUAL( s, "(0.1, 0.5, 0.4, 0.25)" );
 #endif
+    BOOST_CHECK_EQUAL( yy.toString(), s );
 
     z = min( x, y );
     BOOST_CHECK_EQUAL( z[0], (Real)0.1 );
index d6f5a1c..f350bcc 100644 (file)
@@ -200,6 +200,7 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     ss1 << z;
     ss1 >> s;
     BOOST_CHECK_EQUAL( s, std::string("[key1=val1,key2=5,key3=[key3a=val,key3b=7.0],key4=1]") );
+    BOOST_CHECK_EQUAL( z.toString(), s );
     PropertySet y;
     ss2 << z;
     ss2 >> y;
@@ -211,6 +212,15 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     BOOST_CHECK_EQUAL( y.getAs<std::string>( "key2" ), std::string("5") );
     BOOST_CHECK_EQUAL( y.getAs<std::string>( "key3" ), std::string("[key3a=val,key3b=7.0]") );
     BOOST_CHECK_EQUAL( y.getAs<std::string>( "key4" ), std::string("1") );
+    y.fromString( z.toString() );
+    BOOST_CHECK( y.hasKey( "key1" ) );
+    BOOST_CHECK( y.hasKey( "key2" ) );
+    BOOST_CHECK( y.hasKey( "key3" ) );
+    BOOST_CHECK( y.hasKey( "key4" ) );
+    BOOST_CHECK_EQUAL( y.getAs<std::string>( "key1" ), std::string("val1") );
+    BOOST_CHECK_EQUAL( y.getAs<std::string>( "key2" ), std::string("5") );
+    BOOST_CHECK_EQUAL( y.getAs<std::string>( "key3" ), std::string("[key3a=val,key3b=7.0]") );
+    BOOST_CHECK_EQUAL( y.getAs<std::string>( "key4" ), std::string("1") );
 
     z.set( "key5", std::vector<int>() );
     BOOST_CHECK_THROW( ss1 << z, Exception );
index e00a866..bb45de2 100644 (file)
@@ -1029,4 +1029,6 @@ BOOST_AUTO_TEST_CASE( IOTest ) {
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta0 -> b0;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "\ta1 -> b0;" );
     std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "}" );
+
+    BOOST_CHECK_EQUAL( G.toString(), "digraph RegionGraph {\nnode[shape=box];\n\ta0 [label=\"a0: {x0, x1}, c=1\"];\n\ta1 [label=\"a1: {x1, x2}, c=1\"];\nnode[shape=ellipse];\n\tb0 [label=\"b0: {x1}, c=-1\"];\n\ta0 -> b0;\n\ta1 -> b0;\n}\n" );
 }
index e495698..0c7e7ee 100644 (file)
@@ -7,6 +7,7 @@
 
 
 #include <dai/smallset.h>
+#include <sstream>
 #include <vector>
 
 
@@ -825,3 +826,15 @@ BOOST_AUTO_TEST_CASE( OperatorTest ) {
     BOOST_CHECK(  (x123 >> x23 ) );
     BOOST_CHECK(  (x123 >> x123) );
 }
+
+
+BOOST_AUTO_TEST_CASE( IOTest ) {
+    SmallSet<size_t> u( 0, 5 );
+    u |= 1;
+    std::stringstream ss;
+    std::string s;
+    ss << u;
+    std::getline( ss, s ); BOOST_CHECK_EQUAL( s, "{0, 1, 5}" );
+    
+    BOOST_CHECK_EQUAL( u.toString(), "{0, 1, 5}" );
+}
index b1c0929..84bb131 100644 (file)
@@ -67,4 +67,5 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     std::string s;
     ss >> s;
     BOOST_CHECK_EQUAL( s, "x5" );
+    BOOST_CHECK_EQUAL( x.toString(), s );
 }
index 0295c7b..061209c 100644 (file)
@@ -806,6 +806,7 @@ BOOST_AUTO_TEST_CASE( StreamTest ) {
     std::string s;
     std::getline( ss, s );
     BOOST_CHECK_EQUAL( s, "{x1, x2}" );
+    BOOST_CHECK_EQUAL( x.toString(), s );
 }
 
 
index 927ba18..cf3ca4c 100644 (file)
@@ -74,6 +74,7 @@ BOOST_AUTO_TEST_CASE( DEdgeTest ) {
     std::string s;
     ss >> s;
     BOOST_CHECK_EQUAL( s, "(5->3)" );
+    BOOST_CHECK_EQUAL( c.toString(), s );
 }
 
 
@@ -134,9 +135,11 @@ BOOST_AUTO_TEST_CASE( UEdgeTest ) {
     ss << c;
     ss >> s;
     BOOST_CHECK_EQUAL( s, "{3--5}" );
+    BOOST_CHECK_EQUAL( c.toString(), s );
     ss << b;
     ss >> s;
     BOOST_CHECK_EQUAL( s, "{3--5}" );
+    BOOST_CHECK_EQUAL( b.toString(), s );
 }