1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
13 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
27 #include <dai/exceptions.h>
/// Function object that returns the value itself
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_id {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
/// Function object that takes the absolute value
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_abs {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns abs(\a x)
    T operator()( const T &x ) const {
        if( x < 0 )
            return -x;
        return x;
    }
};
/// Function object that takes the exponent
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_exp {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        using std::exp;  // still allows argument-dependent lookup for non-builtin T
        return exp( x );
    }
};
/// Function object that takes the logarithm
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_log {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        using std::log;  // still allows argument-dependent lookup for non-builtin T
        return log( x );
    }
};
/// Function object that takes the logarithm, except that log(0) is defined to be 0
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_log0 {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns (\a x == 0 ? 0 : log(\a x))
    T operator()( const T &x ) const {
        using std::log;
        if( x == 0 )
            return 0;
        return log( x );
    }
};
/// Function object that takes the inverse
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_inv {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns 1 / \a x
    T operator()( const T &x ) const {
        return 1 / x;
    }
};
/// Function object that takes the inverse, except that 1/0 is defined to be 0
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_inv0 {
    typedef T argument_type;  ///< Type of the argument
    typedef T result_type;    ///< Type of the result

    /// Returns (\a x == 0 ? 0 : (1 / \a x))
    T operator()( const T &x ) const {
        if( x == 0 )
            return 0;
        return 1 / x;
    }
};
105 /// Function object that returns p*log0(p)
106 template<typename T
> struct fo_plog0p
: public std::unary_function
<T
, T
> {
107 /// Returns \a p * log0(\a p)
108 T
operator()( const T
&p
) const {
109 return p
* dai::log0(p
);
/// Function object similar to std::divides(), but different in that dividing by zero results in zero
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17;
 *  adaptors such as std::bind2nd only need these typedefs.
 */
template<typename T> struct fo_divides0 {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (\a y == 0 ? 0 : (\a x / \a y))
    T operator()( const T &x, const T &y ) const {
        if( y == (T)0 )
            return 0;
        return x / y;
    }
};
/// Function object useful for calculating the KL distance
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_KL {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
    T operator()( const T &p, const T &q ) const {
        using std::log;
        // A zero-probability outcome contributes nothing (and avoids 0 * log(0))
        if( p == (T)0 )
            return 0;
        return p * (log( p ) - log( q ));
    }
};
/// Function object useful for calculating the Hellinger distance
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_Hellinger {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        using std::sqrt;
        const T diff = sqrt( p ) - sqrt( q );
        return diff * diff;
    }
};
/// Function object that returns x to the power y
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17;
 *  adaptors such as std::bind2nd only need these typedefs.
 */
template<typename T> struct fo_pow {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (\a x ^ \a y)
    T operator()( const T &x, const T &y ) const {
        // x^1 is x itself, so the call to std::pow can be skipped
        if( y == 1 )
            return x;
        return std::pow( x, y );
    }
};
/// Function object that returns the maximum of two values
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_max {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (\a x > \a y ? \a x : \a y)
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return x;
        return y;
    }
};
/// Function object that returns the minimum of two values
/** The adaptor typedefs are declared directly instead of being inherited from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_min {
    typedef T first_argument_type;   ///< Type of the first argument
    typedef T second_argument_type;  ///< Type of the second argument
    typedef T result_type;           ///< Type of the result

    /// Returns (\a x > \a y ? \a y : \a x)
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return y;
        return x;
    }
};
178 /// Function object that returns the absolute difference of x and y
179 template<typename T
> struct fo_absdiff
: public std::binary_function
<T
, T
, T
> {
180 /// Returns abs( \a x - \a y )
181 T
operator()( const T
&x
, const T
&y
) const {
182 return dai::abs( x
- y
);
187 /// Represents a vector with entries of type \a T.
188 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
190 * It is mainly used for representing measures on a finite outcome space, for example, the probability
191 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
192 * represent logarithms of probability mass functions.
194 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
196 template <typename T
> class TProb
{
202 /// Enumerates different ways of normalizing a probability measure.
204 * - NORMPROB means that the sum of all entries should be 1;
205 * - NORMLINF means that the maximum absolute value of all entries should be 1.
207 typedef enum { NORMPROB
, NORMLINF
} NormType
;
208 /// Enumerates different distance measures between probability measures.
210 * - DISTL1 is the \f$\ell_1\f$ distance (sum of absolute values of pointwise difference);
211 * - DISTLINF is the \f$\ell_\infty\f$ distance (maximum absolute value of pointwise difference);
212 * - DISTTV is the total variation distance (half of the \f$\ell_1\f$ distance);
213 * - DISTKL is the Kullback-Leibler distance (\f$\sum_i p_i (\log p_i - \log q_i)\f$).
214 * - DISTHEL is the Hellinger distance (\f$\frac{1}{2}\sum_i (\sqrt{p_i}-\sqrt{q_i})^2\f$).
216 typedef enum { DISTL1
, DISTLINF
, DISTTV
, DISTKL
, DISTHEL
} DistType
;
218 /// \name Constructors and destructors
220 /// Default constructor (constructs empty vector)
223 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
224 explicit TProb( size_t n
) : _p(std::vector
<T
>(n
, (T
)1 / n
)) {}
226 /// Construct vector of length \a n with each entry set to \a p
227 explicit TProb( size_t n
, T p
) : _p(n
, p
) {}
229 /// Construct vector from a range
230 /** \tparam TIterator Iterates over instances that can be cast to \a T
231 * \param begin Points to first instance to be added.
232 * \param end Points just beyond last instance to be added.
233 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint.
235 template <typename TIterator
>
236 TProb( TIterator begin
, TIterator end
, size_t sizeHint
=0 ) : _p() {
237 _p
.reserve( sizeHint
);
238 _p
.insert( _p
.begin(), begin
, end
);
241 /// Construct vector from another vector
242 /** \tparam S type of elements in \a v (should be castable to type \a T)
243 * \param v vector used for initialization
245 template <typename S
>
246 TProb( const std::vector
<S
> &v
) : _p() {
247 _p
.reserve( v
.size() );
248 _p
.insert( _p
.begin(), v
.begin(), v
.end() );
252 /// Constant iterator over the elements
253 typedef typename
std::vector
<T
>::const_iterator const_iterator
;
254 /// Iterator over the elements
255 typedef typename
std::vector
<T
>::iterator iterator
;
256 /// Constant reverse iterator over the elements
257 typedef typename
std::vector
<T
>::const_reverse_iterator const_reverse_iterator
;
258 /// Reverse iterator over the elements
259 typedef typename
std::vector
<T
>::reverse_iterator reverse_iterator
;
261 /// \name Iterator interface
263 /// Returns iterator that points to the first element
264 iterator
begin() { return _p
.begin(); }
265 /// Returns constant iterator that points to the first element
266 const_iterator
begin() const { return _p
.begin(); }
268 /// Returns iterator that points beyond the last element
269 iterator
end() { return _p
.end(); }
270 /// Returns constant iterator that points beyond the last element
271 const_iterator
end() const { return _p
.end(); }
273 /// Returns reverse iterator that points to the last element
274 reverse_iterator
rbegin() { return _p
.rbegin(); }
275 /// Returns constant reverse iterator that points to the last element
276 const_reverse_iterator
rbegin() const { return _p
.rbegin(); }
278 /// Returns reverse iterator that points beyond the first element
279 reverse_iterator
rend() { return _p
.rend(); }
280 /// Returns constant reverse iterator that points beyond the first element
281 const_reverse_iterator
rend() const { return _p
.rend(); }
286 /// Returns a const reference to the wrapped vector
287 const std::vector
<T
> & p() const { return _p
; }
289 /// Returns a reference to the wrapped vector
290 std::vector
<T
> & p() { return _p
; }
292 /// Returns a copy of the \a i 'th entry
293 T
operator[]( size_t i
) const {
301 /// Returns reference to the \a i 'th entry
302 T
& operator[]( size_t i
) { return _p
[i
]; }
304 /// Returns length of the vector (i.e., the number of entries)
305 size_t size() const { return _p
.size(); }
307 /// Accumulate over all values, similar to std::accumulate
308 template<typename binOp
, typename unOp
> T
accumulate( T init
, binOp op1
, unOp op2
) const {
310 for( const_iterator it
= begin(); it
!= end(); it
++ )
311 t
= op1( t
, op2(*it
) );
315 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
316 T
entropy() const { return -accumulate( (T
)0, std::plus
<T
>(), fo_plog0p
<T
>() ); }
318 /// Returns maximum value of all entries
319 T
max() const { return accumulate( (T
)(-INFINITY
), fo_max
<T
>(), fo_id
<T
>() ); }
321 /// Returns minimum value of all entries
322 T
min() const { return accumulate( (T
)INFINITY
, fo_min
<T
>(), fo_id
<T
>() ); }
324 /// Returns sum of all entries
325 T
sum() const { return accumulate( (T
)0, std::plus
<T
>(), fo_id
<T
>() ); }
327 /// Return sum of absolute value of all entries
328 T
sumAbs() const { return accumulate( (T
)0, std::plus
<T
>(), fo_abs
<T
>() ); }
330 /// Returns maximum absolute value of all entries
331 T
maxAbs() const { return accumulate( (T
)0, fo_max
<T
>(), fo_abs
<T
>() ); }
333 /// Returns \c true if one or more entries are NaN
334 bool hasNaNs() const {
335 bool foundnan
= false;
336 for( typename
std::vector
<T
>::const_iterator x
= _p
.begin(); x
!= _p
.end(); x
++ )
344 /// Returns \c true if one or more entries are negative
345 bool hasNegatives() const {
346 return (std::find_if( _p
.begin(), _p
.end(), std::bind2nd( std::less
<T
>(), (T
)0 ) ) != _p
.end());
349 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
350 std::pair
<size_t,T
> argmax() const {
353 for( size_t i
= 1; i
< size(); i
++ ) {
359 return std::make_pair(arg
,max
);
362 /// Returns a random index, according to the (normalized) distribution described by *this
364 Real x
= rnd_uniform() * sum();
366 for( size_t i
= 0; i
< size(); i
++ ) {
371 return( size() - 1 );
374 /// Lexicographical comparison
375 /** \pre <tt>this->size() == q.size()</tt>
377 bool operator<= (const TProb
<T
> & q
) const {
378 DAI_DEBASSERT( size() == q
.size() );
379 return lexicographical_compare( begin(), end(), q
.begin(), q
.end() );
383 /// \name Unary transformations
385 /// Returns the result of applying operation \a op pointwise on \c *this
386 template<typename unaryOp
> TProb
<T
> pwUnaryTr( unaryOp op
) const {
388 r
._p
.reserve( size() );
389 std::transform( _p
.begin(), _p
.end(), back_inserter( r
._p
), op
);
393 /// Returns pointwise absolute value
394 TProb
<T
> abs() const { return pwUnaryTr( fo_abs
<T
>() ); }
396 /// Returns pointwise exponent
397 TProb
<T
> exp() const { return pwUnaryTr( fo_exp
<T
>() ); }
399 /// Returns pointwise logarithm
400 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
402 TProb
<T
> log(bool zero
=false) const {
404 return pwUnaryTr( fo_log0
<T
>() );
406 return pwUnaryTr( fo_log
<T
>() );
409 /// Returns pointwise inverse
410 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
412 TProb
<T
> inverse(bool zero
=true) const {
414 return pwUnaryTr( fo_inv0
<T
>() );
416 return pwUnaryTr( fo_inv
<T
>() );
419 /// Returns normalized copy of \c *this, using the specified norm
420 /** \throw NOT_NORMALIZABLE if the norm is zero
422 TProb
<T
> normalized( NormType norm
= NORMPROB
) const {
424 if( norm
== NORMPROB
)
426 else if( norm
== NORMLINF
)
429 DAI_THROW(NOT_NORMALIZABLE
);
432 return pwUnaryTr( std::bind2nd( std::divides
<T
>(), Z
) );
436 /// \name Unary operations
438 /// Applies unary operation \a op pointwise
439 template<typename unaryOp
> TProb
<T
>& pwUnaryOp( unaryOp op
) {
440 std::transform( _p
.begin(), _p
.end(), _p
.begin(), op
);
444 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
445 TProb
<T
>& randomize() {
446 std::generate( _p
.begin(), _p
.end(), rnd_uniform
);
450 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
451 TProb
<T
>& setUniform () {
452 fill( (T
)1 / size() );
456 /// Applies absolute value pointwise
457 const TProb
<T
>& takeAbs() { return pwUnaryOp( fo_abs
<T
>() ); }
459 /// Applies exponent pointwise
460 const TProb
<T
>& takeExp() { return pwUnaryOp( fo_exp
<T
>() ); }
462 /// Applies logarithm pointwise
463 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
465 const TProb
<T
>& takeLog(bool zero
=false) {
467 return pwUnaryOp( fo_log0
<T
>() );
469 return pwUnaryOp( fo_log
<T
>() );
472 /// Normalizes vector using the specified norm
473 /** \throw NOT_NORMALIZABLE if the norm is zero
475 T
normalize( NormType norm
=NORMPROB
) {
477 if( norm
== NORMPROB
)
479 else if( norm
== NORMLINF
)
482 DAI_THROW(NOT_NORMALIZABLE
);
489 /// \name Operations with scalars
491 /// Sets all entries to \a x
492 TProb
<T
> & fill(T x
) {
493 std::fill( _p
.begin(), _p
.end(), x
);
497 /// Adds scalar \a x to each entry
498 TProb
<T
>& operator+= (T x
) {
500 return pwUnaryOp( std::bind2nd( std::plus
<T
>(), x
) );
505 /// Subtracts scalar \a x from each entry
506 TProb
<T
>& operator-= (T x
) {
508 return pwUnaryOp( std::bind2nd( std::minus
<T
>(), x
) );
513 /// Multiplies each entry with scalar \a x
514 TProb
<T
>& operator*= (T x
) {
516 return pwUnaryOp( std::bind2nd( std::multiplies
<T
>(), x
) );
521 /// Divides each entry by scalar \a x
522 TProb
<T
>& operator/= (T x
) {
523 DAI_DEBASSERT( x
!= 0 );
525 return pwUnaryOp( std::bind2nd( std::divides
<T
>(), x
) );
530 /// Raises entries to the power \a x
531 TProb
<T
>& operator^= (T x
) {
533 return pwUnaryOp( std::bind2nd( fo_pow
<T
>(), x
) );
539 /// \name Transformations with scalars
541 /// Returns sum of \c *this and scalar \a x
542 TProb
<T
> operator+ (T x
) const { return pwUnaryTr( std::bind2nd( std::plus
<T
>(), x
) ); }
544 /// Returns difference of \c *this and scalar \a x
545 TProb
<T
> operator- (T x
) const { return pwUnaryTr( std::bind2nd( std::minus
<T
>(), x
) ); }
547 /// Returns product of \c *this with scalar \a x
548 TProb
<T
> operator* (T x
) const { return pwUnaryTr( std::bind2nd( std::multiplies
<T
>(), x
) ); }
550 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
551 TProb
<T
> operator/ (T x
) const { return pwUnaryTr( std::bind2nd( fo_divides0
<T
>(), x
) ); }
553 /// Returns \c *this raised to the power \a x
554 TProb
<T
> operator^ (T x
) const { return pwUnaryTr( std::bind2nd( fo_pow
<T
>(), x
) ); }
557 /// \name Operations with other equally-sized vectors
559 /// Applies binary operation pointwise on two vectors
560 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
561 * \param q Right operand
562 * \param op Operation of type \a binaryOp
564 template<typename binaryOp
> TProb
<T
>& pwBinaryOp( const TProb
<T
> &q
, binaryOp op
) {
565 DAI_DEBASSERT( size() == q
.size() );
566 std::transform( _p
.begin(), _p
.end(), q
._p
.begin(), _p
.begin(), op
);
570 /// Pointwise addition with \a q
571 /** \pre <tt>this->size() == q.size()</tt>
573 TProb
<T
>& operator+= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::plus
<T
>() ); }
575 /// Pointwise subtraction of \a q
576 /** \pre <tt>this->size() == q.size()</tt>
578 TProb
<T
>& operator-= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::minus
<T
>() ); }
580 /// Pointwise multiplication with \a q
581 /** \pre <tt>this->size() == q.size()</tt>
583 TProb
<T
>& operator*= (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::multiplies
<T
>() ); }
585 /// Pointwise division by \a q, where division by 0 yields 0
586 /** \pre <tt>this->size() == q.size()</tt>
587 * \see divide(const TProb<T> &)
589 TProb
<T
>& operator/= (const TProb
<T
> & q
) { return pwBinaryOp( q
, fo_divides0
<T
>() ); }
591 /// Pointwise division by \a q, where division by 0 yields +Inf
592 /** \pre <tt>this->size() == q.size()</tt>
593 * \see operator/=(const TProb<T> &)
595 TProb
<T
>& divide (const TProb
<T
> & q
) { return pwBinaryOp( q
, std::divides
<T
>() ); }
598 /** \pre <tt>this->size() == q.size()</tt>
600 TProb
<T
>& operator^= (const TProb
<T
> & q
) { return pwBinaryOp( q
, fo_pow
<T
>() ); }
603 /// \name Transformations with other equally-sized vectors
605 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
606 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
607 * \param q Right operand
608 * \param op Operation of type \a binaryOp
610 template<typename binaryOp
> TProb
<T
> pwBinaryTr( const TProb
<T
> &q
, binaryOp op
) const {
611 DAI_DEBASSERT( size() == q
.size() );
613 r
._p
.reserve( size() );
614 std::transform( _p
.begin(), _p
.end(), q
._p
.begin(), back_inserter( r
._p
), op
);
618 /// Returns sum of \c *this and \a q
619 /** \pre <tt>this->size() == q.size()</tt>
621 TProb
<T
> operator+ ( const TProb
<T
>& q
) const { return pwBinaryTr( q
, std::plus
<T
>() ); }
623 /// Return \c *this minus \a q
624 /** \pre <tt>this->size() == q.size()</tt>
626 TProb
<T
> operator- ( const TProb
<T
>& q
) const { return pwBinaryTr( q
, std::minus
<T
>() ); }
628 /// Return product of \c *this with \a q
629 /** \pre <tt>this->size() == q.size()</tt>
631 TProb
<T
> operator* ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, std::multiplies
<T
>() ); }
633 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
634 /** \pre <tt>this->size() == q.size()</tt>
635 * \see divided_by(const TProb<T> &)
637 TProb
<T
> operator/ ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, fo_divides0
<T
>() ); }
639 /// Pointwise division by \a q, where division by 0 yields +Inf
640 /** \pre <tt>this->size() == q.size()</tt>
641 * \see operator/(const TProb<T> &)
643 TProb
<T
> divided_by( const TProb
<T
> &q
) const { return pwBinaryTr( q
, std::divides
<T
>() ); }
645 /// Returns \c *this to the power \a q
646 /** \pre <tt>this->size() == q.size()</tt>
648 TProb
<T
> operator^ ( const TProb
<T
> &q
) const { return pwBinaryTr( q
, fo_pow
<T
>() ); }
651 /// Performs a generalized inner product, similar to std::inner_product
652 /** \pre <tt>this->size() == q.size()</tt>
654 template<typename binOp1
, typename binOp2
> T
innerProduct( const TProb
<T
> &q
, T init
, binOp1 binaryOp1
, binOp2 binaryOp2
) const {
655 DAI_DEBASSERT( size() == q
.size() );
656 return std::inner_product( begin(), end(), q
.begin(), init
, binaryOp1
, binaryOp2
);
661 /// Returns distance between \a p and \a q, measured using distance measure \a dt
663 * \pre <tt>this->size() == q.size()</tt>
665 template<typename T
> T
dist( const TProb
<T
> &p
, const TProb
<T
> &q
, typename TProb
<T
>::DistType dt
) {
667 case TProb
<T
>::DISTL1
:
668 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_absdiff
<T
>() );
669 case TProb
<T
>::DISTLINF
:
670 return p
.innerProduct( q
, (T
)0, fo_max
<T
>(), fo_absdiff
<T
>() );
671 case TProb
<T
>::DISTTV
:
672 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_absdiff
<T
>() ) / 2;
673 case TProb
<T
>::DISTKL
:
674 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_KL
<T
>() );
675 case TProb
<T
>::DISTHEL
:
676 return p
.innerProduct( q
, (T
)0, std::plus
<T
>(), fo_Hellinger
<T
>() ) / 2;
678 DAI_THROW(UNKNOWN_ENUM_VALUE
);
684 /// Writes a TProb<T> to an output stream
687 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TProb
<T
>& p
) {
689 std::copy( p
.p().begin(), p
.p().end(), std::ostream_iterator
<T
>(os
, " ") );
695 /// Returns the pointwise minimum of \a a and \a b
697 * \pre <tt>this->size() == q.size()</tt>
699 template<typename T
> TProb
<T
> min( const TProb
<T
> &a
, const TProb
<T
> &b
) {
700 return a
.pwBinaryTr( b
, fo_min
<T
>() );
704 /// Returns the pointwise maximum of \a a and \a b
706 * \pre <tt>this->size() == q.size()</tt>
708 template<typename T
> TProb
<T
> max( const TProb
<T
> &a
, const TProb
<T
> &b
) {
709 return a
.pwBinaryTr( b
, fo_max
<T
>() );
713 /// Represents a vector with entries of type dai::Real.
714 typedef TProb
<Real
> Prob
;
717 } // end of namespace dai