980154690e41cd048620b03260d5dd3c720b2cd2
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
6 Radboud University Nijmegen, The Netherlands
8 This file is part of libDAI.
10 libDAI is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
15 libDAI is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
20 You should have received a copy of the GNU General Public License
21 along with libDAI; if not, write to the Free Software
22 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
26 #ifndef __defined_libdai_factor_h
27 #define __defined_libdai_factor_h
33 #include <dai/varset.h>
34 #include <dai/index.h>
40 template<typename T
> class TFactor
;
41 typedef TFactor
<Real
> Factor
;
45 template<typename T
> Real
dist( const TFactor
<T
> & x
, const TFactor
<T
> & y
, Prob::DistType dt
);
46 template<typename T
> Real
KL_dist( const TFactor
<T
> & p
, const TFactor
<T
> & q
);
47 template<typename T
> Real
MutualInfo( const TFactor
<T
> & p
);
48 template<typename T
> TFactor
<T
> max( const TFactor
<T
> & P
, const TFactor
<T
> & Q
);
49 template<typename T
> TFactor
<T
> min( const TFactor
<T
> & P
, const TFactor
<T
> & Q
);
50 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TFactor
<T
>& P
);
53 // T should be castable from and to double
54 template <typename T
> class TFactor
{
60 // Construct Factor with empty VarSet but nonempty _p
61 TFactor ( Real p
= 1.0 ) : _vs(), _p(1,p
) {}
63 // Construct Factor from VarSet
64 TFactor( const VarSet
& ns
) : _vs(ns
), _p(nrStates(_vs
)) {}
66 // Construct Factor from VarSet and initial value
67 TFactor( const VarSet
& ns
, Real p
) : _vs(ns
), _p(nrStates(_vs
),p
) {}
69 // Construct Factor from VarSet and initial array
70 TFactor( const VarSet
& ns
, const Real
* p
) : _vs(ns
), _p(nrStates(_vs
),p
) {}
72 // Construct Factor from VarSet and TProb<T>
73 TFactor( const VarSet
& ns
, const TProb
<T
>& p
) : _vs(ns
), _p(p
) {
75 assert( nrStates(_vs
) == _p
.size() );
79 // Construct Factor from Var
80 TFactor( const Var
& n
) : _vs(n
), _p(n
.states()) {}
83 TFactor( const TFactor
<T
> &x
) : _vs(x
._vs
), _p(x
._p
) {}
85 // Assignment operator
86 TFactor
<T
> & operator= (const TFactor
<T
> &x
) {
94 const TProb
<T
> & p() const { return _p
; }
95 TProb
<T
> & p() { return _p
; }
96 const VarSet
& vars() const { return _vs
; }
97 size_t states() const { return _p
.size(); }
99 T
operator[] (size_t i
) const { return _p
[i
]; }
100 T
& operator[] (size_t i
) { return _p
[i
]; }
101 TFactor
<T
> & fill (T p
)
102 { _p
.fill( p
); return(*this); }
103 TFactor
<T
> & randomize ()
104 { _p
.randomize(); return(*this); }
105 TFactor
<T
> operator* (T x
) const {
106 Factor result
= *this;
110 TFactor
<T
>& operator*= (T x
) {
114 TFactor
<T
> operator/ (T x
) const {
115 Factor result
= *this;
119 TFactor
<T
>& operator/= (T x
) {
123 TFactor
<T
> operator* (const TFactor
<T
>& Q
) const;
124 TFactor
<T
> operator/ (const TFactor
<T
>& Q
) const;
125 TFactor
<T
>& operator*= (const TFactor
<T
>& Q
) { return( *this = (*this * Q
) ); }
126 TFactor
<T
>& operator/= (const TFactor
<T
>& Q
) { return( *this = (*this / Q
) ); }
127 TFactor
<T
> operator+ (const TFactor
<T
>& Q
) const {
129 assert( Q
._vs
== _vs
);
131 TFactor
<T
> sum(*this);
135 TFactor
<T
> operator- (const TFactor
<T
>& Q
) const {
137 assert( Q
._vs
== _vs
);
139 TFactor
<T
> sum(*this);
143 TFactor
<T
>& operator+= (const TFactor
<T
>& Q
) {
145 assert( Q
._vs
== _vs
);
150 TFactor
<T
>& operator-= (const TFactor
<T
>& Q
) {
152 assert( Q
._vs
== _vs
);
157 TFactor
<T
>& operator+= (T q
) {
161 TFactor
<T
>& operator-= (T q
) {
165 TFactor
<T
> operator+ (T q
) const {
166 TFactor
<T
> result(*this);
170 TFactor
<T
> operator- (T q
) const {
171 TFactor
<T
> result(*this);
176 TFactor
<T
> operator^ (Real a
) const { TFactor
<T
> x
; x
._vs
= _vs
; x
._p
= _p
^a
; return x
; }
177 TFactor
<T
>& operator^= (Real a
) { _p
^= a
; return *this; }
179 TFactor
<T
>& makeZero( Real epsilon
) {
180 _p
.makeZero( epsilon
);
184 TFactor
<T
>& makePositive( Real epsilon
) {
185 _p
.makePositive( epsilon
);
189 TFactor
<T
> inverse() const {
192 inv
._p
= _p
.inverse(true); // FIXME
196 TFactor
<T
> divided_by( const TFactor
<T
>& denom
) const {
198 assert( denom
._vs
== _vs
);
200 TFactor
<T
> quot(*this);
205 TFactor
<T
>& divide( const TFactor
<T
>& denom
) {
207 assert( denom
._vs
== _vs
);
213 TFactor
<T
> exp() const {
220 TFactor
<T
> abs() const {
227 TFactor
<T
> log() const {
234 TFactor
<T
> log0() const {
241 T
normalize( typename
Prob::NormType norm
= Prob::NORMPROB
) { return _p
.normalize( norm
); }
242 TFactor
<T
> normalized( typename
Prob::NormType norm
= Prob::NORMPROB
) const {
245 result
._p
= _p
.normalized( norm
);
249 // returns slice of this factor where the subset ns is in state ns_state
250 Factor
slice( const VarSet
& ns
, size_t ns_state
) const {
252 VarSet nsrem
= _vs
/ ns
;
253 Factor
result( nsrem
, 0.0 );
256 IndexFor
i_ns (ns
, _vs
);
257 IndexFor
i_nsrem (nsrem
, _vs
);
258 for( size_t i
= 0; i
< states(); i
++, ++i_ns
, ++i_nsrem
)
259 if( (size_t)i_ns
== ns_state
)
260 result
._p
[i_nsrem
] = _p
[i
];
265 // returns unnormalized marginal; ns should be a subset of vars()
266 TFactor
<T
> partSum(const VarSet
& ns
) const;
267 // returns (normalized by default) marginal; ns should be a subset of vars()
268 TFactor
<T
> marginal(const VarSet
& ns
, bool normed
= true) const { if(normed
) return partSum(ns
).normalized(); else return partSum(ns
); }
269 // sums out all variables except those in ns
270 TFactor
<T
> notSum(const VarSet
& ns
) const { return partSum(vars() ^ ns
); }
272 // embeds this factor in larger varset ns
273 TFactor
<T
> embed(const VarSet
& ns
) const {
279 return (*this) * Factor(ns
/ vs
, 1.0);
282 bool hasNaNs() const { return _p
.hasNaNs(); }
283 bool hasNegatives() const { return _p
.hasNegatives(); }
284 T
totalSum() const { return _p
.totalSum(); }
285 T
maxAbs() const { return _p
.maxAbs(); }
286 T
maxVal() const { return _p
.maxVal(); }
287 T
minVal() const { return _p
.minVal(); }
288 Real
entropy() const { return _p
.entropy(); }
289 T
strength( const Var
&i
, const Var
&j
) const;
291 friend Real
dist( const TFactor
<T
> & x
, const TFactor
<T
> & y
, Prob::DistType dt
) {
292 if( x
._vs
.empty() || y
._vs
.empty() )
296 assert( x
._vs
== y
._vs
);
298 return dist( x
._p
, y
._p
, dt
);
301 friend Real KL_dist
<> (const TFactor
<T
> & p
, const TFactor
<T
> & q
);
302 friend Real MutualInfo
<> ( const TFactor
<T
> & P
);
303 template<class U
> friend std::ostream
& operator<< (std::ostream
& os
, const TFactor
<U
>& P
);
307 template<typename T
> TFactor
<T
> TFactor
<T
>::partSum(const VarSet
& ns
) const {
312 TFactor
<T
> res( ns
, 0.0 );
314 IndexFor
i_res( ns
, _vs
);
315 for( size_t i
= 0; i
< _p
.size(); i
++, ++i_res
)
316 res
._p
[i_res
] += _p
[i
];
322 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TFactor
<T
>& P
) {
323 os
<< "(" << P
.vars() << " <";
324 for( size_t i
= 0; i
< P
._p
.size(); i
++ )
325 os
<< P
._p
[i
] << " ";
331 template<typename T
> TFactor
<T
> TFactor
<T
>::operator* (const TFactor
<T
>& Q
) const {
332 TFactor
<T
> prod( _vs
| Q
._vs
, 0.0 );
334 IndexFor
i1(_vs
, prod
._vs
);
335 IndexFor
i2(Q
._vs
, prod
._vs
);
337 for( size_t i
= 0; i
< prod
._p
.size(); i
++, ++i1
, ++i2
)
338 prod
._p
[i
] += _p
[i1
] * Q
._p
[i2
];
344 template<typename T
> TFactor
<T
> TFactor
<T
>::operator/ (const TFactor
<T
>& Q
) const {
345 TFactor
<T
> quot( _vs
+ Q
._vs
, 0.0 );
347 IndexFor
i1(_vs
, quot
._vs
);
348 IndexFor
i2(Q
._vs
, quot
._vs
);
350 for( size_t i
= 0; i
< quot
._p
.size(); i
++, ++i1
, ++i2
)
351 quot
._p
[i
] += _p
[i1
] / Q
._p
[i2
];
357 template<typename T
> Real
KL_dist(const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
358 if( P
._vs
.empty() || Q
._vs
.empty() )
362 assert( P
._vs
== Q
._vs
);
364 return KL_dist( P
._p
, Q
._p
);
369 // calculate mutual information of x_i and x_j where P.vars() = \{x_i,x_j\}
370 template<typename T
> Real
MutualInfo(const TFactor
<T
> & P
) {
371 assert( P
._vs
.size() == 2 );
372 VarSet::const_iterator it
= P
._vs
.begin();
373 Var i
= *it
; it
++; Var j
= *it
;
374 TFactor
<T
> projection
= P
.marginal(i
) * P
.marginal(j
);
375 return real( KL_dist( P
.normalized(), projection
) );
379 template<typename T
> TFactor
<T
> max( const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
380 assert( P
._vs
== Q
._vs
);
381 return TFactor
<T
>( P
._vs
, min( P
.p(), Q
.p() ) );
384 template<typename T
> TFactor
<T
> min( const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
385 assert( P
._vs
== Q
._vs
);
386 return TFactor
<T
>( P
._vs
, max( P
.p(), Q
.p() ) );
389 // calculate N(psi, i, j)
390 template<typename T
> T TFactor
<T
>::strength( const Var
&i
, const Var
&j
) const {
392 assert( _vs
.contains( i
) );
393 assert( _vs
.contains( j
) );
399 for( size_t alpha1
= 0; alpha1
< i
.states(); alpha1
++ )
400 for( size_t alpha2
= 0; alpha2
< i
.states(); alpha2
++ )
401 if( alpha2
!= alpha1
)
402 for( size_t beta1
= 0; beta1
< j
.states(); beta1
++ )
403 for( size_t beta2
= 0; beta2
< j
.states(); beta2
++ )
404 if( beta2
!= beta1
) {
405 size_t as
= 1, bs
= 1;
410 T f1
= slice( ij
, alpha1
* as
+ beta1
* bs
).p().divide( slice( ij
, alpha2
* as
+ beta1
* bs
).p() ).maxVal();
411 T f2
= slice( ij
, alpha2
* as
+ beta2
* bs
).p().divide( slice( ij
, alpha1
* as
+ beta2
* bs
).p() ).maxVal();
417 return std::tanh( 0.25 * std::log( max
) );
421 template<typename T
> TFactor
<T
> RemoveFirstOrderInteractions( const TFactor
<T
> & psi
) {
422 TFactor
<T
> result
= psi
;
424 VarSet vars
= psi
.vars();
425 for( size_t iter
= 0; iter
< 100; iter
++ ) {
426 for( VarSet::const_iterator n
= vars
.begin(); n
!= vars
.end(); n
++ )
427 result
= result
* result
.partSum(*n
).inverse();
435 } // end of namespace dai