Removed deprecated interfaces
[libdai.git] / include / dai / prob.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines TProb<> and Prob classes which represent (probability) vectors (e.g., probability distributions of discrete random variables)
14
15
16 #ifndef __defined_libdai_prob_h
17 #define __defined_libdai_prob_h
18
19
20 #include <cmath>
21 #include <vector>
22 #include <ostream>
23 #include <algorithm>
24 #include <numeric>
25 #include <functional>
26 #include <dai/util.h>
27 #include <dai/exceptions.h>
28
29
30 namespace dai {
31
32
/// Function object that returns the value itself
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_id {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns \a x
    T operator()( const T &x ) const {
        return x;
    }
};
40
41
/// Function object that takes the absolute value
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_abs {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns abs(\a x), i.e., (\a x < 0 ? -\a x : \a x)
    T operator()( const T &x ) const {
        if( x < (T)0 )
            return -x;
        else
            return x;
    }
};
52
53
/// Function object that takes the exponent
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_exp {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns exp(\a x)
    T operator()( const T &x ) const {
        return exp( x );
    }
};
61
62
/// Function object that takes the logarithm
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_log {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns log(\a x)
    T operator()( const T &x ) const {
        return log( x );
    }
};
70
71
/// Function object that takes the logarithm, except that log(0) is defined to be 0
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_log0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : log(\a x))
    T operator()( const T &x ) const {
        if( x )
            return log( x );
        else
            return 0;
    }
};
82
83
/// Function object that takes the inverse
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_inv {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns 1 / \a x
    T operator()( const T &x ) const {
        return 1 / x;
    }
};
91
92
/// Function object that takes the inverse, except that 1/0 is defined to be 0
/** The \c argument_type / \c result_type typedefs expected of an adaptable unary function
 *  object are declared directly instead of being inherited from <tt>std::unary_function</tt>,
 *  which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_inv0 {
    /// Type of the argument
    typedef T argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x == 0 ? 0 : (1 / \a x))
    T operator()( const T &x ) const {
        if( x )
            return 1 / x;
        else
            return 0;
    }
};
103
104
105 /// Function object that returns p*log0(p)
106 template<typename T> struct fo_plog0p : public std::unary_function<T, T> {
107 /// Returns \a p * log0(\a p)
108 T operator()( const T &p ) const {
109 return p * dai::log0(p);
110 }
111 };
112
113
/// Function object similar to std::divides(), but different in that dividing by zero results in zero
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object (e.g., by <tt>std::bind2nd</tt>) are declared directly instead
 *  of being inherited from <tt>std::binary_function</tt>, which is deprecated since C++11 and
 *  removed in C++17.
 */
template<typename T> struct fo_divides0 {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a y == 0 ? 0 : (\a x / \a y))
    T operator()( const T &x, const T &y ) const {
        if( y == (T)0 )
            return (T)0;
        else
            return x / y;
    }
};
124
125
/// Function object useful for calculating the KL distance
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object are declared directly instead of being inherited from
 *  <tt>std::binary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_KL {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a p == 0 ? 0 : (\a p * (log(\a p) - log(\a q))))
    T operator()( const T &p, const T &q ) const {
        if( p == (T)0 )
            return (T)0;
        else
            return p * (log(p) - log(q));
    }
};
136
137
/// Function object useful for calculating the Hellinger distance
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object are declared directly instead of being inherited from
 *  <tt>std::binary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_Hellinger {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (sqrt(\a p) - sqrt(\a q))^2
    T operator()( const T &p, const T &q ) const {
        T x = sqrt(p) - sqrt(q);
        return x * x;
    }
};
146
147
/// Function object that returns x to the power y
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object (e.g., by <tt>std::bind2nd</tt>) are declared directly instead
 *  of being inherited from <tt>std::binary_function</tt>, which is deprecated since C++11 and
 *  removed in C++17.
 */
