Cleaned up error handling by introducing the DAI_THROWE macro.
[libdai.git] / include / dai / varset.h
index 883cb62..e45e40d 100644 (file)
@@ -1,7 +1,10 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
     Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
 */
 
 
+/// \file
+/// \brief Defines VarSet class
+
+
 #ifndef __defined_libdai_varset_h
 #define __defined_libdai_varset_h
 
 
 #include <vector>
 #include <map>
-#include <algorithm>
-#include <iostream>
 #include <cassert>
+#include <ostream>
 #include <dai/var.h>
 #include <dai/util.h>
+#include <dai/smallset.h>
 
 
 namespace dai {
 
 
-/// A VarSet represents a set of variables.
-/**
- *  It is implemented as an ordered std::vector<Var> for efficiency reasons
- *  (indeed, it was found that a std::set<Var> usually has more overhead). 
- *  In addition, it provides an interface for common set-theoretic operations.
+/// Represents a set of variables.
+/** \note A VarSet is implemented using a SmallSet<Var> instead
+ *  of the more natural std::set<Var> because of efficiency reasons.
+ *  That is, internally, the variables in the set are sorted according 
+ *  to their labels: the set of variables \f$\{x_l\}_{l\in L}\f$ is 
+ *  represented as a vector \f$(x_{l(0)},x_{l(1)},\dots,x_{l(|L|-1)})\f$ 
+ *  where \f$l(0) < l(1) < \dots < l(|L|-1)\f$ 
+ *  and \f$L = \{l(0),l(1),\dots,l(|L|-1)\}\f$.
  */
-class VarSet {
-    private:
-        /// The variables in this set
-        std::vector<Var> _vars;
-
-        /// Product of number of states of all contained variables
-        size_t _states;
-
+class VarSet : public SmallSet<Var> {
     public:
         /// Default constructor
-        VarSet() : _vars(), _states(1) {};
-
-        /// Construct a VarSet from one variable
-        VarSet( const Var &n ) : _vars(), _states( n.states() ) { 
-            _vars.push_back( n );
-        }
-
-        /// Construct a VarSet from two variables
-        VarSet( const Var &n1, const Var &n2 ) { 
-            if( n1 < n2 ) {
-                _vars.push_back( n1 );
-                _vars.push_back( n2 );
-            } else if( n1 > n2 ) {
-                _vars.push_back( n2 );
-                _vars.push_back( n1 );
-            } else
-                _vars.push_back( n1 );
-            calcStates();
-        }
-
-        /// Construct from a range of iterators
-        /** The value_type of the VarIterator should be Var.
-         *  For efficiency, the number of variables can be
-         *  speficied by sizeHint.
+        VarSet() : SmallSet<Var>() {}
+
+        /// Construct from SmallSet<Var>
+        VarSet( const SmallSet<Var> &x ) : SmallSet<Var>(x) {}
+
+        /// Calculates the number of states of this VarSet.
+        /** The number of states of the Cartesian product of the variables in this VarSet
+         *  is simply the product of the number of states of each variable in this VarSet.
+         *  If *this corresponds with the set \f$\{x_l\}_{l\in L}\f$,
+         *  where variable \f$x_l\f$ has label \f$l\f$, and denoting by \f$S_l\f$ the 
+         *  number of possible values ("states") of variable \f$x_l\f$, the number of 
+         *  joint configurations of the variables in \f$\{x_l\}_{l\in L}\f$ is given by \f$\prod_{l\in L} S_l\f$.
          */
-        template <typename VarIterator>
-        VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) {
-            _vars.reserve( sizeHint );
-            _vars.insert( _vars.begin(), begin, end );
-            std::sort( _vars.begin(), _vars.end() );
-            std::vector<Var>::iterator new_end = std::unique( _vars.begin(), _vars.end() );
-            _vars.erase( new_end, _vars.end() );
-            calcStates();
-        }
-
-        /// Copy constructor
-        VarSet( const VarSet &x ) : _vars( x._vars ), _states( x._states ) {}
-
-        /// Assignment operator
-        VarSet & operator=( const VarSet &x ) {
-            if( this != &x ) {
-                _vars = x._vars;
-                _states = x._states;
-            }
-            return *this;
-        }
-        
-
-        /// Returns the product of the number of states of each variable in this set
-        size_t states() const { 
-            return _states; 
-        }
-        
-
-        /// Setminus operator (result contains all variables in *this, except those in ns)
-        VarSet operator/ ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_difference( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
-            return res;
-        }
-
-        /// Set-union operator (result contains all variables in *this, plus those in ns)
-        VarSet operator| ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_union( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
-            return res;
-        }
-
-        /// Set-intersection operator (result contains all variables in *this that are also contained in ns)
-        VarSet operator& ( const VarSet& ns ) const {
-            VarSet res;
-            std::set_intersection( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
-            res.calcStates();
-            return res;
-        }
-        
-        /// Erases from *this all variables in ns
-        VarSet& operator/= ( const VarSet& ns ) {
-            return (*this = (*this / ns));
-        }
-
-        /// Erase one variable
-        VarSet& operator/= ( const Var& n ) { 
-            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
-            if( pos != _vars.end() )
-                if( *pos == n ) { // found variable, delete it
-                    _vars.erase( pos ); 
-                    _states /= n.states();
-                }
-            return *this; 
-        }
-
-        /// Adds to *this all variables in ns
-        VarSet& operator|= ( const VarSet& ns ) {
-            return( *this = (*this | ns) );
-        }
-
-        /// Add one variable
-        VarSet& operator|= ( const Var& n ) {
-            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
-            if( pos == _vars.end() || *pos != n ) { // insert it
-                _vars.insert( pos, n );
-                _states *= n.states();
-            }
-            return *this;
+        size_t nrStates() {
+            size_t states = 1;
+            for( VarSet::const_iterator n = begin(); n != end(); n++ )
+                states *= n->states();
+            return states;
         }
 
+        /// Construct a VarSet with one element
+        VarSet( const Var &n ) : SmallSet<Var>(n) {}
 
-        /// Erases from *this all variables not in ns
-        VarSet& operator&= ( const VarSet& ns ) { 
-            return (*this = (*this & ns));
-        }
-
-
-        /// Returns true if *this is a subset of ns
-        bool operator<< ( const VarSet& ns ) const { 
-            return std::includes( ns._vars.begin(), ns._vars.end(), _vars.begin(), _vars.end() );
-        }
+        /// Construct a VarSet with two elements
+        VarSet( const Var &n1, const Var &n2 ) : SmallSet<Var>(n1,n2) {} 
 
-        /// Returns true if ns is a subset of *this
-        bool operator>> ( const VarSet& ns ) const { 
-            return std::includes( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end() );
-        }
-
-        /// Returns true if *this and ns contain common variables
-        bool intersects( const VarSet& ns ) const { 
-            return( (*this & ns).size() > 0 ); 
-        }
-
-        /// Returns true if *this contains the variable n
-        bool contains( const Var& n ) const { 
-            return std::binary_search( _vars.begin(), _vars.end(), n );
-        }
-
-        /// Sends a VarSet to an output stream
-        friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
-            foreach( const Var &n, ns._vars )
-                os << n;
-            return( os );
-        }
-
-        /// Constant iterator over Vars
-        typedef std::vector<Var>::const_iterator const_iterator;
-        /// Iterator over Vars
-        typedef std::vector<Var>::iterator iterator;
-        /// Constant reverse iterator over Vars
-        typedef std::vector<Var>::const_reverse_iterator const_reverse_iterator;
-        /// Reverse iterator over Vars
-        typedef std::vector<Var>::reverse_iterator reverse_iterator;
-        
-        /// Returns iterator that points to the first variable
-        iterator begin() { return _vars.begin(); }
-        /// Returns constant iterator that points to the first variable
-        const_iterator begin() const { return _vars.begin(); }
-
-        /// Returns iterator that points beyond the last variable
-        iterator end() { return _vars.end(); }
-        /// Returns constant iterator that points beyond the last variable
-        const_iterator end() const { return _vars.end(); }
-
-        /// Returns reverse iterator that points to the last variable
-        reverse_iterator rbegin() { return _vars.rbegin(); }
-        /// Returns constant reverse iterator that points to the last variable
-        const_reverse_iterator rbegin() const { return _vars.rbegin(); }
-
-        /// Returns reverse iterator that points beyond the first variable
-        reverse_iterator rend() { return _vars.rend(); }
-        /// Returns constant reverse iterator that points beyond the first variable
-        const_reverse_iterator rend() const { return _vars.rend(); }
-
-
-        /// Returns number of variables
-        std::vector<Var>::size_type size() const { return _vars.size(); }
-
-
-        /// Returns whether the VarSet is empty
-        bool empty() const { return _vars.size() == 0; }
-
-
-        /// Test for equality of variable labels
-        friend bool operator==( const VarSet &a, const VarSet &b ) {
-            return (a._vars == b._vars);
-        }
-
-        /// Test for inequality of variable labels
-        friend bool operator!=( const VarSet &a, const VarSet &b ) {
-            return !(a._vars == b._vars);
-        }
-
-        /// Lexicographical comparison of variable labels
-        friend bool operator<( const VarSet &a, const VarSet &b ) {
-            return a._vars < b._vars;
-        }
-
-        /// calcState calculates the linear index of this VarSet that corresponds
-        /// to the states of the variables given in states, implicitly assuming
-        /// states[m] = 0 for all m in this VarSet which are not in states.
-        size_t calcState( const std::map<Var, size_t> &states ) const {
+        /// Construct a VarSet from a range.
+        /** \tparam VarIterator Iterates over instances of type Var.
+         *  \param begin Points to first Var to be added.
+         *  \param end Points just beyond last Var to be added.
+         *  \param sizeHint For efficiency, the number of elements can be speficied by sizeHint.
+         */
+        template <typename VarIterator>
+        VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) : SmallSet<Var>(begin,end,sizeHint) {}
+
+        /// Calculates the linear index in the Cartesian product of the variables in *this, which corresponds to a particular joint assignment of the variables specified by \a states.
+        /** \param states Specifies the states of some variables.
+         *  \return The linear index in the Cartesian product of the variables in *this
+         *  corresponding with the joint assignment specified by \a states, where it is
+         *  assumed that \a states[\a m]==0 for all \a m in *this which are not in \a states.
+         *  
+         *  The linear index is calculated as follows. The variables in *this are
+         *  ordered according to their label (in ascending order); say *this corresponds with
+         *  the set \f$\{x_{l(0)},x_{l(1)},\dots,x_{l(n-1)}\}\f$ with \f$l(0) < l(1) < \dots < l(n-1)\f$,
+         *  where variable \f$x_l\f$ has label \a l. Denote by \f$S_l\f$ the number of possible values
+         *  ("states") of variable \f$x_l\f$. The argument \a states corresponds
+         *  with a mapping \a s that assigns to each variable \f$x_l\f$ a state \f$s(x_l) \in \{0,1,\dots,S_l-1\}\f$,
+         *  where \f$s(x_l)=0\f$ if \f$x_l\f$ is not specified in \a states. The linear index \a S corresponding
+         *  with \a states is now calculated as:
+         *  \f{eqnarray*}
+         *    S &:=& \sum_{i=0}^{n-1} s(x_{l(i)}) \prod_{j=0}^{i-1} S_{l(j)} \\
+         *      &= & s(x_{l(0)}) + s(x_{l(1)}) S_{l(0)} + s(x_{l(2)}) S_{l(0)} S_{l(1)} + \dots + s(x_{l(n-1)}) S_{l(0)} \cdots S_{l(n-2)}.
+         *  \f}
+         *
+         *  \note If *this corresponds with \f$\{x_l\}_{l\in L}\f$, and \a states specifies a state
+         *  for each variable \f$x_l\f$ for \f$l\in L\f$, calcState(const std::map<Var,size_t> &) induces a mapping 
+         *  \f$\sigma : \prod_{l\in L} X_l \to \{0,1,\dots,\prod_{l\in L} S_l-1\}\f$ that
+         *  maps a joint state to a linear index; this is the inverse of the mapping 
+         *  \f$\sigma^{-1}\f$ induced by calcStates(size_t).
+         */
+        size_t calcState( const std::map<Var, size_t> &states ) {
             size_t prod = 1;
             size_t state = 0;
-            foreach( const Var &n, *this ) {
-                std::map<Var, size_t>::const_iterator m = states.find( n );
+            for( VarSet::const_iterator n = begin(); n != end(); n++ ) {
+                std::map<Var, size_t>::const_iterator m = states.find( *n );
                 if( m != states.end() )
                     state += prod * m->second;
-                prod *= n.states();
+                prod *= n->states();
             }
             return state;
         }
 
-    private:
-        /// Calculates the number of states
-        size_t calcStates() {
-            _states = 1;
-            foreach( Var &i, _vars )
-                _states *= i.states();
-            return _states;
+        /// Calculates the joint assignment of the variables in *this corresponding to the linear index \a linearState.
+        /** \param linearState should be smaller than nrStates().
+         *  \return A mapping \f$s\f$ that maps each Var \f$x_l\f$ in *this to its state \f$s(x_l)\f$, as specified by \a linearState.
+         *
+         *  The variables in *this are ordered according to their label (in ascending order); say *this corresponds with
+         *  the set \f$\{x_{l(0)},x_{l(1)},\dots,x_{l(n-1)}\}\f$ with \f$l(0) < l(1) < \dots < l(n-1)\f$,
+         *  where variable \f$x_l\f$ has label \a l. Denote by \f$S_l\f$ the number of possible values
+         *  ("states") of variable \f$x_l\f$ with label \a l. 
+         *  The mapping \a s returned by this function is defined as:
+         *  \f{eqnarray*}
+         *    s(x_{l(i)}) = \left\lfloor\frac{S \mbox { mod } \prod_{j=0}^{i} S_{l(j)}}{\prod_{j=0}^{i-1} S_{l(j)}}\right\rfloor \qquad \mbox{for all $i=0,\dots,n-1$}.
+         *  \f}
+         *  where \f$S\f$ denotes the value of \a linearState.
+         *
+         *  \note If *this corresponds with \f$\{x_l\}_{l\in L}\f$, calcStates(size_t) induces a mapping 
+         *  \f$\sigma^{-1} : \{0,1,\dots,\prod_{l\in L} S_l-1\} \to \prod_{l\in L} X_l\f$ that
+         *  maps a linear index to a joint state; this is the inverse of the mapping \f$\sigma\f$ 
+         *  induced by calcState(const std::map<Var,size_t> &).
+         */
+        std::map<Var, size_t> calcStates( size_t linearState ) {
+            std::map<Var, size_t> states;
+            for( VarSet::const_iterator n = begin(); n != end(); n++ ) {
+                states[*n] = linearState % n->states();
+                linearState /= n->states();
+            }
+            assert( linearState == 0 );
+            return states;
+        }
+
+        /// Writes a VarSet to an output stream
+        friend std::ostream& operator<< (std::ostream &os, const VarSet& ns)  {
+            os << "{";
+            for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+                os << (n != ns.begin() ? "," : "") << *n;
+            os << "}";
+            return( os );
         }
 };
 
 
-/// For two Vars n1 and n2, the expression n1 | n2 gives the Varset containing n1 and n2
-inline VarSet operator| (const Var& n1, const Var& n2) {
-    return( VarSet(n1, n2) );
-}
+} // end of namespace dai
 
 
-} // end of namespace dai
+/** \example example_varset.cpp
+ *  This example shows how to use the Var and VarSet classes. It also explains the concept of "states" for VarSets.
+ *
+ *  \section Output
+ *  \verbinclude examples/example_varset.out
+ *
+ *  \section Source
+ */
 
 
 #endif