2036e17d29aec391ffdccaaf8cb7a06b61058843
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
9 */
12 /// \file
13 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <utility>
#include <vector>
#include <dai/util.h>
#include <dai/exceptions.h>
30 namespace dai {
/// Function object that returns the value itself
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_id {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
/// Function object that takes the absolute value
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_abs {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns abs(\a x), i.e. -\a x if \a x < 0 and \a x otherwise
    T operator()( const T &x ) const {
        if( x < (T)0 )
            return -x;
        else
            return x;
    }
};
/// Function object that takes the exponent
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_exp {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        return exp( x );
    }
};
/// Function object that takes the (natural) logarithm
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_log {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        return log( x );
    }
};
/// Function object that takes the logarithm, except that log(0) is defined to be 0
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_log0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : log(\a x))
    T operator()( const T &x ) const {
        if( x )
            return log( x );
        else
            return 0;
    }
};
/// Function object that takes the inverse
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_inv {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns 1 / \a x (no special treatment of \a x == 0)
    T operator()( const T &x ) const {
        return 1 / x;
    }
};
/// Function object that takes the inverse, except that 1/0 is defined to be 0
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::unary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_inv0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : (1 / \a x))
    T operator()( const T &x ) const {
        if( x )
            return 1 / x;
        else
            return 0;
    }
};
105 /// Function object that returns p*log0(p)
106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
107 /// Returns \a p * log0(\a p)
108 T operator()( const T &p ) const {
109 return p * dai::log0(p);
110 }
111 };
/// Function object similar to std::divides(), but different in that dividing by zero results in zero
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_divides0 {
    /// Type of the first argument (numerator)
    typedef T first_argument_type;
    /// Type of the second argument (denominator)
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a y == 0 ? 0 : (\a x / \a y))
    T operator()( const T &x, const T &y ) const {
        if( y == (T)0 )
            return (T)0;
        else
            return x / y;
    }
};
/// Function object useful for calculating the Kullback-Leibler divergence
/** Summing this over all entries gives \f$\sum_i p_i (\log p_i - \log q_i)\f$.
 *  The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_KL {
    /// Type of the first argument (\a p)
    typedef T first_argument_type;
    /// Type of the second argument (\a q)
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
    T operator()( const T &p, const T &q ) const {
        // 0 * log(0) is taken to be 0 by convention
        if( p == (T)0 )
            return (T)0;
        else
            return p * (log(p) - log(q));
    }
};
/// Function object useful for calculating the Hellinger distance
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_Hellinger {
    /// Type of the first argument (\a p)
    typedef T first_argument_type;
    /// Type of the second argument (\a q)
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        T x = sqrt(p) - sqrt(q);
        return x * x;
    }
};
/// Function object that returns x to the power y
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_pow {
    /// Type of the first argument (base)
    typedef T first_argument_type;
    /// Type of the second argument (exponent)
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x ^ \a y)
    T operator()( const T &x, const T &y ) const {
        // Shortcut: skip the comparatively expensive std::pow call for exponent 1
        if( y != 1 )
            return std::pow( x, y );
        else
            return x;
    }
};
/// Function object that returns the maximum of two values
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_max {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a x : \a y)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? x : y;
    }
};
/// Function object that returns the minimum of two values
/** The adaptable-function typedefs are declared directly instead of deriving from
 *  std::binary_function, which was deprecated in C++11 and removed in C++17.
 */