template<typename T> struct fo_pow {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x ^ \a y)
    T operator()( const T &x, const T &y ) const {
        // Skip the (comparatively expensive) pow() call for the identity exponent
        if( y != 1 )
            return pow( x, y );
        else
            return x;
    }
};
158
159
/// Function object that returns the maximum of two values
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object are declared directly instead of being inherited from
 *  <tt>std::binary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_max {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a x : \a y)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? x : y;
    }
};
167
168
/// Function object that returns the minimum of two values
/** The \c first_argument_type / \c second_argument_type / \c result_type typedefs expected of an
 *  adaptable binary function object are declared directly instead of being inherited from
 *  <tt>std::binary_function</tt>, which is deprecated since C++11 and removed in C++17.
 */
template<typename T> struct fo_min {
    /// Type of the first argument
    typedef T first_argument_type;
    /// Type of the second argument
    typedef T second_argument_type;
    /// Type of the result
    typedef T result_type;

    /// Returns (\a x > \a y ? \a y : \a x)
    T operator()( const T &x, const T &y ) const {
        return (x > y) ? y : x;
    }
};
176
177
178 /// Function object that returns the absolute difference of x and y
179 template<typename T> struct fo_absdiff : public std::binary_function<T, T, T> {
180 /// Returns abs( \a x - \a y )
181 T operator()( const T &x, const T &y ) const {
182 return dai::abs( x - y );
183 }
184 };
185
186
187 /// Represents a vector with entries of type \a T.
188 /** It is simply a <tt>std::vector</tt><<em>T</em>> with an interface designed for dealing with probability mass functions.
189 *
190 * It is mainly used for representing measures on a finite outcome space, for example, the probability
191 * distribution of a discrete random variable. However, entries are not necessarily non-negative; it is also used to
192 * represent logarithms of probability mass functions.
193 *
194 * \tparam T Should be a scalar that is castable from and to dai::Real and should support elementary arithmetic operations.
195 */
196 template <typename T>
197 class TProb {
198 public:
199 /// Type of data structure used for storing the values
200 typedef std::vector<T> container_type;
201
202 /// Shorthand
203 typedef TProb<T> this_type;
204
205 private:
206 /// The data structure that stores the values
207 container_type _p;
208
209 public:
210 /// \name Constructors and destructors
211 //@{
212 /// Default constructor (constructs empty vector)
213 TProb() : _p() {}
214
215 /// Construct uniform probability distribution over \a n outcomes (i.e., a vector of length \a n with each entry set to \f$1/n\f$)
216 explicit TProb( size_t n ) : _p( n, (T)1 / n ) {}
217
218 /// Construct vector of length \a n with each entry set to \a p
219 explicit TProb( size_t n, T p ) : _p( n, p ) {}
220
221 /// Construct vector from a range
222 /** \tparam TIterator Iterates over instances that can be cast to \a T
223 * \param begin Points to first instance to be added.
224 * \param end Points just beyond last instance to be added.
225 * \param sizeHint For efficiency, the number of entries can be speficied by \a sizeHint;
226 * the value 0 can be given if the size is unknown, but this will result in a performance penalty.
227 */
228 template <typename TIterator>
229 TProb( TIterator begin, TIterator end, size_t sizeHint ) : _p() {
230 _p.reserve( sizeHint );
231 _p.insert( _p.begin(), begin, end );
232 }
233
234 /// Construct vector from another vector
235 /** \tparam S type of elements in \a v (should be castable to type \a T)
236 * \param v vector used for initialization.
237 */
238 template <typename S>
239 TProb( const std::vector<S> &v ) : _p() {
240 _p.reserve( v.size() );
241 _p.insert( _p.begin(), v.begin(), v.end() );
242 }
243 //@}
244
245 /// Constant iterator over the elements
246 typedef typename container_type::const_iterator const_iterator;
247 /// Iterator over the elements
248 typedef typename container_type::iterator iterator;
249 /// Constant reverse iterator over the elements
250 typedef typename container_type::const_reverse_iterator const_reverse_iterator;
251 /// Reverse iterator over the elements
252 typedef typename container_type::reverse_iterator reverse_iterator;
253
254 /// \name Iterator interface
255 //@{
256 /// Returns iterator that points to the first element
257 iterator begin() { return _p.begin(); }
258 /// Returns constant iterator that points to the first element
259 const_iterator begin() const { return _p.begin(); }
260
261 /// Returns iterator that points beyond the last element
262 iterator end() { return _p.end(); }
263 /// Returns constant iterator that points beyond the last element
264 const_iterator end() const { return _p.end(); }
265
266 /// Returns reverse iterator that points to the last element
267 reverse_iterator rbegin() { return _p.rbegin(); }
268 /// Returns constant reverse iterator that points to the last element
269 const_reverse_iterator rbegin() const { return _p.rbegin(); }
270
271 /// Returns reverse iterator that points beyond the first element
272 reverse_iterator rend() { return _p.rend(); }
273 /// Returns constant reverse iterator that points beyond the first element
274 const_reverse_iterator rend() const { return _p.rend(); }
275 //@}
276
277 /// \name Miscellaneous operations
278 //@{
279 void resize( size_t sz ) {
280 _p.resize( sz );
281 }
282 //@}
283
284 /// \name Get/set individual entries
285 //@{
286 /// Gets \a i 'th entry
287 T get( size_t i ) const {
288 #ifdef DAI_DEBUG
289 return _p.at(i);
290 #else
291 return _p[i];
292 #endif
293 }
294
295 /// Sets \a i 'th entry to \a val
296 void set( size_t i, T val ) {
297 DAI_DEBASSERT( i < _p.size() );
298 _p[i] = val;
299 }
300 //@}
301
302 /// \name Queries
303 //@{
304 /// Returns a const reference to the wrapped container
305 const container_type& p() const { return _p; }
306
307 /// Returns a reference to the wrapped container
308 container_type& p() { return _p; }
309
310 /// Returns a copy of the \a i 'th entry
311 T operator[]( size_t i ) const { return get(i); }
312
313 /// Returns length of the vector (i.e., the number of entries)
314 size_t size() const { return _p.size(); }
315
316 /// Accumulate all values (similar to std::accumulate) by summing
317 /** The following calculation is done:
318 * \code
319 * T t = op(init);
320 * for( const_iterator it = begin(); it != end(); it++ )
321 * t += op(*it);
322 * return t;
323 * \endcode
324 */
325 template<typename unOp> T accumulateSum( T init, unOp op ) const {
326 T t = op(init);
327 for( const_iterator it = begin(); it != end(); it++ )
328 t += op(*it);
329 return t;
330 }
331
332 /// Accumulate all values (similar to std::accumulate) by maximization/minimization
333 /** The following calculation is done (with "max" replaced by "min" if \a minimize == \c true):
334 * \code
335 * T t = op(init);
336 * for( const_iterator it = begin(); it != end(); it++ )
337 * t = std::max( t, op(*it) );
338 * return t;
339 * \endcode
340 */
341 template<typename unOp> T accumulateMax( T init, unOp op, bool minimize ) const {
342 T t = op(init);
343 if( minimize ) {
344 for( const_iterator it = begin(); it != end(); it++ )
345 t = std::min( t, op(*it) );
346 } else {
347 for( const_iterator it = begin(); it != end(); it++ )
348 t = std::max( t, op(*it) );
349 }
350 return t;
351 }
352
353 /// Returns the Shannon entropy of \c *this, \f$-\sum_i p_i \log p_i\f$
354 T entropy() const { return -accumulateSum( (T)0, fo_plog0p<T>() ); }
355
356 /// Returns maximum value of all entries
357 T max() const { return accumulateMax( (T)(-INFINITY), fo_id<T>(), false ); }
358
359 /// Returns minimum value of all entries
360 T min() const { return accumulateMax( (T)INFINITY, fo_id<T>(), true ); }
361
362 /// Returns sum of all entries
363 T sum() const { return accumulateSum( (T)0, fo_id<T>() ); }
364
365 /// Return sum of absolute value of all entries
366 T sumAbs() const { return accumulateSum( (T)0, fo_abs<T>() ); }
367
368 /// Returns maximum absolute value of all entries
369 T maxAbs() const { return accumulateMax( (T)0, fo_abs<T>(), false ); }
370
371 /// Returns \c true if one or more entries are NaN
372 bool hasNaNs() const {
373 bool foundnan = false;
374 for( const_iterator x = _p.begin(); x != _p.end(); x++ )
375 if( dai::isnan( *x ) ) {
376 foundnan = true;
377 break;
378 }
379 return foundnan;
380 }
381
382 /// Returns \c true if one or more entries are negative
383 bool hasNegatives() const {
384 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<T>(), (T)0 ) ) != _p.end());
385 }
386
387 /// Returns a pair consisting of the index of the maximum value and the maximum value itself
388 std::pair<size_t,T> argmax() const {
389 T max = _p[0];
390 size_t arg = 0;
391 for( size_t i = 1; i < size(); i++ ) {
392 if( _p[i] > max ) {
393 max = _p[i];
394 arg = i;
395 }
396 }
397 return std::make_pair( arg, max );
398 }
399
400 /// Returns a random index, according to the (normalized) distribution described by *this
401 size_t draw() {
402 Real x = rnd_uniform() * sum();
403 T s = 0;
404 for( size_t i = 0; i < size(); i++ ) {
405 s += get(i);
406 if( s > x )
407 return i;
408 }
409 return( size() - 1 );
410 }
411
412 /// Lexicographical comparison
413 /** \pre <tt>this->size() == q.size()</tt>
414 */
415 bool operator<( const this_type& q ) const {
416 DAI_DEBASSERT( size() == q.size() );
417 return lexicographical_compare( begin(), end(), q.begin(), q.end() );
418 }
419
420 /// Comparison
421 bool operator==( const this_type& q ) const {
422 if( size() != q.size() )
423 return false;
424 return p() == q.p();
425 }
426 //@}
427
428 /// \name Unary transformations
429 //@{
430 /// Returns the result of applying operation \a op pointwise on \c *this
431 template<typename unaryOp> this_type pwUnaryTr( unaryOp op ) const {
432 this_type r;
433 r._p.reserve( size() );
434 std::transform( _p.begin(), _p.end(), back_inserter( r._p ), op );
435 return r;
436 }
437
438 /// Returns negative of \c *this
439 this_type operator- () const { return pwUnaryTr( std::negate<T>() ); }
440
441 /// Returns pointwise absolute value
442 this_type abs() const { return pwUnaryTr( fo_abs<T>() ); }
443
444 /// Returns pointwise exponent
445 this_type exp() const { return pwUnaryTr( fo_exp<T>() ); }
446
447 /// Returns pointwise logarithm
448 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
449 */
450 this_type log(bool zero=false) const {
451 if( zero )
452 return pwUnaryTr( fo_log0<T>() );
453 else
454 return pwUnaryTr( fo_log<T>() );
455 }
456
457 /// Returns pointwise inverse
458 /** If \a zero == \c true, uses <tt>1/0==0</tt>; otherwise, <tt>1/0==Inf</tt>.
459 */
460 this_type inverse(bool zero=true) const {
461 if( zero )
462 return pwUnaryTr( fo_inv0<T>() );
463 else
464 return pwUnaryTr( fo_inv<T>() );
465 }
466
467 /// Returns normalized copy of \c *this, using the specified norm
468 /** \throw NOT_NORMALIZABLE if the norm is zero
469 */
470 this_type normalized( ProbNormType norm = dai::NORMPROB ) const {
471 T Z = 0;
472 if( norm == dai::NORMPROB )
473 Z = sum();
474 else if( norm == dai::NORMLINF )
475 Z = maxAbs();
476 if( Z == (T)0 ) {
477 DAI_THROW(NOT_NORMALIZABLE);
478 return *this;
479 } else
480 return pwUnaryTr( std::bind2nd( std::divides<T>(), Z ) );
481 }
482 //@}
483
484 /// \name Unary operations
485 //@{
486 /// Applies unary operation \a op pointwise
487 template<typename unaryOp> this_type& pwUnaryOp( unaryOp op ) {
488 std::transform( _p.begin(), _p.end(), _p.begin(), op );
489 return *this;
490 }
491
492 /// Draws all entries i.i.d. from a uniform distribution on [0,1)
493 this_type& randomize() {
494 std::generate( _p.begin(), _p.end(), rnd_uniform );
495 return *this;
496 }
497
498 /// Sets all entries to \f$1/n\f$ where \a n is the length of the vector
499 this_type& setUniform () {
500 fill( (T)1 / size() );
501 return *this;
502 }
503
504 /// Applies absolute value pointwise
505 this_type& takeAbs() { return pwUnaryOp( fo_abs<T>() ); }
506
507 /// Applies exponent pointwise
508 this_type& takeExp() { return pwUnaryOp( fo_exp<T>() ); }
509
510 /// Applies logarithm pointwise
511 /** If \a zero == \c true, uses <tt>log(0)==0</tt>; otherwise, <tt>log(0)==-Inf</tt>.
512 */
513 this_type& takeLog(bool zero=false) {
514 if( zero ) {
515 return pwUnaryOp( fo_log0<T>() );
516 } else
517 return pwUnaryOp( fo_log<T>() );
518 }
519
520 /// Normalizes vector using the specified norm
521 /** \throw NOT_NORMALIZABLE if the norm is zero
522 */
523 T normalize( ProbNormType norm=dai::NORMPROB ) {
524 T Z = 0;
525 if( norm == dai::NORMPROB )
526 Z = sum();
527 else if( norm == dai::NORMLINF )
528 Z = maxAbs();
529 if( Z == (T)0 )
530 DAI_THROW(NOT_NORMALIZABLE);
531 else
532 *this /= Z;
533 return Z;
534 }
535 //@}
536
537 /// \name Operations with scalars
538 //@{
539 /// Sets all entries to \a x
540 this_type& fill( T x ) {
541 std::fill( _p.begin(), _p.end(), x );
542 return *this;
543 }
544
545 /// Adds scalar \a x to each entry
546 this_type& operator+= (T x) {
547 if( x != 0 )
548 return pwUnaryOp( std::bind2nd( std::plus<T>(), x ) );
549 else
550 return *this;
551 }
552
553 /// Subtracts scalar \a x from each entry
554 this_type& operator-= (T x) {
555 if( x != 0 )
556 return pwUnaryOp( std::bind2nd( std::minus<T>(), x ) );
557 else
558 return *this;
559 }
560
561 /// Multiplies each entry with scalar \a x
562 this_type& operator*= (T x) {
563 if( x != 1 )
564 return pwUnaryOp( std::bind2nd( std::multiplies<T>(), x ) );
565 else
566 return *this;
567 }
568
569 /// Divides each entry by scalar \a x, where division by 0 yields 0
570 this_type& operator/= (T x) {
571 if( x != 1 )
572 return pwUnaryOp( std::bind2nd( fo_divides0<T>(), x ) );
573 else
574 return *this;
575 }
576
577 /// Raises entries to the power \a x
578 this_type& operator^= (T x) {
579 if( x != (T)1 )
580 return pwUnaryOp( std::bind2nd( fo_pow<T>(), x) );
581 else
582 return *this;
583 }
584 //@}
585
586 /// \name Transformations with scalars
587 //@{
588 /// Returns sum of \c *this and scalar \a x
589 this_type operator+ (T x) const { return pwUnaryTr( std::bind2nd( std::plus<T>(), x ) ); }
590
591 /// Returns difference of \c *this and scalar \a x
592 this_type operator- (T x) const { return pwUnaryTr( std::bind2nd( std::minus<T>(), x ) ); }
593
594 /// Returns product of \c *this with scalar \a x
595 this_type operator* (T x) const { return pwUnaryTr( std::bind2nd( std::multiplies<T>(), x ) ); }
596
597 /// Returns quotient of \c *this and scalar \a x, where division by 0 yields 0
598 this_type operator/ (T x) const { return pwUnaryTr( std::bind2nd( fo_divides0<T>(), x ) ); }
599
600 /// Returns \c *this raised to the power \a x
601 this_type operator^ (T x) const { return pwUnaryTr( std::bind2nd( fo_pow<T>(), x ) ); }
602 //@}
603
604 /// \name Operations with other equally-sized vectors
605 //@{
606 /// Applies binary operation pointwise on two vectors
607 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
608 * \param q Right operand
609 * \param op Operation of type \a binaryOp
610 */
611 template<typename binaryOp> this_type& pwBinaryOp( const this_type &q, binaryOp op ) {
612 DAI_DEBASSERT( size() == q.size() );
613 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), op );
614 return *this;
615 }
616
617 /// Pointwise addition with \a q
618 /** \pre <tt>this->size() == q.size()</tt>
619 */
620 this_type& operator+= (const this_type & q) { return pwBinaryOp( q, std::plus<T>() ); }
621
622 /// Pointwise subtraction of \a q
623 /** \pre <tt>this->size() == q.size()</tt>
624 */
625 this_type& operator-= (const this_type & q) { return pwBinaryOp( q, std::minus<T>() ); }
626
627 /// Pointwise multiplication with \a q
628 /** \pre <tt>this->size() == q.size()</tt>
629 */
630 this_type& operator*= (const this_type & q) { return pwBinaryOp( q, std::multiplies<T>() ); }
631
632 /// Pointwise division by \a q, where division by 0 yields 0
633 /** \pre <tt>this->size() == q.size()</tt>
634 * \see divide(const TProb<T> &)
635 */
636 this_type& operator/= (const this_type & q) { return pwBinaryOp( q, fo_divides0<T>() ); }
637
638 /// Pointwise division by \a q, where division by 0 yields +Inf
639 /** \pre <tt>this->size() == q.size()</tt>
640 * \see operator/=(const TProb<T> &)
641 */
642 this_type& divide (const this_type & q) { return pwBinaryOp( q, std::divides<T>() ); }
643
644 /// Pointwise power
645 /** \pre <tt>this->size() == q.size()</tt>
646 */
647 this_type& operator^= (const this_type & q) { return pwBinaryOp( q, fo_pow<T>() ); }
648 //@}
649
650 /// \name Transformations with other equally-sized vectors
651 //@{
652 /// Returns the result of applying binary operation \a op pointwise on \c *this and \a q
653 /** \tparam binaryOp Type of function object that accepts two arguments of type \a T and outputs a type \a T
654 * \param q Right operand
655 * \param op Operation of type \a binaryOp
656 */
657 template<typename binaryOp> this_type pwBinaryTr( const this_type &q, binaryOp op ) const {
658 DAI_DEBASSERT( size() == q.size() );
659 TProb<T> r;
660 r._p.reserve( size() );
661 std::transform( _p.begin(), _p.end(), q._p.begin(), back_inserter( r._p ), op );
662 return r;
663 }
664
665 /// Returns sum of \c *this and \a q
666 /** \pre <tt>this->size() == q.size()</tt>
667 */
668 this_type operator+ ( const this_type& q ) const { return pwBinaryTr( q, std::plus<T>() ); }
669
670 /// Return \c *this minus \a q
671 /** \pre <tt>this->size() == q.size()</tt>
672 */
673 this_type operator- ( const this_type& q ) const { return pwBinaryTr( q, std::minus<T>() ); }
674
675 /// Return product of \c *this with \a q
676 /** \pre <tt>this->size() == q.size()</tt>
677 */
678 this_type operator* ( const this_type &q ) const { return pwBinaryTr( q, std::multiplies<T>() ); }
679
680 /// Returns quotient of \c *this with \a q, where division by 0 yields 0
681 /** \pre <tt>this->size() == q.size()</tt>
682 * \see divided_by(const TProb<T> &)
683 */
684 this_type operator/ ( const this_type &q ) const { return pwBinaryTr( q, fo_divides0<T>() ); }
685
686 /// Pointwise division by \a q, where division by 0 yields +Inf
687 /** \pre <tt>this->size() == q.size()</tt>
688 * \see operator/(const TProb<T> &)
689 */
690 this_type divided_by( const this_type &q ) const { return pwBinaryTr( q, std::divides<T>() ); }
691
692 /// Returns \c *this to the power \a q
693 /** \pre <tt>this->size() == q.size()</tt>
694 */
695 this_type operator^ ( const this_type &q ) const { return pwBinaryTr( q, fo_pow<T>() ); }
696 //@}
697
698 /// Performs a generalized inner product, similar to std::inner_product
699 /** \pre <tt>this->size() == q.size()</tt>
700 */
701 template<typename binOp1, typename binOp2> T innerProduct( const this_type &q, T init, binOp1 binaryOp1, binOp2 binaryOp2 ) const {
702 DAI_DEBASSERT( size() == q.size() );
703 return std::inner_product( begin(), end(), q.begin(), init, binaryOp1, binaryOp2 );
704 }
705 };
706
707
708 /// Returns distance between \a p and \a q, measured using distance measure \a dt
709 /** \relates TProb
710 * \pre <tt>this->size() == q.size()</tt>
711 */
712 template<typename T> T dist( const TProb<T> &p, const TProb<T> &q, ProbDistType dt ) {
713 switch( dt ) {
714 case DISTL1:
715 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() );
716 case DISTLINF:
717 return p.innerProduct( q, (T)0, fo_max<T>(), fo_absdiff<T>() );
718 case DISTTV:
719 return p.innerProduct( q, (T)0, std::plus<T>(), fo_absdiff<T>() ) / 2;
720 case DISTKL:
721 return p.innerProduct( q, (T)0, std::plus<T>(), fo_KL<T>() );
722 case DISTHEL:
723 return p.innerProduct( q, (T)0, std::plus<T>(), fo_Hellinger<T>() ) / 2;
724 default:
725 DAI_THROW(UNKNOWN_ENUM_VALUE);
726 return INFINITY;
727 }
728 }
729
730
731 /// Writes a TProb<T> to an output stream
732 /** \relates TProb
733 */
734 template<typename T> std::ostream& operator<< (std::ostream& os, const TProb<T>& p) {
735 os << "(";
736 for( size_t i = 0; i < p.size(); i++ )
737 os << ((i != 0) ? ", " : "") << p.get(i);
738 os << ")";
739 return os;
740 }
741
742
743 /// Returns the pointwise minimum of \a a and \a b
744 /** \relates TProb
745 * \pre <tt>this->size() == q.size()</tt>
746 */
747 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
748 return a.pwBinaryTr( b, fo_min<T>() );
749 }
750
751
752 /// Returns the pointwise maximum of \a a and \a b
753 /** \relates TProb
754 * \pre <tt>this->size() == q.size()</tt>
755 */
756 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
757 return a.pwBinaryTr( b, fo_max<T>() );
758 }
759
760
/// Represents a vector with entries of type dai::Real.
/** Shorthand for TProb<Real>, the instantiation of TProb for the library's real scalar type. */
typedef TProb<Real> Prob;
763
764
765 } // end of namespace dai
766
767
768 #endif