afbca04ad5d2aa62fe2b398d72d2699b491efb1f
[libdai.git] / include / dai / prob.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
14
15
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
18
19
#include <cmath>
#include <vector>
#include <ostream>
#include <algorithm>
#include <numeric>
#include <functional>
#include <iterator>
#include <dai/util.h>
#include <dai/exceptions.h>
28
29
30 namespace dai {
31
32
/// Function object that returns the value itself
/** Declares the \c argument_type and \c result_type typedefs expected by adaptors such as
 *  \c std::not1 directly, instead of deriving from \c std::unary_function (deprecated in
 *  C++11 and removed in C++17).
 */
template<typename T> struct fo_id {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
40
41
/// Function object that takes the absolute value
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_abs {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns abs(\a x), i.e. -\a x if \a x is negative and \a x otherwise
    T operator()( const T &x ) const {
        return ( x < static_cast<T>(0) ) ? -x : x;
    }
};
52
53
/// Function object that takes the exponent
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_exp {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        // Make the std::exp overload matching T (float, double, long double) visible;
        // plain unqualified lookup may only find the C function exp(double).
        // A using-declaration does not suppress ADL, so user-defined T still works.
        using std::exp;
        return exp( x );
    }
};
61
62
/// Function object that takes the logarithm
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_log {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        // Select the std::log overload matching T instead of possibly only C's log(double)
        using std::log;
        return log( x );
    }
};
70
71
/// Function object that takes the logarithm, except that log(0) is defined to be 0
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_log0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : log(\a x))
    T operator()( const T &x ) const {
        if( x == static_cast<T>(0) )
            return static_cast<T>(0);
        // Select the std::log overload matching T instead of possibly only C's log(double)
        using std::log;
        return log( x );
    }
};
82
83
/// Function object that takes the inverse
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_inv {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns 1 / \a x (no special handling of \a x == 0)
    T operator()( const T &x ) const {
        return static_cast<T>(1) / x;
    }
};
91
92
/// Function object that takes the inverse, except that 1/0 is defined to be 0
/** Declares \c argument_type / \c result_type directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::unary_function.
 */
template<typename T> struct fo_inv0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : (1 / \a x))
    T operator()( const T &x ) const {
        if( x == static_cast<T>(0) )
            return static_cast<T>(0);
        return static_cast<T>(1) / x;
    }
};
103
104
105 /// Function object that returns p*log0(p)
106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
107 /// Returns \a p * log0(\a p)
108 T operator()( const T &p ) const {
109 return p * dai::log0(p);
110 }
111 };
112
113
/// Function object similar to std::divides(), but different in that dividing by zero results in zero
/** Declares the adaptable-binary-function typedefs directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::binary_function, so it still works with \c std::bind2nd.
 */
template<typename T> struct fo_divides0 {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a y == 0 ? 0 : (\a x / \a y))
    T operator()( const T &x, const T &y ) const {
        if( y == static_cast<T>(0) )
            return static_cast<T>(0);
        return x / y;
    }
};
124
125
/// Function object useful for calculating the KL distance
/** Declares the adaptable-binary-function typedefs directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::binary_function.
 */
template<typename T> struct fo_KL {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
    T operator()( const T &p, const T &q ) const {
        if( p == static_cast<T>(0) )
            return static_cast<T>(0);
        // Select the std::log overload matching T instead of possibly only C's log(double)
        using std::log;
        return p * (log(p) - log(q));
    }
};
136
137
/// Function object that returns x to the power y
/** Declares the adaptable-binary-function typedefs directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::binary_function, so it still works with \c std::bind2nd.
 */
template<typename T> struct fo_pow {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x ^ \a y); the exponent 1 is short-circuited and returns \a x unchanged
    T operator()( const T &x, const T &y ) const {
        if( y == static_cast<T>(1) )
            return x;
        return std::pow( x, y );
    }
};
148
149
/// Function object that returns the maximum of two values
/** Declares the adaptable-binary-function typedefs directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::binary_function.
 */
template<typename T> struct fo_max {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a x : \a y); if the comparison is false (e.g. a NaN is involved), \a y is returned
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? x : y;
    }
};
157
158
/// Function object that returns the minimum of two values
/** Declares the adaptable-binary-function typedefs directly instead of deriving from the
 *  deprecated (removed in C++17) \c std::binary_function.
 */
