Generalized VarSet to "template<typename T> small_set<T>"
authorJoris Mooij <jorism@marvin.jorismooij.nl>
Mon, 29 Sep 2008 19:00:35 +0000 (21:00 +0200)
committerJoris Mooij <jorism@marvin.jorismooij.nl>
Mon, 29 Sep 2008 19:00:35 +0000 (21:00 +0200)
12 files changed:
ChangeLog
Makefile
Makefile.win
STATUS
TODO
include/dai/factor.h
include/dai/varset.h
src/hak.cpp
src/jtree.cpp
src/lc.cpp
src/varset.cpp [new file with mode: 0644]
utils/fginfo.cpp

index d5c9cbc..91987de 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -66,7 +66,7 @@ libDAI-0.2.2 (2008-??-??)
 * Improved documetation
 * Interface changes:
   - VarSet::
-      stateSpace() -> states()
+      VarSet::stateSpace() -> nrStates(const VarSet &)
       VarSet( const std::set<Var> ) -> VarSet( begin, end, sizeHint=0 )
       VarSet( const std::vector<Var> ) -> VarSet( begin, end, sizeHint=0 )
       removed bool operator||
index f519ffa..7c01696 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -122,8 +122,8 @@ all : $(TARGETS)
 
 matlabs : matlab/dai.$(ME) matlab/dai_readfg.$(ME) matlab/dai_writefg.$(ME) matlab/dai_potstrength.$(ME)
 
-$(LIB)/libdai$(LE) : bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) $(OBJECTS)
-       ar rcus $(LIB)/libdai$(LE) bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) $(OBJECTS)
+$(LIB)/libdai$(LE) : bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) varset$(OE) $(OBJECTS)
+       ar rcus $(LIB)/libdai$(LE) bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) varset$(OE) $(OBJECTS)
 
 tests : tests/test$(EE)
 
@@ -194,6 +194,9 @@ exceptions$(OE) : $(SRC)/exceptions.cpp $(HEADERS)
 alldai$(OE) : $(SRC)/alldai.cpp $(HEADERS)
        $(CC) $(CCFLAGS) -c $(SRC)/alldai.cpp
 
+varset$(OE) : $(SRC)/varset.cpp $(HEADERS)
+       $(CC) $(CCFLAGS) -c $(SRC)/varset.cpp
+
 
 # EXAMPLE
 ##########
@@ -214,14 +217,14 @@ tests/test$(EE) : tests/test.cpp $(HEADERS) $(LIB)/libdai$(LE)
 matlab/dai.$(ME) : matlab/dai.cpp $(HEADERS) matlab/matlab$(OE) $(LIB)/libdai$(LE)
        $(MEX) $(MEXFLAGS) -o matlab/dai matlab/dai.cpp matlab/matlab$(OE) $(LIB)/libdai$(LE)
 
-matlab/dai_readfg.$(ME) : matlab/dai_readfg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)
-       $(MEX) $(MEXFLAGS) -o matlab/dai_readfg matlab/dai_readfg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)
+matlab/dai_readfg.$(ME) : matlab/dai_readfg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)
+       $(MEX) $(MEXFLAGS) -o matlab/dai_readfg matlab/dai_readfg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)
 
-matlab/dai_writefg.$(ME) : matlab/dai_writefg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)
-       $(MEX) $(MEXFLAGS) -o matlab/dai_writefg matlab/dai_writefg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)
+matlab/dai_writefg.$(ME) : matlab/dai_writefg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)
+       $(MEX) $(MEXFLAGS) -o matlab/dai_writefg matlab/dai_writefg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)
 
-matlab/dai_potstrength.$(ME) : matlab/dai_potstrength.cpp $(HEADERS) matlab/matlab$(OE) exceptions$(OE)
-       $(MEX) $(MEXFLAGS) -o matlab/dai_potstrength matlab/dai_potstrength.cpp matlab/matlab$(OE) exceptions$(OE)
+matlab/dai_potstrength.$(ME) : matlab/dai_potstrength.cpp $(HEADERS) matlab/matlab$(OE) exceptions$(OE) varset$(OE)
+       $(MEX) $(MEXFLAGS) -o matlab/dai_potstrength matlab/dai_potstrength.cpp matlab/matlab$(OE) exceptions$(OE) varset$(OE)
 
 matlab/matlab$(OE) : matlab/matlab.cpp matlab/matlab.h $(HEADERS)
        $(MEX) $(MEXFLAGS) -outdir matlab -c matlab/matlab.cpp
