Added Exceptions framework (and more)
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_prob_h
23 #define __defined_libdai_prob_h
24
25
#include <cassert>
#include <cmath>
#include <algorithm>
#include <complex>
#include <functional>
#include <limits>
#include <numeric>
#include <ostream>
#include <vector>
#include <dai/util.h>
35
36
37 namespace dai {
38
39
/// Real-valued scalar type used for probabilities.
typedef double Real;
/// Complex scalar type; std::complex over Real (i.e. std::complex<double>).
typedef std::complex<Real> Complex;

/// Forward declaration of the probability-vector class template.
template <typename T>
class TProb;

/// Probability vector with Real entries.
typedef TProb<Real> Prob;
/// Probability vector with Complex entries.
typedef TProb<Complex> CProb;
46
47
48 /// TProb<T> implements a probability vector of type T.
49 /// T should be castable from and to double and to complex.
50 template <typename T> class TProb {
51 protected:
52 /// The entries
53 std::vector<T> _p;
54
55 private:
56 /// Calculate x times log(x), or 0 if x == 0
57 Complex xlogx( Real x ) const { return( x == 0.0 ? 0.0 : Complex(x) * std::log(Complex(x))); }
58
59 public:
60 /// NORMPROB means that the sum of all entries should be 1
61 /// NORMLINF means that the maximum absolute value of all entries should be 1
62 typedef enum { NORMPROB, NORMLINF } NormType;
63 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
64 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
65 /// DISTTV is the Total Variation distance
66 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
67
68 /// Default constructor
69 TProb() {}
70
71 /// Construct uniform distribution of given length
72 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
73
74 /// Construct with given length and initial value
75 TProb( size_t n, Real p ) : _p(n, (T)p) {}
76
77 /// Construct with given length and initial array
78 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
79
80 /// Provide read access to _p
81 const std::vector<T> & p() const { return _p; }
82
83 /// Provide full access to _p
84 std::vector<T> & p() { return _p; }
85
86 /// Provide read access to ith element of _p
87 T operator[]( size_t i ) const { return _p[i]; }
88
89 /// Provide full access to ith element of _p
90 T& operator[]( size_t i ) { return _p[i]; }
91
92 /// Set all elements to x
93 TProb<T> & fill(T x) {
94 std::fill( _p.begin(), _p.end(), x );
95 return *this;
96 }
97
98 /// Set all elements to iid random numbers from uniform(0,1) distribution
99 TProb<T> & randomize() {
100 std::generate(_p.begin(), _p.end(), rnd_uniform);
101 return *this;
102 }
103
104 /// Return size
105 size_t size() const {
106 return _p.size();
107 }
108
109 /// Make entries zero if (Real) absolute value smaller than epsilon
110 TProb<T>& makeZero (Real epsilon) {
111 for( size_t i = 0; i < size(); i++ )
112 if( fabs((Real)_p[i]) < epsilon )
113 _p[i] = 0;
114 // std::replace_if( _p.begin(), _p.end(), fabs((Real)boost::lambda::_1) < epsilon, 0.0 );
115 return *this;
116 }
117
118 /// Multiplication with T x
119 TProb<T>& operator*= (T x) {
120 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
121 return *this;
122 }
123
124 /// Return product of *this with T x
125 TProb<T> operator* (T x) const {
126 TProb<T> prod( *this );
127 prod *= x;
128 return prod;
129 }
130
131 /// Division by T x
132 TProb<T>& operator/= (T x) {
133 #ifdef DAI_DEBUG
134 assert( x != 0.0 );
135 #endif
136 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
137 return *this;
138 }
139
140 /// Return quotient of *this and T x
141 TProb<T> operator/ (T x) const {
142 TProb<T> quot( *this );
143 quot /= x;
144 return quot;
145 }
146
147 /// addition of x
148 TProb<T>& operator+= (T x) {
149 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
150 return *this;
151 }
152
153 /// Return sum of *this with T x
154 TProb<T> operator+ (T x) const {
155 TProb<T> sum( *this );
156 sum += x;
157 return sum;
158 }
159
160 /// Difference by T x
161 TProb<T>& operator-= (T x) {
162 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
163 return *this;
164 }
165
166 /// Return difference of *this and T x
167 TProb<T> operator- (T x) const {
168 TProb<T> diff( *this );
169 diff -= x;
170 return diff;
171 }
172
173 /// Pointwise multiplication with q
174 TProb<T>& operator*= (const TProb<T> & q) {
175 #ifdef DAI_DEBUG
176 assert( size() == q.size() );
177 #endif
178 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
179 return *this;
180 }
181
182 /// Return product of *this with q
183 TProb<T> operator* (const TProb<T> & q) const {
184 #ifdef DAI_DEBUG
185 assert( size() == q.size() );
186 #endif
187 TProb<T> prod( *this );
188 prod *= q;
189 return prod;
190 }
191
192 /// Pointwise addition with q
193 TProb<T>& operator+= (const TProb<T> & q) {
194 #ifdef DAI_DEBUG
195 assert( size() == q.size() );
196 #endif
197 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
198 return *this;
199 }
200
201 /// Pointwise subtraction of q
202 TProb<T>& operator-= (const TProb<T> & q) {
203 #ifdef DAI_DEBUG
204 assert( size() == q.size() );
205 #endif
206 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
207 return *this;
208 }
209
210 /// Return sum of *this and q
211 TProb<T> operator+ (const TProb<T> & q) const {
212 #ifdef DAI_DEBUG
213 assert( size() == q.size() );
214 #endif
215 TProb<T> sum( *this );
216 sum += q;
217 return sum;
218 }
219
220 /// Return *this minus q
221 TProb<T> operator- (const TProb<T> & q) const {
222 #ifdef DAI_DEBUG
223 assert( size() == q.size() );
224 #endif
225 TProb<T> diff( *this );
226 diff -= q;
227 return diff;
228 }
229
230 /// Pointwise division by q (division by zero yields zero)
231 TProb<T>& operator/= (const TProb<T> & q) {
232 #ifdef DAI_DEBUG
233 assert( size() == q.size() );
234 #endif
235 for( size_t i = 0; i < size(); i++ ) {
236 if( q[i] == 0.0 )
237 _p[i] = 0.0;
238 else
239 _p[i] /= q[i];
240 }
241 return *this;
242 }
243
244 /// Pointwise division by q (division by zero yields infinity)
245 TProb<T>& divide (const TProb<T> & q) {
246 #ifdef DAI_DEBUG
247 assert( size() == q.size() );
248 #endif
249 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
250 return *this;
251 }
252
253 /// Return quotient of *this with q
254 TProb<T> operator/ (const TProb<T> & q) const {
255 #ifdef DAI_DEBUG
256 assert( size() == q.size() );
257 #endif
258 TProb<T> quot( *this );
259 quot /= q;
260 return quot;
261 }
262
263 /// Return pointwise inverse
264 TProb<T> inverse(bool zero = false) const {
265 TProb<T> inv;
266 inv._p.reserve( size() );
267 if( zero )
268 for( size_t i = 0; i < size(); i++ )
269 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
270 else
271 for( size_t i = 0; i < size(); i++ ) {
272 #ifdef DAI_DEBUG
273 assert( _p[i] != 0.0 );
274 #endif
275 inv._p.push_back( 1.0 / _p[i] );
276 }
277 return inv;
278 }
279
280 /// Return *this to the power of a (pointwise)
281 TProb<T>& operator^= (Real a) {
282 if( a != 1.0 )
283 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
284 return *this;
285 }
286
287 /// Pointwise power of a
288 TProb<T> operator^ (Real a) const {
289 TProb<T> power(*this);
290 power ^= a;
291 return power;
292 }
293
294 /// Pointwise exp
295 const TProb<T>& takeExp() {
296 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
297 return *this;
298 }
299
300 /// Pointwise log
301 const TProb<T>& takeLog() {
302 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
303 return *this;
304 }
305
306 /// Pointwise log (or 0 if == 0)
307 const TProb<T>& takeLog0() {
308 for( size_t i = 0; i < size(); i++ )
309 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
310 return *this;
311 }
312
313 /// Pointwise exp
314 TProb<T> exp() const {
315 TProb<T> e(*this);
316 e.takeExp();
317 return e;
318 }
319
320 /// Pointwise log
321 TProb<T> log() const {
322 TProb<T> l(*this);
323 l.takeLog();
324 return l;
325 }
326
327 /// Pointwise log (or 0 if == 0)
328 TProb<T> log0() const {
329 TProb<T> l0(*this);
330 l0.takeLog0();
331 return l0;
332 }
333
334 /// Pointwise (complex) log (or 0 if == 0)
335 /* CProb clog0() const {
336 CProb l0;
337 l0._p.reserve( size() );
338 for( size_t i = 0; i < size(); i++ )
339 l0._p.push_back( (_p[i] == 0.0) ? 0.0 : std::log( Complex( _p[i] ) ) );
340 return l0;
341 }*/
342
343 /// Return distance of p and q
344 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
345 #ifdef DAI_DEBUG
346 assert( p.size() == q.size() );
347 #endif
348 Real result = 0.0;
349 switch( dt ) {
350 case DISTL1:
351 for( size_t i = 0; i < p.size(); i++ )
352 result += fabs((Real)p[i] - (Real)q[i]);
353 break;
354
355 case DISTLINF:
356 for( size_t i = 0; i < p.size(); i++ ) {
357 Real z = fabs((Real)p[i] - (Real)q[i]);
358 if( z > result )
359 result = z;
360 }
361 break;
362
363 case DISTTV:
364 for( size_t i = 0; i < p.size(); i++ )
365 result += fabs((Real)p[i] - (Real)q[i]);
366 result *= 0.5;
367 break;
368 }
369 return result;
370 }
371
372 /// Return (complex) Kullback-Leibler distance with q
373 friend Complex KL_dist( const TProb<T> & p, const TProb<T> & q ) {
374 #ifdef DAI_DEBUG
375 assert( p.size() == q.size() );
376 #endif
377 Complex result = 0.0;
378 for( size_t i = 0; i < p.size(); i++ ) {
379 if( (Real) p[i] != 0.0 ) {
380 Complex p_i = p[i];
381 Complex q_i = q[i];
382 result += p_i * (std::log(p_i) - std::log(q_i));
383 }
384 }
385 return result;
386 }
387
388 /// Return sum of all entries
389 T totalSum() const {
390 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
391 return Z;
392 }
393
394 /// Converts entries to Real and returns maximum absolute value
395 T maxAbs() const {
396 T Z = 0;
397 for( size_t i = 0; i < size(); i++ ) {
398 Real mag = fabs( (Real) _p[i] );
399 if( mag > Z )
400 Z = mag;
401 }
402 return Z;
403 }
404
405 /// Returns maximum value
406 T maxVal() const {
407 T Z = *std::max_element( _p.begin(), _p.end() );
408 return Z;
409 }
410
411 /// Normalize, using the specified norm
412 T normalize( NormType norm ) {
413 T Z = 0.0;
414 if( norm == NORMPROB )
415 Z = totalSum();
416 else if( norm == NORMLINF )
417 Z = maxAbs();
418 #ifdef DAI_DEBUG
419 assert( Z != 0.0 );
420 #endif
421 *this /= Z;
422 return Z;
423 }
424
425 /// Return normalized copy of *this, using the specified norm
426 TProb<T> normalized( NormType norm ) const {
427 TProb<T> result(*this);
428 result.normalize( norm );
429 return result;
430 }
431
432 /// Returns true if one or more entries are NaN
433 bool hasNaNs() const {
434 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
435 }
436
437 /// Returns true if one or more entries are negative
438 bool hasNegatives() const {
439 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
440 }
441
442 /// Returns true if one or more entries are non-positive (causes problems with logscale)
443 bool hasNonPositives() const {
444 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
445 }
446
447 /// Returns (complex) entropy
448 Complex entropy() const {
449 Complex S = 0.0;
450 for( size_t i = 0; i < size(); i++ )
451 S -= xlogx(_p[i]);
452 return S;
453 }
454
455 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
456 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
457 os << std::endl;
458 return os;
459 }
460 };
461
462
463 } // end of namespace dai
464
465
466 #endif