d65937a200160aad10904eabf2ccc159b7fce6f2
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 /// \file
27 /// \brief Defines TFactor<T> and Factor classes
28
29
30 #ifndef __defined_libdai_factor_h
31 #define __defined_libdai_factor_h
32
33
34 #include <iostream>
35 #include <cmath>
36 #include <dai/prob.h>
37 #include <dai/varset.h>
38 #include <dai/index.h>
39
40
41 namespace dai {
42
43
44 // predefine TFactor<T> class
45 template<typename T> class TFactor;
46
47
48 /// Represents a factor with probability entries represented as Real
49 typedef TFactor<Real> Factor;
50
51
52 /// Represents a probability factor.
53 /** A \e factor is a function of the Cartesian product of the state
54 * spaces of some set of variables to the nonnegative real numbers.
55 * More formally, if \f$x_i \in X_i\f$ for all \f$i\f$, then a factor
56 * depending on the variables \f$\{x_i\}\f$ is a function defined
57 * on \f$\prod_i X_i\f$ with values in \f$[0,\infty)\f$.
58 *
59 * A Factor has two components: a VarSet, defining the set of variables
60 * that the factor depends on, and a TProb<T>, containing the values of
61 * the factor for all possible joint states of the variables.
62 *
63 * \tparam T Should be castable from and to double.
64 */
65 template <typename T> class TFactor {
66 private:
67 VarSet _vs;
68 TProb<T> _p;
69
70 public:
71 /// Construct Factor with empty VarSet
72 TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
73
74 /// Construct Factor from VarSet
75 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.nrStates()) {}
76
77 /// Construct Factor from VarSet and initial value
78 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.nrStates(),p) {}
79
80 /// Construct Factor from VarSet and initial array
81 TFactor( const VarSet& ns, const Real *p ) : _vs(ns), _p(_vs.nrStates(),p) {}
82
83 /// Construct Factor from VarSet and TProb<T>
84 TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
85 #ifdef DAI_DEBUG
86 assert( _vs.nrStates() == _p.size() );
87 #endif
88 }
89
90 /// Construct Factor from Var
91 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
92
93 /// Copy constructor
94 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
95
96 /// Assignment operator
97 TFactor<T> & operator= (const TFactor<T> &x) {
98 if( this != &x ) {
99 _vs = x._vs;
100 _p = x._p;
101 }
102 return *this;
103 }
104
105 /// Returns const reference to probability entries
106 const TProb<T> & p() const { return _p; }
107 /// Returns reference to probability entries
108 TProb<T> & p() { return _p; }
109
110 /// Returns const reference to variables
111 const VarSet & vars() const { return _vs; }
112
113 /// Returns the number of possible joint states of the variables
114 size_t states() const { return _p.size(); }
115
116 /// Returns a copy of the i'th probability value
117 T operator[] (size_t i) const { return _p[i]; }
118
119 /// Returns a reference to the i'th probability value
120 T& operator[] (size_t i) { return _p[i]; }
121
122 /// Sets all probability entries to p
123 TFactor<T> & fill (T p) { _p.fill( p ); return(*this); }
124
125 /// Fills all probability entries with random values
126 TFactor<T> & randomize () { _p.randomize(); return(*this); }
127
128 /// Returns product of *this with x
129 TFactor<T> operator* (T x) const {
130 Factor result = *this;
131 result.p() *= x;
132 return result;
133 }
134
135 /// Multiplies each probability entry with x
136 TFactor<T>& operator*= (T x) {
137 _p *= x;
138 return *this;
139 }
140
141 /// Returns quotient of *this with x
142 TFactor<T> operator/ (T x) const {
143 Factor result = *this;
144 result.p() /= x;
145 return result;
146 }
147
148 /// Divides each probability entry by x
149 TFactor<T>& operator/= (T x) {
150 _p /= x;
151 return *this;
152 }
153
154 /// Returns product of *this with another Factor
155 TFactor<T> operator* (const TFactor<T>& Q) const;
156
157 /// Returns quotient of *this with another Factor
158 TFactor<T> operator/ (const TFactor<T>& Q) const;
159
160 /// Multiplies *this with another Factor
161 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
162
163 /// Divides *this by another Factor
164 TFactor<T>& operator/= (const TFactor<T>& Q) { return( *this = (*this / Q) ); }
165
166 /// Returns sum of *this and another Factor (their vars() should be identical)
167 TFactor<T> operator+ (const TFactor<T>& Q) const {
168 #ifdef DAI_DEBUG
169 assert( Q._vs == _vs );
170 #endif
171 TFactor<T> sum(*this);
172 sum._p += Q._p;
173 return sum;
174 }
175
176 /// Returns difference of *this and another Factor (their vars() should be identical)
177 TFactor<T> operator- (const TFactor<T>& Q) const {
178 #ifdef DAI_DEBUG
179 assert( Q._vs == _vs );
180 #endif
181 TFactor<T> sum(*this);
182 sum._p -= Q._p;
183 return sum;
184 }
185
186 /// Adds another Factor to *this (their vars() should be identical)
187 TFactor<T>& operator+= (const TFactor<T>& Q) {
188 #ifdef DAI_DEBUG
189 assert( Q._vs == _vs );
190 #endif
191 _p += Q._p;
192 return *this;
193 }
194
195 /// Subtracts another Factor from *this (their vars() should be identical)
196 TFactor<T>& operator-= (const TFactor<T>& Q) {
197 #ifdef DAI_DEBUG
198 assert( Q._vs == _vs );
199 #endif
200 _p -= Q._p;
201 return *this;
202 }
203
204 /// Adds scalar to *this
205 TFactor<T>& operator+= (T q) {
206 _p += q;
207 return *this;
208 }
209
210 /// Subtracts scalar from *this
211 TFactor<T>& operator-= (T q) {
212 _p -= q;
213 return *this;
214 }
215
216 /// Returns sum of *this and a scalar
217 TFactor<T> operator+ (T q) const {
218 TFactor<T> result(*this);
219 result._p += q;
220 return result;
221 }
222
223 /// Returns difference of *this with a scalar
224 TFactor<T> operator- (T q) const {
225 TFactor<T> result(*this);
226 result._p -= q;
227 return result;
228 }
229
230 /// Returns *this raised to some power
231 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
232
233 /// Raises *this to some power
234 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
235
236 /// Sets all entries that are smaller than epsilon to zero
237 TFactor<T>& makeZero( Real epsilon ) {
238 _p.makeZero( epsilon );
239 return *this;
240 }
241
242 /// Sets all entries that are smaller than epsilon to epsilon
243 TFactor<T>& makePositive( Real epsilon ) {
244 _p.makePositive( epsilon );
245 return *this;
246 }
247
248 /// Returns inverse of *this
249 TFactor<T> inverse() const {
250 TFactor<T> inv;
251 inv._vs = _vs;
252 inv._p = _p.inverse(true); // FIXME
253 return inv;
254 }
255
256 /// Returns *this divided by another Factor
257 TFactor<T> divided_by( const TFactor<T>& denom ) const {
258 #ifdef DAI_DEBUG
259 assert( denom._vs == _vs );
260 #endif
261 TFactor<T> quot(*this);
262 quot._p /= denom._p;
263 return quot;
264 }
265
266 /// Divides *this by another Factor
267 TFactor<T>& divide( const TFactor<T>& denom ) {
268 #ifdef DAI_DEBUG
269 assert( denom._vs == _vs );
270 #endif
271 _p /= denom._p;
272 return *this;
273 }
274
275 /// Returns exp of *this
276 TFactor<T> exp() const {
277 TFactor<T> e;
278 e._vs = _vs;
279 e._p = _p.exp();
280 return e;
281 }
282
283 /// Returns absolute value of *this
284 TFactor<T> abs() const {
285 TFactor<T> e;
286 e._vs = _vs;
287 e._p = _p.abs();
288 return e;
289 }
290
291 /// Returns logarithm of *this
292 TFactor<T> log() const {
293 TFactor<T> l;
294 l._vs = _vs;
295 l._p = _p.log();
296 return l;
297 }
298
299 /// Returns logarithm of *this (defining log(0)=0)
300 TFactor<T> log0() const {
301 TFactor<T> l0;
302 l0._vs = _vs;
303 l0._p = _p.log0();
304 return l0;
305 }
306
307 /// Normalizes *this Factor
308 T normalize( typename Prob::NormType norm = Prob::NORMPROB ) { return _p.normalize( norm ); }
309
310 /// Returns a normalized copy of *this
311 TFactor<T> normalized( typename Prob::NormType norm = Prob::NORMPROB ) const {
312 TFactor<T> result;
313 result._vs = _vs;
314 result._p = _p.normalized( norm );
315 return result;
316 }
317
318 /// Returns a slice of this factor, where the subset ns is in state ns_state
319 Factor slice( const VarSet & ns, size_t ns_state ) const {
320 assert( ns << _vs );
321 VarSet nsrem = _vs / ns;
322 Factor result( nsrem, 0.0 );
323
324 // OPTIMIZE ME
325 IndexFor i_ns (ns, _vs);
326 IndexFor i_nsrem (nsrem, _vs);
327 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
328 if( (size_t)i_ns == ns_state )
329 result._p[i_nsrem] = _p[i];
330
331 return result;
332 }
333
334 /// Returns unnormalized marginal; ns should be a subset of vars()
335 TFactor<T> partSum(const VarSet & ns) const;
336
337 /// Returns (normalized by default) marginal; ns should be a subset of vars()
338 TFactor<T> marginal(const VarSet & ns, bool normed = true) const { if(normed) return partSum(ns).normalized(); else return partSum(ns); }
339
340 /// Sums out all variables except those in ns
341 TFactor<T> notSum(const VarSet & ns) const { return partSum(vars() ^ ns); }
342
343 /// Embeds this factor in a larger VarSet
344 TFactor<T> embed(const VarSet & ns) const {
345 VarSet vs = vars();
346 assert( ns >> vs );
347 if( vs == ns )
348 return *this;
349 else
350 return (*this) * Factor(ns / vs, 1.0);
351 }
352
353 /// Returns true if *this has NANs
354 bool hasNaNs() const { return _p.hasNaNs(); }
355
356 /// Returns true if *this has negative entries
357 bool hasNegatives() const { return _p.hasNegatives(); }
358
359 /// Returns total sum of probability entries
360 T totalSum() const { return _p.totalSum(); }
361
362 /// Returns maximum absolute value of probability entries
363 T maxAbs() const { return _p.maxAbs(); }
364
365 /// Returns maximum value of probability entries
366 T maxVal() const { return _p.maxVal(); }
367
368 /// Returns minimum value of probability entries
369 T minVal() const { return _p.minVal(); }
370
371 /// Returns entropy of *this
372 Real entropy() const { return _p.entropy(); }
373
374 /// Returns strength of *this, between variables i and j, using (52) of [\ref MoK07b]
375 T strength( const Var &i, const Var &j ) const;
376 };
377
378
379 template<typename T> TFactor<T> TFactor<T>::partSum(const VarSet & ns) const {
380 #ifdef DAI_DEBUG
381 assert( ns << _vs );
382 #endif
383
384 TFactor<T> res( ns, 0.0 );
385
386 IndexFor i_res( ns, _vs );
387 for( size_t i = 0; i < _p.size(); i++, ++i_res )
388 res._p[i_res] += _p[i];
389
390 return res;
391 }
392
393
394 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
395 TFactor<T> prod( _vs | Q._vs, 0.0 );
396
397 IndexFor i1(_vs, prod._vs);
398 IndexFor i2(Q._vs, prod._vs);
399
400 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
401 prod._p[i] += _p[i1] * Q._p[i2];
402
403 return prod;
404 }
405
406
407 template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& Q) const {
408 TFactor<T> quot( _vs + Q._vs, 0.0 );
409
410 IndexFor i1(_vs, quot._vs);
411 IndexFor i2(Q._vs, quot._vs);
412
413 for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
414 quot._p[i] += _p[i1] / Q._p[i2];
415
416 return quot;
417 }
418
419
420 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
421 #ifdef DAI_DEBUG
422 assert( _vs.contains( i ) );
423 assert( _vs.contains( j ) );
424 assert( i != j );
425 #endif
426 VarSet ij(i, j);
427
428 T max = 0.0;
429 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
430 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
431 if( alpha2 != alpha1 )
432 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
433 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
434 if( beta2 != beta1 ) {
435 size_t as = 1, bs = 1;
436 if( i < j )
437 bs = i.states();
438 else
439 as = j.states();
440 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
441 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
442 T f = f1 * f2;
443 if( f > max )
444 max = f;
445 }
446
447 return std::tanh( 0.25 * std::log( max ) );
448 }
449
450
451 /// Writes a Factor to an output stream
452 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
453 os << "(" << P.vars() << " <";
454 for( size_t i = 0; i < P.states(); i++ )
455 os << P[i] << " ";
456 os << ">)";
457 return os;
458 }
459
460
461 /// Returns distance between two Factors (with identical vars())
462 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
463 if( x.vars().empty() || y.vars().empty() )
464 return -1;
465 else {
466 #ifdef DAI_DEBUG
467 assert( x.vars() == y.vars() );
468 #endif
469 return dist( x.p(), y.p(), dt );
470 }
471 }
472
473
474 /// Returns the pointwise maximum of two Factors
475 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q ) {
476 assert( P._vs == Q._vs );
477 return TFactor<T>( P._vs, min( P.p(), Q.p() ) );
478 }
479
480
481 /// Returns the pointwise minimum of two Factors
482 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q ) {
483 assert( P._vs == Q._vs );
484 return TFactor<T>( P._vs, max( P.p(), Q.p() ) );
485 }
486
487
488 /// Calculates the mutual information between the two variables in P
489 template<typename T> Real MutualInfo(const TFactor<T> & P) {
490 assert( P.vars().size() == 2 );
491 VarSet::const_iterator it = P.vars().begin();
492 Var i = *it; it++; Var j = *it;
493 TFactor<T> projection = P.marginal(i) * P.marginal(j);
494 return real( dist( P.normalized(), projection, Prob::DISTKL ) );
495 }
496
497
498 } // end of namespace dai
499
500
501 #endif