Replaced all "protected:" by "private:" or "public:"
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_prob_h
23 #define __defined_libdai_prob_h
24
25
26 #include <cmath>
27 #include <vector>
28 #include <ostream>
29 #include <cassert>
30 #include <algorithm>
31 #include <numeric>
32 #include <functional>
33 #include <dai/util.h>
34
35
36 namespace dai {
37
38
/// Floating-point type used for real-valued quantities in this header
typedef double Real;

// Forward declaration of the class template, so that the Prob typedef and the
// function templates below can name TProb before it is defined
template <class T> class TProb;

/// Probability vector whose entries have type Real
typedef TProb<Real> Prob;


// TProb declares the following function templates as friends using the
// "min <>" / "max <>" syntax, which requires them to be declared beforehand
template <class T> TProb<T> min( const TProb<T> &a, const TProb<T> &b );
template <class T> TProb<T> max( const TProb<T> &a, const TProb<T> &b );
48
49
50 /// TProb<T> implements a probability vector of type T.
51 /// T should be castable from and to double.
52 template <typename T> class TProb {
53 private:
54 /// The entries
55 std::vector<T> _p;
56
57 private:
58 /// Calculate x times log(x), or 0 if x == 0
59 Real xlogx( Real x ) const { return( x == 0.0 ? 0.0 : x * std::log(x)); }
60
61 public:
62 /// NORMPROB means that the sum of all entries should be 1
63 /// NORMLINF means that the maximum absolute value of all entries should be 1
64 typedef enum { NORMPROB, NORMLINF } NormType;
65 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
66 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
67 /// DISTTV is the Total Variation distance
68 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
69
70 /// Default constructor
71 TProb() : _p() {}
72
73 /// Construct uniform distribution of given length
74 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
75
76 /// Construct with given length and initial value
77 TProb( size_t n, Real p ) : _p(n, (T)p) {}
78
79 /// Construct with given length and initial array
80 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
81
82 /// Provide read access to _p
83 const std::vector<T> & p() const { return _p; }
84
85 /// Provide full access to _p
86 std::vector<T> & p() { return _p; }
87
88 /// Provide read access to ith element of _p
89 T operator[]( size_t i ) const {
90 #ifdef DAI_DEBUG
91 return _p.at(i);
92 #else
93 return _p[i];
94 #endif
95 }
96
97 /// Provide full access to ith element of _p
98 T& operator[]( size_t i ) { return _p[i]; }
99
100 /// Set all elements to x
101 TProb<T> & fill(T x) {
102 std::fill( _p.begin(), _p.end(), x );
103 return *this;
104 }
105
106 /// Set all elements to iid random numbers from uniform(0,1) distribution
107 TProb<T> & randomize() {
108 std::generate(_p.begin(), _p.end(), rnd_uniform);
109 return *this;
110 }
111
112 /// Return size
113 size_t size() const {
114 return _p.size();
115 }
116
117 /// Make entries zero if (Real) absolute value smaller than epsilon
118 TProb<T>& makeZero (Real epsilon) {
119 for( size_t i = 0; i < size(); i++ )
120 if( fabs((Real)_p[i]) < epsilon )
121 _p[i] = 0;
122 // std::replace_if( _p.begin(), _p.end(), fabs((Real)boost::lambda::_1) < epsilon, 0.0 );
123 return *this;
124 }
125
126 /// Make entries epsilon if they are smaller than epsilon
127 TProb<T>& makePositive (Real epsilon) {
128 for( size_t i = 0; i < size(); i++ )
129 if( (0 < (Real)_p[i]) && ((Real)_p[i] < epsilon) )
130 _p[i] = epsilon;
131 return *this;
132 }
133
134 /// Multiplication with T x
135 TProb<T>& operator*= (T x) {
136 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
137 return *this;
138 }
139
140 /// Return product of *this with T x
141 TProb<T> operator* (T x) const {
142 TProb<T> prod( *this );
143 prod *= x;
144 return prod;
145 }
146
147 /// Division by T x
148 TProb<T>& operator/= (T x) {
149 #ifdef DAI_DEBUG
150 assert( x != 0.0 );
151 #endif
152 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
153 return *this;
154 }
155
156 /// Return quotient of *this and T x
157 TProb<T> operator/ (T x) const {
158 TProb<T> quot( *this );
159 quot /= x;
160 return quot;
161 }
162
163 /// addition of x
164 TProb<T>& operator+= (T x) {
165 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
166 return *this;
167 }
168
169 /// Return sum of *this with T x
170 TProb<T> operator+ (T x) const {
171 TProb<T> sum( *this );
172 sum += x;
173 return sum;
174 }
175
176 /// Difference by T x
177 TProb<T>& operator-= (T x) {
178 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
179 return *this;
180 }
181
182 /// Return difference of *this and T x
183 TProb<T> operator- (T x) const {
184 TProb<T> diff( *this );
185 diff -= x;
186 return diff;
187 }
188
189 /// Pointwise comparison
190 bool operator<= (const TProb<T> & q) const {
191 #ifdef DAI_DEBUG
192 assert( size() == q.size() );
193 #endif
194 for( size_t i = 0; i < size(); i++ )
195 if( !(_p[i] <= q[i]) )
196 return false;
197 return true;
198 }
199
200 /// Pointwise multiplication with q
201 TProb<T>& operator*= (const TProb<T> & q) {
202 #ifdef DAI_DEBUG
203 assert( size() == q.size() );
204 #endif
205 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
206 return *this;
207 }
208
209 /// Return product of *this with q
210 TProb<T> operator* (const TProb<T> & q) const {
211 #ifdef DAI_DEBUG
212 assert( size() == q.size() );
213 #endif
214 TProb<T> prod( *this );
215 prod *= q;
216 return prod;
217 }
218
219 /// Pointwise addition with q
220 TProb<T>& operator+= (const TProb<T> & q) {
221 #ifdef DAI_DEBUG
222 assert( size() == q.size() );
223 #endif
224 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
225 return *this;
226 }
227
228 /// Pointwise subtraction of q
229 TProb<T>& operator-= (const TProb<T> & q) {
230 #ifdef DAI_DEBUG
231 assert( size() == q.size() );
232 #endif
233 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
234 return *this;
235 }
236
237 /// Return sum of *this and q
238 TProb<T> operator+ (const TProb<T> & q) const {
239 #ifdef DAI_DEBUG
240 assert( size() == q.size() );
241 #endif
242 TProb<T> sum( *this );
243 sum += q;
244 return sum;
245 }
246
247 /// Return *this minus q
248 TProb<T> operator- (const TProb<T> & q) const {
249 #ifdef DAI_DEBUG
250 assert( size() == q.size() );
251 #endif
252 TProb<T> diff( *this );
253 diff -= q;
254 return diff;
255 }
256
257 /// Pointwise division by q (division by zero yields zero)
258 TProb<T>& operator/= (const TProb<T> & q) {
259 #ifdef DAI_DEBUG
260 assert( size() == q.size() );
261 #endif
262 for( size_t i = 0; i < size(); i++ ) {
263 if( q[i] == 0.0 )
264 _p[i] = 0.0;
265 else
266 _p[i] /= q[i];
267 }
268 return *this;
269 }
270
271 /// Pointwise division by q (division by zero yields infinity)
272 TProb<T>& divide (const TProb<T> & q) {
273 #ifdef DAI_DEBUG
274 assert( size() == q.size() );
275 #endif
276 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
277 return *this;
278 }
279
280 /// Return quotient of *this with q
281 TProb<T> operator/ (const TProb<T> & q) const {
282 #ifdef DAI_DEBUG
283 assert( size() == q.size() );
284 #endif
285 TProb<T> quot( *this );
286 quot /= q;
287 return quot;
288 }
289
290 /// Return pointwise inverse
291 TProb<T> inverse(bool zero = false) const {
292 TProb<T> inv;
293 inv._p.reserve( size() );
294 if( zero )
295 for( size_t i = 0; i < size(); i++ )
296 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
297 else
298 for( size_t i = 0; i < size(); i++ ) {
299 #ifdef DAI_DEBUG
300 assert( _p[i] != 0.0 );
301 #endif
302 inv._p.push_back( 1.0 / _p[i] );
303 }
304 return inv;
305 }
306
307 /// Return *this to the power of a (pointwise)
308 TProb<T>& operator^= (Real a) {
309 if( a != 1.0 )
310 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
311 return *this;
312 }
313
314 /// Pointwise power of a
315 TProb<T> operator^ (Real a) const {
316 TProb<T> power(*this);
317 power ^= a;
318 return power;
319 }
320
321 /// Pointwise signum
322 TProb<T> sgn() const {
323 TProb<T> x;
324 x._p.reserve( size() );
325 for( size_t i = 0; i < size(); i++ ) {
326 T s = 0;
327 if( _p[i] > 0 )
328 s = 1;
329 else if( _p[i] < 0 )
330 s = -1;
331 x._p.push_back( s );
332 }
333 return x;
334 }
335
336 /// Pointwise absolute value
337 TProb<T> abs() const {
338 TProb<T> x;
339 x._p.reserve( size() );
340 for( size_t i = 0; i < size(); i++ )
341 x._p.push_back( _p[i] < 0 ? (-p[i]) : p[i] );
342 return x;
343 }
344
345 /// Pointwise exp
346 const TProb<T>& takeExp() {
347 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
348 return *this;
349 }
350
351 /// Pointwise log
352 const TProb<T>& takeLog() {
353 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
354 return *this;
355 }
356
357 /// Pointwise log (or 0 if == 0)
358 const TProb<T>& takeLog0() {
359 for( size_t i = 0; i < size(); i++ )
360 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
361 return *this;
362 }
363
364 /// Pointwise exp
365 TProb<T> exp() const {
366 TProb<T> e(*this);
367 e.takeExp();
368 return e;
369 }
370
371 /// Pointwise log
372 TProb<T> log() const {
373 TProb<T> l(*this);
374 l.takeLog();
375 return l;
376 }
377
378 /// Pointwise log (or 0 if == 0)
379 TProb<T> log0() const {
380 TProb<T> l0(*this);
381 l0.takeLog0();
382 return l0;
383 }
384
385 /// Return distance of p and q
386 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
387 #ifdef DAI_DEBUG
388 assert( p.size() == q.size() );
389 #endif
390 Real result = 0.0;
391 switch( dt ) {
392 case DISTL1:
393 for( size_t i = 0; i < p.size(); i++ )
394 result += fabs((Real)p[i] - (Real)q[i]);
395 break;
396
397 case DISTLINF:
398 for( size_t i = 0; i < p.size(); i++ ) {
399 Real z = fabs((Real)p[i] - (Real)q[i]);
400 if( z > result )
401 result = z;
402 }
403 break;
404
405 case DISTTV:
406 for( size_t i = 0; i < p.size(); i++ )
407 result += fabs((Real)p[i] - (Real)q[i]);
408 result *= 0.5;
409 break;
410 }
411 return result;
412 }
413
414 /// Return Kullback-Leibler distance with q
415 friend Real KL_dist( const TProb<T> & p, const TProb<T> & q ) {
416 #ifdef DAI_DEBUG
417 assert( p.size() == q.size() );
418 #endif
419 Real result = 0.0;
420 for( size_t i = 0; i < p.size(); i++ ) {
421 if( (Real) p[i] != 0.0 ) {
422 Real p_i = p[i];
423 Real q_i = q[i];
424 result += p_i * (std::log(p_i) - std::log(q_i));
425 }
426 }
427 return result;
428 }
429
430 /// Return sum of all entries
431 T totalSum() const {
432 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
433 return Z;
434 }
435
436 /// Converts entries to Real and returns maximum absolute value
437 T maxAbs() const {
438 T Z = 0;
439 for( size_t i = 0; i < size(); i++ ) {
440 Real mag = fabs( (Real) _p[i] );
441 if( mag > Z )
442 Z = mag;
443 }
444 return Z;
445 }
446
447 /// Returns maximum value
448 T maxVal() const {
449 T Z = *std::max_element( _p.begin(), _p.end() );
450 return Z;
451 }
452
453 /// Returns minimum value
454 T minVal() const {
455 T Z = *std::min_element( _p.begin(), _p.end() );
456 return Z;
457 }
458
459 /// Normalize, using the specified norm
460 T normalize( NormType norm = NORMPROB ) {
461 T Z = 0.0;
462 if( norm == NORMPROB )
463 Z = totalSum();
464 else if( norm == NORMLINF )
465 Z = maxAbs();
466 #ifdef DAI_DEBUG
467 assert( Z != 0.0 );
468 #endif
469 *this /= Z;
470 return Z;
471 }
472
473 /// Return normalized copy of *this, using the specified norm
474 TProb<T> normalized( NormType norm = NORMPROB ) const {
475 TProb<T> result(*this);
476 result.normalize( norm );
477 return result;
478 }
479
480 /// Returns true if one or more entries are NaN
481 bool hasNaNs() const {
482 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
483 }
484
485 /// Returns true if one or more entries are negative
486 bool hasNegatives() const {
487 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
488 }
489
490 /// Returns true if one or more entries are non-positive (causes problems with logscale)
491 bool hasNonPositives() const {
492 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
493 }
494
495 /// Returns entropy
496 Real entropy() const {
497 Real S = 0.0;
498 for( size_t i = 0; i < size(); i++ )
499 S -= xlogx(_p[i]);
500 return S;
501 }
502
503 /// Returns TProb<T> containing the pointwise minimum of a and b (which should have equal size)
504 friend TProb<T> min <> ( const TProb<T> &a, const TProb<T> &b );
505
506 /// Returns TProb<T> containing the pointwise maximum of a and b (which should have equal size)
507 friend TProb<T> max <> ( const TProb<T> &a, const TProb<T> &b );
508
509 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
510 os << "[";
511 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
512 os << "]";
513 return os;
514 }
515 };
516
517
518 template<typename T> TProb<T> min( const TProb<T> &a, const TProb<T> &b ) {
519 assert( a.size() == b.size() );
520 TProb<T> result( a.size() );
521 for( size_t i = 0; i < a.size(); i++ )
522 if( a[i] < b[i] )
523 result[i] = a[i];
524 else
525 result[i] = b[i];
526 return result;
527 }
528
529
530 template<typename T> TProb<T> max( const TProb<T> &a, const TProb<T> &b ) {
531 assert( a.size() == b.size() );
532 TProb<T> result( a.size() );
533 for( size_t i = 0; i < a.size(); i++ )
534 if( a[i] > b[i] )
535 result[i] = a[i];
536 else
537 result[i] = b[i];
538 return result;
539 }
540
541
542 } // end of namespace dai
543
544
545 #endif