Some cleanups of include/dai/prob.h
[libdai.git] / include / dai / prob.h
index 9612d78..4c250d2 100644 (file)
 namespace dai {
 
 
+/// Function object that returns the value itself
+template<typename T> struct fo_id : public std::unary_function<T, T> {
+    /// Returns \a x
+    T operator()( const T &x ) const {
+        return x;
+    }
+};
+
+
 /// Function object that takes the absolute value
 template<typename T> struct fo_abs : public std::unary_function<T, T> {
     /// Returns abs(\a x)
@@ -93,6 +102,15 @@ template<typename T> struct fo_inv0 : public std::unary_function<T, T> {
 };
 
 
+/// Function object that returns p*log0(p)
+template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
+    /// Returns \a p * log0(\a p)
+    T operator()( const T &p ) const {
+        return p * dai::log0(p);
+    }
+};
+
+
 /// Function object similar to std::divides(), but different in that dividing by zero results in zero
 template<typename T> struct fo_divides0 : public std::binary_function<T, T, T> {
     /// Returns (\a y == 0 ? 0 : (\a x / \a y))
@@ -147,33 +165,6 @@ template<typename T> struct fo_min : public std::binary_function<T, T, T> {
 };
 
 
-/// Function object that returns the sum of x and abs(y)
-template<typename T> struct fo_plusabs : public std::binary_function<T, T, T> {
-    /// Returns \a x + abs(\a y)
-    T operator()( const T &x, const T &y ) const {
-        return (y < 0) ? x - y : x + y;
-    }
-};
-
-
-/// Function object that returns the sum of x and abs(y)
-template<typename T> struct fo_maxabs : public std::binary_function<T, T, T> {
-    /// Returns max(\a x, abs(\a y))
-    T operator()( const T &x, const T &y ) const {
-        return (y < 0) ? (x > (-y) ? x : (-y)) : (x > y ? x : y);
-    }
-};
-
-
-/// Function object that returns the sum of x and y*log0(y)
-template<typename T> struct fo_plusplog0p : public std::binary_function<T, T, T> {
-    /// Returns \a x + \a p * log0(\a p)
-    T operator()( const T &x, const T &p ) const {
-        return x + p * dai::log0(p);
-    }
-};
-
-
 /// Function object that returns the absolute difference of x and y
 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
     /// Returns abs( \a x - \a y )
@@ -302,48 +293,31 @@ template <typename T> class TProb {
         /// Returns length of the vector (i.e., the number of entries)
         size_t size() const { return _p.size(); }
 
-        /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
-        T entropy() const {
-            return -std::accumulate( _p.begin(), _p.end(), (T)0, fo_plusplog0p<T>() );
+        /// Accumulate over all values, similar to std::accumulate
+        template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
+            T t = init;
+            for( const_iterator it = begin(); it != end(); it++ )
+                t = op1( t, op2(*it) );
+            return t;
         }
 
+        /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
+        T entropy() const { return -accumulate( (T)0, std::plus<T>(), fo_plog0p<T>() ); }
+
         /// Returns maximum value of all entries
-        T max() const {
-            return std::accumulate( _p.begin(), _p.end(), (T)(-INFINITY), fo_max<T>() );
-        }
+        T max() const { return accumulate( (T)(-INFINITY), fo_max<T>(), fo_id<T>() ); }
 
         /// Returns minimum value of all entries
-        T min() const {
-            return std::accumulate( _p.begin(), _p.end(), (T)INFINITY, fo_min<T>() );
-        }
-
-        /// Returns a pair consisting of the index of the maximum value and the maximum value itself
-        std::pair<size_t,T> argmax() const {
-            T max = _p[0];
-            size_t arg = 0;
-            for( size_t i = 1; i < size(); i++ ) {
-              if( _p[i] > max ) {
-                max = _p[i];
-                arg = i;
-              }
-            }
-            return std::make_pair(arg,max);
-        }
+        T min() const { return accumulate( (T)INFINITY, fo_min<T>(), fo_id<T>() ); }
 
         /// Returns sum of all entries
-        T sum() const {
-            return std::accumulate( _p.begin(), _p.end(), (T)0 );
-        }
+        T sum() const { return accumulate( (T)0, std::plus<T>(), fo_id<T>() ); }
 
         /// Return sum of absolute value of all entries
-        T sumAbs() const {
-            return std::accumulate( _p.begin(), _p.end(), (T)0, fo_plusabs<T>() );
-        }
+        T sumAbs() const { return accumulate( (T)0, std::plus<T>(), fo_abs<T>() ); }
 
         /// Returns maximum absolute value of all entries
-        T maxAbs() const {
-            return std::accumulate( _p.begin(), _p.end(), (T)0, fo_maxabs<T>() );
-        }
+        T maxAbs() const { return accumulate( (T)0, fo_max<T>(), fo_abs<T>() ); }
 
         /// Returns \c true if one or more entries are NaN
         bool hasNaNs() const {
@@ -361,6 +335,19 @@ template <typename T> class TProb {
             return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
         }
 
+        /// Returns a pair consisting of the index of the maximum value and the maximum value itself
+        std::pair<size_t,T> argmax() const {
+            T max = _p[0];
+            size_t arg = 0;
+            for( size_t i = 1; i < size(); i++ ) {
+              if( _p[i] > max ) {
+                max = _p[i];
+                arg = i;
+              }
+            }
+            return std::make_pair(arg,max);
+        }
+
         /// Returns a random index, according to the (normalized) distribution described by *this
         size_t draw() {
             Real x = rnd_uniform() * sum();
@@ -411,14 +398,10 @@ template <typename T> class TProb {
         }
 
         /// Returns pointwise absolute value
-        TProb<T> abs() const {
-            return pwUnaryTr( fo_abs<T>() );
-        }
+        TProb<T> abs() const { return pwUnaryTr( fo_abs<T>() ); }
 
         /// Returns pointwise exponent
-        TProb<T> exp() const {
-            return pwUnaryTr( fo_exp<T>() );
-        }
+        TProb<T> exp() const { return pwUnaryTr( fo_exp<T>() ); }
 
         /// Returns pointwise logarithm
         /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
@@ -598,7 +581,7 @@ template <typename T> class TProb {
 
     /// \name Operations with other equally-sized vectors
     //@{
-        /// Applies binary operation pointwise
+        /// Applies binary operation pointwise on two vectors
         /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
          *  \param q Right operand
          *  \param op Operation of type \a binaryOp
@@ -689,6 +672,14 @@ template <typename T> class TProb {
          */
         TProb<T> operator^ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
     //@}
+
+        /// Performs a generalized inner product, similar to std::inner_product
+        /** \pre <tt>this->size() == q.size()</tt>
+         */
+        template<typename binOp1, typename binOp2> T innerProduct( const TProb<T> &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
+            DAI_DEBASSERT( size() == q.size() );
+            return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
+        }
 };
 
 
@@ -697,16 +688,15 @@ template <typename T> class TProb {
  *  \pre <tt>this->size() == q.size()</tt>
  */
 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, typename TProb<T>::DistType dt ) {
-    DAI_DEBASSERT( p.size() == q.size() );
     switch( dt ) {
         case TProb<T>::DISTL1:
-            return std::inner_product( p.begin(), p.end(), q.begin(), (T)0, std::plus<T>(), fo_absdiff<T>() );
+            return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
         case TProb<T>::DISTLINF:
-            return std::inner_product( p.begin(), p.end(), q.begin(), (T)0, fo_max<T>(), fo_absdiff<T>() );
+            return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
         case TProb<T>::DISTTV:
-            return std::inner_product( p.begin(), p.end(), q.begin(), (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
+            return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
         case TProb<T>::DISTKL:
-            return std::inner_product( p.begin(), p.end(), q.begin(), (T)0, std::plus<T>(), fo_KL<T>() );
+            return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
         default:
             DAI_THROW(UNKNOWN_ENUM_VALUE);
             return INFINITY;