Replaced all "protected:" by "private:" or "public:"
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
3 Radboud University Nijmegen, The Netherlands
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_factor_h
24 #define __defined_libdai_factor_h
25
26
27 #include <iostream>
28 #include <cmath>
29 #include <dai/prob.h>
30 #include <dai/varset.h>
31 #include <dai/index.h>
32
33
34 namespace dai {
35
36
37 template<typename T> class TFactor;
38 typedef TFactor<Real> Factor;
39
40
41 // predefine friends
42 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt );
43 template<typename T> Real KL_dist( const TFactor<T> & p, const TFactor<T> & q );
44 template<typename T> Real MutualInfo( const TFactor<T> & p );
45 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q );
46 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q );
47 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P);
48
49
50 // T should be castable from and to double
51 template <typename T> class TFactor {
52 private:
53 VarSet _vs;
54 TProb<T> _p;
55
56 public:
57 // Construct Factor with empty VarSet but nonempty _p
58 TFactor ( Real p = 1.0 ) : _vs(), _p(1,p) {}
59
60 // Construct Factor from VarSet
61 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.states()) {}
62
63 // Construct Factor from VarSet and initial value
64 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.states(),p) {}
65
66 // Construct Factor from VarSet and initial array
67 TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(_vs.states(),p) {}
68
69 // Construct Factor from VarSet and TProb<T>
70 TFactor( const VarSet& ns, const TProb<T>& p ) : _vs(ns), _p(p) {
71 #ifdef DAI_DEBUG
72 assert( _vs.states() == _p.size() );
73 #endif
74 }
75
76 // Construct Factor from Var
77 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
78
79 // Copy constructor
80 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
81
82 // Assignment operator
83 TFactor<T> & operator= (const TFactor<T> &x) {
84 if( this != &x ) {
85 _vs = x._vs;
86 _p = x._p;
87 }
88 return *this;
89 }
90
91 const TProb<T> & p() const { return _p; }
92 TProb<T> & p() { return _p; }
93 const VarSet & vars() const { return _vs; }
94 size_t states() const {
95 #ifdef DAI_DEBUG
96 assert( _vs.states() == _p.size() );
97 #endif
98 return _p.size();
99 }
100
101 T operator[] (size_t i) const { return _p[i]; }
102 T& operator[] (size_t i) { return _p[i]; }
103 TFactor<T> & fill (T p)
104 { _p.fill( p ); return(*this); }
105 TFactor<T> & randomize ()
106 { _p.randomize(); return(*this); }
107 TFactor<T> operator* (T x) const {
108 Factor result = *this;
109 result.p() *= x;
110 return result;
111 }
112 TFactor<T>& operator*= (T x) {
113 _p *= x;
114 return *this;
115 }
116 TFactor<T> operator/ (T x) const {
117 Factor result = *this;
118 result.p() /= x;
119 return result;
120 }
121 TFactor<T>& operator/= (T x) {
122 _p /= x;
123 return *this;
124 }
125 TFactor<T> operator* (const TFactor<T>& Q) const;
126 TFactor<T> operator/ (const TFactor<T>& Q) const;
127 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
128 TFactor<T>& operator/= (const TFactor<T>& Q) { return( *this = (*this / Q) ); }
129 TFactor<T> operator+ (const TFactor<T>& Q) const {
130 #ifdef DAI_DEBUG
131 assert( Q._vs == _vs );
132 #endif
133 TFactor<T> sum(*this);
134 sum._p += Q._p;
135 return sum;
136 }
137 TFactor<T> operator- (const TFactor<T>& Q) const {
138 #ifdef DAI_DEBUG
139 assert( Q._vs == _vs );
140 #endif
141 TFactor<T> sum(*this);
142 sum._p -= Q._p;
143 return sum;
144 }
145 TFactor<T>& operator+= (const TFactor<T>& Q) {
146 #ifdef DAI_DEBUG
147 assert( Q._vs == _vs );
148 #endif
149 _p += Q._p;
150 return *this;
151 }
152 TFactor<T>& operator-= (const TFactor<T>& Q) {
153 #ifdef DAI_DEBUG
154 assert( Q._vs == _vs );
155 #endif
156 _p -= Q._p;
157 return *this;
158 }
159 TFactor<T>& operator+= (T q) {
160 _p += q;
161 return *this;
162 }
163 TFactor<T>& operator-= (T q) {
164 _p -= q;
165 return *this;
166 }
167 TFactor<T> operator+ (T q) const {
168 TFactor<T> result(*this);
169 result._p += q;
170 return result;
171 }
172 TFactor<T> operator- (T q) const {
173 TFactor<T> result(*this);
174 result._p -= q;
175 return result;
176 }
177
178 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
179 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
180
181 TFactor<T>& makeZero( Real epsilon ) {
182 _p.makeZero( epsilon );
183 return *this;
184 }
185
186 TFactor<T>& makePositive( Real epsilon ) {
187 _p.makePositive( epsilon );
188 return *this;
189 }
190
191 TFactor<T> inverse() const {
192 TFactor<T> inv;
193 inv._vs = _vs;
194 inv._p = _p.inverse(true); // FIXME
195 return inv;
196 }
197
198 TFactor<T> divided_by( const TFactor<T>& denom ) const {
199 #ifdef DAI_DEBUG
200 assert( denom._vs == _vs );
201 #endif
202 TFactor<T> quot(*this);
203 quot._p /= denom._p;
204 return quot;
205 }
206
207 TFactor<T>& divide( const TFactor<T>& denom ) {
208 #ifdef DAI_DEBUG
209 assert( denom._vs == _vs );
210 #endif
211 _p /= denom._p;
212 return *this;
213 }
214
215 TFactor<T> exp() const {
216 TFactor<T> e;
217 e._vs = _vs;
218 e._p = _p.exp();
219 return e;
220 }
221
222 TFactor<T> abs() const {
223 TFactor<T> e;
224 e._vs = _vs;
225 e._p = _p.abs();
226 return e;
227 }
228
229 TFactor<T> log() const {
230 TFactor<T> l;
231 l._vs = _vs;
232 l._p = _p.log();
233 return l;
234 }
235
236 TFactor<T> log0() const {
237 TFactor<T> l0;
238 l0._vs = _vs;
239 l0._p = _p.log0();
240 return l0;
241 }
242
243 T normalize( typename Prob::NormType norm = Prob::NORMPROB ) { return _p.normalize( norm ); }
244 TFactor<T> normalized( typename Prob::NormType norm = Prob::NORMPROB ) const {
245 TFactor<T> result;
246 result._vs = _vs;
247 result._p = _p.normalized( norm );
248 return result;
249 }
250
251 // returns slice of this factor where the subset ns is in state ns_state
252 Factor slice( const VarSet & ns, size_t ns_state ) const {
253 assert( ns << _vs );
254 VarSet nsrem = _vs / ns;
255 Factor result( nsrem, 0.0 );
256
257 // OPTIMIZE ME
258 IndexFor i_ns (ns, _vs);
259 IndexFor i_nsrem (nsrem, _vs);
260 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
261 if( (size_t)i_ns == ns_state )
262 result._p[i_nsrem] = _p[i];
263
264 return result;
265 }
266
267 // returns unnormalized marginal; ns should be a subset of vars()
268 TFactor<T> partSum(const VarSet & ns) const;
269 // returns (normalized by default) marginal; ns should be a subset of vars()
270 TFactor<T> marginal(const VarSet & ns, bool normed = true) const { if(normed) return partSum(ns).normalized(); else return partSum(ns); }
271 // sums out all variables except those in ns
272 TFactor<T> notSum(const VarSet & ns) const { return partSum(vars() ^ ns); }
273
274 // embeds this factor in larger varset ns
275 TFactor<T> embed(const VarSet & ns) const {
276 VarSet vs = vars();
277 assert( ns >> vs );
278 if( vs == ns )
279 return *this;
280 else
281 return (*this) * Factor(ns / vs, 1.0);
282 }
283
284 bool hasNaNs() const { return _p.hasNaNs(); }
285 bool hasNegatives() const { return _p.hasNegatives(); }
286 T totalSum() const { return _p.totalSum(); }
287 T maxAbs() const { return _p.maxAbs(); }
288 T maxVal() const { return _p.maxVal(); }
289 T minVal() const { return _p.minVal(); }
290 Real entropy() const { return _p.entropy(); }
291 T strength( const Var &i, const Var &j ) const;
292
293 friend Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
294 if( x._vs.empty() || y._vs.empty() )
295 return -1;
296 else {
297 #ifdef DAI_DEBUG
298 assert( x._vs == y._vs );
299 #endif
300 return dist( x._p, y._p, dt );
301 }
302 }
303 friend Real KL_dist <> (const TFactor<T> & p, const TFactor<T> & q);
304 friend Real MutualInfo <> ( const TFactor<T> & P );
305 template<class U> friend std::ostream& operator<< (std::ostream& os, const TFactor<U>& P);
306 };
307
308
309 template<typename T> TFactor<T> TFactor<T>::partSum(const VarSet & ns) const {
310 #ifdef DAI_DEBUG
311 assert( ns << _vs );
312 #endif
313
314 TFactor<T> res( ns, 0.0 );
315
316 IndexFor i_res( ns, _vs );
317 for( size_t i = 0; i < _p.size(); i++, ++i_res )
318 res._p[i_res] += _p[i];
319
320 return res;
321 }
322
323
324 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
325 os << "(" << P.vars() << " <";
326 for( size_t i = 0; i < P._p.size(); i++ )
327 os << P._p[i] << " ";
328 os << ">)";
329 return os;
330 }
331
332
333 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
334 TFactor<T> prod( _vs | Q._vs, 0.0 );
335
336 IndexFor i1(_vs, prod._vs);
337 IndexFor i2(Q._vs, prod._vs);
338
339 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
340 prod._p[i] += _p[i1] * Q._p[i2];
341
342 return prod;
343 }
344
345
346 template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& Q) const {
347 TFactor<T> quot( _vs + Q._vs, 0.0 );
348
349 IndexFor i1(_vs, quot._vs);
350 IndexFor i2(Q._vs, quot._vs);
351
352 for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
353 quot._p[i] += _p[i1] / Q._p[i2];
354
355 return quot;
356 }
357
358
359 template<typename T> Real KL_dist(const TFactor<T> & P, const TFactor<T> & Q) {
360 if( P._vs.empty() || Q._vs.empty() )
361 return -1;
362 else {
363 #ifdef DAI_DEBUG
364 assert( P._vs == Q._vs );
365 #endif
366 return KL_dist( P._p, Q._p );
367 }
368 }
369
370
371 // calculate mutual information of x_i and x_j where P.vars() = \{x_i,x_j\}
372 template<typename T> Real MutualInfo(const TFactor<T> & P) {
373 assert( P._vs.size() == 2 );
374 VarSet::const_iterator it = P._vs.begin();
375 Var i = *it; it++; Var j = *it;
376 TFactor<T> projection = P.marginal(i) * P.marginal(j);
377 return real( KL_dist( P.normalized(), projection ) );
378 }
379
380
381 template<typename T> TFactor<T> max( const TFactor<T> & P, const TFactor<T> & Q ) {
382 assert( P._vs == Q._vs );
383 return TFactor<T>( P._vs, min( P.p(), Q.p() ) );
384 }
385
386 template<typename T> TFactor<T> min( const TFactor<T> & P, const TFactor<T> & Q ) {
387 assert( P._vs == Q._vs );
388 return TFactor<T>( P._vs, max( P.p(), Q.p() ) );
389 }
390
391 // calculate N(psi, i, j)
392 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
393 #ifdef DAI_DEBUG
394 assert( _vs.contains( i ) );
395 assert( _vs.contains( j ) );
396 assert( i != j );
397 #endif
398 VarSet ij(i, j);
399
400 T max = 0.0;
401 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
402 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
403 if( alpha2 != alpha1 )
404 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
405 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
406 if( beta2 != beta1 ) {
407 size_t as = 1, bs = 1;
408 if( i < j )
409 bs = i.states();
410 else
411 as = j.states();
412 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
413 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
414 T f = f1 * f2;
415 if( f > max )
416 max = f;
417 }
418
419 return std::tanh( 0.25 * std::log( max ) );
420 }
421
422
423 template<typename T> TFactor<T> RemoveFirstOrderInteractions( const TFactor<T> & psi ) {
424 TFactor<T> result = psi;
425
426 VarSet vars = psi.vars();
427 for( size_t iter = 0; iter < 100; iter++ ) {
428 for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ )
429 result = result * result.partSum(*n).inverse();
430 result.normalize();
431 }
432
433 return result;
434 }
435
436
437 } // end of namespace dai
438
439
440 #endif