Improved index.h (merged from SVN head), which yields a 25% speedup.
authorJoris Mooij <jorism@marvin.jorismooij.nl>
Mon, 8 Sep 2008 17:36:10 +0000 (19:36 +0200)
committerJoris Mooij <jorism@marvin.jorismooij.nl>
Mon, 8 Sep 2008 17:36:10 +0000 (19:36 +0200)
Also, added some copyrights for Martijn Leisink.

include/dai/factor.h
include/dai/index.h
include/dai/var.h
include/dai/varset.h
matlab/matlab.cpp
src/bp.cpp
src/daialg.cpp
src/factorgraph.cpp
src/jtree.cpp
src/treeep.cpp

index d1fea49..7759ce3 100644 (file)
@@ -1,4 +1,5 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
     
     This file is part of libDAI.
@@ -227,8 +228,8 @@ template <typename T> class TFactor {
             Factor result( nsrem, 0.0 );
             
             // OPTIMIZE ME
-            Index i_ns (ns, _vs);
-            Index i_nsrem (nsrem, _vs);
+            IndexFor i_ns (ns, _vs);
+            IndexFor i_nsrem (nsrem, _vs);
             for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
                 if( (size_t)i_ns == ns_state )
                     result._p[i_nsrem] = _p[i];
@@ -271,7 +272,7 @@ template<typename T> TFactor<T> TFactor<T>::part_sum(const VarSet & ns) const {
 
     TFactor<T> res( ns, 0.0 );
 
-    Index i_res( ns, _vs );
+    IndexFor i_res( ns, _vs );
     for( size_t i = 0; i < _p.size(); i++, ++i_res )
         res._p[i_res] += _p[i];
 
@@ -291,8 +292,8 @@ template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T
 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
     TFactor<T> prod( _vs | Q._vs, 0.0 );
 
-    Index i1(_vs, prod._vs);
-    Index i2(Q._vs, prod._vs);
+    IndexFor i1(_vs, prod._vs);
+    IndexFor i2(Q._vs, prod._vs);
 
     for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
         prod._p[i] += _p[i1] * Q._p[i2];
index 7399c5c..83d35af 100644 (file)
@@ -1,4 +1,5 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
     
     This file is part of libDAI.
 
 
 #include <vector>
+#include <algorithm>
+#include <map>
+#include <cassert>
 #include <dai/varset.h>
 
 
 namespace dai {
 
 
-/* Example:
- *
- * Index i ({s_j_1,s_j_2,...,s_j_m}, {s_1,...,s_N});    // j_k in {1,...,N}
- * for( ; i>=0; ++i ) {
- *      // loops over all states of (s_1,...,s_N)
- *      // i is linear index of corresponding state of (s_j_1, ..., s_j_m)
- * }
- */
+    /// Tool for looping over the states of several variables.
+    /** The class IndexFor is an important tool for indexing of Factors.
+     *  Its usage can best be explained by an example.
+     *  Assume indexVars, forVars are two VarSets.
+     *  Then the following code:
+     *  \code
+     *      IndexFor i( indexVars, forVars );
+     *      for( ; i >= 0; ++i ) {
+     *          // use long(i)
+     *      }
+     *  \endcode
+     *  loops over all joint states of the variables in forVars,
+     *  and (long)i is equal to the linear index of the corresponding
+     *  state of indexVars, where the variables in indexVars that are
+     *  not in forVars assume their zero'th value.
+     */
+    class IndexFor {
+        private:
+            /// The current linear index corresponding to the state of indexVars
+            long                _index;
 
+            /// For each variable in forVars, the amount of change in _index
+            std::vector<long>   _sum;
 
-class Index
-{
-private:
-    long _index;
-    std::vector<int> _count,_max,_sum;
-public:
-    Index () { _index=-1; };
-    Index (const VarSet& P, const VarSet& ns)
-    {
-        long sum=1;
-        VarSet::const_iterator j=ns.begin();
-        for(VarSet::const_iterator i=P.begin();i!=P.end();++i)
-        {
-            for(;j!=ns.end()&&j->label()<=i->label();++j)
-            {
-                _count.push_back(0);
-                _max.push_back(j->states());
-                _sum.push_back((i->label()==j->label())?sum:0);
-            };
-            sum*=i->states();
-        };
-        for(;j!=ns.end();++j)
-        {
-            _count.push_back(0);
-            _max.push_back(j->states());
-            _sum.push_back(0);
-        };
-        _index=0;
-    };
-    Index (const Index & ind) : _index(ind._index), _count(ind._count), _max(ind._max), _sum(ind._sum) {};
-    Index & operator=(const Index & ind) {
-        if(this!=&ind) {
-            _index = ind._index;
-            _count = ind._count;
-            _max = ind._max;
-            _sum = ind._sum;
-        }
-        return *this;
-    }
-    Index& clear ()
-    {
-        for(unsigned i=0;i!=_count.size();++i) _count[i]=0;
-        _index=0;
-        return(*this);
-    };
-    operator long () const { return(_index); };
-    Index& operator ++ ()
-    {
-        if(_index>=0)
-        {
-            unsigned i;
-            for(i=0;(i<_count.size())
-                    &&(_index+=_sum[i],++_count[i]==_max[i]);++i)
-            {
-                _index-=_sum[i]*_max[i];
-                _count[i]=0;
-            };
-            if(i==_count.size()) _index=-1;
-        };
-        return(*this);
+            /// For each variable in forVars, the current state
+            std::vector<size_t> _count;
+            
+            /// For each variable in forVars, its number of possible values
+            std::vector<size_t> _dims;
+
+        public:
+            /// Default constructor
+            IndexFor() { 
+                _index = -1; 
+            }
+
+            /// Constructor
+            IndexFor( const VarSet& indexVars, const VarSet& forVars ) : _count( forVars.size(), 0 ) {
+                long sum = 1;
+
+                _dims.reserve( forVars.size() );
+                _sum.reserve( forVars.size() );
+
+                VarSet::const_iterator j = forVars.begin();
+                for( VarSet::const_iterator i = indexVars.begin(); i != indexVars.end(); ++i ) {
+                    for( ; j != forVars.end() && *j <= *i; ++j ) {
+                        _dims.push_back( j->states() );
+                        _sum.push_back( (*i == *j) ? sum : 0 );
+                    }
+                    sum *= i->states();
+                }
+                for( ; j != forVars.end(); ++j ) {
+                    _dims.push_back( j->states() );
+                    _sum.push_back( 0 );
+                }
+                _index = 0;
+            }
+
+            /// Copy constructor
+            IndexFor( const IndexFor & ind ) : _index(ind._index), _sum(ind._sum), _count(ind._count), _dims(ind._dims) {}
+
+            /// Assignment operator
+            IndexFor& operator=( const IndexFor &ind ) {
+                if( this != &ind ) {
+                    _index = ind._index;
+                    _sum = ind._sum;
+                    _count = ind._count;
+                    _dims = ind._dims;
+                }
+                return *this;
+            }
+
+            /// Sets the index back to zero
+            IndexFor& clear() {
+                fill( _count.begin(), _count.end(), 0 );
+                _index = 0;
+                return( *this );
+            }
+
+            /// Conversion to long
+            operator long () const { 
+                return( _index ); 
+            }
+
+            /// Pre-increment operator
+            IndexFor& operator++ () {
+                if( _index >= 0 ) {
+                    size_t i = 0;
+
+                    while( i < _count.size() ) {
+                        _index += _sum[i];
+                        if( ++_count[i] < _dims[i] )
+                            break;
+                        _index -= _sum[i] * _dims[i];
+                        _count[i] = 0;
+                        i++;
+                    }
+
+                    if( i == _count.size() ) 
+                        _index = -1;
+                }
+                return( *this );
+            }
     };
+
+
+/// MultiFor makes it easy to perform a dynamic number of nested for loops.
+/** An example of the usage is as follows:
+ *  \code
+ *      std::vector<size_t> dims;
+ *      dims.push_back( 3 );
+ *      dims.push_back( 4 );
+ *      dims.push_back( 5 );
+ *      for( MultiFor s(dims); s.valid(); ++s )
+ *          cout << "linear index: " << (size_t)s << " corresponds with indices " << s[0] << ", " << s[1] << ", " << s[2] << endl;
+ *  \endcode
+ *  which would be equivalent to:
+ *  \code
+ *  size_t s = 0;
+ *  for( size_t s0 = 0; s0 < 3; s0++ )
+ *      for( size_t s1 = 0; s1 < 4; s1++ )
+ *          for( size_t s2 = 0; s2 < 5; s++, s2++ )
+ *              cout << "linear index: " << (size_t)s << " corresponds with indices " << s0 << ", " << s1 << ", " << s2 << endl;
+ *  \endcode
+ */
+class MultiFor {
+    private:
+        std::vector<size_t>  _dims;
+        std::vector<size_t>  _states;
+        long                 _state;
+
+    public:
+        /// Default constructor
+        MultiFor() : _dims(), _states(), _state(0) {}
+
+        /// Initialize from vector of index dimensions
+        MultiFor( const std::vector<size_t> &d ) : _dims(d), _states(d.size(),0), _state(0) {}
+
+        /// Copy constructor
+        MultiFor( const MultiFor &x ) : _dims(x._dims), _states(x._states), _state(x._state) {}
+
+        /// Assignment operator
+        MultiFor& operator=( const MultiFor & x ) {
+            if( this != &x ) {
+                _dims   = x._dims;
+                _states = x._states;
+                _state  = x._state;
+            }
+            return *this;
+        }
+
+        /// Return linear state
+        operator size_t() const { 
+            assert( valid() );
+            return( _state );
+        }
+
+        /// Return k'th index
+        size_t operator[]( size_t k ) const {
+            assert( valid() );
+            assert( k < _states.size() );
+            return _states[k];
+        }
+
+        /// Prefix increment operator
+        MultiFor & operator++() {
+            if( valid() ) {
+                _state++;
+                size_t i;
+                for( i = 0; i != _states.size(); i++ ) {
+                    if( ++(_states[i]) < _dims[i] )
+                        break;
+                    _states[i] = 0;
+                }
+                if( i == _states.size() )
+                    _state = -1;
+            }
+            return *this;
+        }
+
+        /// Postfix increment operator
+        void operator++( int ) {
+            operator++();
+        }
+
+        /// Returns true if the current state is valid
+        bool valid() const {
+            return( _state >= 0 );
+        }
 };
 
 
-class multind {
+/// Tool for calculating permutations of multiple indices.
+class Permute {
     private:
-        std::vector<size_t> _dims;  // dimensions
-        std::vector<size_t> _pdims; // products of dimensions
+        std::vector<size_t>  _dims;
+        std::vector<size_t>  _sigma;
 
     public:
-        multind(const std::vector<size_t> di) {
-            _dims = di;
-            size_t prod = 1;
-            for( std::vector<size_t>::const_iterator i=di.begin(); i!=di.end(); i++ ) {
-                _pdims.push_back(prod);
-                prod = prod * (*i);
+        /// Default constructor
+        Permute() : _dims(), _sigma() {}
+
+        /// Initialize from vector of index dimensions and permutation sigma
+        Permute( const std::vector<size_t> &d, const std::vector<size_t> &sigma ) : _dims(d), _sigma(sigma) {
+            assert( _dims.size() == _sigma.size() );
+        }
+
+        /// Copy constructor
+        Permute( const Permute &x ) : _dims(x._dims), _sigma(x._sigma) {}
+
+        /// Assignment operator
+        Permute& operator=( const Permute &x ) {
+            if( this != &x ) {
+                _dims  = x._dims;
+                _sigma = x._sigma;
             }
-            _pdims.push_back(prod);
+            return *this;
         }
-        multind(const VarSet& ns) {
-            _dims.reserve( ns.size() ); 
-            _pdims.reserve( ns.size() + 1 ); 
+
+        /// Converts the linear index li to a vector index
+        /// corresponding with the dimensions in _dims,
+        /// permutes it according to sigma, 
+        /// and converts it back to a linear index
+        /// according to the permuted dimensions.
+        size_t convert_linear_index( size_t li ) {
+            size_t N = _dims.size();
+
+            // calculate vector index corresponding to linear index
+            std::vector<size_t> vi;
+            vi.reserve( N );
             size_t prod = 1;
-            for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
-                _pdims.push_back( prod );
-                prod *= n->states();
-                _dims.push_back( n->states() );
+            for( size_t k = 0; k < N; k++ ) {
+                vi.push_back( li % _dims[k] );
+                li /= _dims[k];
+                prod *= _dims[k];
             }
-            _pdims.push_back( prod );
-        }
-        std::vector<size_t> vi(size_t li) const {   // linear index to vector index
-            std::vector<size_t> v(_dims.size(),0);
-            assert(li < _pdims.back());
-            for( long j = v.size()-1; j >= 0; j-- ) {
-                size_t q = li / _pdims[j];
-                v[j] = q;
-                li = li - q * _pdims[j];
+
+            // convert permuted vector index to corresponding linear index
+            prod = 1;
+            size_t sigma_li = 0;
+            for( size_t k = 0; k < N; k++ ) {
+                sigma_li += vi[_sigma[k]] * prod;
+                prod *= _dims[_sigma[k]];
             }
-            return v;
+
+            return sigma_li;
         }
-        size_t li(const std::vector<size_t> vi) const { // linear index
-            size_t s = 0;
-            assert(vi.size() == _dims.size());
-            for( size_t j = 0; j < vi.size(); j++ ) 
-                s += vi[j] * _pdims[j];
-            return s;
+};
+
+
+/// Contains the state of variables within a VarSet and useful things to do with this information.
+/// This is very similar to a MultiFor, but tailored for Vars and Varsets.
+class State {
+    private:
+        typedef std::map<Var, size_t> states_type;
+
+        long                          state;
+        states_type                   states;
+        
+    public:
+        /// Default constructor
+        State() : state(0), states() {}
+
+        /// Initialize from VarSet
+        State( const VarSet &vs ) : state(0) {
+            for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
+                states[*v] = 0;
+        }
+
+        /// Copy constructor
+        State( const State & x ) : state(x.state), states(x.states) {}
+
+        /// Assignment operator
+        State& operator=( const State &x ) {
+            if( this != &x ) {
+                state  = x.state;
+                states = x.states;
+            }
+            return *this;
+        }
+
+        /// Return linear state
+        operator size_t() const { 
+            assert( valid() );
+            return( state );
+        }
+
+        /// Return state of variable n,
+        /// or zero if n is not in this State
+        size_t operator() ( const Var &n ) const {
+            assert( valid() );
+            states_type::const_iterator entry = states.find( n );
+            if( entry == states.end() )
+                return 0;
+            else
+                return entry->second;
+        }
+
+        /// Return linear state of variables in varset,
+        /// setting them to zero if they are not in this State
+        size_t operator() ( const VarSet &vs ) const {
+            assert( valid() );
+            size_t vs_state = 0;
+            size_t prod = 1;
+            for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ ) {
+                states_type::const_iterator entry = states.find( *v );
+                if( entry != states.end() )
+                    vs_state += entry->second * prod; 
+                prod *= v->states();
+            }
+            return vs_state;
         }
-        size_t max() const { return( _pdims.back() ); };
 
-        // FIXME add an iterator, which increases a vector index just using addition
+        /// Postfix increment operator
+        void operator++( int ) {
+            if( valid() ) {
+                state++;
+                states_type::iterator entry = states.begin();
+                while( entry != states.end() ) {
+                    if( ++(entry->second) < entry->first.states() )
+                        break;
+                    entry->second = 0;
+                    entry++;
+                }
+                if( entry == states.end() )
+                    state = -1;
+            }
+        }
+
+        /// Returns true if the current state is valid
+        bool valid() const {
+            return( state >= 0 );
+        }
 };
 
 
index 4391f5f..d0a9950 100644 (file)
@@ -1,4 +1,5 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
     
     This file is part of libDAI.
index d13a2bc..9450637 100644 (file)
@@ -1,4 +1,5 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
+    Copyright (C) 2002  Martijn Leisink  [martijn@mbfys.kun.nl]
     Radboud University Nijmegen, The Netherlands
     
     This file is part of libDAI.
index 1e7ee2a..30eac37 100644 (file)
@@ -120,40 +120,14 @@ vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
 
         // read Factor
         vector<size_t> di(nr_mem,0);
-        vector<size_t> pdi(nr_mem,0);
+        size_t prod = 1;
         for( size_t k = 0; k < nr_mem; k++ ) {
             di[k] = dims[k];
-            pdi[k] = dims[perm[k]];
-        }
-        multind mi(di);
-        multind pmi(pdi);
-        if( verbose >= 3 ) {
-            cout << "  mi.max(): " << mi.max() << endl;
-            cout << "       ";
-            for( size_t k = 0; k < nr_mem; k++ ) 
-                cout << labels[k] << " ";
-            cout << "   ";
-            for( size_t k = 0; k < nr_mem; k++ ) 
-                cout << labels[perm[k]] << " ";
-            cout << endl;
-        }
-        for( size_t li = 0; li < mi.max(); li++ ) {
-            vector<size_t> vi = mi.vi(li);
-            vector<size_t> pvi(vi.size(),0);
-            for( size_t k = 0; k < vi.size(); k++ )
-                pvi[k] = vi[perm[k]];
-            size_t pli = pmi.li(pvi);
-            if( verbose >= 3 ) {
-                cout << "    " << li << ": ";
-                for( vector<size_t>::iterator r=vi.begin(); r!=vi.end(); r++)
-                    cout << *r << " ";
-                cout << "-> ";
-                for( vector<size_t>::iterator r=pvi.begin(); r!=pvi.end(); r++)
-                    cout << *r << " ";
-                cout << ": " << pli << endl;
-            }
-            factors.back()[pli] = factordata[li];
+            prod *= dims[k];
         }
+        Permute permindex( di, perm );
+        for( size_t li = 0; li < prod; li++ )
+            factors.back()[permindex.convert_linear_index(li)] = factordata[li];
     }
 
     if( verbose >= 3 ) {
@@ -193,21 +167,14 @@ Factor mx2Factor(const mxArray *psi) {
 
     // read Factor
     vector<size_t> di(nr_mem,0);
-    vector<size_t> pdi(nr_mem,0);
+    size_t prod = 1;
     for( size_t k = 0; k < nr_mem; k++ ) {
         di[k] = dims[k];
-        pdi[k] = dims[perm[k]];
-    }
-    multind mi(di);
-    multind pmi(pdi);
-    for( size_t li = 0; li < mi.max(); li++ ) {
-        vector<size_t> vi = mi.vi(li);
-        vector<size_t> pvi(vi.size(),0);
-        for( size_t k = 0; k < vi.size(); k++ )
-            pvi[k] = vi[perm[k]];
-        size_t pli = pmi.li(pvi);
-        factor[pli] = factordata[li];
+        prod *= dims[k];
     }
+    Permute permindex( di, perm );
+    for( size_t li = 0; li < prod; li++ )
+        factor[permindex.convert_linear_index(li)] = factordata[li];
 
     return( factor );
 }
index ce27640..1d60bf4 100644 (file)
@@ -71,7 +71,7 @@ void BP::create() {
             newEP.newMessage = Prob( var(i).states() );
 
             newEP.index.reserve( factor(I).states() );
-            for( Index k( var(i), factor(I).vars() ); k >= 0; ++k )
+            for( IndexFor k( var(i), factor(I).vars() ); k >= 0; ++k )
                 newEP.index.push_back( k );
 
             newEP.residual = 0.0;
@@ -127,7 +127,7 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     foreach( const Neighbor &j, nbF(I) ) {
         if( j != i ) {     // for all j in I \ i
             size_t _I = j.dual;
-            // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
             const ind_t & ind = index(j, _I);
 
             // prod_j will be the product of messages coming into j
@@ -145,7 +145,7 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
 
     // Marginalize onto i
     Prob marg( var(i).states(), 0.0 );
-    // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+    // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
     const ind_t ind = index(i,_I);
     for( size_t r = 0; r < prod.size(); ++r )
         marg[ind[r]] += prod[r];
@@ -313,7 +313,7 @@ Factor BP::beliefF (size_t I) const {
 
     foreach( const Neighbor &j, nbF(I) ) {
         size_t _I = j.dual;
-        // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+        // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
         const ind_t & ind = index(j, _I);
 
         // prod_j will be the product of messages coming into j
index b32475d..a688f36 100644 (file)
@@ -34,32 +34,28 @@ using namespace std;
 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
     Factor Pns (ns);
     
-    multind mi( ns );
-
     InfAlg *clamped = obj.clone();
     if( !reInit )
         clamped->init();
 
     Complex logZ0;
-    for( size_t j = 0; j < mi.max(); j++ ) {
+    for( State s(ns); s.valid(); s++ ) {
         // save unclamped factors connected to ns
         clamped->saveProbs( ns );
 
         // set clamping Factors to delta functions
-        vector<size_t> vi = mi.vi( j );
-        size_t k = 0;
-        for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++, k++ )
-            clamped->clamp( *n, vi[k] );
+        for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+            clamped->clamp( *n, s(*n) );
         
         // run DAIAlg, calc logZ, store in Pns
         if( clamped->Verbose() >= 2 )
-            cout << j << ": ";
+            cout << s << ": ";
         if( reInit )
             clamped->init();
         clamped->run();
 
         Complex Z;
-        if( j == 0 ) {
+        if( s == 0 ) {
             logZ0 = clamped->logZ();
             Z = 1.0;
         } else {
@@ -69,7 +65,7 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
                 cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
         }
 
-        Pns[j] = real(Z);
+        Pns[s] = real(Z);
         
         // restore clamped factors
         clamped->undoProbs( ns );
index 0b65e58..774fe31 100644 (file)
@@ -176,7 +176,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             factors.push_back(Factor(I_vars,0.0));
             
             // calculate permutation sigma (internally, members are sorted)
-            vector<long> sigma(nr_members,0);
+            vector<size_t> sigma(nr_members,0);
             VarSet::iterator j = I_vars.begin();
             for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
                 long search_for = j->label();
@@ -190,22 +190,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             }
 
             // calculate multindices
-            vector<size_t> sdims(nr_members,0);
-            for( size_t k = 0; k < nr_members; k++ ) {
-                sdims[k] = dims[sigma[k]];
-            }
-            multind mi(dims);
-            multind smi(sdims);
-            if( verbose >= 3 ) {
-                cout << "  mi.max(): " << mi.max() << endl;
-                cout << "       ";
-                for( size_t k=0; k < nr_members; k++ ) 
-                    cout << labels[k] << " ";
-                cout << "   ";
-                for( size_t k=0; k < nr_members; k++ ) 
-                    cout << labels[sigma[k]] << " ";
-                cout << endl;
-            }
+            Permute permindex( dims, sigma );
             
             // read values
             size_t nr_nonzeros;
@@ -224,19 +209,9 @@ istream& operator >> (istream& is, FactorGraph& fg) {
                     getline(is,line);
                 is >> val;
 
-                vector<size_t> vi = mi.vi(li);
-                vector<size_t> svi(vi.size(),0);
-                for( size_t k = 0; k < vi.size(); k++ )
-                    svi[k] = vi[sigma[k]];
-                size_t sli = smi.li(svi);
-                if( verbose >= 3 ) {
-                    cout << "    " << li << ": ";
-                    copy( vi.begin(), vi.end(), ostream_iterator<size_t>(cout," "));
-                    cout << "-> ";
-                    copy( svi.begin(), svi.end(), ostream_iterator<size_t>(cout," "));
-                    cout << ": " << sli << endl;
-                }
-                factors.back()[sli] = val;
+                // store value, but permute indices first according
+                // to internal representation
+                factors.back()[permindex.convert_linear_index( li  )] = val;
             }
         }
 
index 774c12f..a5e2bf4 100644 (file)
@@ -451,8 +451,6 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
             VarSet nsrem = ns / OR(T.front().n1).vars();
             Factor Pns (ns, 0.0);
             
-            multind mi( nsrem );
-
             // Save _Qa and _Qb on the subtree
             map<size_t,Factor> _Qa_old;
             map<size_t,Factor> _Qb_old;
@@ -476,20 +474,18 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
             }
                 
             // For all states of nsrem
-            for( size_t j = 0; j < mi.max(); j++ ) {
-                vector<size_t> vi = mi.vi( j );
+            for( State s(nsrem); s.valid(); s++ ) {
                 
                 // CollectEvidence
                 double logZ = 0.0;
                 for( size_t i = Tsize; (i--) != 0; ) {
-            //      Make outer region T[i].n1 consistent with outer region T[i].n2
-            //      IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
+                // Make outer region T[i].n1 consistent with outer region T[i].n2
+                // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
 
-                    size_t k = 0;
-                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++, k++ )
+                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                         if( _Qa[T[i].n2].vars() >> *n ) {
                             Factor piet( *n, 0.0 );
-                            piet[vi[k]] = 1.0;
+                            piet[s(*n)] = 1.0;
                             _Qa[T[i].n2] *= piet; 
                         }
 
@@ -501,7 +497,7 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                 logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));
 
                 Factor piet( nsrem, 0.0 );
-                piet[j] = exp(logZ);
+                piet[s] = exp(logZ);
                 Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME
 
                 // Restore clamped beliefs
index a12ff1b..d102e0f 100644 (file)
@@ -112,8 +112,6 @@ void TreeEPSubTree::InvertAndMultiply( const vector<Factor> &Qa, const vector<Fa
 
 
 void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
-    multind mi( _nsrem );
-
     // Backup _Qa and _Qb
     vector<Factor> _Qa_old(_Qa);
     vector<Factor> _Qb_old(_Qb);
@@ -125,20 +123,17 @@ void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
         Qb[_b[beta]].fill( 0.0 );
     
     // For all states of _nsrem
-    for( size_t j = 0; j < mi.max(); j++ ) {
-        vector<size_t> vi = mi.vi( j );
-        
+    for( State s(_nsrem); s.valid(); s++ ) {
         // Multiply root with slice of I
-        _Qa[0] *= _I->slice( _nsrem, j );
+        _Qa[0] *= _I->slice( _nsrem, s );
 
         // CollectEvidence
         for( size_t i = _RTree.size(); (i--) != 0; ) {
             // clamp variables in nsrem
-            size_t k = 0;
-            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++, k++ )
+            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                 if( _Qa[_RTree[i].n2].vars() >> *n ) {
                     Factor delta( *n, 0.0 );
-                    delta[vi[k]] = 1.0;
+                    delta[s(*n)] = 1.0;
                     _Qa[_RTree[i].n2] *= delta;
                 }
             Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );