47f61eb38a70bcb75f3402693a37e12b99f7fcbe
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 This file is part of libDAI.
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
24 #include <dai/daialg.h>
33 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
34 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
36 Factor
calcMarginal( const InfAlg
&obj
, const VarSet
&ns
, bool reInit
) {
39 InfAlg
*clamped
= obj
.clone();
43 map
<Var
,size_t> varindices
;
44 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
45 varindices
[*n
] = obj
.fg().findVar( *n
);
47 Real logZ0
= -INFINITY
;
48 for( State
s(ns
); s
.valid(); s
++ ) {
49 // save unclamped factors connected to ns
50 clamped
->backupFactors( ns
);
52 // set clamping Factors to delta functions
53 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
54 clamped
->clamp( varindices
[*n
], s(*n
) );
56 // run DAIAlg, calc logZ, store in Pns
65 logZ
= clamped
->logZ();
66 } catch( Exception
&e
) {
67 if( e
.code() == Exception::NOT_NORMALIZABLE
)
73 if( logZ0
== -INFINITY
)
74 if( logZ
!= -INFINITY
)
77 if( logZ
== -INFINITY
)
80 Pns
[s
] = exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
82 // restore clamped factors
83 clamped
->restoreFactors( ns
);
88 return( Pns
.normalized() );
92 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
93 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
95 vector
<Factor
> calcPairBeliefs( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
96 // convert ns to vector<VarSet>
100 map
<Var
,size_t> varindices
;
101 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ ) {
103 varindices
[*n
] = obj
.fg().findVar( *n
);
106 vector
<Factor
> pairbeliefs
;
107 pairbeliefs
.reserve( N
* N
);
108 for( size_t j
= 0; j
< N
; j
++ )
109 for( size_t k
= 0; k
< N
; k
++ )
111 pairbeliefs
.push_back( Factor() );
113 pairbeliefs
.push_back( Factor( VarSet(vns
[j
], vns
[k
]) ) );
115 InfAlg
*clamped
= obj
.clone();
119 Real logZ0
= -INFINITY
;
120 for( size_t j
= 0; j
< N
; j
++ ) {
121 // clamp Var j to its possible values
122 for( size_t j_val
= 0; j_val
< vns
[j
].states(); j_val
++ ) {
123 clamped
->clamp( varindices
[vns
[j
]], j_val
, true );
132 logZ
= clamped
->logZ();
133 } catch( Exception
&e
) {
134 if( e
.code() == Exception::NOT_NORMALIZABLE
)
140 if( logZ0
== -INFINITY
)
141 if( logZ
!= -INFINITY
)
145 if( logZ
== -INFINITY
)
148 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
150 for( size_t k
= 0; k
< N
; k
++ )
152 Factor b_k
= clamped
->belief(vns
[k
]);
153 for( size_t k_val
= 0; k_val
< vns
[k
].states(); k_val
++ )
154 if( vns
[j
].label() < vns
[k
].label() )
155 pairbeliefs
[j
* N
+ k
][j_val
+ (k_val
* vns
[j
].states())] = Z_xj
* b_k
[k_val
];
157 pairbeliefs
[j
* N
+ k
][k_val
+ (j_val
* vns
[k
].states())] = Z_xj
* b_k
[k_val
];
160 // restore clamped factors
161 clamped
->restoreFactors( ns
);
167 // Calculate result by taking the geometric average
168 vector
<Factor
> result
;
169 result
.reserve( N
* (N
- 1) / 2 );
170 for( size_t j
= 0; j
< N
; j
++ )
171 for( size_t k
= j
+1; k
< N
; k
++ )
172 result
.push_back( ((pairbeliefs
[j
* N
+ k
] * pairbeliefs
[k
* N
+ j
]) ^ 0.5).normalized() );
178 /// Calculates beliefs of all pairs in ns (by clamping pairs in ns and calculating logZ for each joined state).
179 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
181 Factor
calcMarginal2ndO( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
182 // returns a a probability distribution whose 1st order interactions
183 // are unspecified, whose 2nd order interactions approximate those of
184 // the marginal on ns, and whose higher order interactions are absent.
186 vector
<Factor
> pairbeliefs
= calcPairBeliefs( obj
, ns
, reInit
);
189 for( size_t ij
= 0; ij
< pairbeliefs
.size(); ij
++ )
190 Pns
*= pairbeliefs
[ij
];
192 return( Pns
.normalized() );
196 /// Calculates 2nd order interactions of the marginal of obj on ns.
197 vector
<Factor
> calcPairBeliefsNew( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
198 vector
<Factor
> result
;
199 result
.reserve( ns
.size() * (ns
.size() - 1) / 2 );
201 InfAlg
*clamped
= obj
.clone();
205 map
<Var
,size_t> varindices
;
206 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
207 varindices
[*n
] = obj
.fg().findVar( *n
);
210 VarSet::const_iterator nj
= ns
.begin();
211 for( long j
= 0; j
< (long)ns
.size() - 1; j
++, nj
++ ) {
213 for( VarSet::const_iterator nk
= nj
; (++nk
) != ns
.end(); k
++ ) {
214 Factor
pairbelief( VarSet(*nj
, *nk
) );
216 // clamp Vars j and k to their possible values
217 for( size_t j_val
= 0; j_val
< nj
->states(); j_val
++ )
218 for( size_t k_val
= 0; k_val
< nk
->states(); k_val
++ ) {
219 // save unclamped factors connected to ns
220 clamped
->backupFactors( ns
);
222 clamped
->clamp( varindices
[*nj
], j_val
);
223 clamped
->clamp( varindices
[*nk
], k_val
);
232 logZ
= clamped
->logZ();
233 } catch( Exception
&e
) {
234 if( e
.code() == Exception::NOT_NORMALIZABLE
)
240 if( logZ0
== -INFINITY
)
241 if( logZ
!= -INFINITY
)
245 if( logZ
== -INFINITY
)
248 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
250 // we assume that j.label() < k.label()
251 // i.e. we make an assumption here about the indexing
252 pairbelief
[j_val
+ (k_val
* nj
->states())] = Z_xj
;
254 // restore clamped factors
255 clamped
->restoreFactors( ns
);
258 result
.push_back( pairbelief
.normalized() );
264 assert( result
.size() == (ns
.size() * (ns
.size() - 1) / 2) );
270 } // end of namespace dai