template<typename T> struct fo_min {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a y : \a x)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? y : x;
    }
};
178 /// Function object that returns the absolute difference of x and y
179 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
180 /// Returns abs( \a x - \a y )
181 T operator()( const T &x, const T &y ) const {
182 return dai::abs( x - y );
183 }
184 };
187 /// Represents a vector with entries of type \a T.
188 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
189 *
190 * It is mainly used for representing measures on a finite outcome space, for example, the probability
191 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
192 * represent logarithms of probability mass functions.
193 *
194 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
195 */
196 template <typename T> class TProb {
197 private:
198 /// The vector
199 std::vector<T> _p;
201 public:
202 /// Enumerates different ways of normalizing a probability measure.
203 /**
204 * - NORMPROB means that the sum of all entries should be 1;
205 * - NORMLINF means that the maximum absolute value of all entries should be 1.
206 */
207 typedef enum { NORMPROB, NORMLINF } NormType;
208 /// Enumerates different distance measures between probability measures.
209 /**
210 * - DISTL1 is the \f$\ell_1\f$ distance (sum of absolute values of pointwise difference);
211 * - DISTLINF is the \f$\ell_\infty\f$ distance (maximum absolute value of pointwise difference);
212 * - DISTTV is the total variation distance (half of the \f$\ell_1\f$ distance);
213 * - DISTKL is the Kullback-Leibler distance (\f$\sum_i p_i (\log p_i - \log q_i)\f$).
214 * - DISTHEL is the Hellinger distance (\f$\frac{1}{2}\sum_i (\sqrt{p_i}-\sqrt{q_i})^2\f$).
215 */
216 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL, DISTHEL } DistType;
218 /// \name Constructors and destructors
219 //@{
220 /// Default constructor (constructs empty vector)
221 TProb() : _p() {}
223 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
224 explicit TProb( size_t n ) : _p(std::vector<T>(n, (T)1 / n)) {}
226 /// Construct vector of length \a n with each entry set to \a p
227 explicit TProb( size_t n, T p ) : _p(n, p) {}
229 /// Construct vector from a range
230 /** \tparam TIterator Iterates over instances that can be cast to \a T
231 * \param begin Points to first instance to be added.
232 * \param end Points just beyond last instance to be added.
233 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint.
234 */
235 template <typename TIterator>
236 TProb( TIterator begin, TIterator end, size_t sizeHint=0 ) : _p() {
237 _p.reserve( sizeHint );
238 _p.insert( _p.begin(), begin, end );
239 }
241 /// Construct vector from another vector
242 /** \tparam S type of elements in \a v (should be castable to type \a T)
243 * \param v vector used for initialization
244 */
245 template <typename S>
246 TProb( const std::vector<S> &v ) : _p() {
247 _p.reserve( v.size() );
248 _p.insert( _p.begin(), v.begin(), v.end() );
249 }
250 //@}
252 /// Constant iterator over the elements
253 typedef typename std::vector<T>::const_iterator const_iterator;
254 /// Iterator over the elements
255 typedef typename std::vector<T>::iterator iterator;
256 /// Constant reverse iterator over the elements
257 typedef typename std::vector<T>::const_reverse_iterator const_reverse_iterator;
258 /// Reverse iterator over the elements
259 typedef typename std::vector<T>::reverse_iterator reverse_iterator;
261 /// \name Iterator interface
262 //@{
263 /// Returns iterator that points to the first element
264 iterator begin() { return _p.begin(); }
265 /// Returns constant iterator that points to the first element
266 const_iterator begin() const { return _p.begin(); }
268 /// Returns iterator that points beyond the last element
269 iterator end() { return _p.end(); }
270 /// Returns constant iterator that points beyond the last element
271 const_iterator end() const { return _p.end(); }
273 /// Returns reverse iterator that points to the last element
274 reverse_iterator rbegin() { return _p.rbegin(); }
275 /// Returns constant reverse iterator that points to the last element
276 const_reverse_iterator rbegin() const { return _p.rbegin(); }
278 /// Returns reverse iterator that points beyond the first element
279 reverse_iterator rend() { return _p.rend(); }
280 /// Returns constant reverse iterator that points beyond the first element
281 const_reverse_iterator rend() const { return _p.rend(); }
282 //@}
284 /// \name Queries
285 //@{
286 /// Returns a const reference to the wrapped vector
287 const std::vector<T> & p() const { return _p; }
289 /// Returns a reference to the wrapped vector
290 std::vector<T> & p() { return _p; }
292 /// Returns a copy of the \a i 'th entry
293 T operator[]( size_t i ) const {
294 #ifdef DAI_DEBUG
295 return _p.at(i);
296 #else
297 return _p[i];
298 #endif
299 }
301 /// Returns reference to the \a i 'th entry
302 T& operator[]( size_t i ) { return _p[i]; }
304 /// Returns length of the vector (i.e., the number of entries)
305 size_t size() const { return _p.size(); }
307 /// Accumulate over all values, similar to std::accumulate
308 /** The following calculation is done:
309 * \code
310 * T t = op2(init);
311 * for( const_iterator it = begin(); it != end(); it++ )
312 * t = op1( t, op2(*it) );
313 * return t;
314 * \endcode
315 */
316 template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
317 T t = op2(init);
318 for( const_iterator it = begin(); it != end(); it++ )
319 t = op1( t, op2(*it) );
320 return t;
321 }
323 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
324 T entropy() const { return -accumulate( (T)0, std::plus<T>(), fo_plog0p<T>() ); }
326 /// Returns maximum value of all entries
327 T max() const { return accumulate( (T)(-INFINITY), fo_max<T>(), fo_id<T>() ); }
329 /// Returns minimum value of all entries
330 T min() const { return accumulate( (T)INFINITY, fo_min<T>(), fo_id<T>() ); }
332 /// Returns sum of all entries
333 T sum() const { return accumulate( (T)0, std::plus<T>(), fo_id<T>() ); }
335 /// Return sum of absolute value of all entries
336 T sumAbs() const { return accumulate( (T)0, std::plus<T>(), fo_abs<T>() ); }
338 /// Returns maximum absolute value of all entries
339 T maxAbs() const { return accumulate( (T)0, fo_max<T>(), fo_abs<T>() ); }
341 /// Returns \c true if one or more entries are NaN
342 bool hasNaNs() const {
343 bool foundnan = false;
344 for( typename std::vector<T>::const_iterator x = _p.begin(); x != _p.end(); x++ )
345 if( isnan( *x ) ) {
346 foundnan = true;
347 break;
348 }
349 return foundnan;
350 }
352 /// Returns \c true if one or more entries are negative
353 bool hasNegatives() const {
354 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
355 }
357 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
358 std::pair<size_t,T> argmax() const {
359 T max = _p[0];
360 size_t arg = 0;
361 for( size_t i = 1; i < size(); i++ ) {
362 if( _p[i] > max ) {
363 max = _p[i];
364 arg = i;
365 }
366 }
367 return std::make_pair(arg,max);
368 }
370 /// Returns a random index, according to the (normalized) distribution described by *this
371 size_t draw() {
372 Real x = rnd_uniform() * sum();
373 T s = 0;
374 for( size_t i = 0; i < size(); i++ ) {
375 s += _p[i];
376 if( s > x )
377 return i;
378 }
379 return( size() - 1 );
380 }
382 /// Lexicographical comparison
383 /** \pre <tt>this->size() == q.size()</tt>
384 */
385 bool operator<( const TProb<T>& q ) const {
386 DAI_DEBASSERT( size() == q.size() );
387 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
388 }
390 /// Comparison
391 bool operator==( const TProb<T>& q ) const {
392 if( size() != q.size() )
393 return false;
394 return p() == q.p();
395 }
396 //@}
398 /// \name Unary transformations
399 //@{
400 /// Returns the result of applying operation \a op pointwise on \c *this
401 template<typename unaryOp> TProb<T> pwUnaryTr( unaryOp op ) const {
402 TProb<T> r;
403 r._p.reserve( size() );
404 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
405 return r;
406 }
408 /// Returns negative of \c *this
409 TProb<T> operator- () const { return pwUnaryTr( std::negate<T>() ); }
411 /// Returns pointwise absolute value
412 TProb<T> abs() const { return pwUnaryTr( fo_abs<T>() ); }
414 /// Returns pointwise exponent
415 TProb<T> exp() const { return pwUnaryTr( fo_exp<T>() ); }
417 /// Returns pointwise logarithm
418 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
419 */
420 TProb<T> log(bool zero=false) const {
421 if( zero )
422 return pwUnaryTr( fo_log0<T>() );
423 else
424 return pwUnaryTr( fo_log<T>() );
425 }
427 /// Returns pointwise inverse
428 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
429 */
430 TProb<T> inverse(bool zero=true) const {
431 if( zero )
432 return pwUnaryTr( fo_inv0<T>() );
433 else
434 return pwUnaryTr( fo_inv<T>() );
435 }
437 /// Returns normalized copy of \c *this, using the specified norm
438 /** \throw NOT_NORMALIZABLE if the norm is zero
439 */
440 TProb<T> normalized( NormType norm = NORMPROB ) const {
441 T Z = 0;
442 if( norm == NORMPROB )
443 Z = sum();
444 else if( norm == NORMLINF )
445 Z = maxAbs();
446 if( Z == (T)0 ) {
447 DAI_THROW(NOT_NORMALIZABLE);
448 return *this;
449 } else
450 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
451 }
452 //@}
454 /// \name Unary operations
455 //@{
456 /// Applies unary operation \a op pointwise
457 template<typename unaryOp> TProb<T>& pwUnaryOp( unaryOp op ) {
458 std::transform( _p.begin(), _p.end(), _p.begin(), op );
459 return *this;
460 }
462 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
463 TProb<T>& randomize() {
464 std::generate( _p.begin(), _p.end(), rnd_uniform );
465 return *this;
466 }
468 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
469 TProb<T>& setUniform () {
470 fill( (T)1 / size() );
471 return *this;
472 }
474 /// Applies absolute value pointwise
475 const TProb<T>& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
477 /// Applies exponent pointwise
478 const TProb<T>& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
480 /// Applies logarithm pointwise
481 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
482 */
483 const TProb<T>& takeLog(bool zero=false) {
484 if( zero ) {
485 return pwUnaryOp( fo_log0<T>() );
486 } else
487 return pwUnaryOp( fo_log<T>() );
488 }
490 /// Normalizes vector using the specified norm
491 /** \throw NOT_NORMALIZABLE if the norm is zero
492 */
493 T normalize( NormType norm=NORMPROB ) {
494 T Z = 0;
495 if( norm == NORMPROB )
496 Z = sum();
497 else if( norm == NORMLINF )
498 Z = maxAbs();
499 if( Z == (T)0 )
500 DAI_THROW(NOT_NORMALIZABLE);
501 else
502 *this /= Z;
503 return Z;
504 }
505 //@}
507 /// \name Operations with scalars
508 //@{
509 /// Sets all entries to \a x
510 TProb<T> & fill(T x) {
511 std::fill( _p.begin(), _p.end(), x );
512 return *this;
513 }
515 /// Adds scalar \a x to each entry
516 TProb<T>& operator+= (T x) {
517 if( x != 0 )
518 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
519 else
520 return *this;
521 }
523 /// Subtracts scalar \a x from each entry
524 TProb<T>& operator-= (T x) {
525 if( x != 0 )
526 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
527 else
528 return *this;
529 }
531 /// Multiplies each entry with scalar \a x
532 TProb<T>& operator*= (T x) {
533 if( x != 1 )
534 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
535 else
536 return *this;
537 }
539 /// Divides each entry by scalar \a x, where division by 0 yields 0
540 TProb<T>& operator/= (T x) {
541 if( x != 1 )
542 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
543 else
544 return *this;
545 }
547 /// Raises entries to the power \a x
548 TProb<T>& operator^= (T x) {
549 if( x != (T)1 )
550 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
551 else
552 return *this;
553 }
554 //@}
556 /// \name Transformations with scalars
557 //@{
558 /// Returns sum of \c *this and scalar \a x
559 TProb<T> operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
561 /// Returns difference of \c *this and scalar \a x
562 TProb<T> operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
564 /// Returns product of \c *this with scalar \a x
565 TProb<T> operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
567 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
568 TProb<T> operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
570 /// Returns \c *this raised to the power \a x
571 TProb<T> operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
572 //@}
574 /// \name Operations with other equally-sized vectors
575 //@{
576 /// Applies binary operation pointwise on two vectors
577 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
578 * \param q Right operand
579 * \param op Operation of type \a binaryOp
580 */
581 template<typename binaryOp> TProb<T>& pwBinaryOp( const TProb<T> &q, binaryOp op ) {
582 DAI_DEBASSERT( size() == q.size() );
583 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
584 return *this;
585 }
587 /// Pointwise addition with \a q
588 /** \pre <tt>this->size() == q.size()</tt>
589 */
590 TProb<T>& operator+= (const TProb<T> & q) { return pwBinaryOp( q, std::plus<T>() ); }
592 /// Pointwise subtraction of \a q
593 /** \pre <tt>this->size() == q.size()</tt>
594 */
595 TProb<T>& operator-= (const TProb<T> & q) { return pwBinaryOp( q, std::minus<T>() ); }
597 /// Pointwise multiplication with \a q
598 /** \pre <tt>this->size() == q.size()</tt>
599 */
600 TProb<T>& operator*= (const TProb<T> & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
602 /// Pointwise division by \a q, where division by 0 yields 0
603 /** \pre <tt>this->size() == q.size()</tt>
604 * \see divide(const TProb<T> &)
605 */
606 TProb<T>& operator/= (const TProb<T> & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
608 /// Pointwise division by \a q, where division by 0 yields +Inf
609 /** \pre <tt>this->size() == q.size()</tt>
610 * \see operator/=(const TProb<T> &)
611 */
612 TProb<T>& divide (const TProb<T> & q) { return pwBinaryOp( q, std::divides<T>() ); }
614 /// Pointwise power
615 /** \pre <tt>this->size() == q.size()</tt>
616 */
617 TProb<T>& operator^= (const TProb<T> & q) { return pwBinaryOp( q, fo_pow<T>() ); }
618 //@}
620 /// \name Transformations with other equally-sized vectors
621 //@{
622 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
623 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
624 * \param q Right operand
625 * \param op Operation of type \a binaryOp
626 */
627 template<typename binaryOp> TProb<T> pwBinaryTr( const TProb<T> &q, binaryOp op ) const {
628 DAI_DEBASSERT( size() == q.size() );
629 TProb<T> r;
630 r._p.reserve( size() );
631 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
632 return r;
633 }
635 /// Returns sum of \c *this and \a q
636 /** \pre <tt>this->size() == q.size()</tt>
637 */
638 TProb<T> operator+ ( const TProb<T>& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
640 /// Return \c *this minus \a q
641 /** \pre <tt>this->size() == q.size()</tt>
642 */
643 TProb<T> operator- ( const TProb<T>& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
645 /// Return product of \c *this with \a q
646 /** \pre <tt>this->size() == q.size()</tt>
647 */
648 TProb<T> operator* ( const TProb<T> &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
650 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
651 /** \pre <tt>this->size() == q.size()</tt>
652 * \see divided_by(const TProb<T> &)
653 */
654 TProb<T> operator/ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
656 /// Pointwise division by \a q, where division by 0 yields +Inf
657 /** \pre <tt>this->size() == q.size()</tt>
658 * \see operator/(const TProb<T> &)
659 */
660 TProb<T> divided_by( const TProb<T> &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
662 /// Returns \c *this to the power \a q
663 /** \pre <tt>this->size() == q.size()</tt>
664 */
665 TProb<T> operator^ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
666 //@}
668 /// Performs a generalized inner product, similar to std::inner_product
669 /** \pre <tt>this->size() == q.size()</tt>
670 */
671 template<typename binOp1, typename binOp2> T innerProduct( const TProb<T> &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
672 DAI_DEBASSERT( size() == q.size() );
673 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
674 }
675 };
678 /// Returns distance between \a p and \a q, measured using distance measure \a dt
679 /** \relates TProb
680 * \pre <tt>this->size() == q.size()</tt>
681 */
682 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, typename TProb<T>::DistType dt ) {
683 switch( dt ) {
684 case TProb<T>::DISTL1:
685 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
686 case TProb<T>::DISTLINF:
687 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
688 case TProb<T>::DISTTV:
689 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
690 case TProb<T>::DISTKL:
691 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
692 case TProb<T>::DISTHEL:
693 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
694 default:
695 DAI_THROW(UNKNOWN_ENUM_VALUE);
696 return INFINITY;
697 }
698 }
701 /// Writes a TProb<T> to an output stream
702 /** \relates TProb
703 */
704 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
705 os << "(";
706 for( typename std::vector<T>::const_iterator it = p.begin(); it != p.end(); it++ )
707 os << (it != p.begin() ? ", " : "") << *it;
708 os << ")";
709 return os;
710 }
713 /// Returns the pointwise minimum of \a a and \a b
714 /** \relates TProb
715 * \pre <tt>this->size() == q.size()</tt>
716 */
717 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
718 return a.pwBinaryTr( b, fo_min<T>() );
719 }
722 /// Returns the pointwise maximum of \a a and \a b
723 /** \relates TProb
724 * \pre <tt>this->size() == q.size()</tt>
725 */
726 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
727 return a.pwBinaryTr( b, fo_max<T>() );
728 }
/// Represents a vector with entries of type dai::Real.
/** Shorthand for the TProb<Real> instantiation.
 */
typedef TProb<Real> Prob;
735 } // end of namespace dai
738 #endif