1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
8 This file is part of libDAI.
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
27 /// \brief Defines TFactor<T> and Factor classes
28 /// \todo Improve documentation
31 #ifndef __defined_libdai_factor_h
32 #define __defined_libdai_factor_h
38 #include <dai/varset.h>
39 #include <dai/index.h>
45 // predefine TFactor<T> class
46 template<typename T
> class TFactor
;
49 /// Represents a factor with probability entries represented as Real
50 typedef TFactor
<Real
> Factor
;
53 /// Represents a probability factor.
54 /** A \e factor is a function of the Cartesian product of the state
55 * spaces of some set of variables to the nonnegative real numbers.
56 * More formally, if \f$x_i \in X_i\f$ for all \f$i\f$, then a factor
57 * depending on the variables \f$\{x_i\}\f$ is a function defined
58 * on \f$\prod_i X_i\f$ with values in \f$[0,\infty)\f$.
60 * A Factor has two components: a VarSet, defining the set of variables
61 * that the factor depends on, and a TProb<T>, containing the values of
62 * the factor for all possible joint states of the variables.
64 * \tparam T Should be castable from and to double.
66 template <typename T
> class TFactor
{
72 /// Construct Factor with empty VarSet
73 TFactor ( Real p
= 1.0 ) : _vs(), _p(1,p
) {}
75 /// Construct Factor from VarSet
76 TFactor( const VarSet
& ns
) : _vs(ns
), _p(_vs
.nrStates()) {}
78 /// Construct Factor from VarSet and initial value
79 TFactor( const VarSet
& ns
, Real p
) : _vs(ns
), _p(_vs
.nrStates(),p
) {}
81 /// Construct Factor from VarSet and initial array
82 TFactor( const VarSet
& ns
, const Real
*p
) : _vs(ns
), _p(_vs
.nrStates(),p
) {}
84 /// Construct Factor from VarSet and TProb<T>
85 TFactor( const VarSet
& ns
, const TProb
<T
>& p
) : _vs(ns
), _p(p
) {
87 assert( _vs
.nrStates() == _p
.size() );
91 /// Construct Factor from Var
92 TFactor( const Var
& n
) : _vs(n
), _p(n
.states()) {}
95 TFactor( const TFactor
<T
> &x
) : _vs(x
._vs
), _p(x
._p
) {}
97 /// Assignment operator
98 TFactor
<T
> & operator= (const TFactor
<T
> &x
) {
106 /// Returns const reference to probability entries
107 const TProb
<T
> & p() const { return _p
; }
108 /// Returns reference to probability entries
109 TProb
<T
> & p() { return _p
; }
111 /// Returns const reference to variables
112 const VarSet
& vars() const { return _vs
; }
114 /// Returns the number of possible joint states of the variables
115 size_t states() const { return _p
.size(); }
117 /// Returns a copy of the i'th probability value
118 T
operator[] (size_t i
) const { return _p
[i
]; }
120 /// Returns a reference to the i'th probability value
121 T
& operator[] (size_t i
) { return _p
[i
]; }
123 /// Sets all probability entries to p
124 TFactor
<T
> & fill (T p
) { _p
.fill( p
); return(*this); }
126 /// Fills all probability entries with random values
127 TFactor
<T
> & randomize () { _p
.randomize(); return(*this); }
129 /// Returns product of *this with x
130 TFactor
<T
> operator* (T x
) const {
131 Factor result
= *this;
136 /// Multiplies each probability entry with x
137 TFactor
<T
>& operator*= (T x
) {
142 /// Returns quotient of *this with x
143 TFactor
<T
> operator/ (T x
) const {
144 Factor result
= *this;
149 /// Divides each probability entry by x
150 TFactor
<T
>& operator/= (T x
) {
155 /// Returns product of *this with another Factor
156 TFactor
<T
> operator* (const TFactor
<T
>& Q
) const;
158 /// Returns quotient of *this with another Factor
159 TFactor
<T
> operator/ (const TFactor
<T
>& Q
) const;
161 /// Multiplies *this with another Factor
162 TFactor
<T
>& operator*= (const TFactor
<T
>& Q
) { return( *this = (*this * Q
) ); }
164 /// Divides *this by another Factor
165 TFactor
<T
>& operator/= (const TFactor
<T
>& Q
) { return( *this = (*this / Q
) ); }
167 /// Returns sum of *this and another Factor (their vars() should be identical)
168 TFactor
<T
> operator+ (const TFactor
<T
>& Q
) const {
170 assert( Q
._vs
== _vs
);
172 TFactor
<T
> sum(*this);
177 /// Returns difference of *this and another Factor (their vars() should be identical)
178 TFactor
<T
> operator- (const TFactor
<T
>& Q
) const {
180 assert( Q
._vs
== _vs
);
182 TFactor
<T
> sum(*this);
187 /// Adds another Factor to *this (their vars() should be identical)
188 TFactor
<T
>& operator+= (const TFactor
<T
>& Q
) {
190 assert( Q
._vs
== _vs
);
196 /// Subtracts another Factor from *this (their vars() should be identical)
197 TFactor
<T
>& operator-= (const TFactor
<T
>& Q
) {
199 assert( Q
._vs
== _vs
);
205 /// Adds scalar to *this
206 TFactor
<T
>& operator+= (T q
) {
211 /// Subtracts scalar from *this
212 TFactor
<T
>& operator-= (T q
) {
217 /// Returns sum of *this and a scalar
218 TFactor
<T
> operator+ (T q
) const {
219 TFactor
<T
> result(*this);
224 /// Returns difference of *this with a scalar
225 TFactor
<T
> operator- (T q
) const {
226 TFactor
<T
> result(*this);
231 /// Returns *this raised to some power
232 TFactor
<T
> operator^ (Real a
) const { TFactor
<T
> x
; x
._vs
= _vs
; x
._p
= _p
^a
; return x
; }
234 /// Raises *this to some power
235 TFactor
<T
>& operator^= (Real a
) { _p
^= a
; return *this; }
237 /// Sets all entries that are smaller than epsilon to zero
238 TFactor
<T
>& makeZero( Real epsilon
) {
239 _p
.makeZero( epsilon
);
243 /// Sets all entries that are smaller than epsilon to epsilon
244 TFactor
<T
>& makePositive( Real epsilon
) {
245 _p
.makePositive( epsilon
);
249 /// Returns inverse of *this
250 TFactor
<T
> inverse() const {
253 inv
._p
= _p
.inverse(true); // FIXME
257 /// Returns *this divided by another Factor
258 TFactor
<T
> divided_by( const TFactor
<T
>& denom
) const {
260 assert( denom
._vs
== _vs
);
262 TFactor
<T
> quot(*this);
267 /// Divides *this by another Factor
268 TFactor
<T
>& divide( const TFactor
<T
>& denom
) {
270 assert( denom
._vs
== _vs
);
276 /// Returns exp of *this
277 TFactor
<T
> exp() const {
284 /// Returns absolute value of *this
285 TFactor
<T
> abs() const {
292 /// Returns logarithm of *this
293 TFactor
<T
> log() const {
300 /// Returns logarithm of *this (defining log(0)=0)
301 TFactor
<T
> log0() const {
308 /// Normalizes *this Factor
309 T
normalize( typename
Prob::NormType norm
= Prob::NORMPROB
) { return _p
.normalize( norm
); }
311 /// Returns a normalized copy of *this
312 TFactor
<T
> normalized( typename
Prob::NormType norm
= Prob::NORMPROB
) const {
315 result
._p
= _p
.normalized( norm
);
319 /// Returns a slice of this factor, where the subset ns is in state ns_state
320 Factor
slice( const VarSet
& ns
, size_t ns_state
) const {
322 VarSet nsrem
= _vs
/ ns
;
323 Factor
result( nsrem
, 0.0 );
326 IndexFor
i_ns (ns
, _vs
);
327 IndexFor
i_nsrem (nsrem
, _vs
);
328 for( size_t i
= 0; i
< states(); i
++, ++i_ns
, ++i_nsrem
)
329 if( (size_t)i_ns
== ns_state
)
330 result
._p
[i_nsrem
] = _p
[i
];
335 /// Returns unnormalized marginal; ns should be a subset of vars()
336 TFactor
<T
> partSum(const VarSet
& ns
) const;
338 /// Returns (normalized by default) marginal; ns should be a subset of vars()
339 TFactor
<T
> marginal(const VarSet
& ns
, bool normed
= true) const { if(normed
) return partSum(ns
).normalized(); else return partSum(ns
); }
341 /// Sums out all variables except those in ns
342 TFactor
<T
> notSum(const VarSet
& ns
) const { return partSum(vars() ^ ns
); }
344 /// Embeds this factor in a larger VarSet
345 TFactor
<T
> embed(const VarSet
& ns
) const {
351 return (*this) * Factor(ns
/ vs
, 1.0);
354 /// Returns true if *this has NANs
355 bool hasNaNs() const { return _p
.hasNaNs(); }
357 /// Returns true if *this has negative entries
358 bool hasNegatives() const { return _p
.hasNegatives(); }
360 /// Returns total sum of probability entries
361 T
totalSum() const { return _p
.totalSum(); }
363 /// Returns maximum absolute value of probability entries
364 T
maxAbs() const { return _p
.maxAbs(); }
366 /// Returns maximum value of probability entries
367 T
maxVal() const { return _p
.maxVal(); }
369 /// Returns minimum value of probability entries
370 T
minVal() const { return _p
.minVal(); }
372 /// Returns entropy of *this
373 Real
entropy() const { return _p
.entropy(); }
375 /// Returns strength of *this, between variables i and j, using (52) of [\ref MoK07b]
376 T
strength( const Var
&i
, const Var
&j
) const;
380 template<typename T
> TFactor
<T
> TFactor
<T
>::partSum(const VarSet
& ns
) const {
385 TFactor
<T
> res( ns
, 0.0 );
387 IndexFor
i_res( ns
, _vs
);
388 for( size_t i
= 0; i
< _p
.size(); i
++, ++i_res
)
389 res
._p
[i_res
] += _p
[i
];
395 template<typename T
> TFactor
<T
> TFactor
<T
>::operator* (const TFactor
<T
>& Q
) const {
396 TFactor
<T
> prod( _vs
| Q
._vs
, 0.0 );
398 IndexFor
i1(_vs
, prod
._vs
);
399 IndexFor
i2(Q
._vs
, prod
._vs
);
401 for( size_t i
= 0; i
< prod
._p
.size(); i
++, ++i1
, ++i2
)
402 prod
._p
[i
] += _p
[i1
] * Q
._p
[i2
];
408 template<typename T
> TFactor
<T
> TFactor
<T
>::operator/ (const TFactor
<T
>& Q
) const {
409 TFactor
<T
> quot( _vs
+ Q
._vs
, 0.0 );
411 IndexFor
i1(_vs
, quot
._vs
);
412 IndexFor
i2(Q
._vs
, quot
._vs
);
414 for( size_t i
= 0; i
< quot
._p
.size(); i
++, ++i1
, ++i2
)
415 quot
._p
[i
] += _p
[i1
] / Q
._p
[i2
];
421 template<typename T
> T TFactor
<T
>::strength( const Var
&i
, const Var
&j
) const {
423 assert( _vs
.contains( i
) );
424 assert( _vs
.contains( j
) );
430 for( size_t alpha1
= 0; alpha1
< i
.states(); alpha1
++ )
431 for( size_t alpha2
= 0; alpha2
< i
.states(); alpha2
++ )
432 if( alpha2
!= alpha1
)
433 for( size_t beta1
= 0; beta1
< j
.states(); beta1
++ )
434 for( size_t beta2
= 0; beta2
< j
.states(); beta2
++ )
435 if( beta2
!= beta1
) {
436 size_t as
= 1, bs
= 1;
441 T f1
= slice( ij
, alpha1
* as
+ beta1
* bs
).p().divide( slice( ij
, alpha2
* as
+ beta1
* bs
).p() ).maxVal();
442 T f2
= slice( ij
, alpha2
* as
+ beta2
* bs
).p().divide( slice( ij
, alpha1
* as
+ beta2
* bs
).p() ).maxVal();
448 return std::tanh( 0.25 * std::log( max
) );
452 /// Writes a Factor to an output stream
453 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TFactor
<T
>& P
) {
454 os
<< "(" << P
.vars() << " <";
455 for( size_t i
= 0; i
< P
.states(); i
++ )
462 /// Returns distance between two Factors (with identical vars())
463 template<typename T
> Real
dist( const TFactor
<T
> & x
, const TFactor
<T
> & y
, Prob::DistType dt
) {
464 if( x
.vars().empty() || y
.vars().empty() )
468 assert( x
.vars() == y
.vars() );
470 return dist( x
.p(), y
.p(), dt
);
475 /// Returns the pointwise maximum of two Factors
476 template<typename T
> TFactor
<T
> max( const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
477 assert( P
._vs
== Q
._vs
);
478 return TFactor
<T
>( P
._vs
, min( P
.p(), Q
.p() ) );
482 /// Returns the pointwise minimum of two Factors
483 template<typename T
> TFactor
<T
> min( const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
484 assert( P
._vs
== Q
._vs
);
485 return TFactor
<T
>( P
._vs
, max( P
.p(), Q
.p() ) );
489 /// Calculates the mutual information between the two variables in P
490 template<typename T
> Real
MutualInfo(const TFactor
<T
> & P
) {
491 assert( P
.vars().size() == 2 );
492 VarSet::const_iterator it
= P
.vars().begin();
493 Var i
= *it
; it
++; Var j
= *it
;
494 TFactor
<T
> projection
= P
.marginal(i
) * P
.marginal(j
);
495 return real( dist( P
.normalized(), projection
, Prob::DISTKL
) );
499 } // end of namespace dai