Improved documentation...
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines TFactor<T> and Factor classes
28 /// \todo Improve documentation
29
30
31 #ifndef __defined_libdai_factor_h
32 #define __defined_libdai_factor_h
33
34
35 #include <iostream>
36 #include <cmath>
37 #include <dai/prob.h>
38 #include <dai/varset.h>
39 #include <dai/index.h>
40
41
42 namespace dai {
43
44
45 // predefine TFactor<T> class
46 template<typename T> class TFactor;
47
48
49 /// Represents a factor with probability entries represented as Real
50 typedef TFactor<Real> Factor;
51
52
53 /// Represents a probability factor.
54 /** A \e factor is a function of the Cartesian product of the state
55 * spaces of some set of variables to the nonnegative real numbers.
56 * More formally, if \f$x_i \in X_i\f$ for all \f$i\f$, then a factor
57 * depending on the variables \f$\{x_i\}\f$ is a function defined
58 * on \f$\prod_i X_i\f$ with values in \f$[0,\infty)\f$.
59 *
60 * A Factor has two components: a VarSet, defining the set of variables
61 * that the factor depends on, and a TProb<T>, containing the values of
62 * the factor for all possible joint states of the variables.
63 *
64 * \tparam T Should be castable from and to double.
65 */
66 template <typename T> class TFactor {
67 private:
68 VarSet _vs;
69 TProb<T> _p;
70
71 public:
72 /// Construct Factor with empty VarSet
73 TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
74
75 /// Construct Factor from VarSet
76 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.nrStates()) {}
77
78 /// Construct Factor from VarSet and initial value
79 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.nrStates(),p) {}
80
81 /// Construct Factor from VarSet and initial array
82 TFactor( const VarSet& ns, const Real *p ) : _vs(ns), _p(_vs.nrStates(),p) {}
83
84 /// Construct Factor from VarSet and TProb<T>
85 TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
86 #ifdef DAI_DEBUG
87 assert( _vs.nrStates() == _p.size() );
88 #endif
89 }
90
91 /// Construct Factor from Var
92 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
93
94 /// Copy constructor
95 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
96
97 /// Assignment operator
98 TFactor<T> & operator= (const TFactor<T> &x) {
99 if( this != &x ) {
100 _vs = x._vs;
101 _p = x._p;
102 }
103 return *this;
104 }
105
106 /// Returns const reference to probability entries
107 const TProb<T> & p() const { return _p; }
108 /// Returns reference to probability entries
109 TProb<T> & p() { return _p; }
110
111 /// Returns const reference to variables
112 const VarSet & vars() const { return _vs; }
113
114 /// Returns the number of possible joint states of the variables
115 size_t states() const { return _p.size(); }
116
117 /// Returns a copy of the i'th probability value
118 T operator[] (size_t i) const { return _p[i]; }
119
120 /// Returns a reference to the i'th probability value
121 T& operator[] (size_t i) { return _p[i]; }
122
123 /// Sets all probability entries to p
124 TFactor<T> & fill (T p) { _p.fill( p ); return(*this); }
125
126 /// Fills all probability entries with random values
127 TFactor<T> & randomize () { _p.randomize(); return(*this); }
128
129 /// Returns product of *this with x
130 TFactor<T> operator* (T x) const {
131 Factor result = *this;
132 result.p() *= x;
133 return result;
134 }
135
136 /// Multiplies each probability entry with x
137 TFactor<T>& operator*= (T x) {
138 _p *= x;
139 return *this;
140 }
141
142 /// Returns quotient of *this with x
143 TFactor<T> operator/ (T x) const {
144 Factor result = *this;
145 result.p() /= x;
146 return result;
147 }
148
149 /// Divides each probability entry by x
150 TFactor<T>& operator/= (T x) {
151 _p /= x;
152 return *this;
153 }
154
155 /// Returns product of *this with another Factor
156 TFactor<T> operator* (const TFactor<T>& Q) const;
157
158 /// Returns quotient of *this with another Factor
159 TFactor<T> operator/ (const TFactor<T>& Q) const;
160
161 /// Multiplies *this with another Factor
162 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
163
164 /// Divides *this by another Factor
165 TFactor<T>& operator/= (const TFactor<T>& Q) { return( *this = (*this / Q) ); }
166
167 /// Returns sum of *this and another Factor (their vars() should be identical)
168 TFactor<T> operator+ (const TFactor<T>& Q) const {
169 #ifdef DAI_DEBUG
170 assert( Q._vs == _vs );
171 #endif
172 TFactor<T> sum(*this);
173 sum._p += Q._p;
174 return sum;
175 }
176
177 /// Returns difference of *this and another Factor (their vars() should be identical)
178 TFactor<T> operator- (const TFactor<T>& Q) const {
179 #ifdef DAI_DEBUG
180 assert( Q._vs == _vs );
181 #endif
182 TFactor<T> sum(*this);
183 sum._p -= Q._p;
184 return sum;
185 }
186
187 /// Adds another Factor to *this (their vars() should be identical)
188 TFactor<T>& operator+= (const TFactor<T>& Q) {
189 #ifdef DAI_DEBUG
190 assert( Q._vs == _vs );
191 #endif
192 _p += Q._p;
193 return *this;
194 }
195
196 /// Subtracts another Factor from *this (their vars() should be identical)
197 TFactor<T>& operator-= (const TFactor<T>& Q) {
198 #ifdef DAI_DEBUG
199 assert( Q._vs == _vs );
200 #endif
201 _p -= Q._p;
202 return *this;
203 }
204
205 /// Adds scalar to *this
206 TFactor<T>& operator+= (T q) {
207 _p += q;
208 return *this;
209 }
210
211 /// Subtracts scalar from *this
212 TFactor<T>& operator-= (T q) {
213 _p -= q;
214 return *this;
215 }
216
217 /// Returns sum of *this and a scalar
218 TFactor<T> operator+ (T q) const {
219 TFactor<T> result(*this);
220 result._p += q;
221 return result;
222 }
223
224 /// Returns difference of *this with a scalar
225 TFactor<T> operator- (T q) const {
226 TFactor<T> result(*this);
227 result._p -= q;
228 return result;
229 }
230
231 /// Returns *this raised to some power
232 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
233
234 /// Raises *this to some power
235 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
236
237 /// Sets all entries that are smaller than epsilon to zero
238 TFactor<T>& makeZero( Real epsilon ) {
239 _p.makeZero( epsilon );
240 return *this;
241 }
242
243 /// Sets all entries that are smaller than epsilon to epsilon
244 TFactor<T>& makePositive( Real epsilon ) {
245 _p.makePositive( epsilon );
246 return *this;
247 }
248
249 /// Returns inverse of *this.
250 /** If zero == true, uses 1 / 0 == 0; otherwise 1 / 0 == Inf.
251 */
252 TFactor<T> inverse(bool zero=true) const {
253 TFactor<T> inv;
254 inv._vs = _vs;
255 inv._p = _p.inverse(zero);
256 return inv;
257 }
258
259 /// Returns *this divided by another Factor
260 TFactor<T> divided_by( const TFactor<T>& denom ) const {
261 #ifdef DAI_DEBUG
262 assert( denom._vs == _vs );
263 #endif
264 TFactor<T> quot(*this);
265 quot._p /= denom._p;
266 return quot;
267 }
268
269 /// Divides *this by another Factor
270 TFactor<T>& divide( const TFactor<T>& denom ) {
271 #ifdef DAI_DEBUG
272 assert( denom._vs == _vs );
273 #endif
274 _p /= denom._p;
275 return *this;
276 }
277
278 /// Returns exp of *this
279 TFactor<T> exp() const {
280 TFactor<T> e;
281 e._vs = _vs;
282 e._p = _p.exp();
283 return e;
284 }
285
286 /// Returns absolute value of *this
287 TFactor<T> abs() const {
288 TFactor<T> e;
289 e._vs = _vs;
290 e._p = _p.abs();
291 return e;
292 }
293
294 /// Returns logarithm of *this
295 /** If zero==true, uses log(0)==0; otherwise, log(0)=-Inf.
296 */
297 TFactor<T> log(bool zero=false) const {
298 TFactor<T> l;
299 l._vs = _vs;
300 l._p = _p.log(zero);
301 return l;
302 }
303
304 /// Normalizes *this Factor
305 T normalize( typename Prob::NormType norm = Prob::NORMPROB ) { return _p.normalize( norm ); }
306
307 /// Returns a normalized copy of *this
308 TFactor<T> normalized( typename Prob::NormType norm = Prob::NORMPROB ) const {
309 TFactor<T> result;
310 result._vs = _vs;
311 result._p = _p.normalized( norm );
312 return result;
313 }
314
315 /// Returns a slice of this factor, where the subset ns is in state ns_state
316 Factor slice( const VarSet & ns, size_t ns_state ) const {
317 assert( ns << _vs );
318 VarSet nsrem = _vs / ns;
319 Factor result( nsrem, 0.0 );
320
321 // OPTIMIZE ME
322 IndexFor i_ns (ns, _vs);
323 IndexFor i_nsrem (nsrem, _vs);
324 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
325 if( (size_t)i_ns == ns_state )
326 result._p[i_nsrem] = _p[i];
327
328 return result;
329 }
330
331 /// Returns unnormalized marginal; ns should be a subset of vars()
332 TFactor<T> partSum(const VarSet & ns) const;
333
334 /// Returns (normalized by default) marginal; ns should be a subset of vars()
335 TFactor<T> marginal(const VarSet & ns, bool normed = true) const { if(normed) return partSum(ns).normalized(); else return partSum(ns); }
336
337 /// Sums out all variables except those in ns
338 TFactor<T> notSum(const VarSet & ns) const { return partSum(vars() ^ ns); }
339
340 /// Embeds this factor in a larger VarSet
341 TFactor<T> embed(const VarSet & ns) const {
342 VarSet vs = vars();
343 assert( ns >> vs );
344 if( vs == ns )
345 return *this;
346 else
347 return (*this) * Factor(ns / vs, 1.0);
348 }
349
350 /// Returns true if *this has NANs
351 bool hasNaNs() const { return _p.hasNaNs(); }
352
353 /// Returns true if *this has negative entries
354 bool hasNegatives() const { return _p.hasNegatives(); }
355
356 /// Returns total sum of probability entries
357 T totalSum() const { return _p.totalSum(); }
358
359 /// Returns maximum absolute value of probability entries
360 T maxAbs() const { return _p.maxAbs(); }
361
362 /// Returns maximum value of probability entries
363 T maxVal() const { return _p.maxVal(); }
364
365 /// Returns minimum value of probability entries
366 T minVal() const { return _p.minVal(); }
367
368 /// Returns entropy of *this
369 Real entropy() const { return _p.entropy(); }
370
371 /// Returns strength of *this, between variables i and j, using (52) of [\ref MoK07b]
372 T strength( const Var &i, const Var &j ) const;
373 };
374
375
376 template<typename T> TFactor<T> TFactor<T>::partSum(const VarSet & ns) const {
377 #ifdef DAI_DEBUG
378 assert( ns << _vs );
379 #endif
380
381 TFactor<T> res( ns, 0.0 );
382
383 IndexFor i_res( ns, _vs );
384 for( size_t i = 0; i < _p.size(); i++, ++i_res )
385 res._p[i_res] += _p[i];
386
387 return res;
388 }
389
390
391 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
392 TFactor<T> prod( _vs | Q._vs, 0.0 );
393
394 IndexFor i1(_vs, prod._vs);
395 IndexFor i2(Q._vs, prod._vs);
396
397 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
398 prod._p[i] += _p[i1] * Q._p[i2];
399
400 return prod;
401 }
402
403
404 template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& Q) const {
405 TFactor<T> quot( _vs + Q._vs, 0.0 );
406
407 IndexFor i1(_vs, quot._vs);
408 IndexFor i2(Q._vs, quot._vs);
409
410 for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
411 quot._p[i] += _p[i1] / Q._p[i2];
412
413 return quot;
414 }
415
416
/// Returns strength of *this between variables i and j, using (52) of [\ref MoK07b]
/** For every pair of distinct states (alpha1, alpha2) of i and every pair of
 *  distinct states (beta1, beta2) of j, computes f = f1 * f2, where
 *  f1 = maxVal of slice(i=alpha1, j=beta1).p().divide( slice(i=alpha2, j=beta1).p() )
 *  f2 = maxVal of slice(i=alpha2, j=beta2).p().divide( slice(i=alpha1, j=beta2).p() )
 *  (divide is presumably pointwise — confirm in TProb). Returns
 *  tanh( 0.25 * log(max f) ), where max f starts at 0.
 *  NOTE(review): zero entries in the slices can make the ratios Inf/NaN; not guarded.
 *  NOTE(review): if i or j has a single state no pair is visited, max stays 0
 *  and the result is tanh(-Inf) == -1 — confirm that is intended.
 */
