1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
13 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
27 #include <dai/exceptions.h>
33 /// Function object that returns the value itself
34 template<typename T
> struct fo_id
: public std::unary_function
<T
, T
> {
36 T
operator()( const T
&x
) const {
42 /// Function object that takes the absolute value
43 template<typename T
> struct fo_abs
: public std::unary_function
<T
, T
> {
45 T
operator()( const T
&x
) const {
54 /// Function object that takes the exponent
55 template<typename T
> struct fo_exp
: public std::unary_function
<T
, T
> {
57 T
operator()( const T
&x
) const {
63 /// Function object that takes the logarithm
64 template<typename T
> struct fo_log
: public std::unary_function
<T
, T
> {
66 T
operator()( const T
&x
) const {
72 /// Function object that takes the logarithm, except that log(0) is defined to be 0
73 template<typename T
> struct fo_log0
: public std::unary_function
<T
, T
> {
74 /// Returns (\a x == 0 ? 0 : log(\a x))
75 T
operator()( const T
&x
) const {
84 /// Function object that takes the inverse
85 template<typename T
> struct fo_inv
: public std::unary_function
<T
, T
> {
87 T
operator()( const T
&x
) const {
93 /// Function object that takes the inverse, except that 1/0 is defined to be 0
94 template<typename T
> struct fo_inv0
: public std::unary_function
<T
, T
> {
95 /// Returns (\a x == 0 ? 0 : (1 / \a x))
96 T
operator()( const T
&x
) const {
105 /// Function object that returns p*log0(p)
106 template<typename T
> struct fo_plog0p
: public std::unary_function
<T
, T
> {
107 /// Returns \a p * log0(\a p)
108 T
operator()( const T
&p
) const {
109 return p
* dai::log0(p
);
114 /// Function object similar to std::divides(), but different in that dividing by zero results in zero
115 template<typename T
> struct fo_divides0
: public std::binary_function
<T
, T
, T
> {
116 /// Returns (\a y == 0 ? 0 : (\a x / \a y))
117 T
operator()( const T
&x
, const T
&y
) const {
126 /// Function object useful for calculating the KL distance
127 template<typename T
> struct fo_KL
: public std::binary_function
<T
, T
, T
> {
128 /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
129 T
operator()( const T
&p
, const T
&q
) const {
133 return p
* (log(p
) - log(q
));
138 /// Function object useful for calculating the Hellinger distance
139 template<typename T
> struct fo_Hellinger
: public std::binary_function
<T
, T
, T
> {
140 /// Returns (sqrt(\a p) - sqrt(\a q))^2
141 T
operator()( const T
&p
, const T
&q
) const {
142 T x
= sqrt(p
) - sqrt(q
);
148 /// Function object that returns x to the power y
149 template<typename T
> struct fo_pow
: public std::binary_function
<T
, T
, T
> {
150 /// Returns (\a x ^ \a y)
151 T
operator()( const T
&x
, const T
&y
) const {
153 return std::pow( x
, y
);
160 /// Function object that returns the maximum of two values
161 template<typename T
> struct fo_max
: public std::binary_function
<T
, T
, T
> {
162 /// Returns (\a x > y ? x : y)
163 T
operator()( const T
&x
, const T
&y
) const {
164 return (x
> y
) ? x
: y
;
169 /// Function object that returns the minimum of two values
170 template<typename T
> struct fo_min
: public std::binary_function
<T
, T
, T
> {
171 /// Returns (\a x > y ? y : x)
172 T
operator()( const T
&x
, const T
&y
) const {
173 return (x
> y
) ? y
: x
;
178 /// Function object that returns the absolute difference of x and y
179 template<typename T
> struct fo_absdiff
: public std::binary_function
<T
, T
, T
> {
180 /// Returns abs( \a x - \a y )
181 T
operator()( const T
&x
, const T
&y
) const {
182 return dai::abs( x
- y
);
187 /// Represents a vector with entries of type \a T.
188 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
190 * It is mainly used for representing measures on a finite outcome space, for example, the probability
191 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
192 * represent logarithms of probability mass functions.
194 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
196 template <typename T
> class TProb
{
202 /// Enumerates different ways of normalizing a probability measure.
204 * - NORMPROB means that the sum of all entries should be 1;
205 * - NORMLINF means that the maximum absolute value of all entries should be 1.
207 typedef enum { NORMPROB
, NORMLINF
} NormType
;
208 /// Enumerates different distance measures between probability measures.
210 * - DISTL1 is the \f$\ell_1\f$ distance (sum of absolute values of pointwise difference);
211 * - DISTLINF is the \f$\ell_\infty\f$ distance (maximum absolute value of pointwise difference);
212 * - DISTTV is the total variation distance (half of the \f$\ell_1\f$ distance);
213 * - DISTKL is the Kullback-Leibler distance (\f$\sum_i p_i (\log p_i - \log q_i)\f$).
214 * - DISTHEL is the Hellinger distance (\f$\frac{1}{2}\sum_i (\sqrt{p_i}-\sqrt{q_i})^2\f$).
216 typedef enum { DISTL1
, DISTLINF
, DISTTV
, DISTKL
, DISTHEL
} DistType
;
218 /// \name Constructors and destructors
220 /// Default constructor (constructs empty vector)
223 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
224 explicit TProb( size_t n
) : _p(std::vector
<T
>(n
, (T
)1 / n
)) {}
226 /// Construct vector of length \a n with each entry set to \a p
227 explicit TProb( size_t n
, T p
) : _p(n
, p
) {}
229 /// Construct vector from a range
230 /** \tparam TIterator Iterates over instances that can be cast to \a T
231 * \param begin Points to first instance to be added.
232 * \param end Points just beyond last instance to be added.
233 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint.
235 template <typename TIterator
>
236 TProb( TIterator begin
, TIterator end
, size_t sizeHint
=0 ) : _p() {
237 _p
.reserve( sizeHint
);
238 _p
.insert( _p
.begin(), begin
, end
);
241 /// Construct vector from another vector
242 /** \tparam S type of elements in \a v (should be castable to type \a T)
243 * \param v vector used for initialization
245 template <typename S
>
246 TProb( const std::vector
<S
> &v
) : _p() {
247 _p
.reserve( v
.size() );
248 _p
.insert( _p
.begin(), v
.begin(), v
.end() );
252 /// Constant iterator over the elements
253 typedef typename
std::vector
<T
>::const_iterator const_iterator
;
254 /// Iterator over the elements
255 typedef typename
std::vector
<T
>::iterator iterator
;
256 /// Constant reverse iterator over the elements
257 typedef typename
std::vector
<T
>::const_reverse_iterator const_reverse_iterator
;
258 /// Reverse iterator over the elements
259 typedef typename
std::vector
<T
>::reverse_iterator reverse_iterator
;
261 /// \name Iterator interface
263 /// Returns iterator that points to the first element
264 iterator
begin() { return _p
.begin(); }
265 /// Returns constant iterator that points to the first element
266 const_iterator
begin() const { return _p
.begin(); }
268 /// Returns iterator that points beyond the last element
269 iterator
end() { return _p
.end(); }
270 /// Returns constant iterator that points beyond the last element
271 const_iterator
end() const { return _p
.end(); }
273 /// Returns reverse iterator that points to the last element
274 reverse_iterator
rbegin() { return _p
.rbegin(); }
275 /// Returns constant reverse iterator that points to the last element
276 const_reverse_iterator
rbegin() const { return _p
.rbegin(); }
278 /// Returns reverse iterator that points beyond the first element
279 reverse_iterator
rend() { return _p
.rend(); }
280 /// Returns constant reverse iterator that points beyond the first element
281 const_reverse_iterator
rend() const { return _p
.rend(); }
286 /// Returns a const reference to the wrapped vector
287 const std::vector
<T
> & p() const { return _p
; }
289 /// Returns a reference to the wrapped vector
290 std::vector
<T
> & p() { return _p
; }
292 /// Returns a copy of the \a i 'th entry
293 T
operator[]( size_t i
) const {
301 /// Returns reference to the \a i 'th entry
302 T
& operator[]( size_t i
) { return _p
[i
]; }
304 /// Returns length of the vector (i.e., the number of entries)
305 size_t size() const { return _p
.size(); }
307 /// Accumulate over all values, similar to std::accumulate
308 template<typename binOp
, typename unOp
> T
accumulate( T init
, binOp op1
, unOp op2
) const {
310 for( const_iterator it
= begin(); it
!= end(); it
++ )
311 t
= op1( t
, op2(*it
) );
315 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
316 T
entropy() const { return -accumulate( (T
)0, std::plus
<T
>(), fo_plog0p
<T
>() ); }
318 /// Returns maximum value of all entries
319 T
max() const { return accumulate( (T
)(-INFINITY
), fo_max
<T
>(), fo_id
<T
>() ); }
321 /// Returns minimum value of all entries
322 T
min() const { return accumulate( (T
)INFINITY
, fo_min
<T
>(), fo_id
<T
>() ); }
324 /// Returns sum of all entries
325 T
sum() const { return accumulate( (T
)0, std::plus
<T
>(), fo_id
<T
>() ); }
327 /// Return sum of absolute value of all entries
328 T
sumAbs() const { return accumulate( (T
)0, std::plus
<T
>(), fo_abs
<T
>() ); }
330 /// Returns maximum absolute value of all entries
331 T
maxAbs() const { return accumulate( (T
)0, fo_max
<T
>(), fo_abs
<T
>() ); }
333 /// Returns \c true if one or more entries are NaN
334 bool hasNaNs() const {
335 bool foundnan
= false;
336 for( typename
std::vector
<T
>::const_iterator x
= _p
.begin(); x
!= _p
.end(); x
++ )
344 /// Returns \c true if one or more entries are negative
345 bool hasNegatives() const {
346 return (std::find_if( _p
.begin(), _p
.end(), std::bind2nd( std::less
<T
>(), (T
)0 ) ) != _p
.end());
349 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
350 std::pair
<size_t,T
> argmax() const {
353 for( size_t i
= 1; i
< size(); i
++ ) {
359 return std::make_pair(arg
,max
);
362 /// Returns a random index, according to the (normalized) distribution described by *this
364 Real x
= rnd_uniform() * sum();
366 for( size_t i
= 0; i
< size(); i
++ ) {
371 return( size() - 1 );
374 /// Lexicographical comparison
375 /** \pre <tt>this->size() == q.size()</tt>
377 bool operator<= (const TProb
<T
> & q
) const {
378 DAI_DEBASSERT( size() == q
.size() );
379 return lexicographical_compare( begin(), end(), q
.begin(), q
.end() );
383 /// \name Unary transformations
385 /// Returns the result of applying operation \a op pointwise on \c *this
386 template<typename unaryOp
> TProb
<T
> pwUnaryTr( unaryOp op
) const {
388 r
._p
.reserve( size() );
389 std::transform( _p
.begin(), _p
.end(), back_inserter( r
._p
), op
);
393 /// Returns negative of \c *this
394 TProb
<T
> operator- () const { return pwUnaryTr( std::negate
<T
>() ); }
396 /// Returns pointwise absolute value
397 TProb
<T
> abs() const { return pwUnaryTr( fo_abs
<T
>() ); }
399 /// Returns pointwise exponent
400 TProb
<T
> exp() const { return pwUnaryTr( fo_exp
<T
>() ); }
402 /// Returns pointwise logarithm
403 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
405 TProb
<T
> log(bool zero
=false) const {
407 return pwUnaryTr( fo_log0
<T
>() );
409 return pwUnaryTr( fo_log
<T
>() );
412 /// Returns pointwise inverse
413 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
415 TProb
<T
> inverse(bool zero
=true) const {
417 return pwUnaryTr( fo_inv0
<T
>() );
419 return pwUnaryTr( fo_inv
<T
>() );
422 /// Returns normalized copy of \c *this, using the specified norm
423 /** \throw NOT_NORMALIZABLE if the norm is zero
425 TProb
<T
> normalized( NormType norm
= NORMPROB
) const {
427 if( norm
== NORMPROB
)
429 else if( norm
== NORMLINF
)
432 DAI_THROW(NOT_NORMALIZABLE
);
435 return pwUnaryTr( std::bind2nd( std::divides
<T
>(), Z
) );
439 /// \name Unary operations
441 /// Applies unary operation \a op pointwise
442 template<typename unaryOp
> TProb
<T
>& pwUnaryOp( unaryOp op
) {
443 std::transform( _p
.begin(), _p
.end(), _p
.begin(), op
);
447 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
448 TProb
<T
>& randomize() {
449 std::generate( _p
.begin(), _p
.end(), rnd_uniform
);
453 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
454 TProb
<T
>& setUniform () {
455 fill( (T
)1 / size() );
459 /// Applies absolute value pointwise
460 const TProb
<T
>& takeAbs() { return pwUnaryOp( fo_abs
<T
>() ); }
462 /// Applies exponent pointwise
463 const TProb
<T
>& takeExp() { return pwUnaryOp( fo_exp
<T
>() ); }
465 /// Applies logarithm pointwise
466 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
468 const TProb
<T
>& takeLog(bool zero
=false) {
470 return pwUnaryOp( fo_log0
<T
>() );
472 return pwUnaryOp( fo_log
<T
>() );
475 /// Normalizes vector using the specified norm
476 /** \throw NOT_NORMALIZABLE if the norm is zero
478 T
normalize( NormType norm
=NORMPROB
) {
480 if( norm
== NORMPROB
)
482 else if( norm
== NORMLINF
)
485 DAI_THROW(NOT_NORMALIZABLE
);
492 /// \name Operations with scalars
494 /// Sets all entries to \a x
495 TProb
<T
> & fill(T x
) {
496 std::fill( _p
.begin(), _p
.end(), x
);
500 /// Adds scalar \a x to each entry
501 TProb
<T
>& operator+= (T x
) {
503 return pwUnaryOp( std::bind2nd( std::plus
<T
>(), x
) );
508 /// Subtracts scalar \a x from each entry
509 TProb
<T
>& operator-= (T x
) {
511 return pwUnaryOp( std::bind2nd( std::minus
<T
>(), x
) );
516 /// Multiplies each entry with scalar \a x
517 TProb
<T
>& operator*= (T x
) {
519 return pwUnaryOp( std::bind2nd( std::multiplies
<T
>(), x
) );
524 /// Divides each entry by scalar \a x
525 TProb
<T
>& operator/= (T x
) {
526 DAI_DEBASSERT( x
!= 0 );
528 return pwUnaryOp( std::bind2nd( std::divides
<T
>(), x
) );
533 /// Raises entries to the power \a x
534 TProb
<T
>& operator^= (T x
) {
536 return pwUnaryOp( std::bind2nd( fo_pow
<T
>(), x
) );
542 /// \name Transformations with scalars
544 /// Returns sum of \c *this and scalar \a x
545 TProb
<T
> operator+ (T x
) const { return pwUnaryTr( std::bind2nd( std::plus
<T
>(), x
) ); }
547 /// Returns difference of \c *this and scalar \a x
548 TProb
<T
> operator- (T x
) const { return pwUnaryTr( std::bind2nd( std::minus
<T
>(), x
) ); }
550 /// Returns product of \c *this with scalar \a x
551 TProb
<T
> operator* (T x
) const { return pwUnaryTr( std::bind2nd( std::multiplies
<T
>(), x
) ); }
553 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
554 TProb
<T
> operator/ (T x
) const { return pwUnaryTr( std::bind2nd( fo_divides0
<T
>(), x
) ); }
556 /// Returns \c *this raised to the power \a x
557 TProb
<T
> operator^ (T x
) const { return pwUnaryTr( std::bind2nd( fo_pow
<T
>(), x
) ); }
560 /// \name Operations with other equally-sized vectors
562 /// Applies binary operation pointwise on two vectors
563 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
564 * \param q Right operand
565 * \param op Operation of type \a binaryOp
567 template<typename binaryOp
> TProb
<T
>& pwBinaryOp( const TProb
<T
> &q
, binaryOp op
) {
568 DAI_DEBASSERT( size() == q
.size() );
569 std::transform( _p
.begin(), _p
.end(), q
._p
.begin(), _p
.begin(), op
);
573 /// Pointwise addition with \a q
574 /** \pre <tt>this->size() == q.size()</tt>
576 TProb
<T
>& operator+= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::plus
<T
>() ); }
578 /// Pointwise subtraction of \a q
579 /** \pre <tt>this->size() == q.size()</tt>
581 TProb
<T
>& operator-= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::minus
<T
>() ); }
583 /// Pointwise multiplication with \a q
584 /** \pre <tt>this->size() == q.size()</tt>
586 TProb
<T
>& operator*= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::multiplies
<T
>() ); }
588 /// Pointwise division by \a q, where division by 0 yields 0
589 /** \pre <tt>this->size() == q.size()</tt>
590 * \see divide(const TProb<T> &)
592 TProb
<T
>& operator/= (const TProb
<T
> & q
) { return pwBinaryOp( q
, fo_divides0
<T
>() ); }
594 /// Pointwise division by \a q, where division by 0 yields +Inf
595 /** \pre <tt>this->size() == q.size()</tt>
596 * \see operator/=(const TProb<T> &)
598 TProb
<T
>& divide (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::divides
<T
>() ); }
601 /** \pre <tt>this->size() == q.size()</tt>
603 TProb
<T
>& operator^= (const TProb
<T
> & q
) { return pwBinaryOp( q
, fo_pow
<T
>() ); }
606 /// \name Transformations with other equally-sized vectors
608 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
609 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
610 * \param q Right operand
611 * \param op Operation of type \a binaryOp
613 template<typename binaryOp
> TProb
<T
> pwBinaryTr( const TProb
<T
> &q
, binaryOp op
) const {
614 DAI_DEBASSERT( size() == q
.size() );
616 r
._p
.reserve( size() );
617 std::transform( _p
.begin(), _p
.end(), q
._p
.begin(), back_inserter( r
._p
), op
);
621 /// Returns sum of \c *this and \a q
622 /** \pre <tt>this->size() == q.size()</tt>
624 TProb
<T
> operator+ ( const TProb
<T
>& q
) const { return pwBinaryTr( q
, std::plus
<T
>() ); }
626 /// Return \c *this minus \a q
627 /** \pre <tt>this->size() == q.size()</tt>
629 TProb
<T
> operator- ( const TProb
<T
>& q
) const { return pwBinaryTr( q
, std::minus
<T
>() ); }
631 /// Return product of \c *this with \a q
632 /** \pre <tt>this->size() == q.size()</tt>
634 TProb
<T
> operator* ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, std::multiplies
<T
>() ); }
636 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
637 /** \pre <tt>this->size() == q.size()</tt>
638 * \see divided_by(const TProb<T> &)
640 TProb
<T
> operator/ ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, fo_divides0
<T
>() ); }
642 /// Pointwise division by \a q, where division by 0 yields +Inf
643 /** \pre <tt>this->size() == q.size()</tt>
644 * \see operator/(const TProb<T> &)
646 TProb
<T
> divided_by( const TProb
<T
> &q
) const { return pwBinaryTr( q
, std::divides
<T
>() ); }
648 /// Returns \c *this to the power \a q
649 /** \pre <tt>this->size() == q.size()</tt>
651 TProb
<T
> operator^ ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, fo_pow
<T
>() ); }
654 /// Performs a generalized inner product, similar to std::inner_product
655 /** \pre <tt>this->size() == q.size()</tt>
657 template<typename binOp1
, typename binOp2
> T
innerProduct( const TProb
<T
> &q
, T init
, binOp1 binaryOp1
, binOp2 binaryOp2
) const {
658 DAI_DEBASSERT( size() == q
.size() );
659 return std::inner_product( begin(), end(), q
.begin(), init
, binaryOp1
, binaryOp2
);
664 /// Returns distance between \a p and \a q, measured using distance measure \a dt
666 * \pre <tt>this->size() == q.size()</tt>
668 template<typename T
> T
dist( const TProb
<T
> &p
, const TProb
<T
> &q
, typename TProb
<T
>::DistType dt
) {
670 case TProb
<T
>::DISTL1
:
671 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_absdiff
<T
>() );
672 case TProb
<T
>::DISTLINF
:
673 return p
.innerProduct( q
, (T
)0, fo_max
<T
>(), fo_absdiff
<T
>() );
674 case TProb
<T
>::DISTTV
:
675 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_absdiff
<T
>() ) / 2;
676 case TProb
<T
>::DISTKL
:
677 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_KL
<T
>() );
678 case TProb
<T
>::DISTHEL
:
679 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_Hellinger
<T
>() ) / 2;
681 DAI_THROW(UNKNOWN_ENUM_VALUE
);
687 /// Writes a TProb<T> to an output stream
690 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TProb
<T
>& p
) {
692 std::copy( p
.p().begin(), p
.p().end(), std::ostream_iterator
<T
>(os
, " ") );
698 /// Returns the pointwise minimum of \a a and \a b
700 * \pre <tt>this->size() == q.size()</tt>
702 template<typename T
> TProb
<T
> min( const TProb
<T
> &a
, const TProb
<T
> &b
) {
703 return a
.pwBinaryTr( b
, fo_min
<T
>() );
707 /// Returns the pointwise maximum of \a a and \a b
709 * \pre <tt>this->size() == q.size()</tt>
711 template<typename T
> TProb
<T
> max( const TProb
<T
> &a
, const TProb
<T
> &b
) {
712 return a
.pwBinaryTr( b
, fo_max
<T
>() );
716 /// Represents a vector with entries of type dai::Real.
717 typedef TProb
<Real
> Prob
;
720 } // end of namespace dai