1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 This file is part of libDAI.
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
24 #include <dai/daialg.h>
33 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
/// Works on a clone of obj: for every joint state s of ns the variables in ns are clamped
/// to s, logZ of the clamped model is queried, and exp(logZ - logZ0) is the weight of s
/// relative to the reference value logZ0. The Factor collecting these weights (Pns) is
/// returned as Pns.normalized().
/// @param obj     algorithm whose marginal is wanted; taken by const reference and only cloned here.
/// @param ns      variables on which the marginal is computed.
/// @param reInit  not referenced in the lines visible in this view; presumably selects how the
///                clamped clone is (re)initialised before each evaluation — TODO confirm.
/// @return        the normalized weights over the joint states of ns.
/// NOTE(review): the embedded original numbering jumps (37->42, 50->59, 63->68, 69->74), so the
/// declarations of Pns/logZ0/Z, the condition guarding the logZ0 assignment, the store of Z into
/// Pns and any `delete clamped` are not visible here. Confirm that the clone returned by
/// obj.clone() is deleted before returning; otherwise every call leaks it.
34 Factor
calcMarginal( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
// Clamping is done on a private copy so that obj itself is never modified.
37 InfAlg
*clamped
= obj
.clone();
// One clamped evaluation per joint state of ns.
42 for( State
s(ns
); s
.valid(); s
++ ) {
43 // save unclamped factors connected to ns
clamped
->backupFactors( ns
);
46 // set clamping Factors to delta functions
47 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
48 clamped
->clamp( *n
, s(*n
) );
50 // run DAIAlg, calc logZ, store in Pns
// logZ0 is the reference log partition sum all weights are measured against; presumably it
// is assigned only for the first joint state (the guarding condition is not visible here).
59 logZ0
= clamped
->logZ();
62 // subtract logZ0 to avoid very large numbers
63 Z
= exp(clamped
->logZ() - logZ0
);
68 // restore clamped factors
// Undo this state's clamping so the next joint state starts from the backed-up factors.
69 clamped
->restoreFactors( ns
);
74 return( Pns
.normalized() );
78 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
/// Clamping works on a clone of obj.
/// For each variable vns[j] and each of its states j_val:
///   - vns[j] is clamped to j_val;
///   - Z_xj = exp(logZ - logZ0) is computed;
///   - for every variable vns[k], the entries of pair slot j*N+k are set to Z_xj times the
///     clamped belief of vns[k].
/// Each pair (j,k) is therefore estimated twice: in slot j*N+k (vns[j] clamped) and in slot
/// k*N+j (vns[k] clamped). The result for j < k is the geometric mean of the two estimates,
/// in the order (0,1), (0,2), ..., (1,2), ...
/// logZ0 is taken only once, at j == 0 && j_val == 0, so all entries share one reference scale.
/// @param obj     algorithm whose pair beliefs are wanted; taken by const reference and only cloned here.
/// @param ns      variables whose pairs are considered.
/// @param reInit  not referenced in the lines visible in this view — TODO confirm its role.
/// @return        presumably `result`, holding N*(N-1)/2 pair factors (the return statement is not visible).
/// NOTE(review): the embedded original numbering skips lines. The following are not visible here:
///   - the declarations of vns, N, logZ0 and Z_xj;
///   - the body of the conversion loop;
///   - the condition choosing between the two push_back calls;
///   - the `else` presumably separating the two pairbeliefs assignments;
///   - any `delete clamped`.
/// Confirm that the clone returned by obj.clone() is freed before returning.
79 vector
<Factor
> calcPairBeliefs( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
80 // convert ns to vector<VarSet>
// NOTE(review): the element type of vns is not visible. The uses vns[j].label() and
// VarSet(vns[j], vns[k]) suggest single variables rather than VarSets — confirm and correct
// the comment above. The loop below fills vns; its body (original lines 85-86) is not visible.
84 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
// N*N slots, addressed as j*N+k (vns[j] clamped, belief of vns[k]).
87 vector
<Factor
> pairbeliefs
;
88 pairbeliefs
.reserve( N
* N
);
89 for( size_t j
= 0; j
< N
; j
++ )
90 for( size_t k
= 0; k
< N
; k
++ )
// Presumably diagonal slots (j == k) get an empty Factor and the others a Factor over the
// pair; the selecting condition (original lines 91/93) is not visible here.
92 pairbeliefs
.push_back( Factor() );
94 pairbeliefs
.push_back( Factor( VarSet(vns
[j
], vns
[k
]) ) );
// Clamping is done on a private copy so that obj itself is never modified.
96 InfAlg
*clamped
= obj
.clone();
101 for( size_t j
= 0; j
< N
; j
++ ) {
102 // clamp Var j to its possible values
103 for( size_t j_val
= 0; j_val
< vns
[j
].states(); j_val
++ ) {
104 clamped
->clamp( vns
[j
], j_val
, true );
112 // logZ0 = obj.logZ();
// The commented-out line above would use the unclamped obj as reference. Instead, the
// first clamped evaluation (j == 0 && j_val == 0) provides logZ0 for the whole run.
114 if( j
== 0 && j_val
== 0 ) {
115 logZ0
= clamped
->logZ();
117 // subtract logZ0 to avoid very large numbers
118 Z_xj
= exp(clamped
->logZ() - logZ0
);
121 for( size_t k
= 0; k
< N
; k
++ )
// b_k: belief on vns[k] in the model where vns[j] is clamped to j_val.
123 Factor b_k
= clamped
->belief(vns
[k
]);
124 for( size_t k_val
= 0; k_val
< vns
[k
].states(); k_val
++ )
// The flat index makes the lower-label variable of the pair the fastest-varying one.
// This presumes Factor orders a two-variable VarSet by label — confirm in VarSet/Factor.
125 if( vns
[j
].label() < vns
[k
].label() )
126 pairbeliefs
[j
* N
+ k
][j_val
+ (k_val
* vns
[j
].states())] = Z_xj
* b_k
[k_val
];
128 pairbeliefs
[j
* N
+ k
][k_val
+ (j_val
* vns
[k
].states())] = Z_xj
* b_k
[k_val
];
131 // restore clamped factors
132 clamped
->restoreFactors( ns
);
138 // Calculate result by taking the geometric average
// Slot j*N+k was filled with vns[j] clamped and slot k*N+j with vns[k] clamped. Here
// `^ 0.5` is presumably an elementwise square root on Factor — confirm.
139 vector
<Factor
> result
;
140 result
.reserve( N
* (N
- 1) / 2 );
141 for( size_t j
= 0; j
< N
; j
++ )
142 for( size_t k
= j
+1; k
< N
; k
++ )
143 result
.push_back( (pairbeliefs
[j
* N
+ k
] * pairbeliefs
[k
* N
+ j
]) ^ 0.5 );
149 /// Approximates the marginal of obj on ns by the normalized product of the pair beliefs returned by calcPairBeliefs (which clamps single nodes in ns and calculates logZ and beliefs for each state).
/// @param obj     algorithm, forwarded unchanged to calcPairBeliefs.
/// @param ns      variables whose pairs are combined.
/// @param reInit  forwarded unchanged to calcPairBeliefs.
/// @return        Pns multiplied by every pair belief, then normalized.
/// NOTE(review): the declaration/initialisation of Pns (original lines 156-157) is not visible
/// in this view; presumably it is a Factor over ns — confirm.
150 Factor
calcMarginal2ndO( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
151 // returns a probability distribution whose 1st order interactions
152 // are unspecified, whose 2nd order interactions approximate those of
153 // the marginal on ns, and whose higher order interactions are absent.
// Pairwise estimates, one Factor per unordered pair of variables in ns.
155 vector
<Factor
> pairbeliefs
= calcPairBeliefs( obj
, ns
, reInit
);
// Multiply all pair factors together. Every variable occurs in several of these factors,
// which is presumably why the 1st order interactions are called unspecified above.
158 for( size_t ij
= 0; ij
< pairbeliefs
.size(); ij
++ )
159 Pns
*= pairbeliefs
[ij
];
161 return( Pns
.normalized() );
165 /// Calculates 2nd order interactions of the marginal of obj on ns.
/// Produces one Factor per unordered pair of distinct variables (*nj, *nk) of ns, visited in the
/// iteration order of ns: (v0,v1), (v0,v2), ..., (v1,v2), ...
/// For each pair:
///   - both variables are clamped, on a clone of obj, to every combination (j_val, k_val);
///   - the entry for that combination is exp(logZ - logZ0).
/// logZ0 is re-taken for every pair at j_val == 0 && k_val == 0. Each pair factor therefore has
/// its own reference scale, and the visible lines do not normalize it.
/// @param obj     algorithm whose pairwise interactions are wanted; taken by const reference and only cloned here.
/// @param ns      variables whose pairs are considered.
/// @param reInit  not referenced in the lines visible in this view — TODO confirm its role.
/// @return        presumably `result`, with ns.size()*(ns.size()-1)/2 factors (checked by the
///                assert at the end; the return statement itself is not visible).
/// NOTE(review): the embedded original numbering skips lines. The following are not visible here:
///   - the declarations of logZ0 and Z_xj;
///   - the declaration of `k`, which the inner loop increments;
///   - original lines 189-195, between clamping and the logZ0 check;
///   - the return statement;
///   - any `delete clamped`.
/// Confirm that the clone returned by obj.clone() is freed.
166 vector
<Factor
> calcPairBeliefsNew( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
167 vector
<Factor
> result
;
168 result
.reserve( ns
.size() * (ns
.size() - 1) / 2 );
// Clamping is done on a private copy so that obj itself is never modified.
170 InfAlg
*clamped
= obj
.clone();
175 VarSet::const_iterator nj
= ns
.begin();
// The cast happens before the subtraction. For an empty ns the bound is -1 and the loop is
// skipped, instead of wrapping around as an unsigned ns.size() - 1 would.
176 for( long j
= 0; j
< (long)ns
.size() - 1; j
++, nj
++ ) {
// nk starts at nj and is pre-incremented in the condition, so it visits only the variables
// after nj; each unordered pair is handled exactly once.
178 for( VarSet::const_iterator nk
= nj
; (++nk
) != ns
.end(); k
++ ) {
179 Factor
pairbelief( VarSet(*nj
, *nk
) );
181 // clamp Vars j and k to their possible values
182 for( size_t j_val
= 0; j_val
< nj
->states(); j_val
++ )
183 for( size_t k_val
= 0; k_val
< nk
->states(); k_val
++ ) {
184 // save unclamped factors connected to ns
185 clamped
->backupFactors( ns
);
187 clamped
->clamp( *nj
, j_val
);
188 clamped
->clamp( *nk
, k_val
);
// The reference logZ0 is taken at this pair's first state combination.
196 if( j_val
== 0 && k_val
== 0 ) {
197 logZ0
= clamped
->logZ();
199 // subtract logZ0 to avoid very large numbers
200 Z_xj
= exp(clamped
->logZ() - logZ0
);
203 // we assume that j.label() < k.label()
204 // i.e. we make an assumption here about the indexing
// Since nk comes after nj in the iteration order of ns, the assumption holds if VarSet
// iterates in increasing label order — confirm.
205 pairbelief
[j_val
+ (k_val
* nj
->states())] = Z_xj
;
207 // restore clamped factors
208 clamped
->restoreFactors( ns
);
// pairbelief holds unnormalized weights relative to logZ0.
211 result
.push_back( pairbelief
);
217 assert( result
.size() == (ns
.size() * (ns
.size() - 1) / 2) );
223 } // end of namespace dai