namespace dai {
-/// Represents a discrete variable
+/// Represents a discrete variable.
+/* It contains the label of the variable, an integer-valued
+ * unique ID for that variable, and the number of possible
+ * values ("states") of the variable.
+ */
class Var {
private:
/// Internal label of the variable
public:
/// Default constructor
- Var() : _label(-1), _states(0) {};
+ Var() : _label(-1), _states(0) {}
/// Constructor
- Var(long label, size_t states) : _label(label), _states(states) {};
+ Var( long label, size_t states ) : _label(label), _states(states) {}
- /// Read access to label
- long label() const { return _label; };
- /// Access to label
- long & label() { return _label; };
+ /// Gets the label
+ long label() const { return _label; }
+ /// Returns reference to label
+ long & label() { return _label; }
- /// Read access to states
- size_t states () const { return _states; };
- /// Access to states
- size_t& states () { return _states; };
+ /// Gets the number of states
+ size_t states () const { return _states; }
+ /// Returns reference to number of states
+ size_t& states () { return _states; }
/// Smaller-than operator (compares labels)
- bool operator < (const Var& n) const { return( _label < n._label ); };
+ bool operator < ( const Var& n ) const { return( _label < n._label ); }
/// Larger-than operator (compares labels)
- bool operator > (const Var& n) const { return( _label > n._label ); };
+ bool operator > ( const Var& n ) const { return( _label > n._label ); }
/// Smaller-than-or-equal-to operator (compares labels)
- bool operator <= (const Var& n) const { return( _label <= n._label ); };
+ bool operator <= ( const Var& n ) const { return( _label <= n._label ); }
/// Larger-than-or-equal-to operator (compares labels)
- bool operator >= (const Var& n) const { return( _label >= n._label ); };
+ bool operator >= ( const Var& n ) const { return( _label >= n._label ); }
/// Not-equal-to operator (compares labels)
- bool operator != (const Var& n) const { return( _label != n._label ); };
+ bool operator != ( const Var& n ) const { return( _label != n._label ); }
/// Equal-to operator (compares labels)
- bool operator == (const Var& n) const { return( _label == n._label ); };
+ bool operator == ( const Var& n ) const { return( _label == n._label ); }
- /// Stream output operator
- friend std::ostream& operator << (std::ostream& os, const Var& n) {
+ /// Write this Var to stream
+ friend std::ostream& operator << ( std::ostream& os, const Var& n ) {
return( os << "[" << n.label() << "]" );
- };
+ }
};
#define __defined_libdai_varset_h
-#include <set>
+#include <vector>
+#include <map>
#include <algorithm>
#include <iostream>
#include <cassert>
#include <dai/var.h>
+#include <dai/util.h>
namespace dai {
-/// VarSet represents a set of variables and is a descendant of set<Var>.
-/// In addition, it provides an easy interface for set-theoretic operations
-/// by operator overloading.
-class VarSet : private std::set<Var> {
+/// Represents a set of variables.
+/**
+ * It is implemented as an ordered vector<Var> for efficiency reasons
+ * (this is more efficient than a set<Var>). In addition, it provides
+ * an interface for common set-theoretic operations.
+ */
+class VarSet {
protected:
- /// Product of number of states of all contained variables
- size_t _statespace;
-
- /// Check whether ns is a subset
- bool includes( const VarSet& ns ) const {
- return std::includes( begin(), end(), ns.begin(), ns.end() );
- }
-
- /// Calculate statespace
- size_t calcStateSpace() {
- _statespace = 1;
- for( const_iterator i = begin(); i != end(); ++i )
- _statespace *= i->states();
- return _statespace;
- }
+ /// The variables in this set
+ std::vector<Var> _vars;
+ /// Product of number of states of all contained variables
+ size_t _states;
public:
/// Default constructor
- VarSet() : _statespace(0) {};
+ VarSet() : _vars(), _states(1) {};
- /// Construct a VarSet with one variable
- VarSet( const Var &n ) : _statespace( n.states() ) {
- insert( n );
+ /// Construct a VarSet from one variable
+ VarSet( const Var &n ) : _vars(), _states( n.states() ) {
+ _vars.push_back( n );
}
- /// Construct a VarSet with two variables
+ /// Construct a VarSet from two variables
VarSet( const Var &n1, const Var &n2 ) {
- insert( n1 );
- insert( n2 );
- calcStateSpace();
+ if( n1 < n2 ) {
+ _vars.push_back( n1 );
+ _vars.push_back( n2 );
+ } else if( n1 > n2 ) {
+ _vars.push_back( n2 );
+ _vars.push_back( n1 );
+ } else
+ _vars.push_back( n1 );
+ calcStates();
}
- /// Construct from a set<Var>
- VarSet( const std::set<Var> &ns ) {
- std::set<Var>::operator=( ns );
- calcStateSpace();
- }
-
- /// Construct from a vector<Var>
- VarSet( const std::vector<Var> &ns ) {
- for( std::vector<Var>::const_iterator n = ns.begin(); n != ns.end(); n++ )
- insert( *n );
- calcStateSpace();
+ /// Construct from a range of iterators
+ /* The value_type of the VarIterator should be Var.
+ * For efficiency, the number of variables can be
+ * speficied by sizeHint.
+ */
+ template <typename VarIterator>
+ VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) {
+ _vars.reserve( sizeHint );
+ _vars.insert( _vars.begin(), begin, end );
+ std::sort( _vars.begin(), _vars.end() );
+ std::vector<Var>::iterator new_end = std::unique( _vars.begin(), _vars.end() );
+ _vars.erase( new_end, _vars.end() );
+ calcStates();
}
/// Copy constructor
- VarSet( const VarSet &x ) : std::set<Var>( x ), _statespace( x._statespace ) {}
+ VarSet( const VarSet &x ) : _vars( x._vars ), _states( x._states ) {}
/// Assignment operator
VarSet & operator=( const VarSet &x ) {
if( this != &x ) {
- std::set<Var>::operator=( x );
- _statespace = x._statespace;
+ _vars = x._vars;
+ _states = x._states;
}
return *this;
}
- /// Return statespace, i.e. the product of the number of states of each variable
+ /// Return the product of the number of states of each variable in this set
size_t states() const {
-#ifdef DAI_DEBUG
- size_t x = 1;
- for( const_iterator i = begin(); i != end(); ++i )
- x *= i->states();
- assert( x == _statespace );
-#endif
- return _statespace;
+ return _states;
}
- /// Erase one variable
- VarSet& operator/= (const Var& n) {
- erase( n );
- calcStateSpace();
- return *this;
- }
-
- /// Add one variable
- VarSet& operator|= (const Var& n) {
- insert( n );
- calcStateSpace();
- return *this;
- }
-
/// Setminus operator (result contains all variables except those in ns)
- VarSet operator/ (const VarSet& ns) const {
+ VarSet operator/ ( const VarSet& ns ) const {
VarSet res;
- std::set_difference( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
- res.calcStateSpace();
+ std::set_difference( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+ res.calcStates();
return res;
}
/// Set-union operator (result contains all variables plus those in ns)
- VarSet operator| (const VarSet& ns) const {
+ VarSet operator| ( const VarSet& ns ) const {
VarSet res;
- std::set_union( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
- res.calcStateSpace();
+ std::set_union( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+ res.calcStates();
return res;
}
/// Set-intersection operator (result contains all variables that are also contained in ns)
- VarSet operator& (const VarSet& ns) const {
+ VarSet operator& ( const VarSet& ns ) const {
VarSet res;
- std::set_intersection( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
- res.calcStateSpace();
+ std::set_intersection( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+ res.calcStates();
return res;
}
/// Erases from *this all variables in ns
- VarSet& operator/= (const VarSet& ns) {
+ VarSet& operator/= ( const VarSet& ns ) {
return (*this = (*this / ns));
}
- /// Adds to *this all variables in ns
- VarSet& operator|= (const VarSet& ns) {
- return (*this = (*this | ns));
+ /// Erase one variable
+ VarSet& operator/= ( const Var& n ) {
+ std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
+ if( pos != _vars.end() )
+ if( *pos == n ) { // found variable, delete it
+ _vars.erase( pos );
+ _states /= n.states();
+ }
+ return *this;
}
- /// Erases from *this all variables not in ns
- VarSet& operator&= (const VarSet& ns) {
- return (*this = (*this & ns));
+ /// Adds to *this all variables in ns
+ VarSet& operator|= ( const VarSet& ns ) {
+ return( *this = (*this | ns) );
}
-
- /// Returns false if both *this and ns are empty
- bool operator|| (const VarSet& ns) const {
- return !( this->empty() && ns.empty() );
+ /// Add one variable
+ VarSet& operator|= ( const Var& n ) {
+ std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
+ if( pos == _vars.end() || *pos != n ) { // insert it
+ _vars.insert( pos, n );
+ _states *= n.states();
+ }
+ return *this;
}
- /// Returns true if *this and ns contain common variables
- bool operator&& (const VarSet& ns) const {
- return !( (*this & ns).empty() );
+
+ /// Erases from *this all variables not in ns
+ VarSet& operator&= ( const VarSet& ns ) {
+ return (*this = (*this & ns));
}
+
/// Returns true if *this is a subset of ns
- bool operator<< (const VarSet& ns) const {
- return ns.includes( *this );
+ bool operator<< ( const VarSet& ns ) const {
+ return std::includes( ns._vars.begin(), ns._vars.end(), _vars.begin(), _vars.end() );
}
/// Returns true if ns is a subset of *this
- bool operator>> (const VarSet& ns) const {
- return includes( ns );
+ bool operator>> ( const VarSet& ns ) const {
+ return std::includes( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end() );
+ }
+
+ /// Returns true if *this and ns contain common variables
+ bool intersects( const VarSet& ns ) const {
+ return( (*this & ns).size() > 0 );
}
/// Returns true if *this contains the variable n
- bool operator&& (const Var& n) const {
- return( find( n ) == end() ? false : true );
+ bool contains( const Var& n ) const {
+ return std::binary_search( _vars.begin(), _vars.end(), n );
}
-
/// Sends a VarSet to an output stream
friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
- for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++)
- os << *n;
+ foreach( const Var &n, ns._vars )
+ os << n;
return( os );
}
+ // Constant iterator
+ typedef std::vector<Var>::const_iterator const_iterator;
+ /// Iterator
+ typedef std::vector<Var>::iterator iterator;
+ // Constant reverse iterator
+ typedef std::vector<Var>::const_reverse_iterator const_reverse_iterator;
+ /// Reverse Iterator
+ typedef std::vector<Var>::reverse_iterator reverse_iterator;
-/* The following makes part of the public interface of set<Var> available.
- * It is important to note that insert functions have to be overloaded,
- * because they have to recalculate the statespace. A different approach
- * would be to publicly inherit from set<Var> and only overload the insert
- * methods.
- */
-
- // Additional interface from set<Var> that has to be provided
- using std::set<Var>::const_iterator;
- using std::set<Var>::iterator;
- using std::set<Var>::const_reference;
- using std::set<Var>::begin;
- using std::set<Var>::end;
- using std::set<Var>::size;
- using std::set<Var>::empty;
-
- /// Copy of set<Var>::insert which additionally calculates the new statespace
- std::pair<iterator, bool> insert( const Var& x ) {
- std::pair<iterator, bool> result = std::set<Var>::insert( x );
- calcStateSpace();
- return result;
- }
+ /// Returns iterator that points to the first variable
+ iterator begin() { return _vars.begin(); }
+ /// Returns constant iterator that points to the first variable
+ const_iterator begin() const { return _vars.begin(); }
- /// Copy of set<Var>::insert which additionally calculates the new statespace
- iterator insert( iterator pos, const value_type& x ) {
- iterator result = std::set<Var>::insert( pos, x );
- calcStateSpace();
- return result;
- }
+ /// Returns iterator that points beyond the last variable
+ iterator end() { return _vars.end(); }
+ /// Returns constant iterator that points beyond the last variable
+ const_iterator end() const { return _vars.end(); }
- /// Test for equality (ignore _statespace member)
+ /// Returns reverse iterator that points to the last variable
+ reverse_iterator rbegin() { return _vars.rbegin(); }
+ /// Returns constant reverse iterator that points to the last variable
+ const_reverse_iterator rbegin() const { return _vars.rbegin(); }
+
+ /// Returns reverse iterator that points beyond the first variable
+ reverse_iterator rend() { return _vars.rend(); }
+ /// Returns constant reverse iterator that points beyond the first variable
+ const_reverse_iterator rend() const { return _vars.rend(); }
+
+
+ /// Returns number of variables
+ std::vector<Var>::size_type size() const { return _vars.size(); }
+
+
+ /// Returns whether the set is empty
+ bool empty() const { return _vars.size() == 0; }
+
+
+ /// Test for equality of variable labels
friend bool operator==( const VarSet &a, const VarSet &b ) {
- return operator==( (std::set<Var>)a, (std::set<Var>)b );
+ return (a._vars == b._vars);
}
- /// Test for inequality (ignore _statespace member)
+ /// Test for inequality of variable labels
friend bool operator!=( const VarSet &a, const VarSet &b ) {
- return operator!=( (std::set<Var>)a, (std::set<Var>)b );
+ return !(a._vars == b._vars);
}
+ /// Lexicographical comparison of variable labels
friend bool operator<( const VarSet &a, const VarSet &b ) {
- return operator<( (std::set<Var>)a, (std::set<Var>)b );
+ return a._vars < b._vars;
+ }
+
+ /// calcState calculates the linear index of this VarSet that corresponds
+ /// to the states of the variables given in states, implicitly assuming
+ /// states[m] = 0 for all m in this VarSet which are not in states.
+ size_t calcState( const std::map<Var, size_t> &states ) const {
+ size_t prod = 1;
+ size_t state = 0;
+ foreach( const Var &n, *this ) {
+ std::map<Var, size_t>::const_iterator m = states.find( n );
+ if( m != states.end() )
+ state += prod * m->second;
+ prod *= n.states();
+ }
+ return state;
+ }
+
+ protected:
+ /// Calculate number of states
+ size_t calcStates() {
+ _states = 1;
+ foreach( Var &i, _vars )
+ _states *= i.states();
+ return _states;
}
};