1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
3 Radboud University Nijmegen, The Netherlands
5 This file is part of libDAI.
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 #ifndef __defined_libdai_factor_h
24 #define __defined_libdai_factor_h
30 #include <dai/varset.h>
31 #include <dai/index.h>
37 template<typename T
> class TFactor
;
38 typedef TFactor
<Real
> Factor
;
42 template<typename T
> Real
dist( const TFactor
<T
> & x
, const TFactor
<T
> & y
, Prob::DistType dt
);
43 template<typename T
> Real
KL_dist( const TFactor
<T
> & p
, const TFactor
<T
> & q
);
44 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TFactor
<T
>& P
);
// T must be convertible from and to double
48 template <typename T
> class TFactor
{
54 // Default constructor
55 TFactor () : _vs(), _p(1,1.0) {}
57 // Construct Factor from VarSet
58 TFactor( const VarSet
& ns
) : _vs(ns
), _p(_vs
.states()) {}
60 // Construct Factor from VarSet and initial value
61 TFactor( const VarSet
& ns
, Real p
) : _vs(ns
), _p(_vs
.states(),p
) {}
63 // Construct Factor from VarSet and initial array
64 TFactor( const VarSet
& ns
, const Real
* p
) : _vs(ns
), _p(_vs
.states(),p
) {}
66 // Construct Factor from VarSet and TProb<T>
67 TFactor( const VarSet
& ns
, const TProb
<T
> p
) : _vs(ns
), _p(p
) {
69 assert( _vs
.states() == _p
.size() );
73 // Construct Factor from Var
74 TFactor( const Var
& n
) : _vs(n
), _p(n
.states()) {}
77 TFactor( const TFactor
<T
> &x
) : _vs(x
._vs
), _p(x
._p
) {}
79 // Assignment operator
80 TFactor
<T
> & operator= (const TFactor
<T
> &x
) {
88 const TProb
<T
> & p() const { return _p
; }
89 TProb
<T
> & p() { return _p
; }
90 const VarSet
& vars() const { return _vs
; }
91 size_t states() const {
93 assert( _vs
.states() == _p
.size() );
98 T
operator[] (size_t i
) const { return _p
[i
]; }
99 T
& operator[] (size_t i
) { return _p
[i
]; }
100 TFactor
<T
> & fill (T p
)
101 { _p
.fill( p
); return(*this); }
102 TFactor
<T
> & randomize ()
103 { _p
.randomize(); return(*this); }
104 TFactor
<T
> operator* (T x
) const {
105 Factor result
= *this;
109 TFactor
<T
>& operator*= (T x
) {
113 TFactor
<T
> operator/ (T x
) const {
114 Factor result
= *this;
118 TFactor
<T
>& operator/= (T x
) {
122 TFactor
<T
> operator* (const TFactor
<T
>& Q
) const;
123 TFactor
<T
>& operator*= (const TFactor
<T
>& Q
) { return( *this = (*this * Q
) ); }
124 TFactor
<T
> operator+ (const TFactor
<T
>& Q
) const {
126 assert( Q
._vs
== _vs
);
128 TFactor
<T
> sum(*this);
132 TFactor
<T
> operator- (const TFactor
<T
>& Q
) const {
134 assert( Q
._vs
== _vs
);
136 TFactor
<T
> sum(*this);
140 TFactor
<T
>& operator+= (const TFactor
<T
>& Q
) {
142 assert( Q
._vs
== _vs
);
147 TFactor
<T
>& operator-= (const TFactor
<T
>& Q
) {
149 assert( Q
._vs
== _vs
);
155 TFactor
<T
> operator^ (Real a
) const { TFactor
<T
> x
; x
._vs
= _vs
; x
._p
= _p
^a
; return x
; }
156 TFactor
<T
>& operator^= (Real a
) { _p
^= a
; return *this; }
158 TFactor
<T
>& makeZero( Real epsilon
) {
159 _p
.makeZero( epsilon
);
163 TFactor
<T
> inverse() const {
166 inv
._p
= _p
.inverse(true); // FIXME
170 TFactor
<T
> divided_by( const TFactor
<T
>& denom
) const {
172 assert( denom
._vs
== _vs
);
174 TFactor
<T
> quot(*this);
179 TFactor
<T
>& divide( const TFactor
<T
>& denom
) {
181 assert( denom
._vs
== _vs
);
187 TFactor
<T
> exp() const {
194 TFactor
<T
> log() const {
201 TFactor
<T
> log0() const {
208 T
normalize( typename
Prob::NormType norm
) { return _p
.normalize( norm
); }
209 TFactor
<T
> normalized( typename
Prob::NormType norm
) const {
212 result
._p
= _p
.normalized( norm
);
216 // returns slice of this factor where the subset ns is in state ns_state
217 Factor
slice( const VarSet
& ns
, size_t ns_state
) const {
219 VarSet nsrem
= _vs
/ ns
;
220 Factor
result( nsrem
, 0.0 );
223 IndexFor
i_ns (ns
, _vs
);
224 IndexFor
i_nsrem (nsrem
, _vs
);
225 for( size_t i
= 0; i
< states(); i
++, ++i_ns
, ++i_nsrem
)
226 if( (size_t)i_ns
== ns_state
)
227 result
._p
[i_nsrem
] = _p
[i
];
232 // returns unnormalized marginal
233 TFactor
<T
> part_sum(const VarSet
& ns
) const;
234 // returns normalized marginal
235 TFactor
<T
> marginal(const VarSet
& ns
) const { return part_sum(ns
).normalized( Prob::NORMPROB
); }
237 bool hasNaNs() const { return _p
.hasNaNs(); }
238 bool hasNegatives() const { return _p
.hasNegatives(); }
239 T
totalSum() const { return _p
.totalSum(); }
240 T
maxAbs() const { return _p
.maxAbs(); }
241 T
maxVal() const { return _p
.maxVal(); }
242 Real
entropy() const { return _p
.entropy(); }
243 T
strength( const Var
&i
, const Var
&j
) const;
245 friend Real
dist( const TFactor
<T
> & x
, const TFactor
<T
> & y
, Prob::DistType dt
) {
246 if( x
._vs
.empty() || y
._vs
.empty() )
250 assert( x
._vs
== y
._vs
);
252 return dist( x
._p
, y
._p
, dt
);
255 friend Real KL_dist
<> (const TFactor
<T
> & p
, const TFactor
<T
> & q
);
256 template<class U
> friend std::ostream
& operator<< (std::ostream
& os
, const TFactor
<U
>& P
);
260 template<typename T
> TFactor
<T
> TFactor
<T
>::part_sum(const VarSet
& ns
) const {
265 TFactor
<T
> res( ns
, 0.0 );
267 IndexFor
i_res( ns
, _vs
);
268 for( size_t i
= 0; i
< _p
.size(); i
++, ++i_res
)
269 res
._p
[i_res
] += _p
[i
];
275 template<typename T
> std::ostream
& operator<< (std::ostream
& os
, const TFactor
<T
>& P
) {
276 os
<< "(" << P
.vars() << " <";
277 for( size_t i
= 0; i
< P
._p
.size(); i
++ )
278 os
<< P
._p
[i
] << " ";
284 template<typename T
> TFactor
<T
> TFactor
<T
>::operator* (const TFactor
<T
>& Q
) const {
285 TFactor
<T
> prod( _vs
| Q
._vs
, 0.0 );
287 IndexFor
i1(_vs
, prod
._vs
);
288 IndexFor
i2(Q
._vs
, prod
._vs
);
290 for( size_t i
= 0; i
< prod
._p
.size(); i
++, ++i1
, ++i2
)
291 prod
._p
[i
] += _p
[i1
] * Q
._p
[i2
];
297 template<typename T
> Real
KL_dist(const TFactor
<T
> & P
, const TFactor
<T
> & Q
) {
298 if( P
._vs
.empty() || Q
._vs
.empty() )
302 assert( P
._vs
== Q
._vs
);
304 return KL_dist( P
._p
, Q
._p
);
309 // calculate N(psi, i, j)
310 template<typename T
> T TFactor
<T
>::strength( const Var
&i
, const Var
&j
) const {
312 assert( _vs
.contains( i
) );
313 assert( _vs
.contains( j
) );
319 for( size_t alpha1
= 0; alpha1
< i
.states(); alpha1
++ )
320 for( size_t alpha2
= 0; alpha2
< i
.states(); alpha2
++ )
321 if( alpha2
!= alpha1
)
322 for( size_t beta1
= 0; beta1
< j
.states(); beta1
++ )
323 for( size_t beta2
= 0; beta2
< j
.states(); beta2
++ )
324 if( beta2
!= beta1
) {
325 size_t as
= 1, bs
= 1;
330 T f1
= slice( ij
, alpha1
* as
+ beta1
* bs
).p().divide( slice( ij
, alpha2
* as
+ beta1
* bs
).p() ).maxVal();
331 T f2
= slice( ij
, alpha2
* as
+ beta2
* bs
).p().divide( slice( ij
, alpha1
* as
+ beta2
* bs
).p() ).maxVal();
337 return std::tanh( 0.25 * std::log( max
) );
341 template<typename T
> TFactor
<T
> RemoveFirstOrderInteractions( const TFactor
<T
> & psi
) {
342 TFactor
<T
> result
= psi
;
344 VarSet vars
= psi
.vars();
345 for( size_t iter
= 0; iter
< 100; iter
++ ) {
346 for( VarSet::const_iterator n
= vars
.begin(); n
!= vars
.end(); n
++ )
347 result
= result
* result
.part_sum(*n
).inverse();
348 result
.normalize( Prob::NORMPROB
);
355 } // end of namespace dai