Cleanup of BBP code
[libdai.git] / include / dai / index.h
index 7399c5c..46b2244 100644 (file)
@@ -1,6 +1,10 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
 */
 
 
+/// \file
+/// \brief Defines the IndexFor, MultiFor, Permute and State classes
+/// \todo Improve documentation
+
+
 #ifndef __defined_libdai_index_h
 #define __defined_libdai_index_h
 
 
 #include <vector>
+#include <algorithm>
+#include <map>
+#include <cassert>
 #include <dai/varset.h>
 
 
 namespace dai {
 
 
-/* Example:
- *
- * Index i ({s_j_1,s_j_2,...,s_j_m}, {s_1,...,s_N});    // j_k in {1,...,N}
- * for( ; i>=0; ++i ) {
- *      // loops over all states of (s_1,...,s_N)
- *      // i is linear index of corresponding state of (s_j_1, ..., s_j_m)
- * }
+/// Tool for looping over the states of several variables.
+/** The class IndexFor is an important tool for indexing Factor entries.
+ *  Its usage can best be explained by an example.
+ *  Assume indexVars, forVars are both VarSets.
+ *  Then the following code:
+ *  \code
+ *      IndexFor i( indexVars, forVars );
+ *      for( ; i >= 0; ++i ) {
+ *          // use long(i)
+ *      }
+ *  \endcode
+ *  loops over all joint states of the variables in forVars,
+ *  and (long)i is equal to the linear index of the corresponding
+ *  state of indexVars, where the variables in indexVars that are
+ *  not in forVars assume their zero'th value.
+ *  \idea Optimize all indices as follows: keep a cache of all (or only 
+ *  relatively small) indices that have been computed (use a hash). Then, 
+ *  instead of computing on the fly, use the precomputed ones.
  */
+class IndexFor {
+    private:
+        /// The current linear index corresponding to the state of indexVars
+        long                _index;
+
+        /// For each variable in forVars, the amount of change in _index
+        std::vector<long>   _sum;
+
+        /// For each variable in forVars, the current state
+        std::vector<size_t> _count;
+        
+        /// For each variable in forVars, its number of possible values
+        std::vector<size_t> _dims;
+
+    public:
+        /// Default constructor
+        IndexFor() { 
+            _index = -1; 
+        }
+
+        /// Constructor
+        IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
+            long sum = 1;
+
+            _dims.reserve( forVars.size() );
+            _sum.reserve( forVars.size() );
+
+            VarSet::const_iterator j = forVars.begin();
+            for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
+                for( ; j != forVars.end() && *j <= *i; ++j ) {
+                    _dims.push_back( j->states() );
+                    _sum.push_back( (*i == *j) ? sum : 0 );
+                }
+                sum *= i->states();
+            }
+            for( ; j != forVars.end(); ++j ) {
+                _dims.push_back( j->states() );
+                _sum.push_back( 0 );
+            }
+            _index = 0;
+        }
+
+        /// Sets the index back to zero
+        IndexFor& clear() {
+            fill( _count.begin(), _count.end(), 0 );
+            _index = 0;
+            return( *this );
+        }
+
+        /// Conversion to long
+        operator long () const { 
+            return( _index ); 
+        }
+
+        /// Pre-increment operator
+        IndexFor& operator++ () {
+            if( _index >= 0 ) {
+                size_t i = 0;
+
+                while( i < _count.size() ) {
+                    _index += _sum[i];
+                    if( ++_count[i] < _dims[i] )
+                        break;
+                    _index -= _sum[i] * _dims[i];
+                    _count[i] = 0;
+                    i++;
+                }
+
+                if( i == _count.size() ) 
+                    _index = -1;
+            }
+            return( *this );
+        }
+};
+
+
+/// MultiFor makes it easy to perform a dynamic number of nested for loops.
+/** An example of the usage is as follows:
+ *  \code
+ *  std::vector<size_t> dims;
+ *  dims.push_back( 3 );
+ *  dims.push_back( 4 );
+ *  dims.push_back( 5 );
+ *  for( MultiFor s(dims); s.valid(); ++s )
+ *      cout << "linear index: " << (size_t)s << " corresponds to indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
+ *  \endcode
+ *  which would be equivalent to:
+ *  \code
+ *  size_t s = 0;
+ *  for( size_t s0 = 0; s0 < 3; s0++ )
+ *      for( size_t s1 = 0; s1 < 4; s1++ )
+ *          for( size_t s2 = 0; s2 < 5; s++, s2++ )
+ *              cout << "linear index: " << (size_t)s << " corresponds to indices " << s0 << ", " << s1 << ", " << s2 << endl;
+ *  \endcode
+ */
+class MultiFor {
+    private:
+        std::vector<size_t>  _dims;
+        std::vector<size_t>  _states;
+        long                 _state;
+
+    public:
+        /// Default constructor
+        MultiFor() : _dims(), _states(), _state(0) {}
+
+        /// Initialize from vector of index dimensions
+        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}
 
