51dfe47c7c2c2e6614b1e4dd5f6a12a55e438bdd
[libdai.git] / include / dai / prob.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
11
12
13 #ifndef __defined_libdai_prob_h
14 #define __defined_libdai_prob_h
15
16
17 #include <cmath>
18 #include <vector>
19 #include <ostream>
20 #include <algorithm>
21 #include <numeric>
22 #include <functional>
23 #include <dai/util.h>
24 #include <dai/exceptions.h>
25
26
27 namespace dai {
28
29
/// Identity function object: maps every value onto itself
template<typename T>
struct fo_id : std::unary_function<T, T> {
    /// Returns \a x unchanged
    T operator()( const T &x ) const { return x; }
};
37
38
/// Function object that computes the absolute value of its argument
template<typename T>
struct fo_abs : std::unary_function<T, T> {
    /// Returns -\a x if \a x is smaller than zero, and \a x otherwise
    T operator()( const T &x ) const {
        return ( x < (T)0 ) ? -x : x;
    }
};
49
50
/// Function object that computes the exponential of its argument
template<typename T>
struct fo_exp : std::unary_function<T, T> {
    /// Returns exp(\a x)
    T operator()( const T &x ) const { return exp( x ); }
};
58
59
/// Function object that computes the logarithm of its argument
template<typename T>
struct fo_log : std::unary_function<T, T> {
    /// Returns log(\a x)
    T operator()( const T &x ) const { return log( x ); }
};
67
68
/// Function object that computes the logarithm, with the convention that log(0) equals 0
template<typename T>
struct fo_log0 : std::unary_function<T, T> {
    /// Returns log(\a x) for nonzero \a x, and 0 for \a x == 0
    T operator()( const T &x ) const {
        return x ? log( x ) : 0;
    }
};
79
80
/// Function object that computes the reciprocal of its argument
template<typename T>
struct fo_inv : std::unary_function<T, T> {
    /// Returns 1 / \a x
    T operator()( const T &x ) const { return 1 / x; }
};
88
89
/// Function object that computes the reciprocal, with the convention that 1/0 equals 0
template<typename T>
struct fo_inv0 : std::unary_function<T, T> {
    /// Returns 1 / \a x for nonzero \a x, and 0 for \a x == 0
    T operator()( const T &x ) const {
        return x ? 1 / x : 0;
    }
};
100
101
102 /// Function object that returns p*log0(p)
103 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
104 /// Returns \a p * log0(\a p)
105 T operator()( const T &p ) const {
106 return p * dai::log0(p);
107 }
108 };
109
110
/// Function object that divides like std::divides, except that division by zero yields zero
template<typename T>
struct fo_divides0 : std::binary_function<T, T, T> {
    /// Returns 0 if \a y == 0, and \a x / \a y otherwise
    T operator()( const T &x, const T &y ) const {
        return ( y == (T)0 ) ? (T)0 : x / y;
    }
};
121
122
/// Function object that computes a single summand of the KL distance
template<typename T>
struct fo_KL : std::binary_function<T, T, T> {
    /// Returns 0 if \a p == 0, and \a p * (log(\a p) - log(\a q)) otherwise
    T operator()( const T &p, const T &q ) const {
        return ( p == (T)0 ) ? (T)0 : p * (log(p) - log(q));
    }
};
133
134
/// Function object that computes a single summand of the Hellinger distance
template<typename T>
struct fo_Hellinger : std::binary_function<T, T, T> {
    /// Returns the square of (sqrt(\a p) - sqrt(\a q))
    T operator()( const T &p, const T &q ) const {
        T rootDiff = sqrt(p) - sqrt(q);
        return rootDiff * rootDiff;
    }
};
143
144
/// Function object that raises its first argument to the power given by its second argument
template<typename T>
struct fo_pow : std::binary_function<T, T, T> {
    /// Returns \a x ^ \a y; an exponent equal to 1 returns \a x directly without calling pow()
    T operator()( const T &x, const T &y ) const {
        if( y == 1 )
            return x;
        return pow( x, y );
    }
};
155
156
/// Function object that selects the larger of two values
template<typename T>
struct fo_max : std::binary_function<T, T, T> {
    /// Returns \a x if \a x > \a y, and \a y otherwise
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return x;
        return y;
    }
};
164
165
/// Function object that selects the smaller of two values
template<typename T>
struct fo_min : std::binary_function<T, T, T> {
    /// Returns \a y if \a x > \a y, and \a x otherwise
    T operator()( const T &x, const T &y ) const {
        if( x > y )
            return y;
        return x;
    }
};
173
174
175 /// Function object that returns the absolute difference of x and y
176 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
177 /// Returns abs( \a x - \a y )
178 T operator()( const T &x, const T &y ) const {
179 return dai::abs( x - y );
180 }
181 };
182
183
184 /// Represents a vector with entries of type \a T.
185 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
186 *
187 * It is mainly used for representing measures on a finite outcome space, for example, the probability
188 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
189 * represent logarithms of probability mass functions.
190 *
191 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
192 */
193 template <typename T>
194 class TProb {
195 public:
196 /// Type of data structure used for storing the values
197 typedef std::vector<T> container_type;
198
199 /// Shorthand
200 typedef TProb<T> this_type;
201
202 private:
203 /// The data structure that stores the values
204 container_type _p;
205
206 public:
207 /// \name Constructors and destructors
208 //@{
209 /// Default constructor (constructs empty vector)
210 TProb() : _p() {}
211
212 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
213 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {}
214
215 /// Construct vector of length \a n with each entry set to \a p
216 explicit TProb( size_t n, T p ) : _p( n, p ) {}
217
218 /// Construct vector from a range
219 /** \tparam TIterator Iterates over instances that can be cast to \a T
220 * \param begin Points to first instance to be added.
221 * \param end Points just beyond last instance to be added.
222 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint;
223 * the value 0 can be given if the size is unknown, but this will result in a performance penalty.
224 */
225 template <typename TIterator>
226 TProb( TIterator begin, TIterator end, size_t sizeHint ) : _p() {
227 _p.reserve( sizeHint );
228 _p.insert( _p.begin(), begin, end );
229 }
230
231 /// Construct vector from another vector
232 /** \tparam S type of elements in \a v (should be castable to type \a T)
233 * \param v vector used for initialization.
234 */
235 template <typename S>
236 TProb( const std::vector<S> &v ) : _p() {
237 _p.reserve( v.size() );
238 _p.insert( _p.begin(), v.begin(), v.end() );
239 }
240 //@}
241
242 /// Constant iterator over the elements
243 typedef typename container_type::const_iterator const_iterator;
244 /// Iterator over the elements
245 typedef typename container_type::iterator iterator;
246 /// Constant reverse iterator over the elements
247 typedef typename container_type::const_reverse_iterator const_reverse_iterator;
248 /// Reverse iterator over the elements
249 typedef typename container_type::reverse_iterator reverse_iterator;
250
251 /// \name Iterator interface
252 //@{
253 /// Returns iterator that points to the first element
254 iterator begin() { return _p.begin(); }
255 /// Returns constant iterator that points to the first element
256 const_iterator begin() const { return _p.begin(); }
257
258 /// Returns iterator that points beyond the last element
259 iterator end() { return _p.end(); }
260 /// Returns constant iterator that points beyond the last element
261 const_iterator end() const { return _p.end(); }
262
263 /// Returns reverse iterator that points to the last element
264 reverse_iterator rbegin() { return _p.rbegin(); }
265 /// Returns constant reverse iterator that points to the last element
266 const_reverse_iterator rbegin() const { return _p.rbegin(); }
267
268 /// Returns reverse iterator that points beyond the first element
269 reverse_iterator rend() { return _p.rend(); }
270 /// Returns constant reverse iterator that points beyond the first element
271 const_reverse_iterator rend() const { return _p.rend(); }
272 //@}
273
274 /// \name Miscellaneous operations
275 //@{
276 void resize( size_t sz ) {
277 _p.resize( sz );
278 }
279 //@}
280
281 /// \name Get/set individual entries
282 //@{
283 /// Gets \a i 'th entry
284 T get( size_t i ) const {
285 #ifdef DAI_DEBUG
286 return _p.at(i);
287 #else
288 return _p[i];
289 #endif
290 }
291
292 /// Sets \a i 'th entry to \a val
293 void set( size_t i, T val ) {
294 DAI_DEBASSERT( i < _p.size() );
295 _p[i] = val;
296 }
297 //@}
298
299 /// \name Queries
300 //@{
301 /// Returns a const reference to the wrapped container
302 const container_type& p() const { return _p; }
303
304 /// Returns a reference to the wrapped container
305 container_type& p() { return _p; }
306
307 /// Returns a copy of the \a i 'th entry
308 T operator[]( size_t i ) const { return get(i); }
309
310 /// Returns length of the vector (i.e., the number of entries)
311 size_t size() const { return _p.size(); }
312
313 /// Accumulate all values (similar to std::accumulate) by summing
314 /** The following calculation is done:
315 * \code
316 * T t = op(init);
317 * for( const_iterator it = begin(); it != end(); it++ )
318 * t += op(*it);
319 * return t;
320 * \endcode
321 */
322 template<typename unOp> T accumulateSum( T init, unOp op ) const {
323 T t = op(init);
324 for( const_iterator it = begin(); it != end(); it++ )
325 t += op(*it);
326 return t;
327 }
328
329 /// Accumulate all values (similar to std::accumulate) by maximization/minimization
330 /** The following calculation is done (with "max" replaced by "min" if \a minimize == \c true):
331 * \code
332 * T t = op(init);
333 * for( const_iterator it = begin(); it != end(); it++ )
334 * t = std::max( t, op(*it) );
335 * return t;
336 * \endcode
337 */
338 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const {
339 T t = op(init);
340 if( minimize ) {
341 for( const_iterator it = begin(); it != end(); it++ )
342 t = std::min( t, op(*it) );
343 } else {
344 for( const_iterator it = begin(); it != end(); it++ )
345 t = std::max( t, op(*it) );
346 }
347 return t;
348 }
349
350 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
351 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); }
352
353 /// Returns maximum value of all entries
354 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); }
355
356 /// Returns minimum value of all entries
357 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); }
358
359 /// Returns sum of all entries
360 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); }
361
362 /// Return sum of absolute value of all entries
363 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); }
364
365 /// Returns maximum absolute value of all entries
366 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); }
367
368 /// Returns \c true if one or more entries are NaN
369 bool hasNaNs() const {
370 bool foundnan = false;
371 for( const_iterator x = _p.begin(); x != _p.end(); x++ )
372 if( dai::isnan( *x ) ) {
373 foundnan = true;
374 break;
375 }
376 return foundnan;
377 }
378
379 /// Returns \c true if one or more entries are negative
380 bool hasNegatives() const {
381 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
382 }
383
384 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
385 std::pair<size_t,T> argmax() const {
386 T max = _p[0];
387 size_t arg = 0;
388 for( size_t i = 1; i < size(); i++ ) {
389 if( _p[i] > max ) {
390 max = _p[i];
391 arg = i;
392 }
393 }
394 return std::make_pair( arg, max );
395 }
396
397 /// Returns a random index, according to the (normalized) distribution described by *this
398 size_t draw() {
399 Real x = rnd_uniform() * sum();
400 T s = 0;
401 for( size_t i = 0; i < size(); i++ ) {
402 s += get(i);
403 if( s > x )
404 return i;
405 }
406 return( size() - 1 );
407 }
408
409 /// Lexicographical comparison
410 /** \pre <tt>this->size() == q.size()</tt>
411 */
412 bool operator<( const this_type& q ) const {
413 DAI_DEBASSERT( size() == q.size() );
414 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
415 }
416
417 /// Comparison
418 bool operator==( const this_type& q ) const {
419 if( size() != q.size() )
420 return false;
421 return p() == q.p();
422 }
423
424 /// Formats a TProb as a string
425 std::string toString() const {
426 std::stringstream ss;
427 ss << *this;
428 return ss.str();
429 }
430 //@}
431
432 /// \name Unary transformations
433 //@{
434 /// Returns the result of applying operation \a op pointwise on \c *this
435 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const {
436 this_type r;
437 r._p.reserve( size() );
438 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
439 return r;
440 }
441
442 /// Returns negative of \c *this
443 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); }
444
445 /// Returns pointwise absolute value
446 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); }
447
448 /// Returns pointwise exponent
449 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); }
450
451 /// Returns pointwise logarithm
452 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
453 */
454 this_type log(bool zero=false) const {
455 if( zero )
456 return pwUnaryTr( fo_log0<T>() );
457 else
458 return pwUnaryTr( fo_log<T>() );
459 }
460
461 /// Returns pointwise inverse
462 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
463 */
464 this_type inverse(bool zero=true) const {
465 if( zero )
466 return pwUnaryTr( fo_inv0<T>() );
467 else
468 return pwUnaryTr( fo_inv<T>() );
469 }
470
471 /// Returns normalized copy of \c *this, using the specified norm
472 /** \throw NOT_NORMALIZABLE if the norm is zero
473 */
474 this_type normalized( ProbNormType norm = dai::NORMPROB ) const {
475 T Z = 0;
476 if( norm == dai::NORMPROB )
477 Z = sum();
478 else if( norm == dai::NORMLINF )
479 Z = maxAbs();
480 if( Z == (T)0 ) {
481 DAI_THROW(NOT_NORMALIZABLE);
482 return *this;
483 } else
484 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
485 }
486 //@}
487
488 /// \name Unary operations
489 //@{
490 /// Applies unary operation \a op pointwise
491 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) {
492 std::transform( _p.begin(), _p.end(), _p.begin(), op );
493 return *this;
494 }
495
496 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
497 this_type& randomize() {
498 std::generate( _p.begin(), _p.end(), rnd_uniform );
499 return *this;
500 }
501
502 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
503 this_type& setUniform () {
504 fill( (T)1 / size() );
505 return *this;
506 }
507
508 /// Applies absolute value pointwise
509 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
510
511 /// Applies exponent pointwise
512 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
513
514 /// Applies logarithm pointwise
515 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
516 */
517 this_type& takeLog(bool zero=false) {
518 if( zero ) {
519 return pwUnaryOp( fo_log0<T>() );
520 } else
521 return pwUnaryOp( fo_log<T>() );
522 }
523
524 /// Normalizes vector using the specified norm
525 /** \throw NOT_NORMALIZABLE if the norm is zero
526 */
527 T normalize( ProbNormType norm=dai::NORMPROB ) {
528 T Z = 0;
529 if( norm == dai::NORMPROB )
530 Z = sum();
531 else if( norm == dai::NORMLINF )
532 Z = maxAbs();
533 if( Z == (T)0 )
534 DAI_THROW(NOT_NORMALIZABLE);
535 else
536 *this /= Z;
537 return Z;
538 }
539 //@}
540
541 /// \name Operations with scalars
542 //@{
543 /// Sets all entries to \a x
544 this_type& fill( T x ) {
545 std::fill( _p.begin(), _p.end(), x );
546 return *this;
547 }
548
549 /// Adds scalar \a x to each entry
550 this_type& operator+= (T x) {
551 if( x != 0 )
552 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
553 else
554 return *this;
555 }
556
557 /// Subtracts scalar \a x from each entry
558 this_type& operator-= (T x) {
559 if( x != 0 )
560 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
561 else
562 return *this;
563 }
564
565 /// Multiplies each entry with scalar \a x
566 this_type& operator*= (T x) {
567 if( x != 1 )
568 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
569 else
570 return *this;
571 }
572
573 /// Divides each entry by scalar \a x, where division by 0 yields 0
574 this_type& operator/= (T x) {
575 if( x != 1 )
576 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
577 else
578 return *this;
579 }
580
581 /// Raises entries to the power \a x
582 this_type& operator^= (T x) {
583 if( x != (T)1 )
584 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
585 else
586 return *this;
587 }
588 //@}
589
590 /// \name Transformations with scalars
591 //@{
592 /// Returns sum of \c *this and scalar \a x
593 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
594
595 /// Returns difference of \c *this and scalar \a x
596 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
597
598 /// Returns product of \c *this with scalar \a x
599 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
600
601 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
602 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
603
604 /// Returns \c *this raised to the power \a x
605 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
606 //@}
607
608 /// \name Operations with other equally-sized vectors
609 //@{
610 /// Applies binary operation pointwise on two vectors
611 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
612 * \param q Right operand
613 * \param op Operation of type \a binaryOp
614 */
615 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) {
616 DAI_DEBASSERT( size() == q.size() );
617 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
618 return *this;
619 }
620
621 /// Pointwise addition with \a q
622 /** \pre <tt>this->size() == q.size()</tt>
623 */
624 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); }
625
626 /// Pointwise subtraction of \a q
627 /** \pre <tt>this->size() == q.size()</tt>
628 */
629 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); }
630
631 /// Pointwise multiplication with \a q
632 /** \pre <tt>this->size() == q.size()</tt>
633 */
634 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
635
636 /// Pointwise division by \a q, where division by 0 yields 0
637 /** \pre <tt>this->size() == q.size()</tt>
638 * \see divide(const TProb<T> &)
639 */
640 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
641
642 /// Pointwise division by \a q, where division by 0 yields +Inf
643 /** \pre <tt>this->size() == q.size()</tt>
644 * \see operator/=(const TProb<T> &)
645 */
646 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); }
647
648 /// Pointwise power
649 /** \pre <tt>this->size() == q.size()</tt>
650 */
651 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); }
652 //@}
653
654 /// \name Transformations with other equally-sized vectors
655 //@{
656 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
657 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
658 * \param q Right operand
659 * \param op Operation of type \a binaryOp
660 */
661 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const {
662 DAI_DEBASSERT( size() == q.size() );
663 TProb<T> r;
664 r._p.reserve( size() );
665 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
666 return r;
667 }
668
669 /// Returns sum of \c *this and \a q
670 /** \pre <tt>this->size() == q.size()</tt>
671 */
672 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
673
674 /// Return \c *this minus \a q
675 /** \pre <tt>this->size() == q.size()</tt>
676 */
677 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
678
679 /// Return product of \c *this with \a q
680 /** \pre <tt>this->size() == q.size()</tt>
681 */
682 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
683
684 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
685 /** \pre <tt>this->size() == q.size()</tt>
686 * \see divided_by(const TProb<T> &)
687 */
688 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
689
690 /// Pointwise division by \a q, where division by 0 yields +Inf
691 /** \pre <tt>this->size() == q.size()</tt>
692 * \see operator/(const TProb<T> &)
693 */
694 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
695
696 /// Returns \c *this to the power \a q
697 /** \pre <tt>this->size() == q.size()</tt>
698 */
699 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
700 //@}
701
702 /// Performs a generalized inner product, similar to std::inner_product
703 /** \pre <tt>this->size() == q.size()</tt>
704 */
705 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
706 DAI_DEBASSERT( size() == q.size() );
707 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
708 }
709 };
710
711
712 /// Returns distance between \a p and \a q, measured using distance measure \a dt
713 /** \relates TProb
714 * \pre <tt>this->size() == q.size()</tt>
715 */
716 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) {
717 switch( dt ) {
718 case DISTL1:
719 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
720 case DISTLINF:
721 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
722 case DISTTV:
723 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
724 case DISTKL:
725 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
726 case DISTHEL:
727 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
728 default:
729 DAI_THROW(UNKNOWN_ENUM_VALUE);
730 return INFINITY;
731 }
732 }
733
734
735 /// Writes a TProb<T> to an output stream
736 /** \relates TProb
737 */
738 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
739 os << "(";
740 for( size_t i = 0; i < p.size(); i++ )
741 os << ((i != 0) ? ", " : "") << p.get(i);
742 os << ")";
743 return os;
744 }
745
746
747 /// Returns the pointwise minimum of \a a and \a b
748 /** \relates TProb
749 * \pre <tt>this->size() == q.size()</tt>
750 */
751 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
752 return a.pwBinaryTr( b, fo_min<T>() );
753 }
754
755
756 /// Returns the pointwise maximum of \a a and \a b
757 /** \relates TProb
758 * \pre <tt>this->size() == q.size()</tt>
759 */
760 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
761 return a.pwBinaryTr( b, fo_max<T>() );
762 }
763
764
765 /// Represents a vector with entries of type dai::Real (shorthand for TProb<Real>).
766 typedef TProb<Real> Prob;
767
768
769 } // end of namespace dai
770
771
772 #endif