d1b689592d93eb068de8d664472913e3cdc6901b
[libdai.git] / include / dai / prob.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
11
12
13 #ifndef __defined_libdai_prob_h
14 #define __defined_libdai_prob_h
15
16
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <utility>
#include <vector>
#include <dai/util.h>
#include <dai/exceptions.h>
25
26
27 namespace dai {
28
29
/// Function object that returns the value itself
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_id {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
37
38
/// Function object that takes the absolute value
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_abs {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns abs(\a x)
    T operator()( const T &x ) const {
        if( x < (T)0 )
            return -x;
        else
            return x;
    }
};
49
50
/// Function object that takes the exponent
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_exp {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        return exp( x );
    }
};
58
59
/// Function object that takes the logarithm
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_log {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        return log( x );
    }
};
67
68
/// Function object that takes the logarithm, except that log(0) is defined to be 0
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_log0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : log(\a x))
    T operator()( const T &x ) const {
        if( x )
            return log( x );
        else
            return 0;
    }
};
79
80
/// Function object that takes the inverse
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_inv {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns 1 / \a x
    T operator()( const T &x ) const {
        return 1 / x;
    }
};
88
89
/// Function object that takes the inverse, except that 1/0 is defined to be 0
/** \note Declares the <tt>argument_type</tt> and <tt>result_type</tt> typedefs itself instead of deriving from
 *  <tt>std::unary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_inv0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : (1 / \a x))
    T operator()( const T &x ) const {
        if( x )
            return 1 / x;
        else
            return 0;
    }
};
100
101
102 /// Function object that returns p*log0(p)
103 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
104 /// Returns \a p * log0(\a p)
105 T operator()( const T &p ) const {
106 return p * dai::log0(p);
107 }
108 };
109
110
/// Function object similar to std::divides(), but different in that dividing by zero results in zero
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_divides0 {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a y == 0 ? 0 : (\a x / \a y))
    T operator()( const T &x, const T &y ) const {
        if( y == (T)0 )
            return (T)0;
        else
            return x / y;
    }
};
121
122
/// Function object useful for calculating the KL distance
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_KL {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
    T operator()( const T &p, const T &q ) const {
        if( p == (T)0 )
            return (T)0;
        else
            return p * (log(p) - log(q));
    }
};
133
134
/// Function object useful for calculating the Hellinger distance
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_Hellinger {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        T x = sqrt(p) - sqrt(q);
        return x * x;
    }
};
143
144
/// Function object that returns x to the power y
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_pow {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x ^ \a y)
    T operator()( const T &x, const T &y ) const {
        // Fast path: x^1 == x, avoids a call to pow()
        if( y != 1 )
            return pow( x, y );
        else
            return x;
    }
};
155
156
/// Function object that returns the maximum of two values
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_max {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a x : \a y)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? x : y;
    }
};
164
165
/// Function object that returns the minimum of two values
/** \note Declares the <tt>first_argument_type</tt>, <tt>second_argument_type</tt> and <tt>result_type</tt>
 *  typedefs itself instead of deriving from <tt>std::binary_function</tt>, which is deprecated since C++11
 *  and removed in C++17.
 */
