Oops, correct previous partial commit.
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 /// \file
24 /// \brief Defines TProb<T> and Prob classes
25 /// \todo Improve documentation
26
27
28 #ifndef __defined_libdai_prob_h
29 #define __defined_libdai_prob_h
30
31
32 #include <cmath>
33 #include <vector>
34 #include <ostream>
35 #include <cassert>
36 #include <algorithm>
37 #include <numeric>
38 #include <functional>
39 #include <dai/util.h>
40
41
42 namespace dai {
43
44
45 /// Real number type used for entries and scalar parameters (alias for double, could be changed to long double if necessary)
46 typedef double Real;
47
48 template<typename T> class TProb; // forward declaration, so that the Prob typedef below can name TProb before its definition
49
50 /// Represents a probability measure, with entries of type Real (i.e., TProb instantiated with T = Real).
51 typedef TProb<Real> Prob;
52
53
54 /// Represents a probability measure on a finite outcome space (i.e., corresponding to a discrete random variable).
55 /** It is implemented as a std::vector<T> but adds a convenient interface.
56 * It is not necessarily normalized at all times.
57 * \tparam T Should be castable from and to double.
58 */
59 template <typename T> class TProb {
60 private:
61 /// The probability measure
62 std::vector<T> _p;
63
64 public:
65 /// Enumerates different ways of normalizing a probability measure.
66 /**
67 * - NORMPROB means that the sum of all entries should be 1;
68 * - NORMLINF means that the maximum absolute value of all entries should be 1.
69 */
70 typedef enum { NORMPROB, NORMLINF } NormType;
71 /// Enumerates different distance measures between probability measures.
72 /**
73 * - DISTL1 is the L-1 distance (sum of absolute values of pointwise difference);
74 * - DISTLINF is the L-inf distance (maximum absolute value of pointwise difference);
75 * - DISTTV is the Total Variation distance;
76 * - DISTKL is the Kullback-Leibler distance.
77 */
78 typedef enum { DISTL1, DISTLINF, DISTTV, DISTKL } DistType;
79
80 /// Default constructor
81 TProb() : _p() {}
82
83 /// Construct uniform distribution of given length
84 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
85
86 /// Construct from given length and initial value
87 TProb( size_t n, Real p ) : _p(n, (T)p) {}
88
89 /// Construct from given length and initial array
90 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
91
92 /// Returns a const reference to the probability vector
93 const std::vector<T> & p() const { return _p; }
94
95 /// Returns a reference to the probability vector
96 std::vector<T> & p() { return _p; }
97
98 /// Returns a copy of the i'th probability entry
99 T operator[]( size_t i ) const {
100 #ifdef DAI_DEBUG
101 return _p.at(i);
102 #else
103 return _p[i];
104 #endif
105 }
106
107 /// Returns a reference to the i'th probability entry
108 T& operator[]( size_t i ) { return _p[i]; }
109
110 /// Sets all elements to x
111 TProb<T> & fill(T x) {
112 std::fill( _p.begin(), _p.end(), x );
113 return *this;
114 }
115
116 /// Sets all elements to i.i.d. random numbers from a uniform[0,1) distribution
117 TProb<T> & randomize() {
118 std::generate(_p.begin(), _p.end(), rnd_uniform);
119 return *this;
120 }
121
122 /// Returns number of elements
123 size_t size() const {
124 return _p.size();
125 }
126
127 /// Sets entries that are smaller than epsilon to zero
128 TProb<T>& makeZero( Real epsilon ) {
129 for( size_t i = 0; i < size(); i++ )
130 if( fabs(_p[i]) < epsilon )
131 _p[i] = 0;
132 return *this;
133 }
134
135 /// Sets entries that are smaller than epsilon to epsilon
136 TProb<T>& makePositive (Real epsilon) {
137 for( size_t i = 0; i < size(); i++ )
138 if( (0 < (Real)_p[i]) && ((Real)_p[i] < epsilon) )
139 _p[i] = epsilon;
140 return *this;
141 }
142
143 /// Multiplies each entry with x
144 TProb<T>& operator*= (T x) {
145 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
146 return *this;
147 }
148
149 /// Returns product of *this with x
150 TProb<T> operator* (T x) const {
151 TProb<T> prod( *this );
152 prod *= x;
153 return prod;
154 }
155
156 /// Divides each entry by x
157 TProb<T>& operator/= (T x) {
158 #ifdef DAI_DEBUG
159 assert( x != 0.0 );
160 #endif
161 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
162 return *this;
163 }
164
165 /// Returns quotient of *this and x
166 TProb<T> operator/ (T x) const {
167 TProb<T> quot( *this );
168 quot /= x;
169 return quot;
170 }
171
172 /// Adds x to each entry
173 TProb<T>& operator+= (T x) {
174 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
175 return *this;
176 }
177
178 /// Returns sum of *this and x
179 TProb<T> operator+ (T x) const {
180 TProb<T> sum( *this );
181 sum += x;
182 return sum;
183 }
184
185 /// Subtracts x from each entry
186 TProb<T>& operator-= (T x) {
187 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
188 return *this;
189 }
190
191 /// Returns difference of *this and x
192 TProb<T> operator- (T x) const {
193 TProb<T> diff( *this );
194 diff -= x;
195 return diff;
196 }
197
198 /// Pointwise comparison
199 bool operator<= (const TProb<T> & q) const {
200 #ifdef DAI_DEBUG
201 assert( size() == q.size() );
202 #endif
203 for( size_t i = 0; i < size(); i++ )
204 if( !(_p[i] <= q[i]) )
205 return false;
206 return true;
207 }
208
209 /// Pointwise multiplication with q
210 TProb<T>& operator*= (const TProb<T> & q) {
211 #ifdef DAI_DEBUG
212 assert( size() == q.size() );
213 #endif
214 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
215 return *this;
216 }
217
218 /// Return product of *this with q
219 TProb<T> operator* (const TProb<T> & q) const {
220 #ifdef DAI_DEBUG
221 assert( size() == q.size() );
222 #endif
223 TProb<T> prod( *this );
224 prod *= q;
225 return prod;
226 }
227
228 /// Pointwise addition with q
229 TProb<T>& operator+= (const TProb<T> & q) {
230 #ifdef DAI_DEBUG
231 assert( size() == q.size() );
232 #endif
233 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
234 return *this;
235 }
236
237 /// Return sum of *this and q
238 TProb<T> operator+ (const TProb<T> & q) const {
239 #ifdef DAI_DEBUG
240 assert( size() == q.size() );
241 #endif
242 TProb<T> sum( *this );
243 sum += q;
244 return sum;
245 }
246
247 /// Pointwise subtraction of q
248 TProb<T>& operator-= (const TProb<T> & q) {
249 #ifdef DAI_DEBUG
250 assert( size() == q.size() );
251 #endif
252 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
253 return *this;
254 }
255
256 /// Return *this minus q
257 TProb<T> operator- (const TProb<T> & q) const {
258 #ifdef DAI_DEBUG
259 assert( size() == q.size() );
260 #endif
261 TProb<T> diff( *this );
262 diff -= q;
263 return diff;
264 }
265
266 /// Pointwise division by q, where division by zero yields zero
267 TProb<T>& operator/= (const TProb<T> & q) {
268 #ifdef DAI_DEBUG
269 assert( size() == q.size() );
270 #endif
271 for( size_t i = 0; i < size(); i++ ) {
272 if( q[i] == 0.0 )
273 _p[i] = 0.0;
274 else
275 _p[i] /= q[i];
276 }
277 return *this;
278 }
279
280 /// Pointwise division by q, where division by zero yields infinity
281 TProb<T>& divide (const TProb<T> & q) {
282 #ifdef DAI_DEBUG
283 assert( size() == q.size() );
284 #endif
285 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
286 return *this;
287 }
288
289 /// Returns quotient of *this with q
290 TProb<T> operator/ (const TProb<T> & q) const {
291 #ifdef DAI_DEBUG
292 assert( size() == q.size() );
293 #endif
294 TProb<T> quot( *this );
295 quot /= q;
296 return quot;
297 }
298
299 /// Returns pointwise inverse
300 TProb<T> inverse(bool zero = false) const {
301 TProb<T> inv;
302 inv._p.reserve( size() );
303 if( zero )
304 for( size_t i = 0; i < size(); i++ )
305 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
306 else
307 for( size_t i = 0; i < size(); i++ ) {
308 #ifdef DAI_DEBUG
309 assert( _p[i] != 0.0 );
310 #endif
311 inv._p.push_back( 1.0 / _p[i] );
312 }
313 return inv;
314 }
315
316 /// Raises elements to the power a
317 TProb<T>& operator^= (Real a) {
318 if( a != 1.0 )
319 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
320 return *this;
321 }
322
323 /// Returns *this raised to the power a
324 TProb<T> operator^ (Real a) const {
325 TProb<T> power(*this);
326 power ^= a;
327 return power;
328 }
329
330 /// Returns pointwise signum
331 TProb<T> sgn() const {
332 TProb<T> x;
333 x._p.reserve( size() );
334 for( size_t i = 0; i < size(); i++ ) {
335 T s = 0;
336 if( _p[i] > 0 )
337 s = 1;
338 else if( _p[i] < 0 )
339 s = -1;
340 x._p.push_back( s );
341 }
342 return x;
343 }
344
345 /// Returns pointwise absolute value
346 TProb<T> abs() const {
347 TProb<T> x;
348 x._p.reserve( size() );
349 for( size_t i = 0; i < size(); i++ )
350 x._p.push_back( _p[i] < 0 ? (-p[i]) : p[i] );
351 return x;
352 }
353
354 /// Applies exp pointwise
355 const TProb<T>& takeExp() {
356 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
357 return *this;
358 }
359
360 /// Applies log pointwise
361 const TProb<T>& takeLog() {
362 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
363 return *this;
364 }
365
366 /// Applies log pointwise (defining log(0)=0)
367 const TProb<T>& takeLog0() {
368 for( size_t i = 0; i < size(); i++ )
369 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
370 return *this;
371 }
372
373 /// Returns pointwise exp
374 TProb<T> exp() const {
375 TProb<T> e(*this);
376 e.takeExp();
377 return e;
378 }
379
380 /// Returns pointwise log
381 TProb<T> log() const {
382 TProb<T> l(*this);
383 l.takeLog();
384 return l;
385 }
386
387 /// Returns pointwise log (defining log(0)=0)
388 TProb<T> log0() const {
389 TProb<T> l0(*this);
390 l0.takeLog0();
391 return l0;
392 }
393
394 /// Returns distance of p and q, measured using dt
395 friend Real dist( const TProb<T> &p, const TProb<T> &q, DistType dt ) {
396 #ifdef DAI_DEBUG
397 assert( p.size() == q.size() );
398 #endif
399 Real result = 0.0;
400 switch( dt ) {
401 case DISTL1:
402 for( size_t i = 0; i < p.size(); i++ )
403 result += fabs((Real)p[i] - (Real)q[i]);
404 break;
405
406 case DISTLINF:
407 for( size_t i = 0; i < p.size(); i++ ) {
408 Real z = fabs((Real)p[i] - (Real)q[i]);
409 if( z > result )
410 result = z;
411 }
412 break;
413
414 case DISTTV:
415 for( size_t i = 0; i < p.size(); i++ )
416 result += fabs((Real)p[i] - (Real)q[i]);
417 result *= 0.5;
418 break;
419
420 case DISTKL:
421 for( size_t i = 0; i < p.size(); i++ ) {
422 if( p[i] != 0.0 )
423 result += p[i] * (std::log(p[i]) - std::log(q[i]));
424 }
425 }
426 return result;
427 }
428
429 /// Returns sum of all entries
430 T totalSum() const {
431 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
432 return Z;
433 }
434
435 /// Returns maximum absolute value of entries
436 T maxAbs() const {
437 T Z = 0;
438 for( size_t i = 0; i < size(); i++ ) {
439 Real mag = fabs( (Real) _p[i] );
440 if( mag > Z )
441 Z = mag;
442 }
443 return Z;
444 }
445
446 /// Returns maximum value of entries
447 T maxVal() const {
448 T Z = *std::max_element( _p.begin(), _p.end() );
449 return Z;
450 }
451
452 /// Returns minimum value of entries
453 T minVal() const {
454 T Z = *std::min_element( _p.begin(), _p.end() );
455 return Z;
456 }
457
458 /// Normalizes using the specified norm
459 T normalize( NormType norm = NORMPROB ) {
460 T Z = 0.0;
461 if( norm == NORMPROB )
462 Z = totalSum();
463 else if( norm == NORMLINF )
464 Z = maxAbs();
465 #ifdef DAI_DEBUG
466 assert( Z != 0.0 );
467 #endif
468 *this /= Z;
469 return Z;
470 }
471
472 /// Returns normalized copy of *this, using the specified norm
473 TProb<T> normalized( NormType norm = NORMPROB ) const {
474 TProb<T> result(*this);
475 result.normalize( norm );
476 return result;
477 }
478
479 /// Returns true if one or more entries are NaN
480 bool hasNaNs() const {
481 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
482 }
483
484 /// Returns true if one or more entries are negative
485 bool hasNegatives() const {
486 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
487 }
488
489 /// Returns true if one or more entries are non-positive (causes problems with logscale)
490 bool hasNonPositives() const {
491 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
492 }
493
494 /// Returns entropy
495 Real entropy() const {
496 Real S = 0.0;
497 for( size_t i = 0; i < size(); i++ )
498 S -= xlogx(_p[i]);
499 return S;
500 }
501
502 /// Writes a TProb<T> to an output stream
503 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
504 os << "[";
505 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
506 os << "]";
507 return os;
508 }
509
510 private:
511 /// Returns x*log(x), or 0 if x == 0
512 Real xlogx( Real x ) const { return( x == 0.0 ? 0.0 : x * std::log(x)); }
513 };
514
515
516 /// Returns TProb<T> containing the pointwise minimum of a and b (which should have equal size)
517 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
518 assert( a.size() == b.size() );
519 TProb<T> result( a.size() );
520 for( size_t i = 0; i < a.size(); i++ )
521 if( a[i] < b[i] )
522 result[i] = a[i];
523 else
524 result[i] = b[i];
525 return result;
526 }
527
528
529 /// Returns TProb<T> containing the pointwise maximum of a and b (which should have equal size)
530 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
531 assert( a.size() == b.size() );
532 TProb<T> result( a.size() );
533 for( size_t i = 0; i < a.size(); i++ )
534 if( a[i] > b[i] )
535 result[i] = a[i];
536 else
537 result[i] = b[i];
538 return result;
539 }
540
541
542 } // end of namespace dai
543
544
545 #endif