Miscellaneous changes:
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_prob_h
23 #define __defined_libdai_prob_h
24
25
26 #include <complex>
27 #include <cmath>
28 #include <vector>
29 #include <ostream>
30 #include <cassert>
31 #include <algorithm>
32 #include <numeric>
33 #include <functional>
34 #include <dai/util.h>
35
36
37 namespace dai {
38
39
/// Real-valued scalar type used throughout libDAI
typedef double Real;
/// Complex scalar type whose components are Real
typedef std::complex<Real> Complex;

/// Forward declaration of the probability-vector template, so that the
/// shorthand aliases below can be introduced before its definition.
template<typename T> class TProb;
/// Probability vector with Real entries
typedef TProb<Real> Prob;
/// Probability vector with Complex entries
typedef TProb<Complex> CProb;
46
47
48 /// TProb<T> implements a probability vector of type T.
49 /// T should be castable from and to double and to complex.
50 template <typename T> class TProb {
51 protected:
52 /// The entries
53 std::vector<T> _p;
54
55 private:
56 /// Calculate x times log(x), or 0 if x == 0
57 Complex xlogx( Real x ) const { return( x == 0.0 ? 0.0 : Complex(x) * std::log(Complex(x))); }
58
59 public:
60 /// NORMPROB means that the sum of all entries should be 1
61 /// NORMLINF means that the maximum absolute value of all entries should be 1
62 typedef enum { NORMPROB, NORMLINF } NormType;
63 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
64 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
65 /// DISTTV is the Total Variation distance
66 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
67
68 /// Default constructor
69 TProb() {}
70
71 /// Construct uniform distribution of given length
72 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
73
74 /// Construct with given length and initial value
75 TProb( size_t n, Real p ) : _p(n, (T)p) {}
76
77 /// Construct with given length and initial array
78 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
79
80 /// Provide read access to _p
81 const std::vector<T> & p() const { return _p; }
82
83 /// Provide full access to _p
84 std::vector<T> & p() { return _p; }
85
86 /// Provide read access to ith element of _p
87 T operator[]( size_t i ) const { return _p[i]; }
88
89 /// Provide full access to ith element of _p
90 T& operator[]( size_t i ) { return _p[i]; }
91
92 /// Set all elements to x
93 TProb<T> & fill(T x) {
94 std::fill( _p.begin(), _p.end(), x );
95 return *this;
96 }
97
98 /// Set all elements to iid random numbers from uniform(0,1) distribution
99 TProb<T> & randomize() {
100 std::generate(_p.begin(), _p.end(), rnd_uniform);
101 return *this;
102 }
103
104 /// Return size
105 size_t size() const {
106 return _p.size();
107 }
108
109 /// Make entries zero if (Real) absolute value smaller than epsilon
110 TProb<T>& makeZero (Real epsilon) {
111 for( size_t i = 0; i < size(); i++ )
112 if( fabs((Real)_p[i]) < epsilon )
113 _p[i] = 0;
114 // std::replace_if( _p.begin(), _p.end(), fabs((Real)boost::lambda::_1) < epsilon, 0.0 );
115 return *this;
116 }
117
118 /// Multiplication with T x
119 TProb<T>& operator*= (T x) {
120 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
121 return *this;
122 }
123
124 /// Return product of *this with T x
125 TProb<T> operator* (T x) const {
126 TProb<T> prod( *this );
127 prod *= x;
128 return prod;
129 }
130
131 /// Division by T x
132 TProb<T>& operator/= (T x) {
133 #ifdef DAI_DEBUG
134 assert( x != 0.0 );
135 #endif
136 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
137 return *this;
138 }
139
140 /// Return quotient of *this and T x
141 TProb<T> operator/ (T x) const {
142 TProb<T> quot( *this );
143 quot /= x;
144 return quot;
145 }
146
147 /// addition of x
148 TProb<T>& operator+= (T x) {
149 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
150 return *this;
151 }
152
153 /// Return sum of *this with T x
154 TProb<T> operator+ (T x) const {
155 TProb<T> sum( *this );
156 sum += x;
157 return sum;
158 }
159
160 /// Difference by T x
161 TProb<T>& operator-= (T x) {
162 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
163 return *this;
164 }
165
166 /// Return difference of *this and T x
167 TProb<T> operator- (T x) const {
168 TProb<T> diff( *this );
169 diff -= x;
170 return diff;
171 }
172
173 /// Pointwise multiplication with q
174 TProb<T>& operator*= (const TProb<T> & q) {
175 #ifdef DAI_DEBUG
176 assert( size() == q.size() );
177 #endif
178 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
179 return *this;
180 }
181
182 /// Return product of *this with q
183 TProb<T> operator* (const TProb<T> & q) const {
184 #ifdef DAI_DEBUG
185 assert( size() == q.size() );
186 #endif
187 TProb<T> prod( *this );
188 prod *= q;
189 return prod;
190 }
191
192 /// Pointwise addition with q
193 TProb<T>& operator+= (const TProb<T> & q) {
194 #ifdef DAI_DEBUG
195 assert( size() == q.size() );
196 #endif
197 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
198 return *this;
199 }
200
201 /// Pointwise subtraction of q
202 TProb<T>& operator-= (const TProb<T> & q) {
203 #ifdef DAI_DEBUG
204 assert( size() == q.size() );
205 #endif
206 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
207 return *this;
208 }
209
210 /// Return sum of *this and q
211 TProb<T> operator+ (const TProb<T> & q) const {
212 #ifdef DAI_DEBUG
213 assert( size() == q.size() );
214 #endif
215 TProb<T> sum( *this );
216 sum += q;
217 return sum;
218 }
219
220 /// Return *this minus q
221 TProb<T> operator- (const TProb<T> & q) const {
222 #ifdef DAI_DEBUG
223 assert( size() == q.size() );
224 #endif
225 TProb<T> diff( *this );
226 diff -= q;
227 return diff;
228 }
229
230 /// Pointwise division by q (division by zero yields zero)
231 TProb<T>& operator/= (const TProb<T> & q) {
232 #ifdef DAI_DEBUG
233 assert( size() == q.size() );
234 #endif
235 for( size_t i = 0; i < size(); i++ ) {
236 #ifdef DAI_DEBUG
237 // assert( q[i] != 0.0 );
238 #endif
239 if( q[i] == 0.0 )
240 _p[i] = 0.0;
241 else
242 _p[i] /= q[i];
243 }
244 return *this;
245 }
246
247 /// Pointwise division by q (division by zero yields infinity)
248 TProb<T>& divide (const TProb<T> & q) {
249 #ifdef DAI_DEBUG
250 assert( size() == q.size() );
251 #endif
252 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
253 return *this;
254 }
255
256 /// Return quotient of *this with q
257 TProb<T> operator/ (const TProb<T> & q) const {
258 #ifdef DAI_DEBUG
259 assert( size() == q.size() );
260 #endif
261 TProb<T> quot( *this );
262 quot /= q;
263 return quot;
264 }
265
266 /// Return pointwise inverse
267 TProb<T> inverse(bool zero = false) const {
268 TProb<T> inv;
269 inv._p.reserve( size() );
270 if( zero )
271 for( size_t i = 0; i < size(); i++ )
272 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
273 else
274 for( size_t i = 0; i < size(); i++ ) {
275 #ifdef DAI_DEBUG
276 assert( _p[i] != 0.0 );
277 #endif
278 inv._p.push_back( 1.0 / _p[i] );
279 }
280 return inv;
281 }
282
283 /// Return *this to the power of a (pointwise)
284 TProb<T>& operator^= (Real a) {
285 if( a != 1.0 )
286 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
287 return *this;
288 }
289
290 /// Pointwise power of a
291 TProb<T> operator^ (Real a) const {
292 TProb<T> power(*this);
293 power ^= a;
294 return power;
295 }
296
297 /// Pointwise exp
298 const TProb<T>& takeExp() {
299 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
300 return *this;
301 }
302
303 /// Pointwise log
304 const TProb<T>& takeLog() {
305 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
306 return *this;
307 }
308
309 /// Pointwise log (or 0 if == 0)
310 const TProb<T>& takeLog0() {
311 for( size_t i = 0; i < size(); i++ )
312 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
313 return *this;
314 }
315
316 /// Pointwise exp
317 TProb<T> exp() const {
318 TProb<T> e(*this);
319 e.takeExp();
320 return e;
321 }
322
323 /// Pointwise log
324 TProb<T> log() const {
325 TProb<T> l(*this);
326 l.takeLog();
327 return l;
328 }
329
330 /// Pointwise log (or 0 if == 0)
331 TProb<T> log0() const {
332 TProb<T> l0(*this);
333 l0.takeLog0();
334 return l0;
335 }
336
337 /// Pointwise (complex) log (or 0 if == 0)
338 /* CProb clog0() const {
339 CProb l0;
340 l0._p.reserve( size() );
341 for( size_t i = 0; i < size(); i++ )
342 l0._p.push_back( (_p[i] == 0.0) ? 0.0 : std::log( Complex( _p[i] ) ) );
343 return l0;
344 }*/
345
346 /// Return distance of p and q
347 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
348 #ifdef DAI_DEBUG
349 assert( p.size() == q.size() );
350 #endif
351 Real result = 0.0;
352 switch( dt ) {
353 case DISTL1:
354 for( size_t i = 0; i < p.size(); i++ )
355 result += fabs((Real)p[i] - (Real)q[i]);
356 break;
357
358 case DISTLINF:
359 for( size_t i = 0; i < p.size(); i++ ) {
360 Real z = fabs((Real)p[i] - (Real)q[i]);
361 if( z > result )
362 result = z;
363 }
364 break;
365
366 case DISTTV:
367 for( size_t i = 0; i < p.size(); i++ )
368 result += fabs((Real)p[i] - (Real)q[i]);
369 result *= 0.5;
370 break;
371 }
372 return result;
373 }
374
375 /// Return (complex) Kullback-Leibler distance with q
376 friend Complex KL_dist( const TProb<T> & p, const TProb<T> & q ) {
377 #ifdef DAI_DEBUG
378 assert( p.size() == q.size() );
379 #endif
380 Complex result = 0.0;
381 for( size_t i = 0; i < p.size(); i++ ) {
382 if( (Real) p[i] != 0.0 ) {
383 Complex p_i = p[i];
384 Complex q_i = q[i];
385 result += p_i * (std::log(p_i) - std::log(q_i));
386 }
387 }
388 return result;
389 }
390
391 /// Return sum of all entries
392 T totalSum() const {
393 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
394 return Z;
395 }
396
397 /// Converts entries to Real and returns maximum absolute value
398 T maxAbs() const {
399 T Z = 0;
400 for( size_t i = 0; i < size(); i++ ) {
401 Real mag = fabs( (Real) _p[i] );
402 if( mag > Z )
403 Z = mag;
404 }
405 return Z;
406 }
407
408 /// Returns maximum value
409 T maxVal() const {
410 T Z = *std::max_element( _p.begin(), _p.end() );
411 return Z;
412 }
413
414 /// Normalize, using the specified norm
415 T normalize( NormType norm ) {
416 T Z = 0.0;
417 if( norm == NORMPROB )
418 Z = totalSum();
419 else if( norm == NORMLINF )
420 Z = maxAbs();
421 #ifdef DAI_DEBUG
422 assert( Z != 0.0 );
423 #endif
424 *this /= Z;
425 return Z;
426 }
427
428 /// Return normalized copy of *this, using the specified norm
429 TProb<T> normalized( NormType norm ) const {
430 TProb<T> result(*this);
431 result.normalize( norm );
432 return result;
433 }
434
435 /// Returns true if one or more entries are NaN
436 bool hasNaNs() const {
437 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
438 }
439
440 /// Returns true if one or more entries are negative
441 bool hasNegatives() const {
442 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
443 }
444
445 /// Returns true if one or more entries are non-positive (causes problems with logscale)
446 bool hasNonPositives() const {
447 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
448 }
449
450 /// Returns (complex) entropy
451 Complex entropy() const {
452 Complex S = 0.0;
453 for( size_t i = 0; i < size(); i++ )
454 S -= xlogx(_p[i]);
455 return S;
456 }
457
458 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
459 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
460 os << std::endl;
461 return os;
462 }
463 };
464
465
466 } // end of namespace dai
467
468
469 #endif