template<typename T> struct fo_min {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a y : \a x)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? y : x;
    }
};
173
174
175 /// Function object that returns the absolute difference of x and y
176 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
177 /// Returns abs( \a x - \a y )
178 T operator()( const T &x, const T &y ) const {
179 return dai::abs( x - y );
180 }
181 };
182
183
184 /// Represents a vector with entries of type \a T.
185 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
186 *
187 * It is mainly used for representing measures on a finite outcome space, for example, the probability
188 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
189 * represent logarithms of probability mass functions.
190 *
191 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
192 */
193 template <typename T>
194 class TProb {
195 public:
196 /// Type of data structure used for storing the values
197 typedef std::vector<T> container_type;
198
199 /// Shorthand
200 typedef TProb<T> this_type;
201
202 private:
203 /// The data structure that stores the values
204 container_type _p;
205
206 public:
207 /// \name Constructors and destructors
208 //@{
209 /// Default constructor (constructs empty vector)
210 TProb() : _p() {}
211
212 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
213 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {}
214
215 /// Construct vector of length \a n with each entry set to \a p
216 explicit TProb( size_t n, T p ) : _p( n, p ) {}
217
218 /// Construct vector from a range
219 /** \tparam TIterator Iterates over instances that can be cast to \a T
220 * \param begin Points to first instance to be added.
221 * \param end Points just beyond last instance to be added.
222 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint;
223 * the value 0 can be given if the size is unknown, but this will result in a performance penalty.
224 */
225 template <typename TIterator>
226 TProb( TIterator begin, TIterator end, size_t sizeHint ) : _p() {
227 _p.reserve( sizeHint );
228 _p.insert( _p.begin(), begin, end );
229 }
230
231 /// Construct vector from another vector
232 /** \tparam S type of elements in \a v (should be castable to type \a T)
233 * \param v vector used for initialization.
234 */
235 template <typename S>
236 TProb( const std::vector<S> &v ) : _p() {
237 _p.reserve( v.size() );
238 _p.insert( _p.begin(), v.begin(), v.end() );
239 }
240 //@}
241
242 /// Constant iterator over the elements
243 typedef typename container_type::const_iterator const_iterator;
244 /// Iterator over the elements
245 typedef typename container_type::iterator iterator;
246 /// Constant reverse iterator over the elements
247 typedef typename container_type::const_reverse_iterator const_reverse_iterator;
248 /// Reverse iterator over the elements
249 typedef typename container_type::reverse_iterator reverse_iterator;
250
251 /// \name Iterator interface
252 //@{
253 /// Returns iterator that points to the first element
254 iterator begin() { return _p.begin(); }
255 /// Returns constant iterator that points to the first element
256 const_iterator begin() const { return _p.begin(); }
257
258 /// Returns iterator that points beyond the last element
259 iterator end() { return _p.end(); }
260 /// Returns constant iterator that points beyond the last element
261 const_iterator end() const { return _p.end(); }
262
263 /// Returns reverse iterator that points to the last element
264 reverse_iterator rbegin() { return _p.rbegin(); }
265 /// Returns constant reverse iterator that points to the last element
266 const_reverse_iterator rbegin() const { return _p.rbegin(); }
267
268 /// Returns reverse iterator that points beyond the first element
269 reverse_iterator rend() { return _p.rend(); }
270 /// Returns constant reverse iterator that points beyond the first element
271 const_reverse_iterator rend() const { return _p.rend(); }
272 //@}
273
274 /// \name Miscellaneous operations
275 //@{
276 void resize( size_t sz ) {
277 _p.resize( sz );
278 }
279 //@}
280
281 /// \name Get/set individual entries
282 //@{
283 /// Gets \a i 'th entry
284 T get( size_t i ) const {
285 #ifdef DAI_DEBUG
286 return _p.at(i);
287 #else
288 return _p[i];
289 #endif
290 }
291
292 /// Sets \a i 'th entry to \a val
293 void set( size_t i, T val ) {
294 DAI_DEBASSERT( i < _p.size() );
295 _p[i] = val;
296 }
297 //@}
298
299 /// \name Queries
300 //@{
301 /// Returns a const reference to the wrapped container
302 const container_type& p() const { return _p; }
303
304 /// Returns a reference to the wrapped container
305 container_type& p() { return _p; }
306
307 /// Returns a copy of the \a i 'th entry
308 T operator[]( size_t i ) const { return get(i); }
309
310 /// Returns length of the vector (i.e., the number of entries)
311 size_t size() const { return _p.size(); }
312
313 /// Accumulate all values (similar to std::accumulate) by summing
314 /** The following calculation is done:
315 * \code
316 * T t = op(init);
317 * for( const_iterator it = begin(); it != end(); it++ )
318 * t += op(*it);
319 * return t;
320 * \endcode
321 */
322 template<typename unOp> T accumulateSum( T init, unOp op ) const {
323 T t = op(init);
324 for( const_iterator it = begin(); it != end(); it++ )
325 t += op(*it);
326 return t;
327 }
328
329 /// Accumulate all values (similar to std::accumulate) by maximization/minimization
330 /** The following calculation is done (with "max" replaced by "min" if \a minimize == \c true):
331 * \code
332 * T t = op(init);
333 * for( const_iterator it = begin(); it != end(); it++ )
334 * t = std::max( t, op(*it) );
335 * return t;
336 * \endcode
337 */
338 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const {
339 T t = op(init);
340 if( minimize ) {
341 for( const_iterator it = begin(); it != end(); it++ )
342 t = std::min( t, op(*it) );
343 } else {
344 for( const_iterator it = begin(); it != end(); it++ )
345 t = std::max( t, op(*it) );
346 }
347 return t;
348 }
349
350 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
351 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); }
352
353 /// Returns maximum value of all entries
354 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); }
355
356 /// Returns minimum value of all entries
357 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); }
358
359 /// Returns sum of all entries
360 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); }
361
362 /// Return sum of absolute value of all entries
363 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); }
364
365 /// Returns maximum absolute value of all entries
366 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); }
367
368 /// Returns \c true if one or more entries are NaN
369 bool hasNaNs() const {
370 bool foundnan = false;
371 for( const_iterator x = _p.begin(); x != _p.end(); x++ )
372 if( dai::isnan( *x ) ) {
373 foundnan = true;
374 break;
375 }
376 return foundnan;
377 }
378
379 /// Returns \c true if one or more entries are negative
380 bool hasNegatives() const {
381 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
382 }
383
384 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
385 std::pair<size_t,T> argmax() const {
386 T max = _p[0];
387 size_t arg = 0;
388 for( size_t i = 1; i < size(); i++ ) {
389 if( _p[i] > max ) {
390 max = _p[i];
391 arg = i;
392 }
393 }
394 return std::make_pair( arg, max );
395 }
396
397 /// Returns a random index, according to the (normalized) distribution described by *this
398 size_t draw() {
399 Real x = rnd_uniform() * sum();
400 T s = 0;
401 for( size_t i = 0; i < size(); i++ ) {
402 s += get(i);
403 if( s > x )
404 return i;
405 }
406 return( size() - 1 );
407 }
408
409 /// Lexicographical comparison
410 /** \pre <tt>this->size() == q.size()</tt>
411 */
412 bool operator<( const this_type& q ) const {
413 DAI_DEBASSERT( size() == q.size() );
414 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
415 }
416
417 /// Comparison
418 bool operator==( const this_type& q ) const {
419 if( size() != q.size() )
420 return false;
421 return p() == q.p();
422 }
423 //@}
424
425 /// \name Unary transformations
426 //@{
427 /// Returns the result of applying operation \a op pointwise on \c *this
428 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const {
429 this_type r;
430 r._p.reserve( size() );
431 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
432 return r;
433 }
434
435 /// Returns negative of \c *this
436 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); }
437
438 /// Returns pointwise absolute value
439 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); }
440
441 /// Returns pointwise exponent
442 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); }
443
444 /// Returns pointwise logarithm
445 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
446 */
447 this_type log(bool zero=false) const {
448 if( zero )
449 return pwUnaryTr( fo_log0<T>() );
450 else
451 return pwUnaryTr( fo_log<T>() );
452 }
453
454 /// Returns pointwise inverse
455 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
456 */
457 this_type inverse(bool zero=true) const {
458 if( zero )
459 return pwUnaryTr( fo_inv0<T>() );
460 else
461 return pwUnaryTr( fo_inv<T>() );
462 }
463
464 /// Returns normalized copy of \c *this, using the specified norm
465 /** \throw NOT_NORMALIZABLE if the norm is zero
466 */
467 this_type normalized( ProbNormType norm = dai::NORMPROB ) const {
468 T Z = 0;
469 if( norm == dai::NORMPROB )
470 Z = sum();
471 else if( norm == dai::NORMLINF )
472 Z = maxAbs();
473 if( Z == (T)0 ) {
474 DAI_THROW(NOT_NORMALIZABLE);
475 return *this;
476 } else
477 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
478 }
479 //@}
480
481 /// \name Unary operations
482 //@{
483 /// Applies unary operation \a op pointwise
484 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) {
485 std::transform( _p.begin(), _p.end(), _p.begin(), op );
486 return *this;
487 }
488
489 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
490 this_type& randomize() {
491 std::generate( _p.begin(), _p.end(), rnd_uniform );
492 return *this;
493 }
494
495 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
496 this_type& setUniform () {
497 fill( (T)1 / size() );
498 return *this;
499 }
500
501 /// Applies absolute value pointwise
502 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
503
504 /// Applies exponent pointwise
505 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
506
507 /// Applies logarithm pointwise
508 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
509 */
510 this_type& takeLog(bool zero=false) {
511 if( zero ) {
512 return pwUnaryOp( fo_log0<T>() );
513 } else
514 return pwUnaryOp( fo_log<T>() );
515 }
516
517 /// Normalizes vector using the specified norm
518 /** \throw NOT_NORMALIZABLE if the norm is zero
519 */
520 T normalize( ProbNormType norm=dai::NORMPROB ) {
521 T Z = 0;
522 if( norm == dai::NORMPROB )
523 Z = sum();
524 else if( norm == dai::NORMLINF )
525 Z = maxAbs();
526 if( Z == (T)0 )
527 DAI_THROW(NOT_NORMALIZABLE);
528 else
529 *this /= Z;
530 return Z;
531 }
532 //@}
533
534 /// \name Operations with scalars
535 //@{
536 /// Sets all entries to \a x
537 this_type& fill( T x ) {
538 std::fill( _p.begin(), _p.end(), x );
539 return *this;
540 }
541
542 /// Adds scalar \a x to each entry
543 this_type& operator+= (T x) {
544 if( x != 0 )
545 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
546 else
547 return *this;
548 }
549
550 /// Subtracts scalar \a x from each entry
551 this_type& operator-= (T x) {
552 if( x != 0 )
553 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
554 else
555 return *this;
556 }
557
558 /// Multiplies each entry with scalar \a x
559 this_type& operator*= (T x) {
560 if( x != 1 )
561 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
562 else
563 return *this;
564 }
565
566 /// Divides each entry by scalar \a x, where division by 0 yields 0
567 this_type& operator/= (T x) {
568 if( x != 1 )
569 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
570 else
571 return *this;
572 }
573
574 /// Raises entries to the power \a x
575 this_type& operator^= (T x) {
576 if( x != (T)1 )
577 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
578 else
579 return *this;
580 }
581 //@}
582
583 /// \name Transformations with scalars
584 //@{
585 /// Returns sum of \c *this and scalar \a x
586 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
587
588 /// Returns difference of \c *this and scalar \a x
589 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
590
591 /// Returns product of \c *this with scalar \a x
592 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
593
594 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
595 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
596
597 /// Returns \c *this raised to the power \a x
598 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
599 //@}
600
601 /// \name Operations with other equally-sized vectors
602 //@{
603 /// Applies binary operation pointwise on two vectors
604 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
605 * \param q Right operand
606 * \param op Operation of type \a binaryOp
607 */
608 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) {
609 DAI_DEBASSERT( size() == q.size() );
610 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
611 return *this;
612 }
613
614 /// Pointwise addition with \a q
615 /** \pre <tt>this->size() == q.size()</tt>
616 */
617 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); }
618
619 /// Pointwise subtraction of \a q
620 /** \pre <tt>this->size() == q.size()</tt>
621 */
622 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); }
623
624 /// Pointwise multiplication with \a q
625 /** \pre <tt>this->size() == q.size()</tt>
626 */
627 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
628
629 /// Pointwise division by \a q, where division by 0 yields 0
630 /** \pre <tt>this->size() == q.size()</tt>
631 * \see divide(const TProb<T> &)
632 */
633 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
634
635 /// Pointwise division by \a q, where division by 0 yields +Inf
636 /** \pre <tt>this->size() == q.size()</tt>
637 * \see operator/=(const TProb<T> &)
638 */
639 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); }
640
641 /// Pointwise power
642 /** \pre <tt>this->size() == q.size()</tt>
643 */
644 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); }
645 //@}
646
647 /// \name Transformations with other equally-sized vectors
648 //@{
649 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
650 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
651 * \param q Right operand
652 * \param op Operation of type \a binaryOp
653 */
654 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const {
655 DAI_DEBASSERT( size() == q.size() );
656 TProb<T> r;
657 r._p.reserve( size() );
658 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
659 return r;
660 }
661
662 /// Returns sum of \c *this and \a q
663 /** \pre <tt>this->size() == q.size()</tt>
664 */
665 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
666
667 /// Return \c *this minus \a q
668 /** \pre <tt>this->size() == q.size()</tt>
669 */
670 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
671
672 /// Return product of \c *this with \a q
673 /** \pre <tt>this->size() == q.size()</tt>
674 */
675 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
676
677 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
678 /** \pre <tt>this->size() == q.size()</tt>
679 * \see divided_by(const TProb<T> &)
680 */
681 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
682
683 /// Pointwise division by \a q, where division by 0 yields +Inf
684 /** \pre <tt>this->size() == q.size()</tt>
685 * \see operator/(const TProb<T> &)
686 */
687 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
688
689 /// Returns \c *this to the power \a q
690 /** \pre <tt>this->size() == q.size()</tt>
691 */
692 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
693 //@}
694
695 /// Performs a generalized inner product, similar to std::inner_product
696 /** \pre <tt>this->size() == q.size()</tt>
697 */
698 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
699 DAI_DEBASSERT( size() == q.size() );
700 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
701 }
702 };
703
704
705 /// Returns distance between \a p and \a q, measured using distance measure \a dt
706 /** \relates TProb
707 * \pre <tt>this->size() == q.size()</tt>
708 */
709 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) {
710 switch( dt ) {
711 case DISTL1:
712 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
713 case DISTLINF:
714 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
715 case DISTTV:
716 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
717 case DISTKL:
718 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
719 case DISTHEL:
720 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
721 default:
722 DAI_THROW(UNKNOWN_ENUM_VALUE);
723 return INFINITY;
724 }
725 }
726
727
728 /// Writes a TProb<T> to an output stream
729 /** \relates TProb
730 */
731 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
732 os << "(";
733 for( size_t i = 0; i < p.size(); i++ )
734 os << ((i != 0) ? ", " : "") << p.get(i);
735 os << ")";
736 return os;
737 }
738
739
740 /// Returns the pointwise minimum of \a a and \a b
741 /** \relates TProb
742 * \pre <tt>this->size() == q.size()</tt>
743 */
744 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
745 return a.pwBinaryTr( b, fo_min<T>() );
746 }
747
748
749 /// Returns the pointwise maximum of \a a and \a b
750 /** \relates TProb
751 * \pre <tt>this->size() == q.size()</tt>
752 */
753 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
754 return a.pwBinaryTr( b, fo_max<T>() );
755 }
756
757
758 /// Represents a vector with entries of type dai::Real.
759 typedef TProb<Real> Prob;
760
761
762 } // end of namespace dai
763
764
765 #endif