index 9ecc574..876e988 100755 (executable)
@@ -123,8 +123,8 @@ all : $(TARGETS)
 \r
 matlabs : matlab/dai.$(ME) matlab/dai_readfg.$(ME) matlab/dai_writefg.$(ME) matlab/dai_potstrength.$(ME)\r
 \r
-$(LIB)/libdai$(LE) : bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) $(OBJECTS)\r
-       lib /out:$(LIB)/libdai$(LE) bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) $(OBJECTS)\r
+$(LIB)/libdai$(LE) : bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) varset$(OE) $(OBJECTS)\r
+       lib /out:$(LIB)/libdai$(LE) bipgraph$(OE) daialg$(OE) alldai$(OE) clustergraph$(OE) factorgraph$(OE) properties$(OE) regiongraph$(OE) util$(OE) weightedgraph$(OE) exceptions$(OE) varset$(OE) $(OBJECTS)\r
 \r
 tests : tests/test$(EE)\r
 \r
@@ -195,6 +195,9 @@ exceptions$(OE) : $(SRC)/exceptions.cpp $(HEADERS)
 alldai$(OE) : $(SRC)/alldai.cpp $(HEADERS)\r
        $(CC) $(CCFLAGS) -c $(SRC)/alldai.cpp\r
 \r
+varset$(OE) : $(SRC)/varset.cpp $(HEADERS)\r
+       $(CC) $(CCFLAGS) -c $(SRC)/varset.cpp\r
+\r
 \r
 # EXAMPLE\r
 ##########\r
@@ -216,14 +219,14 @@ tests/test$(EE) : tests/test.cpp $(HEADERS) $(LIB)/libdai$(LE)
 matlab/dai.$(ME) : matlab/dai.cpp $(HEADERS) matlab/matlab$(OE) $(LIB)/libdai$(LE)\r
        $(MEX) $(MEXFLAGS) -o matlab/dai matlab/dai.cpp matlab/matlab$(OE) $(LIB)/libdai$(LE)\r
 \r
-matlab/dai_readfg.$(ME) : matlab/dai_readfg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)\r
-       $(MEX) $(MEXFLAGS) -o matlab/dai_readfg matlab/dai_readfg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)\r
+matlab/dai_readfg.$(ME) : matlab/dai_readfg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
+       $(MEX) $(MEXFLAGS) -o matlab/dai_readfg matlab/dai_readfg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
 \r
-matlab/dai_writefg.$(ME) : matlab/dai_writefg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)\r
-       $(MEX) $(MEXFLAGS) -o matlab/dai_writefg matlab/dai_writefg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE)\r
+matlab/dai_writefg.$(ME) : matlab/dai_writefg.cpp $(HEADERS) factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
+       $(MEX) $(MEXFLAGS) -o matlab/dai_writefg matlab/dai_writefg.cpp factorgraph$(OE) matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
 \r
-matlab/dai_potstrength.$(ME) : matlab/dai_potstrength.cpp $(HEADERS) matlab/matlab$(OE) exceptions$(OE)\r
-       $(MEX) $(MEXFLAGS) -o matlab/dai_potstrength matlab/dai_potstrength.cpp matlab/matlab$(OE) exceptions$(OE)\r
+matlab/dai_potstrength.$(ME) : matlab/dai_potstrength.cpp $(HEADERS) matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
+       $(MEX) $(MEXFLAGS) -o matlab/dai_potstrength matlab/dai_potstrength.cpp matlab/matlab$(OE) exceptions$(OE) varset$(OE)\r
 \r
 matlab/matlab$(OE) : matlab/matlab.cpp matlab/matlab.h $(HEADERS)\r
        $(MEX) $(MEXFLAGS) -outdir matlab -c matlab/matlab.cpp\r
