Merged var.h and varset.h from SVN head
[libdai.git] / include / dai / factor.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
3 Radboud University Nijmegen, The Netherlands
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_factor_h
24 #define __defined_libdai_factor_h
25
26
27 #include <iostream>
28 #include <cmath>
29 #include <dai/prob.h>
30 #include <dai/varset.h>
31 #include <dai/index.h>
32
33
34 namespace dai {
35
36
37 template<typename T> class TFactor;
38 typedef TFactor<Real> Factor;
39 typedef TFactor<Complex> CFactor;
40
41
42 // predefine friends
43 template<typename T> Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt );
44 template<typename T> Complex KL_dist( const TFactor<T> & p, const TFactor<T> & q );
45 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P);
46
47
48 // T should be castable from and to double and to complex
49 template <typename T> class TFactor {
50 protected:
51 VarSet _vs;
52 TProb<T> _p;
53
54 public:
55 // Default constructor
56 TFactor () : _vs(), _p(1,1.0) {}
57
58 // Construct Factor from VarSet
59 TFactor( const VarSet& ns ) : _vs(ns), _p(_vs.states()) {}
60
61 // Construct Factor from VarSet and initial value
62 TFactor( const VarSet& ns, Real p ) : _vs(ns), _p(_vs.states(),p) {}
63
64 // Construct Factor from VarSet and initial array
65 TFactor( const VarSet& ns, const Real* p ) : _vs(ns), _p(_vs.states(),p) {}
66
67 // Construct Factor from VarSet and TProb<T>
68 TFactor( const VarSet& ns, const TProb<T> p ) : _vs(ns), _p(p) {
69 #ifdef DAI_DEBUG
70 assert( _vs.states() == _p.size() );
71 #endif
72 }
73
74 // Construct Factor from Var
75 TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
76
77 // Copy constructor
78 TFactor( const TFactor<T> &x ) : _vs(x._vs), _p(x._p) {}
79
80 // Assignment operator
81 TFactor<T> & operator= (const TFactor<T> &x) {
82 if( this != &x ) {
83 _vs = x._vs;
84 _p = x._p;
85 }
86 return *this;
87 }
88
89 const TProb<T> & p() const { return _p; }
90 TProb<T> & p() { return _p; }
91 const VarSet & vars() const { return _vs; }
92 size_t states() const {
93 #ifdef DAI_DEBUG
94 assert( _vs.states() == _p.size() );
95 #endif
96 return _p.size();
97 }
98
99 T operator[] (size_t i) const { return _p[i]; }
100 T& operator[] (size_t i) { return _p[i]; }
101 TFactor<T> & fill (T p)
102 { _p.fill( p ); return(*this); }
103 TFactor<T> & randomize ()
104 { _p.randomize(); return(*this); }
105 TFactor<T> operator* (T x) const {
106 Factor result = *this;
107 result.p() *= x;
108 return result;
109 }
110 TFactor<T>& operator*= (T x) {
111 _p *= x;
112 return *this;
113 }
114 TFactor<T> operator/ (T x) const {
115 Factor result = *this;
116 result.p() /= x;
117 return result;
118 }
119 TFactor<T>& operator/= (T x) {
120 _p /= x;
121 return *this;
122 }
123 TFactor<T> operator* (const TFactor<T>& Q) const;
124 TFactor<T>& operator*= (const TFactor<T>& Q) { return( *this = (*this * Q) ); }
125 TFactor<T> operator+ (const TFactor<T>& Q) const {
126 #ifdef DAI_DEBUG
127 assert( Q._vs == _vs );
128 #endif
129 TFactor<T> sum(*this);
130 sum._p += Q._p;
131 return sum;
132 }
133 TFactor<T> operator- (const TFactor<T>& Q) const {
134 #ifdef DAI_DEBUG
135 assert( Q._vs == _vs );
136 #endif
137 TFactor<T> sum(*this);
138 sum._p -= Q._p;
139 return sum;
140 }
141 TFactor<T>& operator+= (const TFactor<T>& Q) {
142 #ifdef DAI_DEBUG
143 assert( Q._vs == _vs );
144 #endif
145 _p += Q._p;
146 return *this;
147 }
148 TFactor<T>& operator-= (const TFactor<T>& Q) {
149 #ifdef DAI_DEBUG
150 assert( Q._vs == _vs );
151 #endif
152 _p -= Q._p;
153 return *this;
154 }
155
156 TFactor<T> operator^ (Real a) const { TFactor<T> x; x._vs = _vs; x._p = _p^a; return x; }
157 TFactor<T>& operator^= (Real a) { _p ^= a; return *this; }
158
159 TFactor<T>& makeZero( Real epsilon ) {
160 _p.makeZero( epsilon );
161 return *this;
162 }
163
164 TFactor<T> inverse() const {
165 TFactor<T> inv;
166 inv._vs = _vs;
167 inv._p = _p.inverse(true); // FIXME
168 return inv;
169 }
170
171 TFactor<T> divided_by( const TFactor<T>& denom ) const {
172 #ifdef DAI_DEBUG
173 assert( denom._vs == _vs );
174 #endif
175 TFactor<T> quot(*this);
176 quot._p /= denom._p;
177 return quot;
178 }
179
180 TFactor<T>& divide( const TFactor<T>& denom ) {
181 #ifdef DAI_DEBUG
182 assert( denom._vs == _vs );
183 #endif
184 _p /= denom._p;
185 return *this;
186 }
187
188 TFactor<T> exp() const {
189 TFactor<T> e;
190 e._vs = _vs;
191 e._p = _p.exp();
192 return e;
193 }
194
195 TFactor<T> log() const {
196 TFactor<T> l;
197 l._vs = _vs;
198 l._p = _p.log();
199 return l;
200 }
201
202 TFactor<T> log0() const {
203 TFactor<T> l0;
204 l0._vs = _vs;
205 l0._p = _p.log0();
206 return l0;
207 }
208
209 CFactor clog0() const {
210 CFactor l0;
211 l0._vs = _vs;
212 l0._p = _p.clog0();
213 return l0;
214 }
215
216 T normalize( typename Prob::NormType norm ) { return _p.normalize( norm ); }
217 TFactor<T> normalized( typename Prob::NormType norm ) const {
218 TFactor<T> result;
219 result._vs = _vs;
220 result._p = _p.normalized( norm );
221 return result;
222 }
223
224 // returns slice of this factor where the subset ns is in state ns_state
225 Factor slice( const VarSet & ns, size_t ns_state ) const {
226 assert( ns << _vs );
227 VarSet nsrem = _vs / ns;
228 Factor result( nsrem, 0.0 );
229
230 // OPTIMIZE ME
231 IndexFor i_ns (ns, _vs);
232 IndexFor i_nsrem (nsrem, _vs);
233 for( size_t i = 0; i < states(); i++, ++i_ns, ++i_nsrem )
234 if( (size_t)i_ns == ns_state )
235 result._p[i_nsrem] = _p[i];
236
237 return result;
238 }
239
240 // returns unnormalized marginal
241 TFactor<T> part_sum(const VarSet & ns) const;
242 // returns normalized marginal
243 TFactor<T> marginal(const VarSet & ns) const { return part_sum(ns).normalized( Prob::NORMPROB ); }
244
245 bool hasNaNs() const { return _p.hasNaNs(); }
246 bool hasNegatives() const { return _p.hasNegatives(); }
247 T totalSum() const { return _p.totalSum(); }
248 T maxAbs() const { return _p.maxAbs(); }
249 T maxVal() const { return _p.maxVal(); }
250 Complex entropy() const { return _p.entropy(); }
251 T strength( const Var &i, const Var &j ) const;
252
253 friend Real dist( const TFactor<T> & x, const TFactor<T> & y, Prob::DistType dt ) {
254 if( x._vs.empty() || y._vs.empty() )
255 return -1;
256 else {
257 #ifdef DAI_DEBUG
258 assert( x._vs == y._vs );
259 #endif
260 return dist( x._p, y._p, dt );
261 }
262 }
263 friend Complex KL_dist <> (const TFactor<T> & p, const TFactor<T> & q);
264 template<class U> friend std::ostream& operator<< (std::ostream& os, const TFactor<U>& P);
265 };
266
267
268 template<typename T> TFactor<T> TFactor<T>::part_sum(const VarSet & ns) const {
269 #ifdef DAI_DEBUG
270 assert( ns << _vs );
271 #endif
272
273 TFactor<T> res( ns, 0.0 );
274
275 IndexFor i_res( ns, _vs );
276 for( size_t i = 0; i < _p.size(); i++, ++i_res )
277 res._p[i_res] += _p[i];
278
279 return res;
280 }
281
282
283 template<typename T> std::ostream& operator<< (std::ostream& os, const TFactor<T>& P) {
284 os << "(" << P.vars() << " <";
285 for( size_t i = 0; i < P._p.size(); i++ )
286 os << P._p[i] << " ";
287 os << ">)";
288 return os;
289 }
290
291
292 template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& Q) const {
293 TFactor<T> prod( _vs | Q._vs, 0.0 );
294
295 IndexFor i1(_vs, prod._vs);
296 IndexFor i2(Q._vs, prod._vs);
297
298 for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
299 prod._p[i] += _p[i1] * Q._p[i2];
300
301 return prod;
302 }
303
304
305 template<typename T> Complex KL_dist(const TFactor<T> & P, const TFactor<T> & Q) {
306 if( P._vs.empty() || Q._vs.empty() )
307 return -1;
308 else {
309 #ifdef DAI_DEBUG
310 assert( P._vs == Q._vs );
311 #endif
312 return KL_dist( P._p, Q._p );
313 }
314 }
315
316
317 // calculate N(psi, i, j)
318 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
319 #ifdef DAI_DEBUG
320 assert( _vs.contains( i ) );
321 assert( _vs.contains( j ) );
322 assert( i != j );
323 #endif
324 VarSet ij = i | j;
325
326 T max = 0.0;
327 for( size_t alpha1 = 0; alpha1 < i.states(); alpha1++ )
328 for( size_t alpha2 = 0; alpha2 < i.states(); alpha2++ )
329 if( alpha2 != alpha1 )
330 for( size_t beta1 = 0; beta1 < j.states(); beta1++ )
331 for( size_t beta2 = 0; beta2 < j.states(); beta2++ )
332 if( beta2 != beta1 ) {
333 size_t as = 1, bs = 1;
334 if( i < j )
335 bs = i.states();
336 else
337 as = j.states();
338 T f1 = slice( ij, alpha1 * as + beta1 * bs ).p().divide( slice( ij, alpha2 * as + beta1 * bs ).p() ).maxVal();
339 T f2 = slice( ij, alpha2 * as + beta2 * bs ).p().divide( slice( ij, alpha1 * as + beta2 * bs ).p() ).maxVal();
340 T f = f1 * f2;
341 if( f > max )
342 max = f;
343 }
344
345 return std::tanh( 0.25 * std::log( max ) );
346 }
347
348
349 template<typename T> TFactor<T> RemoveFirstOrderInteractions( const TFactor<T> & psi ) {
350 TFactor<T> result = psi;
351
352 VarSet vars = psi.vars();
353 for( size_t iter = 0; iter < 100; iter++ ) {
354 for( VarSet::const_iterator n = vars.begin(); n != vars.end(); n++ )
355 result = result * result.part_sum(*n).inverse();
356 result.normalize( Prob::NORMPROB );
357 }
358
359 return result;
360 }
361
362
363 } // end of namespace dai
364
365
366 #endif