Cleanup of BBP code
[libdai.git] / include / dai / index.h
index 9193d5b..46b2244 100644 (file)
 */
 
 
+/// \file
+/// \brief Defines the IndexFor, MultiFor, Permute and State classes
+/// \todo Improve documentation
+
+
 #ifndef __defined_libdai_index_h
 #define __defined_libdai_index_h
 
 namespace dai {
 
 
-    /// Tool for looping over the states of several variables.
-    /** The class IndexFor is an important tool for indexing of Factors.
-     *  Its usage can best be explained by an example.
-     *  Assume indexVars, forVars are two VarSets.
-     *  Then the following code:
-     *  \code
-     *      IndexFor i( indexVars, forVars );
-     *      for( ; i >= 0; ++i ) {
-     *          // use long(i)
-     *      }
-     *  \endcode
-     *  loops over all joint states of the variables in forVars,
-     *  and (long)i is equal to the linear index of the corresponding
-     *  state of indexVars, where the variables in indexVars that are
-     *  not in forVars assume their zero'th value.
-     */
-    class IndexFor {
-        private:
-            /// The current linear index corresponding to the state of indexVars
-            long                _index;
-
-            /// For each variable in forVars, the amount of change in _index
-            std::vector<long>   _sum;
-
-            /// For each variable in forVars, the current state
-            std::vector<size_t> _count;
-            
-            /// For each variable in forVars, its number of possible values
-            std::vector<size_t> _dims;
-
-        public:
-            /// Default constructor
-            IndexFor() { 
-                _index = -1; 
-            }
+/// Tool for looping over the states of several variables.
+/** The class IndexFor is an important tool for indexing Factor entries.
+ *  Its usage can best be explained by an example.
+ *  Assume indexVars, forVars are both VarSets.
+ *  Then the following code:
+ *  \code
+ *      IndexFor i( indexVars, forVars );
+ *      for( ; i >= 0; ++i ) {
+ *          // use long(i)
+ *      }
+ *  \endcode
+ *  loops over all joint states of the variables in forVars,
+ *  and (long)i is equal to the linear index of the corresponding
+ *  state of indexVars, where the variables in indexVars that are
+ *  not in forVars assume their zero'th value.
+ *  \idea Optimize all indices as follows: keep a cache of all (or only 
+ *  relatively small) indices that have been computed (use a hash). Then, 
+ *  instead of computing on the fly, use the precomputed ones.
+ */
+class IndexFor {
+    private:
+        /// The current linear index corresponding to the state of indexVars
+        long                _index;
 
-            /// Constructor
-            IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
-                long sum = 1;
+        /// For each variable in forVars, the amount of change in _index
+        std::vector<long>   _sum;
 
-                _dims.reserve( forVars.size() );
-                _sum.reserve( forVars.size() );
+        /// For each variable in forVars, the current state
+        std::vector<size_t> _count;
+        
+        /// For each variable in forVars, its number of possible values
+        std::vector<size_t> _dims;
 
-                VarSet::const_iterator j = forVars.begin();
-                for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
-                    for( ; j != forVars.end() && *j <= *i; ++j ) {
-                        _dims.push_back( j->states() );
-                        _sum.push_back( (*i == *j) ? sum : 0 );
-                    }
-                    sum *= i->states();
-                }
-                for( ; j != forVars.end(); ++j ) {
-                    _dims.push_back( j->states() );
-                    _sum.push_back( 0 );
-                }
-                _index = 0;
-            }
+    public:
+        /// Default constructor
+        IndexFor() { 
+            _index = -1; 
+        }
 
-            /// Copy constructor
-            IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
+        /// Constructor
+        IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
+            long sum = 1;
 
-            /// Assignment operator
-            IndexFor& operator=( const IndexFor &ind ) {
-                if( this != &ind ) {
-                    _index = ind._index;
-                    _sum = ind._sum;
-                    _count = ind._count;
-                    _dims = ind._dims;
+            _dims.reserve( forVars.size() );
+            _sum.reserve( forVars.size() );
+
+            VarSet::const_iterator j = forVars.begin();
+            for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
+                for( ; j != forVars.end() && *j <= *i; ++j ) {
+                    _dims.push_back( j->states() );
+                    _sum.push_back( (*i == *j) ? sum : 0 );
                 }
-                return *this;
+                sum *= i->states();
             }
-
-            /// Sets the index back to zero
-            IndexFor& clear() {
-                fill( _count.begin(), _count.end(), 0 );
-                _index = 0;
-                return( *this );
+            for( ; j != forVars.end(); ++j ) {
+                _dims.push_back( j->states() );
+                _sum.push_back( 0 );
             }
+            _index = 0;
+        }
 
-            /// Conversion to long
-            operator long () const { 
-                return( _index ); 
-            }
+        /// Sets the index back to zero
+        IndexFor& clear() {
+            fill( _count.begin(), _count.end(), 0 );
+            _index = 0;
+            return( *this );
+        }
+
+        /// Conversion to long
+        operator long () const { 
+            return( _index ); 
+        }
 
-            /// Pre-increment operator
-            IndexFor& operator++ () {
-                if( _index >= 0 ) {
-                    size_t i = 0;
-
-                    while( i < _count.size() ) {
-                        _index += _sum[i];
-                        if( ++_count[i] < _dims[i] )
-                            break;
-                        _index -= _sum[i] * _dims[i];
-                        _count[i] = 0;
-                        i++;
-                    }
-
-                    if( i == _count.size() ) 
-                        _index = -1;
+        /// Pre-increment operator
+        IndexFor& operator++ () {
+            if( _index >= 0 ) {
+                size_t i = 0;
+
+                while( i < _count.size() ) {
+                    _index += _sum[i];
+                    if( ++_count[i] < _dims[i] )
+                        break;
+                    _index -= _sum[i] * _dims[i];
+                    _count[i] = 0;
+                    i++;
                 }
-                return( *this );
+
+                if( i == _count.size() ) 
+                    _index = -1;
             }
-    };
+            return( *this );
+        }
+};
 
 
 /// MultiFor makes it easy to perform a dynamic number of nested for loops.
 /** An example of the usage is as follows:
  *  \code
- *      std::vector<size_t> dims;
- *      dims.push_back( 3 );
- *      dims.push_back( 4 );
- *      dims.push_back( 5 );
- *      for( MultiFor s(dims); s.valid(); ++s )
- *          cout << "linear index: " << (size_t)s << " corresponds with indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
+ *  std::vector<size_t> dims;
+ *  dims.push_back( 3 );
+ *  dims.push_back( 4 );
+ *  dims.push_back( 5 );
+ *  for( MultiFor s(dims); s.valid(); ++s )
+ *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
  *  \endcode
  *  which would be equivalent to:
  *  \code
@@ -159,7 +153,7 @@ namespace dai {
  *  for( size_t s0 = 0; s0 < 3; s0++ )
  *      for( size_t s1 = 0; s1 < 4; s1++ )
  *          for( size_t s2 = 0; s2 < 5; s++, s2++ )
- *              cout << "linear index: " << (size_t)s << " corresponds with indices " << s0 << ", " << s1 << ", " << s2 << endl;
+ *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
  *  \endcode
  */
 class MultiFor {
@@ -175,19 +169,6 @@ class MultiFor {
         /// Initialize from vector of index dimensions
         MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}
 
-        /// Copy constructor
-        MultiFor( const MultiFor &x ) : _dims(x._dims), _states(x._states), _state(x._state) {}
-
-        /// Assignment operator
-        MultiFor& operator=( const MultiFor & x ) {
-            if( this != &x ) {
-                _dims   = x._dims;
-                _states = x._states;
-                _state  = x._state;
-            }
-            return *this;
-        }
-
         /// Return linear state
         operator size_t() const { 
             assert( valid() );
@@ -244,23 +225,11 @@ class Permute {
             assert( _dims.size() == _sigma.size() );
         }
 
-        /// Copy constructor
-        Permute( const Permute &x ) : _dims(x._dims), _sigma(x._sigma) {}
-
-        /// Assignment operator
-        Permute& operator=( const Permute &x ) {
-            if( this != &x ) {
-                _dims  = x._dims;
-                _sigma = x._sigma;
-            }
-            return *this;
-        }
-
-        /// Converts the linear index li to a vector index
-        /// corresponding with the dimensions in _dims,
-        /// permutes it according to sigma, 
-        /// and converts it back to a linear index
-        /// according to the permuted dimensions.
+        /// Calculates a permuted linear index.
+        /** Converts the linear index li to a vector index
+         *  corresponding with the dimensions in _dims, permutes it according to sigma, 
+         *  and converts it back to a linear index  according to the permuted dimensions.
+         */
         size_t convert_linear_index( size_t li ) {
             size_t N = _dims.size();
 
@@ -287,8 +256,9 @@ class Permute {
 };
 
 
-/// Contains the state of variables within a VarSet and useful things to do with this information.
-/// This is very similar to a MultiFor, but tailored for Vars and Varsets.
+/// Contains the joint state of variables within a VarSet and useful things to do with this information.
+/** This is very similar to a MultiFor, but tailored for Vars and Varsets.
+ */
 class State {
     private:
         typedef std::map<Var, size_t> states_type;
@@ -306,26 +276,13 @@ class State {
                 states[*v] = 0;
         }
 
-        /// Copy constructor
-        State( const State & x ) : state(x.state), states(x.states) {}
-
-        /// Assignment operator
-        State& operator=( const State &x ) {
-            if( this != &x ) {
-                state  = x.state;
-                states = x.states;
-            }
-            return *this;
-        }
-
         /// Return linear state
         operator size_t() const { 
             assert( valid() );
             return( state );
         }
 
-        /// Return state of variable n,
-        /// or zero if n is not in this State
+        /// Return state of variable n, or zero if n is not in this State
         size_t operator() ( const Var &n ) const {
             assert( valid() );
             states_type::const_iterator entry = states.find( n );
@@ -335,8 +292,7 @@ class State {
                 return entry->second;
         }
 
-        /// Return linear state of variables in varset,
-        /// setting them to zero if they are not in this State
+        /// Return linear state of variables in varset, setting them to zero if they are not in this State
         size_t operator() ( const VarSet &vs ) const {
             assert( valid() );
             size_t vs_state = 0;
@@ -350,8 +306,8 @@ class State {
             return vs_state;
         }
 
-        /// Postfix increment operator
-        void operator++( int ) {
+        /// Prefix increment operator
+        void operator++( ) {
             if( valid() ) {
                 state++;
                 states_type::iterator entry = states.begin();
@@ -365,6 +321,11 @@ class State {
                     state = -1;
             }
         }
+        
+        /// Postfix increment operator
+        void operator++( int ) {
+               operator++();
+        }
 
         /// Returns true if the current state is valid
         bool valid() const {