90b457b20ac8bb57cb172caf6c2b63437169ba67
[libdai.git] / include / dai / prob.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
/// \brief Defines the TProb<> class template and its Prob typedef, which represent (probability) vectors (e.g., probability distributions of discrete random variables)
14
15
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
18
19
#include <cmath>
#include <vector>
#include <ostream>
#include <algorithm>
#include <numeric>
#include <functional>
#include <iterator>
#include <utility>
#include <dai/util.h>
#include <dai/exceptions.h>
28
29
30 namespace dai {
31
32
/// Identity function object: hands back its argument unchanged
template<typename T>
struct fo_id : public std::unary_function<T, T> {
    /// Returns a copy of \a value
    T operator()( const T &value ) const {
        T copy( value );
        return copy;
    }
};
40
41
/// Function object computing the absolute value of its argument
template<typename T>
struct fo_abs : public std::unary_function<T, T> {
    /// Returns -\a v when \a v is below zero, and \a v otherwise
    T operator()( const T &v ) const {
        return ( v < (T)0 ) ? -v : v;
    }
};
52
53
/// Function object applying the exponential function
template<typename T>
struct fo_exp : public std::unary_function<T, T> {
    /// Returns exp(\a arg), converted to \a T
    T operator()( const T &arg ) const {
        const T result = exp( arg );
        return result;
    }
};
61
62
/// Function object applying the natural logarithm (log(0) is left to the math library, i.e. -Inf)
template<typename T>
struct fo_log : public std::unary_function<T, T> {
    /// Returns log(\a arg), converted to \a T
    T operator()( const T &arg ) const {
        const T result = log( arg );
        return result;
    }
};
70
71
/// Function object applying the natural logarithm, with the convention log(0) = 0
template<typename T>
struct fo_log0 : public std::unary_function<T, T> {
    /// Returns 0 if \a arg is zero, and log(\a arg) otherwise
    T operator()( const T &arg ) const {
        return arg ? (T)log( arg ) : (T)0;
    }
};
82
83
/// Function object computing the reciprocal 1 / x
template<typename T>
struct fo_inv : public std::unary_function<T, T> {
    /// Returns 1 / \a arg (zero is not treated specially)
    T operator()( const T &arg ) const {
        const T reciprocal = 1 / arg;
        return reciprocal;
    }
};
91
92
/// Function object computing the reciprocal, with the convention 1/0 = 0
template<typename T>
struct fo_inv0 : public std::unary_function<T, T> {
    /// Returns 0 if \a arg is zero, and 1 / \a arg otherwise
    T operator()( const T &arg ) const {
        return arg ? (T)( 1 / arg ) : (T)0;
    }
};
103
104
105 /// Function object that returns p*log0(p)
106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
107 /// Returns \a p * log0(\a p)
108 T operator()( const T &p ) const {
109 return p * dai::log0(p);
110 }
111 };
112
113
/// Division function object (cf. std::divides) that yields zero instead of dividing by zero
template<typename T>
struct fo_divides0 : public std::binary_function<T, T, T> {
    /// Returns 0 if \a denom equals zero, and \a num / \a denom otherwise
    T operator()( const T &num, const T &denom ) const {
        return ( denom == (T)0 ) ? (T)0 : (T)( num / denom );
    }
};
124
125
/// Function object computing one summand of the Kullback-Leibler distance
template<typename T>
struct fo_KL : public std::binary_function<T, T, T> {
    /// Returns 0 if \a p is zero, and \a p * (log(\a p) - log(\a q)) otherwise
    T operator()( const T &p, const T &q ) const {
        return ( p == (T)0 ) ? (T)0 : (T)( p * ( log( p ) - log( q ) ) );
    }
};
136
137
/// Function object computing one (not yet halved) summand of the Hellinger distance
template<typename T>
struct fo_Hellinger : public std::binary_function<T, T, T> {
    /// Returns the squared difference of square roots, (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        const T gap = sqrt( p ) - sqrt( q );
        return gap * gap;
    }
};
146
147
/// Function object raising its first argument to the power given by its second
template<typename T>
struct fo_pow : public std::binary_function<T, T, T> {
    /// Returns \a base ^ \a exponent; an exponent of exactly 1 returns \a base without calling pow()
    T operator()( const T &base, const T &exponent ) const {
        return ( exponent == 1 ) ? base : (T)pow( base, exponent );
    }
};
158
159
/// Function object selecting the larger of two values
template<typename T>
struct fo_max : public std::binary_function<T, T, T> {
    /// Returns \a a if \a a > \a b, and \a b otherwise
    T operator()( const T &a, const T &b ) const {
        if( a > b )
            return a;
        return b;
    }
};
167
168
/// Function object selecting the smaller of two values
template<typename T>
struct fo_min : public std::binary_function<T, T, T> {
    /// Returns \a b if \a a > \a b, and \a a otherwise
    T operator()( const T &a, const T &b ) const {
        if( a > b )
            return b;
        return a;
    }
};
176
177
178 /// Function object that returns the absolute difference of x and y
179 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
180 /// Returns abs( \a x - \a y )
181 T operator()( const T &x, const T &y ) const {
182 return dai::abs( x - y );
183 }
184 };
185
186
187 /// Represents a vector with entries of type \a T.
188 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
189 *
190 * It is mainly used for representing measures on a finite outcome space, for example, the probability
191 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
192 * represent logarithms of probability mass functions.
193 *
194 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
195 */
196 template <typename T>
197 class TProb {
198 public:
199 /// Type of data structure used for storing the values
200 typedef std::vector<T> container_type;
201
202 /// Shorthand
203 typedef TProb<T> this_type;
204
205 private:
206 /// The data structure that stores the values
207 container_type _p;
208
209 public:
210 /// Enumerates different ways of normalizing a probability measure.
211 /**
212 * - NORMPROB means that the sum of all entries should be 1;
213 * - NORMLINF means that the maximum absolute value of all entries should be 1.
214 * \deprecated Please use dai::ProbNormType instead.
215 */
216 typedef enum { NORMPROB, NORMLINF } NormType;
217 /// Enumerates different distance measures between probability measures.
218 /**
219 * - DISTL1 is the \f$\ell_1\f$ distance (sum of absolute values of pointwise difference);
220 * - DISTLINF is the \f$\ell_\infty\f$ distance (maximum absolute value of pointwise difference);
221 * - DISTTV is the total variation distance (half of the \f$\ell_1\f$ distance);
222 * - DISTKL is the Kullback-Leibler distance (\f$\sum_i p_i (\log p_i - \log q_i)\f$).
223 * - DISTHEL is the Hellinger distance (\f$\frac{1}{2}\sum_i (\sqrt{p_i}-\sqrt{q_i})^2\f$).
224 * \deprecated Please use dai::ProbDistType instead.
225 */
226 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL, DISTHEL } DistType;
227
228 /// \name Constructors and destructors
229 //@{
230 /// Default constructor (constructs empty vector)
231 TProb() : _p() {}
232
233 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
234 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {}
235
236 /// Construct vector of length \a n with each entry set to \a p
237 explicit TProb( size_t n, T p ) : _p( n, p ) {}
238
239 /// Construct vector from a range
240 /** \tparam TIterator Iterates over instances that can be cast to \a T
241 * \param begin Points to first instance to be added.
242 * \param end Points just beyond last instance to be added.
243 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint;
244 * the value 0 can be given if the size is unknown, but this will result in a performance penalty.
245 * \deprecated In future libDAI versions, the \a sizeHint argument will no longer default to 0.
246 */
247 template <typename TIterator>
248 TProb( TIterator begin, TIterator end, size_t sizeHint=0 ) : _p() {
249 _p.reserve( sizeHint );
250 _p.insert( _p.begin(), begin, end );
251 }
252
253 /// Construct vector from another vector
254 /** \tparam S type of elements in \a v (should be castable to type \a T)
255 * \param v vector used for initialization.
256 */
257 template <typename S>
258 TProb( const std::vector<S> &v ) : _p() {
259 _p.reserve( v.size() );
260 _p.insert( _p.begin(), v.begin(), v.end() );
261 }
262 //@}
263
264 /// Constant iterator over the elements
265 typedef typename container_type::const_iterator const_iterator;
266 /// Iterator over the elements
267 typedef typename container_type::iterator iterator;
268 /// Constant reverse iterator over the elements
269 typedef typename container_type::const_reverse_iterator const_reverse_iterator;
270 /// Reverse iterator over the elements
271 typedef typename container_type::reverse_iterator reverse_iterator;
272
273 /// \name Iterator interface
274 //@{
275 /// Returns iterator that points to the first element
276 iterator begin() { return _p.begin(); }
277 /// Returns constant iterator that points to the first element
278 const_iterator begin() const { return _p.begin(); }
279
280 /// Returns iterator that points beyond the last element
281 iterator end() { return _p.end(); }
282 /// Returns constant iterator that points beyond the last element
283 const_iterator end() const { return _p.end(); }
284
285 /// Returns reverse iterator that points to the last element
286 reverse_iterator rbegin() { return _p.rbegin(); }
287 /// Returns constant reverse iterator that points to the last element
288 const_reverse_iterator rbegin() const { return _p.rbegin(); }
289
290 /// Returns reverse iterator that points beyond the first element
291 reverse_iterator rend() { return _p.rend(); }
292 /// Returns constant reverse iterator that points beyond the first element
293 const_reverse_iterator rend() const { return _p.rend(); }
294 //@}
295
296 /// \name Miscellaneous operations
297 //@{
298 void resize( size_t sz ) {
299 _p.resize( sz );
300 }
301 //@}
302
303 /// \name Get/set individual entries
304 //@{
305 /// Gets \a i 'th entry
306 T get( size_t i ) const {
307 #ifdef DAI_DEBUG
308 return _p.at(i);
309 #else
310 return _p[i];
311 #endif
312 }
313
314 /// Sets \a i 'th entry to \a val
315 void set( size_t i, T val ) {
316 DAI_DEBASSERT( i < _p.size() );
317 _p[i] = val;
318 }
319 //@}
320
321 /// \name Queries
322 //@{
323 /// Returns a const reference to the wrapped container
324 const container_type& p() const { return _p; }
325
326 /// Returns a reference to the wrapped container
327 container_type& p() { return _p; }
328
329 /// Returns a copy of the \a i 'th entry
330 T operator[]( size_t i ) const { return get(i); }
331
332 /// Returns reference to the \a i 'th entry
333 /** \deprecated Please use dai::TProb::set() instead
334 */
335 T& operator[]( size_t i ) { return _p[i]; }
336
337 /// Returns length of the vector (i.e., the number of entries)
338 size_t size() const { return _p.size(); }
339
340 /// Accumulate over all values, similar to std::accumulate
341 /** The following calculation is done:
342 * \code
343 * T t = op2(init);
344 * for( const_iterator it = begin(); it != end(); it++ )
345 * t = op1( t, op2(*it) );
346 * return t;
347 * \endcode
348 * \deprecated Please use dai::TProb::accumulateSum or dai::TProb::accumulateMax instead
349 */
350 template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
351 T t = op2(init);
352 for( const_iterator it = begin(); it != end(); it++ )
353 t = op1( t, op2(*it) );
354 return t;
355 }
356
357
358 /// Accumulate all values (similar to std::accumulate) by summing
359 /** The following calculation is done:
360 * \code
361 * T t = op(init);
362 * for( const_iterator it = begin(); it != end(); it++ )
363 * t += op(*it);
364 * return t;
365 * \endcode
366 */
367 template<typename unOp> T accumulateSum( T init, unOp op ) const {
368 T t = op(init);
369 for( const_iterator it = begin(); it != end(); it++ )
370 t += op(*it);
371 return t;
372 }
373
374 /// Accumulate all values (similar to std::accumulate) by maximization/minimization
375 /** The following calculation is done (with "max" replaced by "min" if \a minimize == \c true):
376 * \code
377 * T t = op(init);
378 * for( const_iterator it = begin(); it != end(); it++ )
379 * t = std::max( t, op(*it) );
380 * return t;
381 * \endcode
382 */
383 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const {
384 T t = op(init);
385 if( minimize ) {
386 for( const_iterator it = begin(); it != end(); it++ )
387 t = std::min( t, op(*it) );
388 } else {
389 for( const_iterator it = begin(); it != end(); it++ )
390 t = std::max( t, op(*it) );
391 }
392 return t;
393 }
394
395 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
396 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); }
397
398 /// Returns maximum value of all entries
399 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); }
400
401 /// Returns minimum value of all entries
402 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); }
403
404 /// Returns sum of all entries
405 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); }
406
407 /// Return sum of absolute value of all entries
408 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); }
409
410 /// Returns maximum absolute value of all entries
411 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); }
412
413 /// Returns \c true if one or more entries are NaN
414 bool hasNaNs() const {
415 bool foundnan = false;
416 for( const_iterator x = _p.begin(); x != _p.end(); x++ )
417 if( isnan( *x ) ) {
418 foundnan = true;
419 break;
420 }
421 return foundnan;
422 }
423
424 /// Returns \c true if one or more entries are negative
425 bool hasNegatives() const {
426 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
427 }
428
429 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
430 std::pair<size_t,T> argmax() const {
431 T max = _p[0];
432 size_t arg = 0;
433 for( size_t i = 1; i < size(); i++ ) {
434 if( _p[i] > max ) {
435 max = _p[i];
436 arg = i;
437 }
438 }
439 return std::make_pair( arg, max );
440 }
441
442 /// Returns a random index, according to the (normalized) distribution described by *this
443 size_t draw() {
444 Real x = rnd_uniform() * sum();
445 T s = 0;
446 for( size_t i = 0; i < size(); i++ ) {
447 s += get(i);
448 if( s > x )
449 return i;
450 }
451 return( size() - 1 );
452 }
453
454 /// Lexicographical comparison
455 /** \pre <tt>this->size() == q.size()</tt>
456 */
457 bool operator<( const this_type& q ) const {
458 DAI_DEBASSERT( size() == q.size() );
459 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
460 }
461
462 /// Comparison
463 bool operator==( const this_type& q ) const {
464 if( size() != q.size() )
465 return false;
466 return p() == q.p();
467 }
468 //@}
469
470 /// \name Unary transformations
471 //@{
472 /// Returns the result of applying operation \a op pointwise on \c *this
473 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const {
474 this_type r;
475 r._p.reserve( size() );
476 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
477 return r;
478 }
479
480 /// Returns negative of \c *this
481 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); }
482
483 /// Returns pointwise absolute value
484 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); }
485
486 /// Returns pointwise exponent
487 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); }
488
489 /// Returns pointwise logarithm
490 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
491 */
492 this_type log(bool zero=false) const {
493 if( zero )
494 return pwUnaryTr( fo_log0<T>() );
495 else
496 return pwUnaryTr( fo_log<T>() );
497 }
498
499 /// Returns pointwise inverse
500 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
501 */
502 this_type inverse(bool zero=true) const {
503 if( zero )
504 return pwUnaryTr( fo_inv0<T>() );
505 else
506 return pwUnaryTr( fo_inv<T>() );
507 }
508
509 /// Returns normalized copy of \c *this, using the specified norm
510 /** \throw NOT_NORMALIZABLE if the norm is zero
511 */
512 this_type normalized( ProbNormType norm = dai::NORMPROB ) const {
513 T Z = 0;
514 if( norm == dai::NORMPROB )
515 Z = sum();
516 else if( norm == dai::NORMLINF )
517 Z = maxAbs();
518 if( Z == (T)0 ) {
519 DAI_THROW(NOT_NORMALIZABLE);
520 return *this;
521 } else
522 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
523 }
524 //@}
525
526 /// \name Unary operations
527 //@{
528 /// Applies unary operation \a op pointwise
529 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) {
530 std::transform( _p.begin(), _p.end(), _p.begin(), op );
531 return *this;
532 }
533
534 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
535 this_type& randomize() {
536 std::generate( _p.begin(), _p.end(), rnd_uniform );
537 return *this;
538 }
539
540 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
541 this_type& setUniform () {
542 fill( (T)1 / size() );
543 return *this;
544 }
545
546 /// Applies absolute value pointwise
547 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
548
549 /// Applies exponent pointwise
550 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
551
552 /// Applies logarithm pointwise
553 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
554 */
555 this_type& takeLog(bool zero=false) {
556 if( zero ) {
557 return pwUnaryOp( fo_log0<T>() );
558 } else
559 return pwUnaryOp( fo_log<T>() );
560 }
561
562 /// Normalizes vector using the specified norm
563 /** \throw NOT_NORMALIZABLE if the norm is zero
564 */
565 T normalize( ProbNormType norm=dai::NORMPROB ) {
566 T Z = 0;
567 if( norm == dai::NORMPROB )
568 Z = sum();
569 else if( norm == dai::NORMLINF )
570 Z = maxAbs();
571 if( Z == (T)0 )
572 DAI_THROW(NOT_NORMALIZABLE);
573 else
574 *this /= Z;
575 return Z;
576 }
577 //@}
578
579 /// \name Operations with scalars
580 //@{
581 /// Sets all entries to \a x
582 this_type& fill( T x ) {
583 std::fill( _p.begin(), _p.end(), x );
584 return *this;
585 }
586
587 /// Adds scalar \a x to each entry
588 this_type& operator+= (T x) {
589 if( x != 0 )
590 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
591 else
592 return *this;
593 }
594
595 /// Subtracts scalar \a x from each entry
596 this_type& operator-= (T x) {
597 if( x != 0 )
598 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
599 else
600 return *this;
601 }
602
603 /// Multiplies each entry with scalar \a x
604 this_type& operator*= (T x) {
605 if( x != 1 )
606 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
607 else
608 return *this;
609 }
610
611 /// Divides each entry by scalar \a x, where division by 0 yields 0
612 this_type& operator/= (T x) {
613 if( x != 1 )
614 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
615 else
616 return *this;
617 }
618
619 /// Raises entries to the power \a x
620 this_type& operator^= (T x) {
621 if( x != (T)1 )
622 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
623 else
624 return *this;
625 }
626 //@}
627
628 /// \name Transformations with scalars
629 //@{
630 /// Returns sum of \c *this and scalar \a x
631 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
632
633 /// Returns difference of \c *this and scalar \a x
634 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
635
636 /// Returns product of \c *this with scalar \a x
637 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
638
639 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
640 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
641
642 /// Returns \c *this raised to the power \a x
643 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
644 //@}
645
646 /// \name Operations with other equally-sized vectors
647 //@{
648 /// Applies binary operation pointwise on two vectors
649 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
650 * \param q Right operand
651 * \param op Operation of type \a binaryOp
652 */
653 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) {
654 DAI_DEBASSERT( size() == q.size() );
655 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
656 return *this;
657 }
658
659 /// Pointwise addition with \a q
660 /** \pre <tt>this->size() == q.size()</tt>
661 */
662 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); }
663
664 /// Pointwise subtraction of \a q
665 /** \pre <tt>this->size() == q.size()</tt>
666 */
667 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); }
668
669 /// Pointwise multiplication with \a q
670 /** \pre <tt>this->size() == q.size()</tt>
671 */
672 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
673
674 /// Pointwise division by \a q, where division by 0 yields 0
675 /** \pre <tt>this->size() == q.size()</tt>
676 * \see divide(const TProb<T> &)
677 */
678 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
679
680 /// Pointwise division by \a q, where division by 0 yields +Inf
681 /** \pre <tt>this->size() == q.size()</tt>
682 * \see operator/=(const TProb<T> &)
683 */
684 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); }
685
686 /// Pointwise power
687 /** \pre <tt>this->size() == q.size()</tt>
688 */
689 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); }
690 //@}
691
692 /// \name Transformations with other equally-sized vectors
693 //@{
694 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
695 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
696 * \param q Right operand
697 * \param op Operation of type \a binaryOp
698 */
699 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const {
700 DAI_DEBASSERT( size() == q.size() );
701 TProb<T> r;
702 r._p.reserve( size() );
703 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
704 return r;
705 }
706
707 /// Returns sum of \c *this and \a q
708 /** \pre <tt>this->size() == q.size()</tt>
709 */
710 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
711
712 /// Return \c *this minus \a q
713 /** \pre <tt>this->size() == q.size()</tt>
714 */
715 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
716
717 /// Return product of \c *this with \a q
718 /** \pre <tt>this->size() == q.size()</tt>
719 */
720 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
721
722 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
723 /** \pre <tt>this->size() == q.size()</tt>
724 * \see divided_by(const TProb<T> &)
725 */
726 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
727
728 /// Pointwise division by \a q, where division by 0 yields +Inf
729 /** \pre <tt>this->size() == q.size()</tt>
730 * \see operator/(const TProb<T> &)
731 */
732 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
733
734 /// Returns \c *this to the power \a q
735 /** \pre <tt>this->size() == q.size()</tt>
736 */
737 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
738 //@}
739
740 /// Performs a generalized inner product, similar to std::inner_product
741 /** \pre <tt>this->size() == q.size()</tt>
742 */
743 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
744 DAI_DEBASSERT( size() == q.size() );
745 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
746 }
747 };
748
749
750 /// Returns distance between \a p and \a q, measured using distance measure \a dt
751 /** \relates TProb
752 * \pre <tt>this->size() == q.size()</tt>
753 */
754 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) {
755 switch( dt ) {
756 case DISTL1:
757 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
758 case DISTLINF:
759 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
760 case DISTTV:
761 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
762 case DISTKL:
763 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
764 case DISTHEL:
765 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
766 default:
767 DAI_THROW(UNKNOWN_ENUM_VALUE);
768 return INFINITY;
769 }
770 }
771
772
773 /// Writes a TProb<T> to an output stream
774 /** \relates TProb
775 */
776 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
777 os << "(";
778 for( size_t i = 0; i < p.size(); i++ )
779 os << ((i != 0) ? ", " : "") << p.get(i);
780 os << ")";
781 return os;
782 }
783
784
785 /// Returns the pointwise minimum of \a a and \a b
786 /** \relates TProb
787 * \pre <tt>this->size() == q.size()</tt>
788 */
789 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
790 return a.pwBinaryTr( b, fo_min<T>() );
791 }
792
793
794 /// Returns the pointwise maximum of \a a and \a b
795 /** \relates TProb
796 * \pre <tt>this->size() == q.size()</tt>
797 */
798 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
799 return a.pwBinaryTr( b, fo_max<T>() );
800 }
801
802
803 /// Represents a vector with entries of type dai::Real.
804 typedef TProb<Real> Prob;
805
806
807 } // end of namespace dai
808
809
810 #endif