Partial adoption of contributions by Giuseppe:
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_prob_h
23 #define __defined_libdai_prob_h
24
25
26 #include <complex>
27 #include <cmath>
28 #include <vector>
29 #include <iostream>
30 #include <cassert>
31 #include <dai/util.h>
32
33
34 namespace dai {
35
36
37 typedef double Real;
38 typedef std::complex<double> Complex;
39
40 template<typename T> class TProb;
41 typedef TProb<Real> Prob;
42 typedef TProb<Complex> CProb;
43
44
45 /// TProb<T> implements a probability vector of type T.
46 /// T should be castable from and to double and to complex.
47 template <typename T> class TProb {
48 protected:
49 /// The entries
50 std::vector<T> _p;
51
52 private:
53 /// Calculate x times log(x), or 0 if x == 0
54 Complex xlogx( Real x ) const { return( x == 0.0 ? 0.0 : Complex(x) * std::log(Complex(x))); }
55
56 public:
57 /// NORMPROB means that the sum of all entries should be 1
58 /// NORMLINF means that the maximum absolute value of all entries should be 1
59 typedef enum { NORMPROB, NORMLINF } NormType;
60 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
61 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
62 /// DISTTV is the Total Variation distance
63 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
64
65 /// Default constructor
66 TProb() : _p() {}
67
68 /// Construct uniform distribution of given length
69 TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
70
71 /// Construct with given length and initial value
72 TProb( size_t n, Real p ) : _p(std::vector<T>(n,(T)p)) {}
73
74 /// Construct with given length and initial array
75 TProb( size_t n, const Real* p ) {
76 // Reserve-push_back is faster than resize-copy
77 _p.reserve( n );
78 for( size_t i = 0; i < n; i++ )
79 _p.push_back( p[i] );
80 }
81
82 /// Copy constructor
83 TProb( const TProb<T> & x ) : _p(x._p) {}
84
85 /// Assignment operator
86 TProb<T> & operator=( const TProb<T> &x ) {
87 if( this != &x ) {
88 _p = x._p;
89 }
90 return *this;
91 }
92
93 /// Provide read access to _p
94 const std::vector<T> & p() const { return _p; }
95
96 /// Provide full access to _p
97 std::vector<T> & p() { return _p; }
98
99 /// Provide read access to ith element of _p
100 T operator[]( size_t i ) const { return _p[i]; }
101
102 /// Provide full access to ith element of _p
103 T& operator[]( size_t i ) { return _p[i]; }
104
105 /// Set all elements to x
106 TProb<T> & fill(T x) {
107 for( size_t i = 0; i < size(); i++ )
108 _p[i] = x;
109 return *this;
110 }
111
112 /// Set all elements to iid random numbers from uniform(0,1) distribution
113 TProb<T> & randomize() {
114 for( size_t i = 0; i < size(); i++ )
115 _p[i] = rnd_uniform();
116 return *this;
117 }
118
119 /// Return size
120 size_t size() const {
121 return _p.size();
122 }
123
124 /// Make entries zero if (Real) absolute value smaller than epsilon
125 TProb<T>& makeZero (Real epsilon) {
126 for( size_t i = 0; i < size(); i++ )
127 if( fabs((Real)_p[i]) < epsilon )
128 _p[i] = 0;
129 return *this;
130 }
131
132 /// Multiplication with T x
133 TProb<T>& operator*= (T x) {
134 for( size_t i = 0; i < size(); i++ )
135 _p[i] *= x;
136 return *this;
137 }
138
139 /// Return product of *this with T x
140 TProb<T> operator* (T x) const {
141 TProb<T> prod( *this );
142 prod *= x;
143 return prod;
144 }
145
146 /// Division by T x
147 TProb<T>& operator/= (T x) {
148 #ifdef DEBUG
149 assert( x != 0.0 );
150 #endif
151 for( size_t i = 0; i < size(); i++ )
152 _p[i] /= x;
153 return *this;
154 }
155
156 /// Return quotient of *this and T x
157 TProb<T> operator/ (T x) const {
158 TProb<T> prod( *this );
159 prod /= x;
160 return prod;
161 }
162
163 /// Pointwise multiplication with q
164 TProb<T>& operator*= (const TProb<T> & q) {
165 #ifdef DEBUG
166 assert( size() == q.size() );
167 #endif
168 for( size_t i = 0; i < size(); i++ )
169 _p[i] *= q[i];
170 return *this;
171 }
172
173 /// Return product of *this with q
174 TProb<T> operator* (const TProb<T> & q) const {
175 #ifdef DEBUG
176 assert( size() == q.size() );
177 #endif
178 TProb<T> prod( *this );
179 prod *= q;
180 return prod;
181 }
182
183 /// Pointwise addition with q
184 TProb<T>& operator+= (const TProb<T> & q) {
185 #ifdef DEBUG
186 assert( size() == q.size() );
187 #endif
188 for( size_t i = 0; i < size(); i++ )
189 _p[i] += q[i];
190 return *this;
191 }
192
193 /// Pointwise subtraction of q
194 TProb<T>& operator-= (const TProb<T> & q) {
195 #ifdef DEBUG
196 assert( size() == q.size() );
197 #endif
198 for( size_t i = 0; i < size(); i++ )
199 _p[i] -= q[i];
200 return *this;
201 }
202
203 /// Return sum of *this and q
204 TProb<T> operator+ (const TProb<T> & q) const {
205 #ifdef DEBUG
206 assert( size() == q.size() );
207 #endif
208 TProb<T> sum( *this );
209 sum += q;
210 return sum;
211 }
212
213 /// Return *this minus q
214 TProb<T> operator- (const TProb<T> & q) const {
215 #ifdef DEBUG
216 assert( size() == q.size() );
217 #endif
218 TProb<T> sum( *this );
219 sum -= q;
220 return sum;
221 }
222
223 /// Pointwise division by q
224 TProb<T>& operator/= (const TProb<T> & q) {
225 #ifdef DEBUG
226 assert( size() == q.size() );
227 #endif
228 for( size_t i = 0; i < size(); i++ ) {
229 #ifdef DEBUG
230 // assert( q[i] != 0.0 );
231 #endif
232 if( q[i] == 0.0 ) // FIXME
233 _p[i] = 0.0;
234 else
235 _p[i] /= q[i];
236 }
237 return *this;
238 }
239
240 /// Pointwise division by q, division by zero yields infinity
241 TProb<T>& divide (const TProb<T> & q) {
242 #ifdef DEBUG
243 assert( size() == q.size() );
244 #endif
245 for( size_t i = 0; i < size(); i++ )
246 _p[i] /= q[i];
247 return *this;
248 }
249
250 /// Return quotient of *this with q
251 TProb<T> operator/ (const TProb<T> & q) const {
252 #ifdef DEBUG
253 assert( size() == q.size() );
254 #endif
255 TProb<T> quot( *this );
256 quot /= q;
257 return quot;
258 }
259
260 /// Return pointwise inverse
261 TProb<T> inverse(bool zero = false) const {
262 TProb<T> inv;
263 inv._p.reserve( size() );
264 if( zero )
265 for( size_t i = 0; i < size(); i++ )
266 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
267 else
268 for( size_t i = 0; i < size(); i++ ) {
269 #ifdef DEBUG
270 assert( _p[i] != 0.0 );
271 #endif
272 inv._p.push_back( 1.0 / _p[i] );
273 }
274 return inv;
275 }
276
277 /// Return *this to the power of a (pointwise)
278 TProb<T>& operator^= (Real a) {
279 if( a != 1.0 ) {
280 for( size_t i = 0; i < size(); i++ )
281 _p[i] = std::pow( _p[i], a );
282 }
283 return *this;
284 }
285
286 /// Pointwise power of a
287 TProb<T> operator^ (Real a) const {
288 TProb<T> power;
289 if( a != 1.0 ) {
290 power._p.reserve( size() );
291 for( size_t i = 0; i < size(); i++ )
292 power._p.push_back( std::pow( _p[i], a ) );
293 } else
294 power = *this;
295 return power;
296 }
297
298 /// Pointwise exp
299 TProb<T> exp() const {
300 TProb<T> e;
301 e._p.reserve( size() );
302 for( size_t i = 0; i < size(); i++ )
303 e._p.push_back( std::exp( _p[i] ) );
304 return e;
305 }
306
307 /// Pointwise log
308 TProb<T> log() const {
309 TProb<T> l;
310 l._p.reserve( size() );
311 for( size_t i = 0; i < size(); i++ )
312 l._p.push_back( std::log( _p[i] ) );
313 return l;
314 }
315
316 /// Pointwise log (or 0 if == 0)
317 TProb<T> log0() const {
318 TProb<T> l0;
319 l0._p.reserve( size() );
320 for( size_t i = 0; i < size(); i++ )
321 l0._p.push_back( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
322 return l0;
323 }
324
325 /// Pointwise (complex) log (or 0 if == 0)
326 /* CProb clog0() const {
327 CProb l0;
328 l0._p.reserve( size() );
329 for( size_t i = 0; i < size(); i++ )
330 l0._p.push_back( (_p[i] == 0.0) ? 0.0 : std::log( Complex( _p[i] ) ) );
331 return l0;
332 }*/
333
334 /// Return distance of p and q
335 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
336 #ifdef DEBUG
337 assert( p.size() == q.size() );
338 #endif
339 Real result = 0.0;
340 switch( dt ) {
341 case DISTL1:
342 for( size_t i = 0; i < p.size(); i++ )
343 result += fabs((Real)p[i] - (Real)q[i]);
344 break;
345
346 case DISTLINF:
347 for( size_t i = 0; i < p.size(); i++ ) {
348 Real z = fabs((Real)p[i] - (Real)q[i]);
349 if( z > result )
350 result = z;
351 }
352 break;
353
354 case DISTTV:
355 for( size_t i = 0; i < p.size(); i++ )
356 result += fabs((Real)p[i] - (Real)q[i]);
357 result *= 0.5;
358 break;
359 }
360 return result;
361 }
362
363 /// Return (complex) Kullback-Leibler distance with q
364 friend Complex KL_dist( const TProb<T> & p, const TProb<T> & q ) {
365 #ifdef DEBUG
366 assert( p.size() == q.size() );
367 #endif
368 Complex result = 0.0;
369 for( size_t i = 0; i < p.size(); i++ ) {
370 if( (Real) p[i] != 0.0 ) {
371 Complex p_i = p[i];
372 Complex q_i = q[i];
373 result += p_i * (std::log(p_i) - std::log(q_i));
374 }
375 }
376 return result;
377 }
378
379 /// Return sum of all entries
380 T totalSum() const {
381 T Z = 0.0;
382 for( size_t i = 0; i < size(); i++ )
383 Z += _p[i];
384 return Z;
385 }
386
387 /// Converts entries to Real and returns maximum absolute value
388 T maxAbs() const {
389 T Z = 0.0;
390 for( size_t i = 0; i < size(); i++ ) {
391 Real mag = fabs( (Real) _p[i] );
392 if( mag > Z )
393 Z = mag;
394 }
395 return Z;
396 }
397
398 /// Returns maximum value
399 T max() const {
400 T Z = 0.0;
401 for( size_t i = 0; i < size(); i++ ) {
402 if( _p[i] > Z )
403 Z = _p[i];
404 }
405 return Z;
406 }
407
408 /// Normalize, using the specified norm
409 T normalize( NormType norm ) {
410 T Z = 0.0;
411 if( norm == NORMPROB )
412 Z = totalSum();
413 else if( norm == NORMLINF )
414 Z = maxAbs();
415 #ifdef DEBUG
416 assert( Z != 0.0 );
417 #endif
418 T Zi = 1.0 / Z;
419 for( size_t i = 0; i < size(); i++ )
420 _p[i] *= Zi;
421 return Z;
422 }
423
424 /// Return normalized copy of *this, using the specified norm
425 TProb<T> normalized( NormType norm ) const {
426 T Z = 0.0;
427 if( norm == NORMPROB )
428 Z = totalSum();
429 else if( norm == NORMLINF )
430 Z = maxAbs();
431 #ifdef DEBUG
432 assert( Z != 0.0 );
433 #endif
434 Z = 1.0 / Z;
435
436 TProb<T> result;
437 result._p.reserve( size() );
438 for( size_t i = 0; i < size(); i++ )
439 result._p.push_back( _p[i] * Z );
440 return result;
441 }
442
443 /// Returns true if one or more entries are NaN
444 bool hasNaNs() const {
445 bool NaNs = false;
446 for( size_t i = 0; i < size() && !NaNs; i++ )
447 if( isnan( _p[i] ) )
448 NaNs = true;
449 return NaNs;
450 }
451
452 /// Returns true if one or more entries are negative
453 bool hasNegatives() const {
454 bool Negatives = false;
455 for( size_t i = 0; i < size() && !Negatives; i++ )
456 if( _p[i] < 0.0 )
457 Negatives = true;
458 return Negatives;
459 }
460
461 /// Returns (complex) entropy
462 Complex entropy() const {
463 Complex S = 0.0;
464 for( size_t i = 0; i < size(); i++ )
465 S -= xlogx(_p[i]);
466 return S;
467 }
468
469 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
470 for( size_t i = 0; i < P.size(); i++ )
471 os << P._p[i] << " ";
472 os << std::endl;
473 return os;
474 }
475 };
476
477
478 } // end of namespace dai
479
480
481 #endif