diff --git a/STATUS b/STATUS
index 940897a..f418fbb 100644 (file)
--- a/STATUS
+++ b/STATUS
@@ -2,11 +2,8 @@ OPTIMIZATION:
 - BipartiteGraph::isConnected should be optimized using boost::graph
 - Can the FactorGraph constructors be optimized further?
 - Cache second-order neighborhoods (delta's) in BipGraph?
-- Would it be a good idea to remove the states() caching from VarSet? 
-  Then, we could turn a VarSet into an IndexSet (set of integers).
-  This may restrict the use of findVar().
+- Replace VarSets by small_set<size_t> if appropriate, in order to minimize the use of findVar().
 
-- Idea: use a PropertySet as output of a DAIAlg, instead of functions like maxDiff and Iterations().
 - A DAIAlg<T> should not inherit from a FactorGraph/RegionGraph, but should store a reference to it
 
 TODO FOR RELEASE:
diff --git a/TODO b/TODO
index ba34d11..6a563c6 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,4 +1,5 @@
-IMPORTANT
+- Idea: use a PropertySet as output of a DAIAlg, instead of functions like
+  maxDiff and Iterations().
 
 - Idea: a FactorGraph and a RegionGraph are often equipped with
 extra properties for nodes and edges. The code to initialize those
@@ -50,8 +51,6 @@ Also, the current setup is stupid: I wrote a new function that works
 on FactorGraphs, and I had to write boiler plate code for it in graphicalmodel.h
 and in regiongraph.h (which is stupid).
 
-- Clean up 
-
 - Use Boost::uBLAS framework to deal with matrices, especially, with
 2D sparse matrices. See http://www.boost.org/libs/numeric/ublas/doc/matrix_sparse.htm
 and tests/errorbounds/errorbounds3
index 6332f20..a7e534b 100644 (file)
@@ -58,18 +58,18 @@ template <typename T> class TFactor {
         TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
 
         // Construct Factor from VarSet
-        TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.states()) {}
+        TFactor( const VarSet& ns ) : _vs(ns), _p(nrStates(_vs)) {}
         
         // Construct Factor from VarSet and initial value
-        TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.states(),p) {}
+        TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(nrStates(_vs),p) {}
         
         // Construct Factor from VarSet and initial array
-        TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(_vs.states(),p) {}
+        TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(nrStates(_vs),p) {}
 
         // Construct Factor from VarSet and TProb<T>
         TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
 #ifdef DAI_DEBUG
-            assert( _vs.states() == _p.size() );
+            assert( nrStates(_vs) == _p.size() );
 #endif
         }
         