+        /// Return linear state
+        operator size_t() const { 
+            assert( valid() );
+            return( _state );
+        }
+
+        /// Return k'th index
+        size_t operator[]( size_t k ) const {
+            assert( valid() );
+            assert( k < _states.size() );
+            return _states[k];
+        }
 
-class Index
-{
-private:
-    long _index;
-    std::vector<int> _count,_max,_sum;
-public:
-    Index () { _index=-1; };
-    Index (const VarSet& P, const VarSet& ns)
-    {
-        long sum=1;
-        VarSet::const_iterator j=ns.begin();
-        for(VarSet::const_iterator i=P.begin();i!=P.end();++i)
-        {
-            for(;j!=ns.end()&&j->label()<=i->label();++j)
-            {
-                _count.push_back(0);
-                _max.push_back(j->states());
-                _sum.push_back((i->label()==j->label())?sum:0);
-            };
-            sum*=i->states();
-        };
-        for(;j!=ns.end();++j)
-        {
-            _count.push_back(0);
-            _max.push_back(j->states());
-            _sum.push_back(0);
-        };
-        _index=0;
-    };
-    Index (const Index & ind) : _index(ind._index), _count(ind._count), _max(ind._max), _sum(ind._sum) {};
-    Index & operator=(const Index & ind) {
-        if(this!=&ind) {
-            _index = ind._index;
-            _count = ind._count;
-            _max = ind._max;
-            _sum = ind._sum;
-        }
-        return *this;
-    }
-    Index& clear ()
-    {
-        for(unsigned i=0;i!=_count.size();++i) _count[i]=0;
-        _index=0;
-        return(*this);
-    };
-    operator long () const { return(_index); };
-    Index& operator ++ ()
-    {
-        if(_index>=0)
-        {
-            unsigned i;
-            for(i=0;(i<_count.size())
-                    &&(_index+=_sum[i],++_count[i]==_max[i]);++i)
-            {
-                _index-=_sum[i]*_max[i];
-                _count[i]=0;
-            };
-            if(i==_count.size()) _index=-1;
-        };
-        return(*this);
-    };
+        /// Prefix increment operator
+        MultiFor & operator++() {
+            if( valid() ) {
+                _state++;
+                size_t i;
+                for( i = 0; i != _states.size(); i++ ) {
+                    if( ++(_states[i]) < _dims[i] )
+                        break;
+                    _states[i] = 0;
+                }
+                if( i == _states.size() )
+                    _state = -1;
+            }
+            return *this;
+        }
+
+        /// Postfix increment operator
+        void operator++( int ) {
+            operator++();
+        }
+
+        /// Returns true if the current state is valid
+        bool valid() const {
+            return( _state >= 0 );
+        }
 };
 
 