template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
#ifdef DAI_DEBUG
    assert( _vs.contains( i ) );
    assert( _vs.contains( j ) );
    assert( i != j );
#endif
    VarSet ij(i, j);

    T max = 0.0;
    for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
        for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
            if( alpha2 != alpha1 )
                for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
                    for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
                        if( beta2 != beta1 ) {
                            // Strides of i (as) and j (bs) in the linear state index of ij;
                            // assumes the variable that sorts first (operator<) has stride 1
                            size_t as = 1, bs = 1;
                            if( i < j )
                                bs = i.states();
                            else
                                as = j.states();
                            T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
                            T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
                            T f = f1 * f2;
                            if( f > max )
                                max = f;
                        }

    return std::tanh( 0.25 * std::log( max ) );
}
446
447
448 /// Writes a Factor to an output stream
449 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
450 os << "(" << P.vars() << " <";
451 for( size_t i = 0; i < P.states(); i++ )
452 os << P[i] << " ";
453 os << ">)";
454 return os;
455 }
456
457
458 /// Returns distance between two Factors (with identical vars())
459 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
460 if( x.vars().empty() || y.vars().empty() )
461 return -1;
462 else {
463 #ifdef DAI_DEBUG
464 assert( x.vars() == y.vars() );
465 #endif
466 return dist( x.p(), y.p(), dt );
467 }
468 }
469
470
471 /// Returns the pointwise maximum of two Factors
472 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q ) {
473 assert( P._vs == Q._vs );
474 return TFactor<T>( P._vs, min( P.p(), Q.p() ) );
475 }
476
477
478 /// Returns the pointwise minimum of two Factors
479 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q ) {
480 assert( P._vs == Q._vs );
481 return TFactor<T>( P._vs, max( P.p(), Q.p() ) );
482 }
483
484
485 /// Calculates the mutual information between the two variables in P
486 template<typename T> Real MutualInfo(const TFactor<T> & P) {
487 assert( P.vars().size() == 2 );
488 VarSet::const_iterator it = P.vars().begin();
489 Var i = *it; it++; Var j = *it;
490 TFactor<T> projection = P.marginal(i) * P.marginal(j);
491 return real( dist( P.normalized(), projection, Prob::DISTKL ) );
492 }
493
494
495 } // end of namespace dai
496
497
498 #endif