Oops, correct previous partial commit.
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines TFactor<T> and Factor classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_factor_h
32 #define __defined_libdai_factor_h
33
34
35 #include <iostream>
36 #include <cmath>
37 #include <dai/prob.h>
38 #include <dai/varset.h>
39 #include <dai/index.h>
40
41
42 namespace dai {
43
44
45 // predefine TFactor<T> class
46 template<typename T> class TFactor;
47
48
49 /// Represents a factor with probability entries represented as Real
50 typedef TFactor<Real> Factor;
51
52
53 /// Represents a probability factor.
54 /** A \e factor is a function of the Cartesian product of the state
55 * spaces of some set of variables to the nonnegative real numbers.
56 * More formally, if \f$x_i \in X_i\f$ for all \f$i\f$, then a factor
57 * depending on the variables \f$\{x_i\}\f$ is a function defined
58 * on \f$\prod_i X_i\f$ with values in \f$[0,\infty)\f$.
59 *
60 * A Factor has two components: a VarSet, defining the set of variables
61 * that the factor depends on, and a TProb<T>, containing the values of
62 * the factor for all possible joint states of the variables.
63 *
64 * \tparam T Should be castable from and to double.
65 */
66 template <typename T> class TFactor {
67 private:
68 VarSet _vs;
69 TProb<T> _p;
70
71 public:
72 /// Construct Factor with empty VarSet
73 TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
74
75 /// Construct Factor from VarSet
76 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.nrStates()) {}
77
78 /// Construct Factor from VarSet and initial value
79 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.nrStates(),p) {}
80
81 /// Construct Factor from VarSet and initial array
82 TFactor( const VarSet& ns, const Real *p ) : _vs(ns), _p(_vs.nrStates(),p) {}
83
84 /// Construct Factor from VarSet and TProb<T>
85 TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
86 #ifdef DAI_DEBUG
87 assert( _vs.nrStates() == _p.size() );
88 #endif
89 }
90
91 /// Construct Factor from Var
92 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
93
94 /// Copy constructor
95 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
96
97 /// Assignment operator
98 TFactor<T> & operator= (const TFactor<T> &x) {
99 if( this != &x ) {
100 _vs = x._vs;
101 _p = x._p;
102 }
103 return *this;
104 }
105
106 /// Returns const reference to probability entries
107 const TProb<T> & p() const { return _p; }
108 /// Returns reference to probability entries
109 TProb<T> & p() { return _p; }
110
111 /// Returns const reference to variables
112 const VarSet & vars() const { return _vs; }
113
114 /// Returns the number of possible joint states of the variables
115 size_t states() const { return _p.size(); }
116
117 /// Returns a copy of the i'th probability value
118 T operator[] (size_t i) const { return _p[i]; }
119
120 /// Returns a reference to the i'th probability value
121 T& operator[] (size_t i) { return _p[i]; }
122
123 /// Sets all probability entries to p
124 TFactor<T> & fill (T p) { _p.fill( p ); return(*this); }
125
126 /// Fills all probability entries with random values
127 TFactor<T> & randomize () { _p.randomize(); return(*this); }
128
129 /// Returns product of *this with x
130 TFactor<T> operator* (T x) const {
131 Factor result = *this;
132 result.p() *= x;
133 return result;
134 }
135
136 /// Multiplies each probability entry with x
137 TFactor<T>& operator*= (T x) {
138 _p *= x;
139 return *this;
140 }
141
142 /// Returns quotient of *this with x
143 TFactor<T> operator/ (T x) const {
144 Factor result = *this;
145 result.p() /= x;
146 return result;
147 }
148
149 /// Divides each probability entry by x
150 TFactor<T>& operator/= (T x) {
151 _p /= x;
152 return *this;
153 }
154
155 /// Returns product of *this with another Factor
156 TFactor<T> operator* (const TFactor<T>& Q) const;
157
158 /// Returns quotient of *this with another Factor
159 TFactor<T> operator/ (const TFactor<T>& Q) const;
160
161 /// Multiplies *this with another Factor
162 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
163
164 /// Divides *this by another Factor
165 TFactor<T>& operator/= (const TFactor<T>& Q) { return( *this = (*this / Q) ); }
166
167 /// Returns sum of *this and another Factor (their vars() should be identical)
168 TFactor<T> operator+ (const TFactor<T>& Q) const {
169 #ifdef DAI_DEBUG
170 assert( Q._vs == _vs );
171 #endif
172 TFactor<T> sum(*this);
173 sum._p += Q._p;
174 return sum;
175 }
176
177 /// Returns difference of *this and another Factor (their vars() should be identical)
178 TFactor<T> operator- (const TFactor<T>& Q) const {
179 #ifdef DAI_DEBUG
180 assert( Q._vs == _vs );
181 #endif
182 TFactor<T> sum(*this);
183 sum._p -= Q._p;
184 return sum;
185 }
186
187 /// Adds another Factor to *this (their vars() should be identical)
188 TFactor<T>& operator+= (const TFactor<T>& Q) {
189 #ifdef DAI_DEBUG
190 assert( Q._vs == _vs );
191 #endif
192 _p += Q._p;
193 return *this;
194 }
195
196 /// Subtracts another Factor from *this (their vars() should be identical)
197 TFactor<T>& operator-= (const TFactor<T>& Q) {
198 #ifdef DAI_DEBUG
199 assert( Q._vs == _vs );
200 #endif
201 _p -= Q._p;
202 return *this;
203 }
204
205 /// Adds scalar to *this
206 TFactor<T>& operator+= (T q) {
207 _p += q;
208 return *this;
209 }
210
211 /// Subtracts scalar from *this
212 TFactor<T>& operator-= (T q) {
213 _p -= q;
214 return *this;
215 }
216
217 /// Returns sum of *this and a scalar
218 TFactor<T> operator+ (T q) const {
219 TFactor<T> result(*this);
220 result._p += q;
221 return result;
222 }
223
224 /// Returns difference of *this with a scalar
225 TFactor<T> operator- (T q) const {
226 TFactor<T> result(*this);
227 result._p -= q;
228 return result;
229 }
230
231 /// Returns *this raised to some power
232 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
233
234 /// Raises *this to some power
235 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
236
237 /// Sets all entries that are smaller than epsilon to zero
238 TFactor<T>& makeZero( Real epsilon ) {
239 _p.makeZero( epsilon );
240 return *this;
241 }
242
243 /// Sets all entries that are smaller than epsilon to epsilon
244 TFactor<T>& makePositive( Real epsilon ) {
245 _p.makePositive( epsilon );
246 return *this;
247 }
248
249 /// Returns inverse of *this
250 TFactor<T> inverse() const {
251 TFactor<T> inv;
252 inv._vs = _vs;
253 inv._p = _p.inverse(true); // FIXME
254 return inv;
255 }
256
257 /// Returns *this divided by another Factor
258 TFactor<T> divided_by( const TFactor<T>& denom ) const {
259 #ifdef DAI_DEBUG
260 assert( denom._vs == _vs );
261 #endif
262 TFactor<T> quot(*this);
263 quot._p /= denom._p;
264 return quot;
265 }
266
267 /// Divides *this by another Factor
268 TFactor<T>& divide( const TFactor<T>& denom ) {
269 #ifdef DAI_DEBUG
270 assert( denom._vs == _vs );
271 #endif
272 _p /= denom._p;
273 return *this;
274 }
275
276 /// Returns exp of *this
277 TFactor<T> exp() const {
278 TFactor<T> e;
279 e._vs = _vs;
280 e._p = _p.exp();
281 return e;
282 }
283
284 /// Returns absolute value of *this
285 TFactor<T> abs() const {
286 TFactor<T> e;
287 e._vs = _vs;
288 e._p = _p.abs();
289 return e;
290 }
291
292 /// Returns logarithm of *this
293 TFactor<T> log() const {
294 TFactor<T> l;
295 l._vs = _vs;
296 l._p = _p.log();
297 return l;
298 }
299
300 /// Returns logarithm of *this (defining log(0)=0)
301 TFactor<T> log0() const {
302 TFactor<T> l0;
303 l0._vs = _vs;
304 l0._p = _p.log0();
305 return l0;
306 }
307
308 /// Normalizes *this Factor
309 T normalize( typename Prob::NormType norm = Prob::NORMPROB ) { return _p.normalize( norm ); }
310
311 /// Returns a normalized copy of *this
312 TFactor<T> normalized( typename Prob::NormType norm = Prob::NORMPROB ) const {
313 TFactor<T> result;
314 result._vs = _vs;
315 result._p = _p.normalized( norm );
316 return result;
317 }
318
319 /// Returns a slice of this factor, where the subset ns is in state ns_state
320 Factor slice( const VarSet & ns, size_t ns_state ) const {
321 assert( ns << _vs );
322 VarSet nsrem = _vs / ns;
323 Factor result( nsrem, 0.0 );
324
325 // OPTIMIZE ME
326 IndexFor i_ns (ns, _vs);
327 IndexFor i_nsrem (nsrem, _vs);
328 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
329 if( (size_t)i_ns == ns_state )
330 result._p[i_nsrem] = _p[i];
331
332 return result;
333 }
334
335 /// Returns unnormalized marginal; ns should be a subset of vars()
336 TFactor<T> partSum(const VarSet & ns) const;
337
338 /// Returns (normalized by default) marginal; ns should be a subset of vars()
339 TFactor<T> marginal(const VarSet & ns, bool normed = true) const { if(normed) return partSum(ns).normalized(); else return partSum(ns); }
340
341 /// Sums out all variables except those in ns
342 TFactor<T> notSum(const VarSet & ns) const { return partSum(vars() ^ ns); }
343
344 /// Embeds this factor in a larger VarSet
345 TFactor<T> embed(const VarSet & ns) const {
346 VarSet vs = vars();
347 assert( ns >> vs );
348 if( vs == ns )
349 return *this;
350 else
351 return (*this) * Factor(ns / vs, 1.0);
352 }
353
354 /// Returns true if *this has NANs
355 bool hasNaNs() const { return _p.hasNaNs(); }
356
357 /// Returns true if *this has negative entries
358 bool hasNegatives() const { return _p.hasNegatives(); }
359
360 /// Returns total sum of probability entries
361 T totalSum() const { return _p.totalSum(); }
362
363 /// Returns maximum absolute value of probability entries
364 T maxAbs() const { return _p.maxAbs(); }
365
366 /// Returns maximum value of probability entries
367 T maxVal() const { return _p.maxVal(); }
368
369 /// Returns minimum value of probability entries
370 T minVal() const { return _p.minVal(); }
371
372 /// Returns entropy of *this
373 Real entropy() const { return _p.entropy(); }
374
375 /// Returns strength of *this, between variables i and j, using (52) of [\ref MoK07b]
376 T strength( const Var &i, const Var &j ) const;
377 };
378
379
380 template<typename T> TFactor<T> TFactor<T>::partSum(const VarSet & ns) const {
381 #ifdef DAI_DEBUG
382 assert( ns << _vs );
383 #endif
384
385 TFactor<T> res( ns, 0.0 );
386
387 IndexFor i_res( ns, _vs );
388 for( size_t i = 0; i < _p.size(); i++, ++i_res )
389 res._p[i_res] += _p[i];
390
391 return res;
392 }
393
394
395 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
396 TFactor<T> prod( _vs | Q._vs, 0.0 );
397
398 IndexFor i1(_vs, prod._vs);
399 IndexFor i2(Q._vs, prod._vs);
400
401 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
402 prod._p[i] += _p[i1] * Q._p[i2];
403
404 return prod;
405 }
406
407
408 template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& Q) const {
409 TFactor<T> quot( _vs + Q._vs, 0.0 );
410
411 IndexFor i1(_vs, quot._vs);
412 IndexFor i2(Q._vs, quot._vs);
413
414 for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
415 quot._p[i] += _p[i1] / Q._p[i2];
416
417 return quot;
418 }
419
420
421 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
422 #ifdef DAI_DEBUG
423 assert( _vs.contains( i ) );
424 assert( _vs.contains( j ) );
425 assert( i != j );
426 #endif
427 VarSet ij(i, j);
428
429 T max = 0.0;
430 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
431 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
432 if( alpha2 != alpha1 )
433 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
434 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
435 if( beta2 != beta1 ) {
436 size_t as = 1, bs = 1;
437 if( i < j )
438 bs = i.states();
439 else
440 as = j.states();
441 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
442 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
443 T f = f1 * f2;
444 if( f > max )
445 max = f;
446 }
447
448 return std::tanh( 0.25 * std::log( max ) );
449 }
450
451
452 /// Writes a Factor to an output stream
453 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
454 os << "(" << P.vars() << " <";
455 for( size_t i = 0; i < P.states(); i++ )
456 os << P[i] << " ";
457 os << ">)";
458 return os;
459 }
460
461
462 /// Returns distance between two Factors (with identical vars())
463 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
464 if( x.vars().empty() || y.vars().empty() )
465 return -1;
466 else {
467 #ifdef DAI_DEBUG
468 assert( x.vars() == y.vars() );
469 #endif
470 return dist( x.p(), y.p(), dt );
471 }
472 }
473
474
475 /// Returns the pointwise maximum of two Factors
476 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q ) {
477 assert( P._vs == Q._vs );
478 return TFactor<T>( P._vs, min( P.p(), Q.p() ) );
479 }
480
481
482 /// Returns the pointwise minimum of two Factors
483 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q ) {
484 assert( P._vs == Q._vs );
485 return TFactor<T>( P._vs, max( P.p(), Q.p() ) );
486 }
487
488
489 /// Calculates the mutual information between the two variables in P
490 template<typename T> Real MutualInfo(const TFactor<T> & P) {
491 assert( P.vars().size() == 2 );
492 VarSet::const_iterator it = P.vars().begin();
493 Var i = *it; it++; Var j = *it;
494 TFactor<T> projection = P.marginal(i) * P.marginal(j);
495 return real( dist( P.normalized(), projection, Prob::DISTKL ) );
496 }
497
498
499 } // end of namespace dai
500
501
502 #endif