Updated copyrights
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
7
8 This file is part of libDAI.
9
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 */
24
25
26 #ifndef __defined_libdai_factor_h
27 #define __defined_libdai_factor_h
28
29
30 #include <iostream>
31 #include <cmath>
32 #include <dai/prob.h>
33 #include <dai/varset.h>
34 #include <dai/index.h>
35
36
37 namespace dai {
38
39
40 template<typename T> class TFactor;
41 typedef TFactor<Real> Factor;
42
43
44 // predefine friends
45 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt );
46 template<typename T> Real KL_dist( const TFactor<T> & p, const TFactor<T> & q );
47 template<typename T> Real MutualInfo( const TFactor<T> & p );
48 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q );
49 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q );
50 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P);
51
52
53 // T should be castable from and to double
54 template <typename T> class TFactor {
55 private:
56 VarSet _vs;
57 TProb<T> _p;
58
59 public:
60 // Construct Factor with empty VarSet but nonempty _p
61 TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
62
63 // Construct Factor from VarSet
64 TFactor( const VarSet& ns ) : _vs(ns), _p(nrStates(_vs)) {}
65
66 // Construct Factor from VarSet and initial value
67 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(nrStates(_vs),p) {}
68
69 // Construct Factor from VarSet and initial array
70 TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(nrStates(_vs),p) {}
71
72 // Construct Factor from VarSet and TProb<T>
73 TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
74 #ifdef DAI_DEBUG
75 assert( nrStates(_vs) == _p.size() );
76 #endif
77 }
78
79 // Construct Factor from Var
80 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
81
82 // Copy constructor
83 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
84
85 // Assignment operator
86 TFactor<T> & operator= (const TFactor<T> &x) {
87 if( this != &x ) {
88 _vs = x._vs;
89 _p = x._p;
90 }
91 return *this;
92 }
93
94 const TProb<T> & p() const { return _p; }
95 TProb<T> & p() { return _p; }
96 const VarSet & vars() const { return _vs; }
97 size_t states() const { return _p.size(); }
98
99 T operator[] (size_t i) const { return _p[i]; }
100 T& operator[] (size_t i) { return _p[i]; }
101 TFactor<T> & fill (T p)
102 { _p.fill( p ); return(*this); }
103 TFactor<T> & randomize ()
104 { _p.randomize(); return(*this); }
105 TFactor<T> operator* (T x) const {
106 Factor result = *this;
107 result.p() *= x;
108 return result;
109 }
110 TFactor<T>& operator*= (T x) {
111 _p *= x;
112 return *this;
113 }
114 TFactor<T> operator/ (T x) const {
115 Factor result = *this;
116 result.p() /= x;
117 return result;
118 }
119 TFactor<T>& operator/= (T x) {
120 _p /= x;
121 return *this;
122 }
123 TFactor<T> operator* (const TFactor<T>& Q) const;
124 TFactor<T> operator/ (const TFactor<T>& Q) const;
125 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
126 TFactor<T>& operator/= (const TFactor<T>& Q) { return( *this = (*this / Q) ); }
127 TFactor<T> operator+ (const TFactor<T>& Q) const {
128 #ifdef DAI_DEBUG
129 assert( Q._vs == _vs );
130 #endif
131 TFactor<T> sum(*this);
132 sum._p += Q._p;
133 return sum;
134 }
135 TFactor<T> operator- (const TFactor<T>& Q) const {
136 #ifdef DAI_DEBUG
137 assert( Q._vs == _vs );
138 #endif
139 TFactor<T> sum(*this);
140 sum._p -= Q._p;
141 return sum;
142 }
143 TFactor<T>& operator+= (const TFactor<T>& Q) {
144 #ifdef DAI_DEBUG
145 assert( Q._vs == _vs );
146 #endif
147 _p += Q._p;
148 return *this;
149 }
150 TFactor<T>& operator-= (const TFactor<T>& Q) {
151 #ifdef DAI_DEBUG
152 assert( Q._vs == _vs );
153 #endif
154 _p -= Q._p;
155 return *this;
156 }
157 TFactor<T>& operator+= (T q) {
158 _p += q;
159 return *this;
160 }
161 TFactor<T>& operator-= (T q) {
162 _p -= q;
163 return *this;
164 }
165 TFactor<T> operator+ (T q) const {
166 TFactor<T> result(*this);
167 result._p += q;
168 return result;
169 }
170 TFactor<T> operator- (T q) const {
171 TFactor<T> result(*this);
172 result._p -= q;
173 return result;
174 }
175
176 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
177 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
178
179 TFactor<T>& makeZero( Real epsilon ) {
180 _p.makeZero( epsilon );
181 return *this;
182 }
183
184 TFactor<T>& makePositive( Real epsilon ) {
185 _p.makePositive( epsilon );
186 return *this;
187 }
188
189 TFactor<T> inverse() const {
190 TFactor<T> inv;
191 inv._vs = _vs;
192 inv._p = _p.inverse(true); // FIXME
193 return inv;
194 }
195
196 TFactor<T> divided_by( const TFactor<T>& denom ) const {
197 #ifdef DAI_DEBUG
198 assert( denom._vs == _vs );
199 #endif
200 TFactor<T> quot(*this);
201 quot._p /= denom._p;
202 return quot;
203 }
204
205 TFactor<T>& divide( const TFactor<T>& denom ) {
206 #ifdef DAI_DEBUG
207 assert( denom._vs == _vs );
208 #endif
209 _p /= denom._p;
210 return *this;
211 }
212
213 TFactor<T> exp() const {
214 TFactor<T> e;
215 e._vs = _vs;
216 e._p = _p.exp();
217 return e;
218 }
219
220 TFactor<T> abs() const {
221 TFactor<T> e;
222 e._vs = _vs;
223 e._p = _p.abs();
224 return e;
225 }
226
227 TFactor<T> log() const {
228 TFactor<T> l;
229 l._vs = _vs;
230 l._p = _p.log();
231 return l;
232 }
233
234 TFactor<T> log0() const {
235 TFactor<T> l0;
236 l0._vs = _vs;
237 l0._p = _p.log0();
238 return l0;
239 }
240
241 T normalize( typename Prob::NormType norm = Prob::NORMPROB ) { return _p.normalize( norm ); }
242 TFactor<T> normalized( typename Prob::NormType norm = Prob::NORMPROB ) const {
243 TFactor<T> result;
244 result._vs = _vs;
245 result._p = _p.normalized( norm );
246 return result;
247 }
248
249 // returns slice of this factor where the subset ns is in state ns_state
250 Factor slice( const VarSet & ns, size_t ns_state ) const {
251 assert( ns << _vs );
252 VarSet nsrem = _vs / ns;
253 Factor result( nsrem, 0.0 );
254
255 // OPTIMIZE ME
256 IndexFor i_ns (ns, _vs);
257 IndexFor i_nsrem (nsrem, _vs);
258 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
259 if( (size_t)i_ns == ns_state )
260 result._p[i_nsrem] = _p[i];
261
262 return result;
263 }
264
265 // returns unnormalized marginal; ns should be a subset of vars()
266 TFactor<T> partSum(const VarSet & ns) const;
267 // returns (normalized by default) marginal; ns should be a subset of vars()
268 TFactor<T> marginal(const VarSet & ns, bool normed = true) const { if(normed) return partSum(ns).normalized(); else return partSum(ns); }
269 // sums out all variables except those in ns
270 TFactor<T> notSum(const VarSet & ns) const { return partSum(vars() ^ ns); }
271
272 // embeds this factor in larger varset ns
273 TFactor<T> embed(const VarSet & ns) const {
274 VarSet vs = vars();
275 assert( ns >> vs );
276 if( vs == ns )
277 return *this;
278 else
279 return (*this) * Factor(ns / vs, 1.0);
280 }
281
282 bool hasNaNs() const { return _p.hasNaNs(); }
283 bool hasNegatives() const { return _p.hasNegatives(); }
284 T totalSum() const { return _p.totalSum(); }
285 T maxAbs() const { return _p.maxAbs(); }
286 T maxVal() const { return _p.maxVal(); }
287 T minVal() const { return _p.minVal(); }
288 Real entropy() const { return _p.entropy(); }
289 T strength( const Var &i, const Var &j ) const;
290
291 friend Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
292 if( x._vs.empty() || y._vs.empty() )
293 return -1;
294 else {
295 #ifdef DAI_DEBUG
296 assert( x._vs == y._vs );
297 #endif
298 return dist( x._p, y._p, dt );
299 }
300 }
301 friend Real KL_dist <> (const TFactor<T> & p, const TFactor<T> & q);
302 friend Real MutualInfo <> ( const TFactor<T> & P );
303 template<class U> friend std::ostream& operator<< (std::ostream& os, const TFactor<U>& P);
304 };
305
306
307 template<typename T> TFactor<T> TFactor<T>::partSum(const VarSet & ns) const {
308 #ifdef DAI_DEBUG
309 assert( ns << _vs );
310 #endif
311
312 TFactor<T> res( ns, 0.0 );
313
314 IndexFor i_res( ns, _vs );
315 for( size_t i = 0; i < _p.size(); i++, ++i_res )
316 res._p[i_res] += _p[i];
317
318 return res;
319 }
320
321
322 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
323 os << "(" << P.vars() << " <";
324 for( size_t i = 0; i < P._p.size(); i++ )
325 os << P._p[i] << " ";
326 os << ">)";
327 return os;
328 }
329
330
331 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
332 TFactor<T> prod( _vs | Q._vs, 0.0 );
333
334 IndexFor i1(_vs, prod._vs);
335 IndexFor i2(Q._vs, prod._vs);
336
337 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
338 prod._p[i] += _p[i1] * Q._p[i2];
339
340 return prod;
341 }
342
343
344 template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& Q) const {
345 TFactor<T> quot( _vs + Q._vs, 0.0 );
346
347 IndexFor i1(_vs, quot._vs);
348 IndexFor i2(Q._vs, quot._vs);
349
350 for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
351 quot._p[i] += _p[i1] / Q._p[i2];
352
353 return quot;
354 }
355
356
357 template<typename T> Real KL_dist(const TFactor<T> & P, const TFactor<T> & Q) {
358 if( P._vs.empty() || Q._vs.empty() )
359 return -1;
360 else {
361 #ifdef DAI_DEBUG
362 assert( P._vs == Q._vs );
363 #endif
364 return KL_dist( P._p, Q._p );
365 }
366 }
367
368
369 // calculate mutual information of x_i and x_j where P.vars() = \{x_i,x_j\}
370 template<typename T> Real MutualInfo(const TFactor<T> & P) {
371 assert( P._vs.size() == 2 );
372 VarSet::const_iterator it = P._vs.begin();
373 Var i = *it; it++; Var j = *it;
374 TFactor<T> projection = P.marginal(i) * P.marginal(j);
375 return real( KL_dist( P.normalized(), projection ) );
376 }
377
378
379 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q ) {
380 assert( P._vs == Q._vs );
381 return TFactor<T>( P._vs, min( P.p(), Q.p() ) );
382 }
383
384 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q ) {
385 assert( P._vs == Q._vs );
386 return TFactor<T>( P._vs, max( P.p(), Q.p() ) );
387 }
388
389 // calculate N(psi, i, j)
390 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
391 #ifdef DAI_DEBUG
392 assert( _vs.contains( i ) );
393 assert( _vs.contains( j ) );
394 assert( i != j );
395 #endif
396 VarSet ij(i, j);
397
398 T max = 0.0;
399 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
400 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
401 if( alpha2 != alpha1 )
402 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
403 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
404 if( beta2 != beta1 ) {
405 size_t as = 1, bs = 1;
406 if( i < j )
407 bs = i.states();
408 else
409 as = j.states();
410 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
411 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
412 T f = f1 * f2;
413 if( f > max )
414 max = f;
415 }
416
417 return std::tanh( 0.25 * std::log( max ) );
418 }
419
420
421 template<typename T> TFactor<T> RemoveFirstOrderInteractions( const TFactor<T> & psi ) {
422 TFactor<T> result = psi;
423
424 VarSet vars = psi.vars();
425 for( size_t iter = 0; iter < 100; iter++ ) {
426 for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ )
427 result = result * result.partSum(*n).inverse();
428 result.normalize();
429 }
430
431 return result;
432 }
433
434
435 } // end of namespace dai
436
437
438 #endif