Fixed testem failure caused by rounding error
[libdai.git] / include / dai / prob.h
index d165bfc..4b85374 100644 (file)
@@ -37,6 +37,7 @@
 #include <numeric>
 #include <functional>
 #include <dai/util.h>
+#include <dai/exceptions.h>
 
 
 namespace dai {
@@ -54,6 +55,11 @@ template <typename T> class TProb {
         std::vector<T> _p;
 
     public:
+        /// Iterator over entries
+       typedef typename std::vector<T>::iterator iterator;
+        /// Const iterator over entries
+       typedef typename std::vector<T>::const_iterator const_iterator;
+
         /// Enumerates different ways of normalizing a probability measure.
         /** 
          *  - NORMPROB means that the sum of all entries should be 1;
@@ -78,8 +84,17 @@ template <typename T> class TProb {
         /// Construct vector of length n with each entry set to p
         explicit TProb( size_t n, Real p ) : _p(n, (T)p) {}
         
-        /// Construct vector of length n by copying the elements between p and p+n
-        TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
+        /// Construct vector from a range
+        /** \tparam Iterator Iterates over instances that can be cast to T.
+         *  \param begin Points to first instance to be added.
+         *  \param end Points just beyond last instance to be added.
+         *  \param sizeHint For efficiency, the number of entries can be speficied by sizeHint.
+         */
+        template <typename Iterator>
+        TProb( Iterator begin, Iterator end, size_t sizeHint=0 ) : _p() {
+            _p.reserve( sizeHint );
+            _p.insert( _p.begin(), begin, end );
+        }
         
         /// Returns a const reference to the vector
         const std::vector<T> & p() const { return _p; }
@@ -98,6 +113,18 @@ template <typename T> class TProb {
         
         /// Returns reference to the i'th entry
         T& operator[]( size_t i ) { return _p[i]; }
+        
+        /// Returns iterator pointing to first entry
+        iterator begin() { return _p.begin(); }
+
+        /// Returns const iterator pointing to first entry
+        const_iterator begin() const { return _p.begin(); }
+
+        /// Returns iterator pointing beyond last entry
+        iterator end() { return _p.end(); }
+
+        /// Returns const iterator pointing beyond last entry
+        const_iterator end() const { return _p.end(); }
 
         /// Sets all entries to x
         TProb<T> & fill(T x) { 
@@ -123,6 +150,12 @@ template <typename T> class TProb {
                     _p[i] = 0;
             return *this;
         }
+        
+        /// Set all entries to 1.0/size()
+        TProb<T>& setUniform () {
+            fill(1.0/size());
+            return *this;
+        }
 
         /// Sets entries that are smaller than epsilon to epsilon
         TProb<T>& makePositive( Real epsilon ) {
@@ -337,7 +370,7 @@ template <typename T> class TProb {
             TProb<T> x;
             x._p.reserve( size() );
             for( size_t i = 0; i < size(); i++ )
-                x._p.push_back( _p[i] < 0 ? (-p[i]) : p[i] );
+                x._p.push_back( _p[i] < 0 ? (-_p[i]) : _p[i] );
             return x;
         }
 
@@ -376,11 +409,19 @@ template <typename T> class TProb {
         }
 
         /// Returns sum of all entries
-        T totalSum() const {
+        T sum() const {
             T Z = std::accumulate( _p.begin(),  _p.end(), (T)0 );
             return Z;
         }
 
+        /// Return sum of absolute value of all entries
+        T sumAbs() const {
+            T s = 0;
+            for( size_t i = 0; i < size(); i++ )
+                s += fabs( (Real) _p[i] );
+            return s;
+        }
+
         /// Returns maximum absolute value of all entries
         T maxAbs() const {
             T Z = 0;
@@ -393,28 +434,41 @@ template <typename T> class TProb {
         }
 
         /// Returns maximum value of all entries
-        T maxVal() const {
+        T max() const {
             T Z = *std::max_element( _p.begin(), _p.end() );
             return Z;
         }
 
         /// Returns minimum value of all entries
-        T minVal() const {
+        T min() const {
             T Z = *std::min_element( _p.begin(), _p.end() );
             return Z;
         }
 
+        /// Returns {arg,}maximum value
+        std::pair<size_t,T> argmax() const {
+            T max = _p[0];
+            size_t arg = 0;
+            for( size_t i = 1; i < size(); i++ ) {
+              if( _p[i] > max ) {
+                max = _p[i];
+                arg = i;
+              }
+            }
+            return std::make_pair(arg,max);
+        }
+
         /// Normalizes vector using the specified norm
         T normalize( NormType norm=NORMPROB ) {
             T Z = 0.0;
             if( norm == NORMPROB )
-                Z = totalSum();
+                Z = sum();
             else if( norm == NORMLINF )
                 Z = maxAbs();
-#ifdef DAI_DEBUG
-            assert( Z != 0.0 );
-#endif
-            *this /= Z;
+            if( Z == 0.0 )
+                DAI_THROW(NOT_NORMALIZABLE);
+            else
+                *this /= Z;
             return Z;
         }
 
@@ -427,7 +481,13 @@ template <typename T> class TProb {
     
         /// Returns true if one or more entries are NaN
         bool hasNaNs() const {
-            return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
+            bool foundnan = false;
+            for( typename std::vector<T>::const_iterator x = _p.begin(); x != _p.end(); x++ )
+                if( isnan( *x ) ) {
+                    foundnan = true;
+                    break;
+                }
+            return foundnan;
         }
 
         /// Returns true if one or more entries are negative
@@ -442,6 +502,18 @@ template <typename T> class TProb {
                 S -= (_p[i] == 0 ? 0 : _p[i] * std::log(_p[i]));
             return S;
         }
+
+        /// Returns a random index, according to the (normalized) distribution described by *this
+        size_t draw() {
+            double x = rnd_uniform() * sum();
+            T s = 0;
+            for( size_t i = 0; i < size(); i++ ) {
+                s += _p[i];
+                if( s > x ) 
+                    return i;
+            }
+            return( size() - 1 );
+        }
 };