Updated copyrights
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_prob_h
24 #define __defined_libdai_prob_h
25
26
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <vector>
#include <dai/util.h>
35
36
37 namespace dai {
38
39
/// Floating-point type used for real-valued quantities throughout libDAI.
typedef double Real;

/// Forward declaration of the probability vector template (defined below).
template <class T> class TProb;

// Forward declarations of the pointwise min/max templates. They must be
// declared before TProb so that TProb can befriend the specializations
// min<T> and max<T>.
template <class T> TProb<T> min( const TProb<T> &a, const TProb<T> &b );
template <class T> TProb<T> max( const TProb<T> &a, const TProb<T> &b );

/// Probability vector with Real-valued entries.
typedef TProb<Real> Prob;
49
50
51 /// TProb<T> implements a probability vector of type T.
52 /// T should be castable from and to double.
53 template <typename T> class TProb {
54 private:
55 /// The entries
56 std::vector<T> _p;
57
58 private:
59 /// Calculate x times log(x), or 0 if x == 0
60 Real xlogx( Real x ) const { return( x == 0.0 ? 0.0 : x * std::log(x)); }
61
62 public:
63 /// NORMPROB means that the sum of all entries should be 1
64 /// NORMLINF means that the maximum absolute value of all entries should be 1
65 typedef enum { NORMPROB, NORMLINF } NormType;
66 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
67 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
68 /// DISTTV is the Total Variation distance
69 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
70
71 /// Default constructor
72 TProb() : _p() {}
73
74 /// Construct uniform distribution of given length
75 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
76
77 /// Construct with given length and initial value
78 TProb( size_t n, Real p ) : _p(n, (T)p) {}
79
80 /// Construct with given length and initial array
81 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
82
83 /// Provide read access to _p
84 const std::vector<T> & p() const { return _p; }
85
86 /// Provide full access to _p
87 std::vector<T> & p() { return _p; }
88
89 /// Provide read access to ith element of _p
90 T operator[]( size_t i ) const {
91 #ifdef DAI_DEBUG
92 return _p.at(i);
93 #else
94 return _p[i];
95 #endif
96 }
97
98 /// Provide full access to ith element of _p
99 T& operator[]( size_t i ) { return _p[i]; }
100
101 /// Set all elements to x
102 TProb<T> & fill(T x) {
103 std::fill( _p.begin(), _p.end(), x );
104 return *this;
105 }
106
107 /// Set all elements to iid random numbers from uniform(0,1) distribution
108 TProb<T> & randomize() {
109 std::generate(_p.begin(), _p.end(), rnd_uniform);
110 return *this;
111 }
112
113 /// Return size
114 size_t size() const {
115 return _p.size();
116 }
117
118 /// Make entries zero if (Real) absolute value smaller than epsilon
119 TProb<T>& makeZero (Real epsilon) {
120 for( size_t i = 0; i < size(); i++ )
121 if( fabs((Real)_p[i]) < epsilon )
122 _p[i] = 0;
123 // std::replace_if( _p.begin(), _p.end(), fabs((Real)boost::lambda::_1) < epsilon, 0.0 );
124 return *this;
125 }
126
127 /// Make entries epsilon if they are smaller than epsilon
128 TProb<T>& makePositive (Real epsilon) {
129 for( size_t i = 0; i < size(); i++ )
130 if( (0 < (Real)_p[i]) && ((Real)_p[i] < epsilon) )
131 _p[i] = epsilon;
132 return *this;
133 }
134
135 /// Multiplication with T x
136 TProb<T>& operator*= (T x) {
137 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
138 return *this;
139 }
140
141 /// Return product of *this with T x
142 TProb<T> operator* (T x) const {
143 TProb<T> prod( *this );
144 prod *= x;
145 return prod;
146 }
147
148 /// Division by T x
149 TProb<T>& operator/= (T x) {
150 #ifdef DAI_DEBUG
151 assert( x != 0.0 );
152 #endif
153 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
154 return *this;
155 }
156
157 /// Return quotient of *this and T x
158 TProb<T> operator/ (T x) const {
159 TProb<T> quot( *this );
160 quot /= x;
161 return quot;
162 }
163
164 /// addition of x
165 TProb<T>& operator+= (T x) {
166 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
167 return *this;
168 }
169
170 /// Return sum of *this with T x
171 TProb<T> operator+ (T x) const {
172 TProb<T> sum( *this );
173 sum += x;
174 return sum;
175 }
176
177 /// Difference by T x
178 TProb<T>& operator-= (T x) {
179 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
180 return *this;
181 }
182
183 /// Return difference of *this and T x
184 TProb<T> operator- (T x) const {
185 TProb<T> diff( *this );
186 diff -= x;
187 return diff;
188 }
189
190 /// Pointwise comparison
191 bool operator<= (const TProb<T> & q) const {
192 #ifdef DAI_DEBUG
193 assert( size() == q.size() );
194 #endif
195 for( size_t i = 0; i < size(); i++ )
196 if( !(_p[i] <= q[i]) )
197 return false;
198 return true;
199 }
200
201 /// Pointwise multiplication with q
202 TProb<T>& operator*= (const TProb<T> & q) {
203 #ifdef DAI_DEBUG
204 assert( size() == q.size() );
205 #endif
206 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
207 return *this;
208 }
209
210 /// Return product of *this with q
211 TProb<T> operator* (const TProb<T> & q) const {
212 #ifdef DAI_DEBUG
213 assert( size() == q.size() );
214 #endif
215 TProb<T> prod( *this );
216 prod *= q;
217 return prod;
218 }
219
220 /// Pointwise addition with q
221 TProb<T>& operator+= (const TProb<T> & q) {
222 #ifdef DAI_DEBUG
223 assert( size() == q.size() );
224 #endif
225 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
226 return *this;
227 }
228
229 /// Pointwise subtraction of q
230 TProb<T>& operator-= (const TProb<T> & q) {
231 #ifdef DAI_DEBUG
232 assert( size() == q.size() );
233 #endif
234 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
235 return *this;
236 }
237
238 /// Return sum of *this and q
239 TProb<T> operator+ (const TProb<T> & q) const {
240 #ifdef DAI_DEBUG
241 assert( size() == q.size() );
242 #endif
243 TProb<T> sum( *this );
244 sum += q;
245 return sum;
246 }
247
248 /// Return *this minus q
249 TProb<T> operator- (const TProb<T> & q) const {
250 #ifdef DAI_DEBUG
251 assert( size() == q.size() );
252 #endif
253 TProb<T> diff( *this );
254 diff -= q;
255 return diff;
256 }
257
258 /// Pointwise division by q (division by zero yields zero)
259 TProb<T>& operator/= (const TProb<T> & q) {
260 #ifdef DAI_DEBUG
261 assert( size() == q.size() );
262 #endif
263 for( size_t i = 0; i < size(); i++ ) {
264 if( q[i] == 0.0 )
265 _p[i] = 0.0;
266 else
267 _p[i] /= q[i];
268 }
269 return *this;
270 }
271
272 /// Pointwise division by q (division by zero yields infinity)
273 TProb<T>& divide (const TProb<T> & q) {
274 #ifdef DAI_DEBUG
275 assert( size() == q.size() );
276 #endif
277 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
278 return *this;
279 }
280
281 /// Return quotient of *this with q
282 TProb<T> operator/ (const TProb<T> & q) const {
283 #ifdef DAI_DEBUG
284 assert( size() == q.size() );
285 #endif
286 TProb<T> quot( *this );
287 quot /= q;
288 return quot;
289 }
290
291 /// Return pointwise inverse
292 TProb<T> inverse(bool zero = false) const {
293 TProb<T> inv;
294 inv._p.reserve( size() );
295 if( zero )
296 for( size_t i = 0; i < size(); i++ )
297 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
298 else
299 for( size_t i = 0; i < size(); i++ ) {
300 #ifdef DAI_DEBUG
301 assert( _p[i] != 0.0 );
302 #endif
303 inv._p.push_back( 1.0 / _p[i] );
304 }
305 return inv;
306 }
307
308 /// Return *this to the power of a (pointwise)
309 TProb<T>& operator^= (Real a) {
310 if( a != 1.0 )
311 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
312 return *this;
313 }
314
315 /// Pointwise power of a
316 TProb<T> operator^ (Real a) const {
317 TProb<T> power(*this);
318 power ^= a;
319 return power;
320 }
321
322 /// Pointwise signum
323 TProb<T> sgn() const {
324 TProb<T> x;
325 x._p.reserve( size() );
326 for( size_t i = 0; i < size(); i++ ) {
327 T s = 0;
328 if( _p[i] > 0 )
329 s = 1;
330 else if( _p[i] < 0 )
331 s = -1;
332 x._p.push_back( s );
333 }
334 return x;
335 }
336
337 /// Pointwise absolute value
338 TProb<T> abs() const {
339 TProb<T> x;
340 x._p.reserve( size() );
341 for( size_t i = 0; i < size(); i++ )
342 x._p.push_back( _p[i] < 0 ? (-p[i]) : p[i] );
343 return x;
344 }
345
346 /// Pointwise exp
347 const TProb<T>& takeExp() {
348 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
349 return *this;
350 }
351
352 /// Pointwise log
353 const TProb<T>& takeLog() {
354 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
355 return *this;
356 }
357
358 /// Pointwise log (or 0 if == 0)
359 const TProb<T>& takeLog0() {
360 for( size_t i = 0; i < size(); i++ )
361 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
362 return *this;
363 }
364
365 /// Pointwise exp
366 TProb<T> exp() const {
367 TProb<T> e(*this);
368 e.takeExp();
369 return e;
370 }
371
372 /// Pointwise log
373 TProb<T> log() const {
374 TProb<T> l(*this);
375 l.takeLog();
376 return l;
377 }
378
379 /// Pointwise log (or 0 if == 0)
380 TProb<T> log0() const {
381 TProb<T> l0(*this);
382 l0.takeLog0();
383 return l0;
384 }
385
386 /// Return distance of p and q
387 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
388 #ifdef DAI_DEBUG
389 assert( p.size() == q.size() );
390 #endif
391 Real result = 0.0;
392 switch( dt ) {
393 case DISTL1:
394 for( size_t i = 0; i < p.size(); i++ )
395 result += fabs((Real)p[i] - (Real)q[i]);
396 break;
397
398 case DISTLINF:
399 for( size_t i = 0; i < p.size(); i++ ) {
400 Real z = fabs((Real)p[i] - (Real)q[i]);
401 if( z > result )
402 result = z;
403 }
404 break;
405
406 case DISTTV:
407 for( size_t i = 0; i < p.size(); i++ )
408 result += fabs((Real)p[i] - (Real)q[i]);
409 result *= 0.5;
410 break;
411 }
412 return result;
413 }
414
415 /// Return Kullback-Leibler distance with q
416 friend Real KL_dist( const TProb<T> & p, const TProb<T> & q ) {
417 #ifdef DAI_DEBUG
418 assert( p.size() == q.size() );
419 #endif
420 Real result = 0.0;
421 for( size_t i = 0; i < p.size(); i++ ) {
422 if( (Real) p[i] != 0.0 ) {
423 Real p_i = p[i];
424 Real q_i = q[i];
425 result += p_i * (std::log(p_i) - std::log(q_i));
426 }
427 }
428 return result;
429 }
430
431 /// Return sum of all entries
432 T totalSum() const {
433 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
434 return Z;
435 }
436
437 /// Converts entries to Real and returns maximum absolute value
438 T maxAbs() const {
439 T Z = 0;
440 for( size_t i = 0; i < size(); i++ ) {
441 Real mag = fabs( (Real) _p[i] );
442 if( mag > Z )
443 Z = mag;
444 }
445 return Z;
446 }
447
448 /// Returns maximum value
449 T maxVal() const {
450 T Z = *std::max_element( _p.begin(), _p.end() );
451 return Z;
452 }
453
454 /// Returns minimum value
455 T minVal() const {
456 T Z = *std::min_element( _p.begin(), _p.end() );
457 return Z;
458 }
459
460 /// Normalize, using the specified norm
461 T normalize( NormType norm = NORMPROB ) {
462 T Z = 0.0;
463 if( norm == NORMPROB )
464 Z = totalSum();
465 else if( norm == NORMLINF )
466 Z = maxAbs();
467 #ifdef DAI_DEBUG
468 assert( Z != 0.0 );
469 #endif
470 *this /= Z;
471 return Z;
472 }
473
474 /// Return normalized copy of *this, using the specified norm
475 TProb<T> normalized( NormType norm = NORMPROB ) const {
476 TProb<T> result(*this);
477 result.normalize( norm );
478 return result;
479 }
480
481 /// Returns true if one or more entries are NaN
482 bool hasNaNs() const {
483 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
484 }
485
486 /// Returns true if one or more entries are negative
487 bool hasNegatives() const {
488 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
489 }
490
491 /// Returns true if one or more entries are non-positive (causes problems with logscale)
492 bool hasNonPositives() const {
493 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
494 }
495
496 /// Returns entropy
497 Real entropy() const {
498 Real S = 0.0;
499 for( size_t i = 0; i < size(); i++ )
500 S -= xlogx(_p[i]);
501 return S;
502 }
503
504 /// Returns TProb<T> containing the pointwise minimum of a and b (which should have equal size)
505 friend TProb<T> min <> ( const TProb<T> &a, const TProb<T> &b );
506
507 /// Returns TProb<T> containing the pointwise maximum of a and b (which should have equal size)
508 friend TProb<T> max <> ( const TProb<T> &a, const TProb<T> &b );
509
510 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
511 os << "[";
512 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
513 os << "]";
514 return os;
515 }
516 };
517
518
519 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
520 assert( a.size() == b.size() );
521 TProb<T> result( a.size() );
522 for( size_t i = 0; i < a.size(); i++ )
523 if( a[i] < b[i] )
524 result[i] = a[i];
525 else
526 result[i] = b[i];
527 return result;
528 }
529
530
531 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
532 assert( a.size() == b.size() );
533 TProb<T> result( a.size() );
534 for( size_t i = 0; i < a.size(); i++ )
535 if( a[i] > b[i] )
536 result[i] = a[i];
537 else
538 result[i] = b[i];
539 return result;
540 }
541
542
543 } // end of namespace dai
544
545
546 #endif