@@ -91,12 +91,7 @@ template <typename T> class TFactor {
         const TProb<T> & p() const { return _p; }
         TProb<T> & p() { return _p; }
         const VarSet & vars() const { return _vs; }
-        size_t states() const { 
-#ifdef DAI_DEBUG
-            assert( _vs.states() == _p.size() );
-#endif
-            return _p.size();
-        }
+        size_t states() const { return _p.size(); }
 
         T operator[] (size_t i) const { return _p[i]; }
         T& operator[] (size_t i) { return _p[i]; }
index 883cb62..30ff83e 100644 (file)
 namespace dai {
 
 
-/// A VarSet represents a set of variables.
-/**
- *  It is implemented as an ordered std::vector<Var> for efficiency reasons
- *  (indeed, it was found that a std::set<Var> usually has more overhead). 
- *  In addition, it provides an interface for common set-theoretic operations.
+/// A small_set<T> represents a set, optimized for a small number of elements.
+/** For sets consisting of a small number of elements, an implementation using
+ *  an ordered std::vector<T> is faster than an implementation using std::set<T>.
+ *  The elements should be less-than-comparable.
  */
-class VarSet {
+template <typename T>
+class small_set {
     private:
-        /// The variables in this set
-        std::vector<Var> _vars;
-
-        /// Product of number of states of all contained variables
-        size_t _states;
+        /// The elements in this set
+        std::vector<T> _elements;
 
     public:
         /// Default constructor
-        VarSet() : _vars(), _states(1) {};
+        small_set() : _elements() {}
 
-        /// Construct a VarSet from one variable
-        VarSet( const Var &n ) : _vars(), _states( n.states() ) { 
-            _vars.push_back( n );
+        /// Construct a small_set with one element
+        small_set( const T &n ) : _elements() { 
+            _elements.push_back( n );
         }
 
-        /// Construct a VarSet from two variables
-        VarSet( const Var &n1, const Var &n2 ) { 
+        /// Construct a small_set with two elements
+        small_set( const T &n1, const T &n2 ) { 
             if( n1 < n2 ) {
-                _vars.push_back( n1 );
-                _vars.push_back( n2 );
-            } else if( n1 > n2 ) {
-                _vars.push_back( n2 );
-                _vars.push_back( n1 );
+                _elements.push_back( n1 );
+                _elements.push_back( n2 );
+            } else if( n2 < n1 ) {
+                _elements.push_back( n2 );
+                _elements.push_back( n1 );
             } else
-                _vars.push_back( n1 );
-            calcStates();
+                _elements.push_back( n1 );
         }
 
-        /// Construct from a range of iterators
-        /** The value_type of the VarIterator should be Var.
-         *  For efficiency, the number of variables can be
+        /// Construct a small_set from a range of iterators.
+        /** The value_type of the Iterator should be T.
+         *  For efficiency, the number of elements can be
          *  speficied by sizeHint.
          */
-        template <typename VarIterator>
-        VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) {
-            _vars.reserve( sizeHint );
-            _vars.insert( _vars.begin(), begin, end );
-            std::sort( _vars.begin(), _vars.end() );
-            std::vector<Var>::iterator new_end = std::unique( _vars.begin(), _vars.end() );
-            _vars.erase( new_end, _vars.end() );
-            calcStates();
+        template <typename Iterator>
+        small_set( Iterator begin, Iterator end, size_t sizeHint=0 ) {
+            _elements.reserve( sizeHint );
+            _elements.insert( _elements.begin(), begin, end );
+            std::sort( _elements.begin(), _elements.end() );
+            typename std::vector<T>::iterator new_end = std::unique( _elements.begin(), _elements.end() );
+            _elements.erase( new_end, _elements.end() );
         }
 
         /// Copy constructor
-        VarSet( const VarSet &x ) : _vars( x._vars ), _states( x._states ) {}
+        small_set( const small_set &x ) : _elements( x._elements ) {}
 
         /// Assignment operator
-        VarSet & operator=( const VarSet &x ) {
+        small_set & operator=( const small_set &x ) {
             if( this != &x ) {
-                _vars = x._vars;
-                _states = x._states;
+                _elements = x._elements;
             }
             return *this;
         }
         
-
-        /// Returns the product of the number of states of each variable in this set
-        size_t states() const { 
-            return _states; 
-        }
-        
-
-        /// Setminus operator (result contains all variables in *this, except those in ns)
-        VarSet operator/ ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_difference( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
+        /// Setminus operator (result contains all elements in *this, except those in ns)
+        small_set operator/ ( const small_set& ns ) const {
+            small_set res;
+            std::set_difference( _elements.begin(), _elements.end(), ns._elements.begin(), ns._elements.end(), inserter( res._elements, res._elements.begin() ) );
             return res;
         }
 
-        /// Set-union operator (result contains all variables in *this, plus those in ns)
-        VarSet operator| ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_union( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
+        /// Set-union operator (result contains all elements in *this, plus those in ns)
+        small_set operator| ( const small_set& ns ) const {
+            small_set res;
+            std::set_union( _elements.begin(), _elements.end(), ns._elements.begin(), ns._elements.end(), inserter( res._elements, res._elements.begin() ) );
             return res;
         }
 
-        /// Set-intersection operator (result contains all variables in *this that are also contained in ns)
-        VarSet operator& ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_intersection( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
+        /// Set-intersection operator (result contains all elements in *this that are also contained in ns)
+        small_set operator& ( const small_set& ns ) const {
+            small_set res;
+            std::set_intersection( _elements.begin(), _elements.end(), ns._elements.begin(), ns._elements.end(), inserter( res._elements, res._elements.begin() ) );
             return res;
         }
         
-        /// Erases from *this all variables in ns
-        VarSet& operator/= ( const VarSet& ns ) {
+        /// Erases from *this all elements in ns
+        small_set& operator/= ( const small_set& ns ) {
             return (*this = (*this / ns));
         }
 
-        /// Erase one variable
-        VarSet& operator/= ( const Var& n ) { 
-            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
-            if( pos != _vars.end() )
-                if( *pos == n ) { // found variable, delete it
-                    _vars.erase( pos ); 
-                    _states /= n.states();
-                }
+        /// Erase one element
+        small_set& operator/= ( const T& n ) { 
+            typename std::vector<T>::iterator pos = lower_bound( _elements.begin(), _elements.end(), n );
+            if( pos != _elements.end() )
+                if( *pos == n ) // found element, delete it
+                    _elements.erase( pos ); 
             return *this; 
         }
 
-        /// Adds to *this all variables in ns
-        VarSet& operator|= ( const VarSet& ns ) {
+        /// Adds to *this all elements in ns
+        small_set& operator|= ( const small_set& ns ) {
             return( *this = (*this | ns) );
         }
 
-        /// Add one variable
-        VarSet& operator|= ( const Var& n ) {
-            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
-            if( pos == _vars.end() || *pos != n ) { // insert it
-                _vars.insert( pos, n );
-                _states *= n.states();
-            }
+        /// Add one element
+        small_set& operator|= ( const T& n ) {
+            typename std::vector<T>::iterator pos = lower_bound( _elements.begin(), _elements.end(), n );
+            if( pos == _elements.end() || *pos != n ) // insert it
+                _elements.insert( pos, n );
             return *this;
         }
 
-
-        /// Erases from *this all variables not in ns
-        VarSet& operator&= ( const VarSet& ns ) { 
+        /// Erases from *this all elements not in ns
+        small_set& operator&= ( const small_set& ns ) { 
             return (*this = (*this & ns));
         }
 
-
         /// Returns true if *this is a subset of ns
-        bool operator<< ( const VarSet& ns ) const { 
-            return std::includes( ns._vars.begin(), ns._vars.end(), _vars.begin(), _vars.end() );
+        bool operator<< ( const small_set& ns ) const { 
+            return std::includes( ns._elements.begin(), ns._elements.end(), _elements.begin(), _elements.end() );
         }
 
         /// Returns true if ns is a subset of *this
-        bool operator>> ( const VarSet& ns ) const { 
-            return std::includes( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end() );
+        bool operator>> ( const small_set& ns ) const { 
+            return std::includes( _elements.begin(), _elements.end(), ns._elements.begin(), ns._elements.end() );
         }
 
-        /// Returns true if *this and ns contain common variables
-        bool intersects( const VarSet& ns ) const { 
+        /// Returns true if *this and ns contain common elements
+        bool intersects( const small_set& ns ) const { 
             return( (*this & ns).size() > 0 ); 
         }
 
-        /// Returns true if *this contains the variable n
-        bool contains( const Var& n ) const { 
-            return std::binary_search( _vars.begin(), _vars.end(), n );
+        /// Returns true if *this contains the element n
+        bool contains( const T& n ) const { 
+            return std::binary_search( _elements.begin(), _elements.end(), n );
         }
 
-        /// Sends a VarSet to an output stream
-        friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
-            foreach( const Var &n, ns._vars )
-                os << n;
-            return( os );
-        }
-
-        /// Constant iterator over Vars
-        typedef std::vector<Var>::const_iterator const_iterator;
-        /// Iterator over Vars
-        typedef std::vector<Var>::iterator iterator;
-        /// Constant reverse iterator over Vars
-        typedef std::vector<Var>::const_reverse_iterator const_reverse_iterator;
-        /// Reverse iterator over Vars
-        typedef std::vector<Var>::reverse_iterator reverse_iterator;
+        /// Constant iterator over the elements
+        typedef typename std::vector<T>::const_iterator const_iterator;
+        /// Iterator over the elements
+        typedef typename std::vector<T>::iterator iterator;
+        /// Constant reverse iterator over the elements
+        typedef typename std::vector<T>::const_reverse_iterator const_reverse_iterator;
+        /// Reverse iterator over the elements
+        typedef typename std::vector<T>::reverse_iterator reverse_iterator;
         
-        /// Returns iterator that points to the first variable
-        iterator begin() { return _vars.begin(); }
-        /// Returns constant iterator that points to the first variable
-        const_iterator begin() const { return _vars.begin(); }
-
-        /// Returns iterator that points beyond the last variable
-        iterator end() { return _vars.end(); }
-        /// Returns constant iterator that points beyond the last variable
-        const_iterator end() const { return _vars.end(); }
-
-        /// Returns reverse iterator that points to the last variable
-        reverse_iterator rbegin() { return _vars.rbegin(); }
-        /// Returns constant reverse iterator that points to the last variable
-        const_reverse_iterator rbegin() const { return _vars.rbegin(); }
-
-        /// Returns reverse iterator that points beyond the first variable
-        reverse_iterator rend() { return _vars.rend(); }
-        /// Returns constant reverse iterator that points beyond the first variable
-        const_reverse_iterator rend() const { return _vars.rend(); }
+        /// Returns iterator that points to the first element
+        iterator begin() { return _elements.begin(); }
+        /// Returns constant iterator that points to the first element
+        const_iterator begin() const { return _elements.begin(); }
 
+        /// Returns iterator that points beyond the last element
+        iterator end() { return _elements.end(); }
+        /// Returns constant iterator that points beyond the last element
+        const_iterator end() const { return _elements.end(); }
 
-        /// Returns number of variables
-        std::vector<Var>::size_type size() const { return _vars.size(); }
+        /// Returns reverse iterator that points to the last element
+        reverse_iterator rbegin() { return _elements.rbegin(); }
+        /// Returns constant reverse iterator that points to the last element
+        const_reverse_iterator rbegin() const { return _elements.rbegin(); }
 
+        /// Returns reverse iterator that points beyond the first element
+        reverse_iterator rend() { return _elements.rend(); }
+        /// Returns constant reverse iterator that points beyond the first element
+        const_reverse_iterator rend() const { return _elements.rend(); }
 
-        /// Returns whether the VarSet is empty
-        bool empty() const { return _vars.size() == 0; }
+        /// Returns number of elements
+        typename std::vector<T>::size_type size() const { return _elements.size(); }
 
+        /// Returns whether the small_set is empty
+        bool empty() const { return _elements.size() == 0; }
 
-        /// Test for equality of variable labels
-        friend bool operator==( const VarSet &a, const VarSet &b ) {
-            return (a._vars == b._vars);
+        /// Test for equality of element labels
+        friend bool operator==( const small_set &a, const small_set &b ) {
+            return (a._elements == b._elements);
         }
 
-        /// Test for inequality of variable labels
-        friend bool operator!=( const VarSet &a, const VarSet &b ) {
-            return !(a._vars == b._vars);
+        /// Test for inequality of element labels
+        friend bool operator!=( const small_set &a, const small_set &b ) {
+            return !(a._elements == b._elements);
         }
 
-        /// Lexicographical comparison of variable labels
-        friend bool operator<( const VarSet &a, const VarSet &b ) {
-            return a._vars < b._vars;
+        /// Lexicographical comparison of element labels
+        friend bool operator<( const small_set &a, const small_set &b ) {
+            return a._elements < b._elements;
         }
+};
 
-        /// calcState calculates the linear index of this VarSet that corresponds
-        /// to the states of the variables given in states, implicitly assuming
-        /// states[m] = 0 for all m in this VarSet which are not in states.
-        size_t calcState( const std::map<Var, size_t> &states ) const {
-            size_t prod = 1;
-            size_t state = 0;
-            foreach( const Var &n, *this ) {
-                std::map<Var, size_t>::const_iterator m = states.find( n );
-                if( m != states.end() )
-                    state += prod * m->second;
-                prod *= n.states();
-            }
-            return state;
-        }
 
-    private:
-        /// Calculates the number of states
-        size_t calcStates() {
-            _states = 1;
-            foreach( Var &i, _vars )
-                _states *= i.states();
-            return _states;
-        }
-};
+/// A VarSet represents a set of variables.
+/**
+ *  It is implemented as an ordered std::vector<Var> for efficiency reasons
+ *  (indeed, it was found that a std::set<Var> usually has more overhead). 
+ *  In addition, it provides an interface for common set-theoretic operations.
+ */
+typedef small_set<Var> VarSet;
 
 
 /// For two Vars n1 and n2, the expression n1 | n2 gives the Varset containing n1 and n2
@@ -280,6 +233,20 @@ inline VarSet operator| (const Var& n1, const Var& n2) {
 }
 
 
+/// Calculates the product of number of states of all variables in vars
+size_t nrStates( const VarSet &vars );
+
+
+/// calcState calculates the linear index of vars that corresponds
+/// to the states of the variables given in states, implicitly assuming
+/// states[m] = 0 for all m in this VarSet which are not in states.
+size_t calcState( const VarSet &vars, const std::map<Var, size_t> &states );
+
+
+/// Sends a VarSet to an output stream
+std::ostream& operator<< (std::ostream &os, const VarSet& ns);
+
+
 } // end of namespace dai
 
 
index adcc545..330b0fb 100644 (file)
@@ -187,8 +187,8 @@ void HAK::init( const VarSet &ns ) {
             _Qb[beta].fill( 1.0 );
             foreach( const Neighbor &alpha, nbIR(beta) ) {
                 size_t _beta = alpha.dual;
-                muab( alpha, _beta ).fill( 1.0 / IR(beta).states() );
-                muba( alpha, _beta ).fill( 1.0 / IR(beta).states() );
+                muab( alpha, _beta ).fill( 1.0 );
+                muba( alpha, _beta ).fill( 1.0 );
             }
         }
 }
index 3f16a28..a64d84a 100644 (file)
@@ -323,7 +323,7 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
     // find new root clique (the one with maximal statespace overlap with ns)
     size_t maxval = 0, maxalpha = 0;
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
-        size_t val = (ns & OR(alpha).vars()).states();
+        size_t val = nrStates( ns & OR(alpha).vars() );
         if( val > maxval ) {
             maxval = val;
             maxalpha = alpha;
@@ -524,8 +524,9 @@ pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
     for( size_t i = 0; i < ElimVec.size(); i++ ) {
         if( ElimVec[i].size() > treewidth )
             treewidth = ElimVec[i].size();
-        if( ElimVec[i].states() > nrstates )
-            nrstates = ElimVec[i].states();
+        size_t s = nrStates(ElimVec[i]);
+        if( s > nrstates )
+            nrstates = s;
     }
 
     return pair<size_t,size_t>(treewidth, nrstates);
index 2c796fe..3aae934 100644 (file)
@@ -135,7 +135,7 @@ double LC::CalcCavityDist (size_t i, const std::string &name, const PropertySet
     double maxdiff = 0;
 
     if( props.verbose >= 2 )
-        cout << "Initing cavity " << var(i) << "(" << delta(i).size() << " vars, " << delta(i).states() << " states)" << endl;
+        cout << "Initing cavity " << var(i) << "(" << delta(i).size() << " vars, " << nrStates(delta(i)) << " states)" << endl;
 
     if( props.cavity == Properties::CavityType::UNIFORM )
         Bi = Factor(delta(i));
diff --git a/src/varset.cpp b/src/varset.cpp
new file mode 100644 (file)
index 0000000..ab4f2de
--- /dev/null
@@ -0,0 +1,73 @@
+/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
+    Radboud University Nijmegen, The Netherlands
+
+    This file is part of libDAI.
+
+    libDAI is free software; you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation; either version 2 of the License, or
+    (at your option) any later version.
+
+    libDAI is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with libDAI; if not, write to the Free Software
+    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+*/
+
+
+/*#include <vector>
+#include <map>
+#include <algorithm>
+#include <iostream>
+#include <cassert>
+#include <dai/var.h>
+#include <dai/util.h>
+*/
+#include <dai/varset.h>
+
+
+namespace dai {
+
+
+using namespace std;
+
+
+/// Calculates the product of number of states of all variables in vars
+size_t nrStates( const VarSet &vars ) {
+    size_t states = 1;
+    for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ )
+        states *= n->states();
+    return states;
+}
+
+
+/// calcState calculates the linear index of vars that corresponds
+/// to the states of the variables given in states, implicitly assuming
+/// states[m] = 0 for all m in this VarSet which are not in states.
+size_t calcState( const VarSet &vars, const std::map<Var, size_t> &states ) {
+    size_t prod = 1;
+    size_t state = 0;
+    for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ ) {
+        map<Var, size_t>::const_iterator m = states.find( *n );
+        if( m != states.end() )
+            state += prod * m->second;
+        prod *= n->states();
+    }
+    return state;
+}
+
+
+/// Sends a VarSet to an output stream
+std::ostream& operator<< (std::ostream &os, const VarSet& ns) {
+    for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+        os << *n;
+    return( os );
+}
+
+
+} // end of namespace dai
index f13e84b..1e8b88a 100644 (file)
@@ -122,10 +122,10 @@ int main( int argc, char *argv[] ) {
                 cavsizes[di.size()]++;
             else
                 cavsizes[di.size()] = 1;
-            size_t Ds = fg.Delta(i).states();
+            size_t Ds = nrStates( fg.Delta(i) );
             if( Ds > max_Delta_size )
                 max_Delta_size = Ds;
-            cavsum_lcbp += di.states();
+            cavsum_lcbp += nrStates( di );
             for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
                 cavsum_lcbp2 += j->states();
         }