Small misc changes
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_factor_h
23 #define __defined_libdai_factor_h
24
25
26 #include <iostream>
27 #include <cmath>
28 #include <dai/prob.h>
29 #include <dai/varset.h>
30 #include <dai/index.h>
31
32
33 namespace dai {
34
35
36 template<typename T> class TFactor;
37 typedef TFactor<Real> Factor;
38 typedef TFactor<Complex> CFactor;
39
40
41 // predefine friends
42 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt );
43 template<typename T> Complex KL_dist( const TFactor<T> & p, const TFactor<T> & q );
44 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P);
45
46
47 // T should be castable from and to double and to complex
48 template <typename T> class TFactor {
49 protected:
50 VarSet _vs;
51 TProb<T> _p;
52
53 public:
54 // Default constructor
55 TFactor () : _vs(), _p(1,1.0) {}
56
57 // Construct Factor from VarSet
58 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.states()) {}
59
60 // Construct Factor from VarSet and initial value
61 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.states(),p) {}
62
63 // Construct Factor from VarSet and initial array
64 TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(_vs.states(),p) {}
65
66 // Construct Factor from VarSet and TProb<T>
67 TFactor( const VarSet& ns, const TProb<T> p ) : _vs(ns), _p(p) {
68 #ifdef DAI_DEBUG
69 assert( _vs.states() == _p.size() );
70 #endif
71 }
72
73 // Construct Factor from Var
74 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
75
76 // Copy constructor
77 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
78
79 // Assignment operator
80 TFactor<T> & operator= (const TFactor<T> &x) {
81 if( this != &x ) {
82 _vs = x._vs;
83 _p = x._p;
84 }
85 return *this;
86 }
87
88 const TProb<T> & p() const { return _p; }
89 TProb<T> & p() { return _p; }
90 const VarSet & vars() const { return _vs; }
91 size_t states() const {
92 #ifdef DAI_DEBUG
93 assert( _vs.states() == _p.size() );
94 #endif
95 return _p.size();
96 }
97
98 T operator[] (size_t i) const { return _p[i]; }
99 T& operator[] (size_t i) { return _p[i]; }
100 TFactor<T> & fill (T p)
101 { _p.fill( p ); return(*this); }
102 TFactor<T> & randomize ()
103 { _p.randomize(); return(*this); }
104 TFactor<T> operator* (T x) const {
105 Factor result = *this;
106 result.p() *= x;
107 return result;
108 }
109 TFactor<T>& operator*= (T x) {
110 _p *= x;
111 return *this;
112 }
113 TFactor<T> operator/ (T x) const {
114 Factor result = *this;
115 result.p() /= x;
116 return result;
117 }
118 TFactor<T>& operator/= (T x) {
119 _p /= x;
120 return *this;
121 }
122 TFactor<T> operator* (const TFactor<T>& Q) const;
123 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
124 TFactor<T> operator+ (const TFactor<T>& Q) const {
125 #ifdef DAI_DEBUG
126 assert( Q._vs == _vs );
127 #endif
128 TFactor<T> sum(*this);
129 sum._p += Q._p;
130 return sum;
131 }
132 TFactor<T> operator- (const TFactor<T>& Q) const {
133 #ifdef DAI_DEBUG
134 assert( Q._vs == _vs );
135 #endif
136 TFactor<T> sum(*this);
137 sum._p -= Q._p;
138 return sum;
139 }
140 TFactor<T>& operator+= (const TFactor<T>& Q) {
141 #ifdef DAI_DEBUG
142 assert( Q._vs == _vs );
143 #endif
144 _p += Q._p;
145 return *this;
146 }
147 TFactor<T>& operator-= (const TFactor<T>& Q) {
148 #ifdef DAI_DEBUG
149 assert( Q._vs == _vs );
150 #endif
151 _p -= Q._p;
152 return *this;
153 }
154
155 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
156 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
157
158 TFactor<T>& makeZero( Real epsilon ) {
159 _p.makeZero( epsilon );
160 return *this;
161 }
162
163 TFactor<T> inverse() const {
164 TFactor<T> inv;
165 inv._vs = _vs;
166 inv._p = _p.inverse(true); // FIXME
167 return inv;
168 }
169
170 TFactor<T> divided_by( const TFactor<T>& denom ) const {
171 #ifdef DAI_DEBUG
172 assert( denom._vs == _vs );
173 #endif
174 TFactor<T> quot(*this);
175 quot._p /= denom._p;
176 return quot;
177 }
178
179 TFactor<T>& divide( const TFactor<T>& denom ) {
180 #ifdef DAI_DEBUG
181 assert( denom._vs == _vs );
182 #endif
183 _p /= denom._p;
184 return *this;
185 }
186
187 TFactor<T> exp() const {
188 TFactor<T> e;
189 e._vs = _vs;
190 e._p = _p.exp();
191 return e;
192 }
193
194 TFactor<T> log() const {
195 TFactor<T> l;
196 l._vs = _vs;
197 l._p = _p.log();
198 return l;
199 }
200
201 TFactor<T> log0() const {
202 TFactor<T> l0;
203 l0._vs = _vs;
204 l0._p = _p.log0();
205 return l0;
206 }
207
208 CFactor clog0() const {
209 CFactor l0;
210 l0._vs = _vs;
211 l0._p = _p.clog0();
212 return l0;
213 }
214
215 T normalize( typename Prob::NormType norm ) { return _p.normalize( norm ); }
216 TFactor<T> normalized( typename Prob::NormType norm ) const {
217 TFactor<T> result;
218 result._vs = _vs;
219 result._p = _p.normalized( norm );
220 return result;
221 }
222
223 // returns slice of this factor where the subset ns is in state ns_state
224 Factor slice( const VarSet & ns, size_t ns_state ) const {
225 assert( ns << _vs );
226 VarSet nsrem = _vs / ns;
227 Factor result( nsrem, 0.0 );
228
229 // OPTIMIZE ME
230 Index i_ns (ns, _vs);
231 Index i_nsrem (nsrem, _vs);
232 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
233 if( (size_t)i_ns == ns_state )
234 result._p[i_nsrem] = _p[i];
235
236 return result;
237 }
238
239 // returns unnormalized marginal
240 TFactor<T> part_sum(const VarSet & ns) const;
241 // returns normalized marginal
242 TFactor<T> marginal(const VarSet & ns) const { return part_sum(ns).normalized( Prob::NORMPROB ); }
243
244 bool hasNaNs() const { return _p.hasNaNs(); }
245 bool hasNegatives() const { return _p.hasNegatives(); }
246 T totalSum() const { return _p.totalSum(); }
247 T maxAbs() const { return _p.maxAbs(); }
248 T max() const { return _p.max(); }
249 Complex entropy() const { return _p.entropy(); }
250 T strength( const Var &i, const Var &j ) const;
251
252 friend Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
253 if( x._vs.empty() || y._vs.empty() )
254 return -1;
255 else {
256 #ifdef DAI_DEBUG
257 assert( x._vs == y._vs );
258 #endif
259 return dist( x._p, y._p, dt );
260 }
261 }
262 friend Complex KL_dist <> (const TFactor<T> & p, const TFactor<T> & q);
263 template<class U> friend std::ostream& operator<< (std::ostream& os, const TFactor<U>& P);
264 };
265
266
267 template<typename T> TFactor<T> TFactor<T>::part_sum(const VarSet & ns) const {
268 #ifdef DAI_DEBUG
269 assert( ns << _vs );
270 #endif
271
272 TFactor<T> res( ns, 0.0 );
273
274 Index i_res( ns, _vs );
275 for( size_t i = 0; i < _p.size(); i++, ++i_res )
276 res._p[i_res] += _p[i];
277
278 return res;
279 }
280
281
282 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
283 os << "(" << P.vars() << " <";
284 for( size_t i = 0; i < P._p.size(); i++ )
285 os << P._p[i] << " ";
286 os << ">)";
287 return os;
288 }
289
290
291 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
292 TFactor<T> prod( _vs | Q._vs, 0.0 );
293
294 Index i1(_vs, prod._vs);
295 Index i2(Q._vs, prod._vs);
296
297 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
298 prod._p[i] += _p[i1] * Q._p[i2];
299
300 return prod;
301 }
302
303
304 template<typename T> Complex KL_dist(const TFactor<T> & P, const TFactor<T> & Q) {
305 if( P._vs.empty() || Q._vs.empty() )
306 return -1;
307 else {
308 #ifdef DAI_DEBUG
309 assert( P._vs == Q._vs );
310 #endif
311 return KL_dist( P._p, Q._p );
312 }
313 }
314
315
316 // calculate N(psi, i, j)
317 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
318 #ifdef DAI_DEBUG
319 assert( _vs && i );
320 assert( _vs && j );
321 assert( i != j );
322 #endif
323 VarSet ij = i | j;
324
325 T max = 0.0;
326 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
327 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
328 if( alpha2 != alpha1 )
329 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
330 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
331 if( beta2 != beta1 ) {
332 size_t as = 1, bs = 1;
333 if( i < j )
334 bs = i.states();
335 else
336 as = j.states();
337 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).max();
338 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).max();
339 T f = f1 * f2;
340 if( f > max )
341 max = f;
342 }
343
344 return std::tanh( 0.25 * std::log( max ) );
345 }
346
347
348 template<typename T> TFactor<T> RemoveFirstOrderInteractions( const TFactor<T> & psi ) {
349 TFactor<T> result = psi;
350
351 VarSet vars = psi.vars();
352 for( size_t iter = 0; iter < 100; iter++ ) {
353 for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ )
354 result = result * result.part_sum(*n).inverse();
355 result.normalize( Prob::NORMPROB );
356 }
357
358 return result;
359 }
360
361
362 } // end of namespace dai
363
364
365 #endif