Replaced complex numbers by real numbers
[libdai.git] / include / dai / prob.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_prob_h
23 #define __defined_libdai_prob_h
24
25
26 #include <cmath>
27 #include <vector>
28 #include <ostream>
29 #include <cassert>
30 #include <algorithm>
31 #include <numeric>
32 #include <functional>
33 #include <dai/util.h>
34
35
36 namespace dai {
37
38
/// Scalar floating-point type used by this header for probabilities and for
/// scalar results such as distances and entropies.
typedef double Real;

/// Forward declaration of the probability-vector class template (its
/// definition follows later in this header).
template <class T> class TProb;

/// Probability vector whose entries are of type Real.
typedef TProb<Real> Prob;
43
44
45 /// TProb<T> implements a probability vector of type T.
46 /// T should be castable from and to double.
47 template <typename T> class TProb {
48 protected:
49 /// The entries
50 std::vector<T> _p;
51
52 private:
53 /// Calculate x times log(x), or 0 if x == 0
54 Real xlogx( Real x ) const { return( x == 0.0 ? 0.0 : x * std::log(x)); }
55
56 public:
57 /// NORMPROB means that the sum of all entries should be 1
58 /// NORMLINF means that the maximum absolute value of all entries should be 1
59 typedef enum { NORMPROB, NORMLINF } NormType;
60 /// DISTL1 is the L-1 distance (sum of absolute values of pointwise difference)
61 /// DISTLINF is the L-inf distance (maximum absolute value of pointwise difference)
62 /// DISTTV is the Total Variation distance
63 typedef enum { DISTL1, DISTLINF, DISTTV } DistType;
64
65 /// Default constructor
66 TProb() {}
67
68 /// Construct uniform distribution of given length
69 explicit TProb( size_t n ) : _p(std::vector<T>(n, 1.0 / n)) {}
70
71 /// Construct with given length and initial value
72 TProb( size_t n, Real p ) : _p(n, (T)p) {}
73
74 /// Construct with given length and initial array
75 TProb( size_t n, const Real* p ) : _p(p, p + n ) {}
76
77 /// Provide read access to _p
78 const std::vector<T> & p() const { return _p; }
79
80 /// Provide full access to _p
81 std::vector<T> & p() { return _p; }
82
83 /// Provide read access to ith element of _p
84 T operator[]( size_t i ) const { return _p[i]; }
85
86 /// Provide full access to ith element of _p
87 T& operator[]( size_t i ) { return _p[i]; }
88
89 /// Set all elements to x
90 TProb<T> & fill(T x) {
91 std::fill( _p.begin(), _p.end(), x );
92 return *this;
93 }
94
95 /// Set all elements to iid random numbers from uniform(0,1) distribution
96 TProb<T> & randomize() {
97 std::generate(_p.begin(), _p.end(), rnd_uniform);
98 return *this;
99 }
100
101 /// Return size
102 size_t size() const {
103 return _p.size();
104 }
105
106 /// Make entries zero if (Real) absolute value smaller than epsilon
107 TProb<T>& makeZero (Real epsilon) {
108 for( size_t i = 0; i < size(); i++ )
109 if( fabs((Real)_p[i]) < epsilon )
110 _p[i] = 0;
111 // std::replace_if( _p.begin(), _p.end(), fabs((Real)boost::lambda::_1) < epsilon, 0.0 );
112 return *this;
113 }
114
115 /// Multiplication with T x
116 TProb<T>& operator*= (T x) {
117 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::multiplies<T>(), x) );
118 return *this;
119 }
120
121 /// Return product of *this with T x
122 TProb<T> operator* (T x) const {
123 TProb<T> prod( *this );
124 prod *= x;
125 return prod;
126 }
127
128 /// Division by T x
129 TProb<T>& operator/= (T x) {
130 #ifdef DAI_DEBUG
131 assert( x != 0.0 );
132 #endif
133 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::divides<T>(), x ) );
134 return *this;
135 }
136
137 /// Return quotient of *this and T x
138 TProb<T> operator/ (T x) const {
139 TProb<T> quot( *this );
140 quot /= x;
141 return quot;
142 }
143
144 /// addition of x
145 TProb<T>& operator+= (T x) {
146 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::plus<T>(), x ) );
147 return *this;
148 }
149
150 /// Return sum of *this with T x
151 TProb<T> operator+ (T x) const {
152 TProb<T> sum( *this );
153 sum += x;
154 return sum;
155 }
156
157 /// Difference by T x
158 TProb<T>& operator-= (T x) {
159 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::minus<T>(), x ) );
160 return *this;
161 }
162
163 /// Return difference of *this and T x
164 TProb<T> operator- (T x) const {
165 TProb<T> diff( *this );
166 diff -= x;
167 return diff;
168 }
169
170 /// Pointwise multiplication with q
171 TProb<T>& operator*= (const TProb<T> & q) {
172 #ifdef DAI_DEBUG
173 assert( size() == q.size() );
174 #endif
175 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::multiplies<T>() );
176 return *this;
177 }
178
179 /// Return product of *this with q
180 TProb<T> operator* (const TProb<T> & q) const {
181 #ifdef DAI_DEBUG
182 assert( size() == q.size() );
183 #endif
184 TProb<T> prod( *this );
185 prod *= q;
186 return prod;
187 }
188
189 /// Pointwise addition with q
190 TProb<T>& operator+= (const TProb<T> & q) {
191 #ifdef DAI_DEBUG
192 assert( size() == q.size() );
193 #endif
194 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::plus<T>() );
195 return *this;
196 }
197
198 /// Pointwise subtraction of q
199 TProb<T>& operator-= (const TProb<T> & q) {
200 #ifdef DAI_DEBUG
201 assert( size() == q.size() );
202 #endif
203 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::minus<T>() );
204 return *this;
205 }
206
207 /// Return sum of *this and q
208 TProb<T> operator+ (const TProb<T> & q) const {
209 #ifdef DAI_DEBUG
210 assert( size() == q.size() );
211 #endif
212 TProb<T> sum( *this );
213 sum += q;
214 return sum;
215 }
216
217 /// Return *this minus q
218 TProb<T> operator- (const TProb<T> & q) const {
219 #ifdef DAI_DEBUG
220 assert( size() == q.size() );
221 #endif
222 TProb<T> diff( *this );
223 diff -= q;
224 return diff;
225 }
226
227 /// Pointwise division by q (division by zero yields zero)
228 TProb<T>& operator/= (const TProb<T> & q) {
229 #ifdef DAI_DEBUG
230 assert( size() == q.size() );
231 #endif
232 for( size_t i = 0; i < size(); i++ ) {
233 if( q[i] == 0.0 )
234 _p[i] = 0.0;
235 else
236 _p[i] /= q[i];
237 }
238 return *this;
239 }
240
241 /// Pointwise division by q (division by zero yields infinity)
242 TProb<T>& divide (const TProb<T> & q) {
243 #ifdef DAI_DEBUG
244 assert( size() == q.size() );
245 #endif
246 std::transform( _p.begin(), _p.end(), q._p.begin(), _p.begin(), std::divides<T>() );
247 return *this;
248 }
249
250 /// Return quotient of *this with q
251 TProb<T> operator/ (const TProb<T> & q) const {
252 #ifdef DAI_DEBUG
253 assert( size() == q.size() );
254 #endif
255 TProb<T> quot( *this );
256 quot /= q;
257 return quot;
258 }
259
260 /// Return pointwise inverse
261 TProb<T> inverse(bool zero = false) const {
262 TProb<T> inv;
263 inv._p.reserve( size() );
264 if( zero )
265 for( size_t i = 0; i < size(); i++ )
266 inv._p.push_back( _p[i] == 0.0 ? 0.0 : 1.0 / _p[i] );
267 else
268 for( size_t i = 0; i < size(); i++ ) {
269 #ifdef DAI_DEBUG
270 assert( _p[i] != 0.0 );
271 #endif
272 inv._p.push_back( 1.0 / _p[i] );
273 }
274 return inv;
275 }
276
277 /// Return *this to the power of a (pointwise)
278 TProb<T>& operator^= (Real a) {
279 if( a != 1.0 )
280 std::transform( _p.begin(), _p.end(), _p.begin(), std::bind2nd( std::ptr_fun<T, Real, T>(std::pow), a) );
281 return *this;
282 }
283
284 /// Pointwise power of a
285 TProb<T> operator^ (Real a) const {
286 TProb<T> power(*this);
287 power ^= a;
288 return power;
289 }
290
291 /// Pointwise exp
292 const TProb<T>& takeExp() {
293 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::exp) );
294 return *this;
295 }
296
297 /// Pointwise log
298 const TProb<T>& takeLog() {
299 std::transform( _p.begin(), _p.end(), _p.begin(), std::ptr_fun<T, T>(std::log) );
300 return *this;
301 }
302
303 /// Pointwise log (or 0 if == 0)
304 const TProb<T>& takeLog0() {
305 for( size_t i = 0; i < size(); i++ )
306 _p[i] = ( (_p[i] == 0.0) ? 0.0 : std::log( _p[i] ) );
307 return *this;
308 }
309
310 /// Pointwise exp
311 TProb<T> exp() const {
312 TProb<T> e(*this);
313 e.takeExp();
314 return e;
315 }
316
317 /// Pointwise log
318 TProb<T> log() const {
319 TProb<T> l(*this);
320 l.takeLog();
321 return l;
322 }
323
324 /// Pointwise log (or 0 if == 0)
325 TProb<T> log0() const {
326 TProb<T> l0(*this);
327 l0.takeLog0();
328 return l0;
329 }
330
331 /// Return distance of p and q
332 friend Real dist( const TProb<T> & p, const TProb<T> & q, DistType dt ) {
333 #ifdef DAI_DEBUG
334 assert( p.size() == q.size() );
335 #endif
336 Real result = 0.0;
337 switch( dt ) {
338 case DISTL1:
339 for( size_t i = 0; i < p.size(); i++ )
340 result += fabs((Real)p[i] - (Real)q[i]);
341 break;
342
343 case DISTLINF:
344 for( size_t i = 0; i < p.size(); i++ ) {
345 Real z = fabs((Real)p[i] - (Real)q[i]);
346 if( z > result )
347 result = z;
348 }
349 break;
350
351 case DISTTV:
352 for( size_t i = 0; i < p.size(); i++ )
353 result += fabs((Real)p[i] - (Real)q[i]);
354 result *= 0.5;
355 break;
356 }
357 return result;
358 }
359
360 /// Return Kullback-Leibler distance with q
361 friend Real KL_dist( const TProb<T> & p, const TProb<T> & q ) {
362 #ifdef DAI_DEBUG
363 assert( p.size() == q.size() );
364 #endif
365 Real result = 0.0;
366 for( size_t i = 0; i < p.size(); i++ ) {
367 if( (Real) p[i] != 0.0 ) {
368 Real p_i = p[i];
369 Real q_i = q[i];
370 result += p_i * (std::log(p_i) - std::log(q_i));
371 }
372 }
373 return result;
374 }
375
376 /// Return sum of all entries
377 T totalSum() const {
378 T Z = std::accumulate( _p.begin(), _p.end(), (T)0 );
379 return Z;
380 }
381
382 /// Converts entries to Real and returns maximum absolute value
383 T maxAbs() const {
384 T Z = 0;
385 for( size_t i = 0; i < size(); i++ ) {
386 Real mag = fabs( (Real) _p[i] );
387 if( mag > Z )
388 Z = mag;
389 }
390 return Z;
391 }
392
393 /// Returns maximum value
394 T maxVal() const {
395 T Z = *std::max_element( _p.begin(), _p.end() );
396 return Z;
397 }
398
399 /// Normalize, using the specified norm
400 T normalize( NormType norm ) {
401 T Z = 0.0;
402 if( norm == NORMPROB )
403 Z = totalSum();
404 else if( norm == NORMLINF )
405 Z = maxAbs();
406 #ifdef DAI_DEBUG
407 assert( Z != 0.0 );
408 #endif
409 *this /= Z;
410 return Z;
411 }
412
413 /// Return normalized copy of *this, using the specified norm
414 TProb<T> normalized( NormType norm ) const {
415 TProb<T> result(*this);
416 result.normalize( norm );
417 return result;
418 }
419
420 /// Returns true if one or more entries are NaN
421 bool hasNaNs() const {
422 return (std::find_if( _p.begin(), _p.end(), isnan ) != _p.end());
423 }
424
425 /// Returns true if one or more entries are negative
426 bool hasNegatives() const {
427 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less<Real>(), 0.0 ) ) != _p.end());
428 }
429
430 /// Returns true if one or more entries are non-positive (causes problems with logscale)
431 bool hasNonPositives() const {
432 return (std::find_if( _p.begin(), _p.end(), std::bind2nd( std::less_equal<Real>(), 0.0 ) ) != _p.end());
433 }
434
435 /// Returns entropy
436 Real entropy() const {
437 Real S = 0.0;
438 for( size_t i = 0; i < size(); i++ )
439 S -= xlogx(_p[i]);
440 return S;
441 }
442
443 friend std::ostream& operator<< (std::ostream& os, const TProb<T>& P) {
444 std::copy( P._p.begin(), P._p.end(), std::ostream_iterator<T>(os, " ") );
445 os << std::endl;
446 return os;
447 }
448 };
449
450
451 } // end of namespace dai
452
453
454 #endif