template<typename T> struct fo_min {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a y : \a x); if the comparison is false (e.g. a NaN is involved), \a x is returned
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? y : x;
    }
};
166
167
168 /// Function object that returns the absolute difference of x and y
169 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
170 /// Returns abs( \a x - \a y )
171 T operator()( const T &x, const T &y ) const {
172 return dai::abs( x - y );
173 }
174 };
175
176
177 /// Represents a vector with entries of type \a T.
178 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
179 *
180 * It is mainly used for representing measures on a finite outcome space, for example, the probability
181 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
182 * represent logarithms of probability mass functions.
183 *
184 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
185 */
186 template <typename T> class TProb {
187 private:
188 /// The vector
189 std::vector<T> _p;
190
191 public:
192 /// Enumerates different ways of normalizing a probability measure.
193 /**
194 * - NORMPROB means that the sum of all entries should be 1;
195 * - NORMLINF means that the maximum absolute value of all entries should be 1.
196 */
197 typedef enum { NORMPROB, NORMLINF } NormType;
198 /// Enumerates different distance measures between probability measures.
199 /**
200 * - DISTL1 is the \f$\ell_1\f$ distance (sum of absolute values of pointwise difference);
201 * - DISTLINF is the \f$\ell_\infty\f$ distance (maximum absolute value of pointwise difference);
202 * - DISTTV is the total variation distance (half of the \f$\ell_1\f$ distance);
203 * - DISTKL is the Kullback-Leibler distance (\f$\sum_i p_i (\log p_i - \log q_i)\f$).
204 */
205 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL } DistType;
206
207 /// \name Constructors and destructors
208 //@{
209 /// Default constructor (constructs empty vector)
210 TProb() : _p() {}
211
212 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
213 explicit TProb( size_t n ) : _p(std::vector<T>(n, (T)1 / n)) {}
214
215 /// Construct vector of length \a n with each entry set to \a p
216 explicit TProb( size_t n, T p ) : _p(n, p) {}
217
218 /// Construct vector from a range
219 /** \tparam TIterator Iterates over instances that can be cast to \a T
220 * \param begin Points to first instance to be added.
221 * \param end Points just beyond last instance to be added.
222 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint.
223 */
224 template <typename TIterator>
225 TProb( TIterator begin, TIterator end, size_t sizeHint=0 ) : _p() {
226 _p.reserve( sizeHint );
227 _p.insert( _p.begin(), begin, end );
228 }
229
230 /// Construct vector from another vector
231 /** \tparam S type of elements in \a v (should be castable to type \a T)
232 * \param v vector used for initialization
233 */
234 template <typename S>
235 TProb( const std::vector<S> &v ) : _p() {
236 _p.reserve( v.size() );
237 _p.insert( _p.begin(), v.begin(), v.end() );
238 }
239 //@}
240
241 /// Constant iterator over the elements
242 typedef typename std::vector<T>::const_iterator const_iterator;
243 /// Iterator over the elements
244 typedef typename std::vector<T>::iterator iterator;
245 /// Constant reverse iterator over the elements
246 typedef typename std::vector<T>::const_reverse_iterator const_reverse_iterator;
247 /// Reverse iterator over the elements
248 typedef typename std::vector<T>::reverse_iterator reverse_iterator;
249
250 /// \name Iterator interface
251 //@{
252 /// Returns iterator that points to the first element
253 iterator begin() { return _p.begin(); }
254 /// Returns constant iterator that points to the first element
255 const_iterator begin() const { return _p.begin(); }
256
257 /// Returns iterator that points beyond the last element
258 iterator end() { return _p.end(); }
259 /// Returns constant iterator that points beyond the last element
260 const_iterator end() const { return _p.end(); }
261
262 /// Returns reverse iterator that points to the last element
263 reverse_iterator rbegin() { return _p.rbegin(); }
264 /// Returns constant reverse iterator that points to the last element
265 const_reverse_iterator rbegin() const { return _p.rbegin(); }
266
267 /// Returns reverse iterator that points beyond the first element
268 reverse_iterator rend() { return _p.rend(); }
269 /// Returns constant reverse iterator that points beyond the first element
270 const_reverse_iterator rend() const { return _p.rend(); }
271 //@}
272
273 /// \name Queries
274 //@{
275 /// Returns a const reference to the wrapped vector
276 const std::vector<T> & p() const { return _p; }
277
278 /// Returns a reference to the wrapped vector
279 std::vector<T> & p() { return _p; }
280
281 /// Returns a copy of the \a i 'th entry
282 T operator[]( size_t i ) const {
283 #ifdef DAI_DEBUG
284 return _p.at(i);
285 #else
286 return _p[i];
287 #endif
288 }
289
290 /// Returns reference to the \a i 'th entry
291 T& operator[]( size_t i ) { return _p[i]; }
292
293 /// Returns length of the vector (i.e., the number of entries)
294 size_t size() const { return _p.size(); }
295
296 /// Accumulate over all values, similar to std::accumulate
297 template<typename binOp, typename unOp> T accumulate( T init, binOp op1, unOp op2 ) const {
298 T t = init;
299 for( const_iterator it = begin(); it != end(); it++ )
300 t = op1( t, op2(*it) );
301 return t;
302 }
303
304 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
305 T entropy() const { return -accumulate( (T)0, std::plus<T>(), fo_plog0p<T>() ); }
306
307 /// Returns maximum value of all entries
308 T max() const { return accumulate( (T)(-INFINITY), fo_max<T>(), fo_id<T>() ); }
309
310 // OBSOLETE
311 /// Returns maximum value of all entries
312 /** \deprecated Please use max() instead
313 */
314 T maxVal() const { return max(); }
315
316 /// Returns minimum value of all entries
317 T min() const { return accumulate( (T)INFINITY, fo_min<T>(), fo_id<T>() ); }
318
319 // OBSOLETE
320 /// Returns minimum value of all entries
321 /** \deprecated Please use min() instead
322 */
323 T minVal() const { return min(); }
324
325 /// Returns sum of all entries
326 T sum() const { return accumulate( (T)0, std::plus<T>(), fo_id<T>() ); }
327
328 // OBSOLETE
329 /// Returns sum of all entries
330 /** \deprecated Please use sum() instead
331 */
332 T totalSum() const { return sum(); }
333
334 /// Return sum of absolute value of all entries
335 T sumAbs() const { return accumulate( (T)0, std::plus<T>(), fo_abs<T>() ); }
336
337 /// Returns maximum absolute value of all entries
338 T maxAbs() const { return accumulate( (T)0, fo_max<T>(), fo_abs<T>() ); }
339
340 /// Returns \c true if one or more entries are NaN
341 bool hasNaNs() const {
342 bool foundnan = false;
343 for( typename std::vector<T>::const_iterator x = _p.begin(); x != _p.end(); x++ )
344 if( isnan( *x ) ) {
345 foundnan = true;
346 break;
347 }
348 return foundnan;
349 }
350
351 /// Returns \c true if one or more entries are negative
352 bool hasNegatives() const {
353 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
354 }
355
356 // OBSOLETE
357 /// Returns \c true if one or more entries are non-positive
358 /** \deprecated Functionality was not widely used
359 */
360 bool hasNonPositives() const {
361 bool found = false;
362 for( typename std::vector<T>::const_iterator x = _p.begin(); x != _p.end(); x++ )
363 if( *x <= (T)0 ) {
364 found = true;
365 break;
366 }
367 return found;
368 }
369
370 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
371 std::pair<size_t,T> argmax() const {
372 T max = _p[0];
373 size_t arg = 0;
374 for( size_t i = 1; i < size(); i++ ) {
375 if( _p[i] > max ) {
376 max = _p[i];
377 arg = i;
378 }
379 }
380 return std::make_pair(arg,max);
381 }
382
383 /// Returns a random index, according to the (normalized) distribution described by *this
384 size_t draw() {
385 Real x = rnd_uniform() * sum();
386 T s = 0;
387 for( size_t i = 0; i < size(); i++ ) {
388 s += _p[i];
389 if( s > x )
390 return i;
391 }
392 return( size() - 1 );
393 }
394
395 /// Lexicographical comparison
396 /** \pre <tt>this->size() == q.size()</tt>
397 */
398 bool operator<= (const TProb<T> & q) const {
399 DAI_DEBASSERT( size() == q.size() );
400 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
401 }
402 //@}
403
404 /// \name Unary transformations
405 //@{
406 /// Returns the result of applying operation \a op pointwise on \c *this
407 template<typename unaryOp> TProb<T> pwUnaryTr( unaryOp op ) const {
408 TProb<T> r;
409 r._p.reserve( size() );
410 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
411 return r;
412 }
413
414 // OBSOLETE
415 /// Returns pointwise signum
416 /** \deprecated Functionality was not widely used
417 */
418 TProb<T> sgn() const {
419 TProb<T> x;
420 x._p.reserve( size() );
421 for( size_t i = 0; i < size(); i++ ) {
422 T s = 0;
423 if( _p[i] > 0 )
424 s = 1;
425 else if( _p[i] < 0 )
426 s = -1;
427 x._p.push_back( s );
428 }
429 return x;
430 }
431
432 /// Returns pointwise absolute value
433 TProb<T> abs() const { return pwUnaryTr( fo_abs<T>() ); }
434
435 /// Returns pointwise exponent
436 TProb<T> exp() const { return pwUnaryTr( fo_exp<T>() ); }
437
438 /// Returns pointwise logarithm
439 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
440 */
441 TProb<T> log(bool zero=false) const {
442 if( zero )
443 return pwUnaryTr( fo_log0<T>() );
444 else
445 return pwUnaryTr( fo_log<T>() );
446 }
447
448 // OBSOLETE
449 /// Returns pointwise logarithm (or zero if argument is zero)
450 /** \deprecated Please use log() instead with \a zero == \c true
451 */
452 TProb<T> log0() const { return log(true); }
453
454 /// Returns pointwise inverse
455 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
456 */
457 TProb<T> inverse(bool zero=true) const {
458 if( zero )
459 return pwUnaryTr( fo_inv0<T>() );
460 else
461 return pwUnaryTr( fo_inv<T>() );
462 }
463
464 /// Returns normalized copy of \c *this, using the specified norm
465 /** \throw NOT_NORMALIZABLE if the norm is zero
466 */
467 TProb<T> normalized( NormType norm = NORMPROB ) const {
468 T Z = 0;
469 if( norm == NORMPROB )
470 Z = sum();
471 else if( norm == NORMLINF )
472 Z = maxAbs();
473 if( Z == (T)0 ) {
474 DAI_THROW(NOT_NORMALIZABLE);
475 return *this;
476 } else
477 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
478 }
479 //@}
480
481 /// \name Unary operations
482 //@{
483 /// Applies unary operation \a op pointwise
484 template<typename unaryOp> TProb<T>& pwUnaryOp( unaryOp op ) {
485 std::transform( _p.begin(), _p.end(), _p.begin(), op );
486 return *this;
487 }
488
489 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
490 TProb<T>& randomize() {
491 std::generate( _p.begin(), _p.end(), rnd_uniform );
492 return *this;
493 }
494
495 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
496 TProb<T>& setUniform () {
497 fill( (T)1 / size() );
498 return *this;
499 }
500
501 /// Applies absolute value pointwise
502 const TProb<T>& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
503
504 /// Applies exponent pointwise
505 const TProb<T>& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
506
507 /// Applies logarithm pointwise
508 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
509 */
510 const TProb<T>& takeLog(bool zero=false) {
511 if( zero ) {
512 return pwUnaryOp( fo_log0<T>() );
513 } else
514 return pwUnaryOp( fo_log<T>() );
515 }
516
517 // OBSOLETE
518 /// Applies logarithm pointwise
519 /** \deprecated Please use takeLog() instead, with \a zero == \c true
520 */
521 const TProb<T>& takeLog0() { return takeLog(true); }
522
523 /// Normalizes vector using the specified norm
524 /** \throw NOT_NORMALIZABLE if the norm is zero
525 */
526 T normalize( NormType norm=NORMPROB ) {
527 T Z = 0;
528 if( norm == NORMPROB )
529 Z = sum();
530 else if( norm == NORMLINF )
531 Z = maxAbs();
532 if( Z == (T)0 )
533 DAI_THROW(NOT_NORMALIZABLE);
534 else
535 *this /= Z;
536 return Z;
537 }
538 //@}
539
540 /// \name Operations with scalars
541 //@{
542 /// Sets all entries to \a x
543 TProb<T> & fill(T x) {
544 std::fill( _p.begin(), _p.end(), x );
545 return *this;
546 }
547
548 // OBSOLETE
549 /// Sets entries that are smaller (in absolute value) than \a epsilon to 0
550 /** \deprecated Functionality was not widely used
551 */
552 TProb<T>& makeZero( T epsilon ) {
553 for( size_t i = 0; i < size(); i++ )
554 if( (_p[i] < epsilon) && (_p[i] > -epsilon) )
555 _p[i] = 0;
556 return *this;
557 }
558
559 // OBSOLETE
560 /// Sets entries that are smaller than \a epsilon to \a epsilon
561 /** \deprecated Functionality was not widely used
562 */
563 TProb<T>& makePositive( T epsilon ) {
564 for( size_t i = 0; i < size(); i++ )
565 if( (0 < _p[i]) && (_p[i] < epsilon) )
566 _p[i] = epsilon;
567 return *this;
568 }
569
570 /// Adds scalar \a x to each entry
571 TProb<T>& operator+= (T x) {
572 if( x != 0 )
573 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
574 else
575 return *this;
576 }
577
578 /// Subtracts scalar \a x from each entry
579 TProb<T>& operator-= (T x) {
580 if( x != 0 )
581 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
582 else
583 return *this;
584 }
585
586 /// Multiplies each entry with scalar \a x
587 TProb<T>& operator*= (T x) {
588 if( x != 1 )
589 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
590 else
591 return *this;
592 }
593
594 /// Divides each entry by scalar \a x
595 TProb<T>& operator/= (T x) {
596 DAI_DEBASSERT( x != 0 );
597 if( x != 1 )
598 return pwUnaryOp( std::bind2nd( std::divides<T>(), x ) );
599 else
600 return *this;
601 }
602
603 /// Raises entries to the power \a x
604 TProb<T>& operator^= (T x) {
605 if( x != (T)1 )
606 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
607 else
608 return *this;
609 }
610 //@}
611
612 /// \name Transformations with scalars
613 //@{
614 /// Returns sum of \c *this and scalar \a x
615 TProb<T> operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
616
617 /// Returns difference of \c *this and scalar \a x
618 TProb<T> operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
619
620 /// Returns product of \c *this with scalar \a x
621 TProb<T> operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
622
623 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
624 TProb<T> operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
625
626 /// Returns \c *this raised to the power \a x
627 TProb<T> operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
628 //@}
629
630 /// \name Operations with other equally-sized vectors
631 //@{
632 /// Applies binary operation pointwise on two vectors
633 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
634 * \param q Right operand
635 * \param op Operation of type \a binaryOp
636 */
637 template<typename binaryOp> TProb<T>& pwBinaryOp( const TProb<T> &q, binaryOp op ) {
638 DAI_DEBASSERT( size() == q.size() );
639 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
640 return *this;
641 }
642
643 /// Pointwise addition with \a q
644 /** \pre <tt>this->size() == q.size()</tt>
645 */
646 TProb<T>& operator+= (const TProb<T> & q) { return pwBinaryOp( q, std::plus<T>() ); }
647
648 /// Pointwise subtraction of \a q
649 /** \pre <tt>this->size() == q.size()</tt>
650 */
651 TProb<T>& operator-= (const TProb<T> & q) { return pwBinaryOp( q, std::minus<T>() ); }
652
653 /// Pointwise multiplication with \a q
654 /** \pre <tt>this->size() == q.size()</tt>
655 */
656 TProb<T>& operator*= (const TProb<T> & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
657
658 /// Pointwise division by \a q, where division by 0 yields 0
659 /** \pre <tt>this->size() == q.size()</tt>
660 * \see divide(const TProb<T> &)
661 */
662 TProb<T>& operator/= (const TProb<T> & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
663
664 /// Pointwise division by \a q, where division by 0 yields +Inf
665 /** \pre <tt>this->size() == q.size()</tt>
666 * \see operator/=(const TProb<T> &)
667 */
668 TProb<T>& divide (const TProb<T> & q) { return pwBinaryOp( q, std::divides<T>() ); }
669
670 /// Pointwise power
671 /** \pre <tt>this->size() == q.size()</tt>
672 */
673 TProb<T>& operator^= (const TProb<T> & q) { return pwBinaryOp( q, fo_pow<T>() ); }
674 //@}
675
676 /// \name Transformations with other equally-sized vectors
677 //@{
678 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
679 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
680 * \param q Right operand
681 * \param op Operation of type \a binaryOp
682 */
683 template<typename binaryOp> TProb<T> pwBinaryTr( const TProb<T> &q, binaryOp op ) const {
684 DAI_DEBASSERT( size() == q.size() );
685 TProb<T> r;
686 r._p.reserve( size() );
687 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
688 return r;
689 }
690
691 /// Returns sum of \c *this and \a q
692 /** \pre <tt>this->size() == q.size()</tt>
693 */
694 TProb<T> operator+ ( const TProb<T>& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
695
696 /// Return \c *this minus \a q
697 /** \pre <tt>this->size() == q.size()</tt>
698 */
699 TProb<T> operator- ( const TProb<T>& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
700
701 /// Return product of \c *this with \a q
702 /** \pre <tt>this->size() == q.size()</tt>
703 */
704 TProb<T> operator* ( const TProb<T> &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
705
706 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
707 /** \pre <tt>this->size() == q.size()</tt>
708 * \see divided_by(const TProb<T> &)
709 */
710 TProb<T> operator/ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
711
712 /// Pointwise division by \a q, where division by 0 yields +Inf
713 /** \pre <tt>this->size() == q.size()</tt>
714 * \see operator/(const TProb<T> &)
715 */
716 TProb<T> divided_by( const TProb<T> &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
717
718 /// Returns \c *this to the power \a q
719 /** \pre <tt>this->size() == q.size()</tt>
720 */
721 TProb<T> operator^ ( const TProb<T> &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
722 //@}
723
724 /// Performs a generalized inner product, similar to std::inner_product
725 /** \pre <tt>this->size() == q.size()</tt>
726 */
727 template<typename binOp1, typename binOp2> T innerProduct( const TProb<T> &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
728 DAI_DEBASSERT( size() == q.size() );
729 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
730 }
731 };
732
733
734 /// Returns distance between \a p and \a q, measured using distance measure \a dt
735 /** \relates TProb
736 * \pre <tt>this->size() == q.size()</tt>
737 */
738 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, typename TProb<T>::DistType dt ) {
739 switch( dt ) {
740 case TProb<T>::DISTL1:
741 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
742 case TProb<T>::DISTLINF:
743 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
744 case TProb<T>::DISTTV:
745 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
746 case TProb<T>::DISTKL:
747 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
748 default:
749 DAI_THROW(UNKNOWN_ENUM_VALUE);
750 return INFINITY;
751 }
752 }
753
754
755 /// Writes a TProb<T> to an output stream
756 /** \relates TProb
757 */
758 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
759 os << "[";
760 std::copy( p.p().begin(), p.p().end(), std::ostream_iterator<T>(os, " ") );
761 os << "]";
762 return os;
763 }
764
765
766 /// Returns the pointwise minimum of \a a and \a b
767 /** \relates TProb
768 * \pre <tt>this->size() == q.size()</tt>
769 */
770 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
771 return a.pwBinaryTr( b, fo_min<T>() );
772 }
773
774
775 /// Returns the pointwise maximum of \a a and \a b
776 /** \relates TProb
777 * \pre <tt>this->size() == q.size()</tt>
778 */
779 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
780 return a.pwBinaryTr( b, fo_max<T>() );
781 }
782
783
/// Represents a vector with entries of type dai::Real.
/** The TProb instantiation for ordinary real-valued (probability) vectors; dai::Real itself is
 *  declared outside this file (presumably in dai/util.h, which this header includes).
 */
typedef TProb<Real> Prob;
786
787
788 } // end of namespace dai
789
790
791 #endif