-class multind {
+/// Tool for calculating permutations of multiple indices.
+class Permute {
     private:
-        std::vector<size_t> _dims;  // dimensions
-        std::vector<size_t> _pdims; // products of dimensions
+        std::vector<size_t>  _dims;
+        std::vector<size_t>  _sigma;
 
     public:
-        multind(const std::vector<size_t> di) {
-            _dims = di;
+        /// Default constructor
+        Permute() : _dims(), _sigma() {}
+
+        /// Initialize from vector of index dimensions and permutation sigma
+        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
+            assert( _dims.size() == _sigma.size() );
+        }
+
+        /// Calculates a permuted linear index.
+        /** Converts the linear index li to a vector index
+         *  corresponding with the dimensions in _dims, permutes it according to sigma, 
+         *  and converts it back to a linear index  according to the permuted dimensions.
+         */
+        size_t convert_linear_index( size_t li ) {
+            size_t N = _dims.size();
+
+            // calculate vector index corresponding to linear index
+            std::vector<size_t> vi;
+            vi.reserve( N );
             size_t prod = 1;
-            for( std::vector<size_t>::const_iterator i=di.begin(); i!=di.end(); i++ ) {
-                _pdims.push_back(prod);
-                prod = prod * (*i);
+            for( size_t k = 0; k < N; k++ ) {
+                vi.push_back( li % _dims[k] );
+                li /= _dims[k];
+                prod *= _dims[k];
             }
-            _pdims.push_back(prod);
+
+            // convert permuted vector index to corresponding linear index
+            prod = 1;
+            size_t sigma_li = 0;
+            for( size_t k = 0; k < N; k++ ) {
+                sigma_li += vi[_sigma[k]] * prod;
+                prod *= _dims[_sigma[k]];
+            }
+
+            return sigma_li;
         }
-        multind(const VarSet& ns) {
-            _dims.reserve( ns.size() ); 
-            _pdims.reserve( ns.size() + 1 ); 
+};
+
+
+/// Contains the joint state of variables within a VarSet and useful things to do with this information.
+/** This is very similar to a MultiFor, but tailored for Vars and Varsets.
+ */
+class State {
+    private:
+        typedef std::map<Var, size_t> states_type;
+
+        long                          state;
+        states_type                   states;
+        
+    public:
+        /// Default constructor
+        State() : state(0), states() {}
+
+        /// Initialize from VarSet
+        State( const VarSet &vs ) : state(0) {
+            for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
+                states[*v] = 0;
+        }
+
+        /// Return linear state
+        operator size_t() const { 
+            assert( valid() );
+            return( state );
+        }
+
+        /// Return state of variable n, or zero if n is not in this State
+        size_t operator() ( const Var &n ) const {
+            assert( valid() );
+            states_type::const_iterator entry = states.find( n );
+            if( entry == states.end() )
+                return 0;
+            else
+                return entry->second;
+        }
+
+        /// Return linear state of variables in varset, setting them to zero if they are not in this State
+        size_t operator() ( const VarSet &vs ) const {
+            assert( valid() );
+            size_t vs_state = 0;
             size_t prod = 1;
-            for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
-                _pdims.push_back( prod );
-                prod *= n->states();
-                _dims.push_back( n->states() );
+            for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
+                states_type::const_iterator entry = states.find( *v );
+                if( entry != states.end() )
+                    vs_state += entry->second * prod; 
+                prod *= v->states();
             }
-            _pdims.push_back( prod );
-        }
-        std::vector<size_t> vi(size_t li) const {   // linear index to vector index
-            std::vector<size_t> v(_dims.size(),0);
-            assert(li < _pdims.back());
-            for( long j = v.size()-1; j >= 0; j-- ) {
-                size_t q = li / _pdims[j];
-                v[j] = q;
-                li = li - q * _pdims[j];
+            return vs_state;
+        }
+
+        /// Prefix increment operator
+        void operator++( ) {
+            if( valid() ) {
+                state++;
+                states_type::iterator entry = states.begin();
+                while( entry != states.end() ) {
+                    if( ++(entry->second) < entry->first.states() )
+                        break;
+                    entry->second = 0;
+                    entry++;
+                }
+                if( entry == states.end() )
+                    state = -1;
             }
-            return v;
         }
-        size_t li(const std::vector<size_t> vi) const { // linear index
-            size_t s = 0;
-            assert(vi.size() == _dims.size());
-            for( size_t j = 0; j < vi.size(); j++ ) 
-                s += vi[j] * _pdims[j];
-            return s;
+        
+        /// Postfix increment operator
+        void operator++( int ) {
+               operator++();
         }
-        size_t max() const { return( _pdims.back() ); };
 
-        // FIXME add an iterator, which increases a vector index just using addition
+        /// Returns true if the current state is valid
+        bool valid() const {
+            return( state >= 0 );
+        }
 };