1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
13 #include <dai/daialg.h>
22 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joint state.
23 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
// calcMarginal -- approximate marginal of obj on the variables in ns,
// computed by clamping.
//
// For every joint state s visited by State s(ns), the function works on a
// clone of obj. On that clone it clamps each variable of ns to its value in s,
// evaluates logZ, and stores exp(logZ - logZ0) in Pns[s]. It returns
// Pns.normalized().
//
// Parameters:
//   obj    - inference algorithm to query. It is taken by const reference;
//            all clamping is done on a clone of it.
//   ns     - variables on which the marginal is computed.
//   reInit - per the comment above, set to true if some clamped state could
//            give a factor graph with zero partition sum. NOTE(review): the
//            lines that read reInit are not visible in this extract.
//
// Cost: one logZ evaluation of the clamped clone per joint state of ns.
//
// NOTE(review): this extract is missing source lines. Among others, the
// declarations of Pns and logZ, the opening of the try block, the
// else-branches, the closing braces, and any `delete clamped` are not
// visible. Confirm against the full file before relying on these comments.
25 Factor
calcMarginal( const InfAlg
&obj
, const VarSet
&ns
, bool reInit
) {
// Work on a private copy so that obj itself is never clamped.
// NOTE(review): clone() hands back a raw owning pointer. If it is released
// by an explicit delete after the loop, any exception escaping the loop
// would leak it. Holding it in a std::unique_ptr would make this
// exception-safe.
28 InfAlg
*clamped
= obj
.clone();
// Look up each variable of ns in obj's factor graph once. The clamp loop
// below then indexes varindices instead of calling findVar per joint state.
32 map
<Var
,size_t> varindices
;
33 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
34 varindices
[*n
] = obj
.fg().findVar( *n
);
// Reference offset for the stored values: -INFINITY means "no finite logZ
// seen yet".
36 Real logZ0
= -INFINITY
;
// Loop over every joint state s of the variables in ns.
37 for( State
s(ns
); s
.valid(); s
++ ) {
38 // save unclamped factors connected to ns
39 clamped
->backupFactors( ns
);
41 // set clamping Factors to delta functions
42 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
43 clamped
->clamp( varindices
[*n
], s(*n
) );
45 // run DAIAlg, calc logZ, store in Pns
// An Exception whose code is NOT_NORMALIZABLE (a clamped state with zero
// partition sum) is caught below. NOTE(review): the handler body is not
// visible here. It presumably sets logZ = -INFINITY and rethrows other
// codes; confirm.
54 logZ
= clamped
->logZ();
55 } catch( Exception
&e
) {
56 if( e
.code() == Exception::NOT_NORMALIZABLE
)
// Presumably records the first finite logZ as the offset logZ0. The
// assignment is on a line missing from this extract.
62 if( logZ0
== -INFINITY
)
63 if( logZ
!= -INFINITY
)
// Infeasible clamped state. NOTE(review): the statement for this branch
// (presumably Pns[s] = 0, with the Pns[s] assignment below as its else
// branch) is missing from this extract.
66 if( logZ
== -INFINITY
)
69 Pns
[s
] = exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
71 // restore clamped factors
72 clamped
->restoreFactors( ns
);
// Return the normalized table over the joint states of ns.
77 return( Pns
.normalized() );
81 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
82 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
// calcPairBeliefs -- approximate pairwise beliefs for every pair of variables
// in ns, obtained by clamping one variable at a time.
//
// The work is done on a clone of obj. For each variable vns[j] and each of
// its values j_val:
//   - vns[j] is clamped to j_val on the clone;
//   - Z_xj = exp(logZ - logZ0) is computed from the clamped clone;
//   - for each k, the single-variable belief b_k of vns[k] is read from the
//     clone;
//   - Z_xj * b_k[k_val] is written into entry (j_val, k_val) of
//     pairbeliefs[j*N + k], the estimate obtained by clamping vns[j].
// Each pair j < k of the result combines the two estimates
// pairbeliefs[j*N + k] and pairbeliefs[k*N + j] by their geometric average,
// then normalizes it.
//
// Parameters:
//   obj    - inference algorithm to query. It is taken by const reference;
//            all clamping is done on a clone of it.
//   ns     - variables whose pairwise beliefs are wanted.
//   reInit - per the comment above, true if some clamped state could give a
//            zero partition sum. NOTE(review): the lines that read reInit are
//            not visible in this extract.
//
// Builds result with N*(N-1)/2 normalized Factors, ordered by (j, k) with
// j < k. N is presumably ns.size(); its definition is missing here.
//
// NOTE(review): several source lines are missing from this extract. These
// include the declarations of vns, N, logZ and Z_xj, the try block opening,
// the j == k / k != j tests, closing braces, any delete of clamped, and the
// final return. Confirm against the full file.
84 vector
<Factor
> calcPairBeliefs( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
85 // convert ns to vector<VarSet>
// NOTE(review): label() and states() are called on the elements of vns,
// which suggests they are individual variables; confirm the element type.
// Look up each variable of ns in obj's factor graph once.
89 map
<Var
,size_t> varindices
;
90 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ ) {
92 varindices
[*n
] = obj
.fg().findVar( *n
);
// N*N slots indexed j*N + k. The diagonal (j == k) presumably receives the
// empty Factor() and every other slot a Factor over {vns[j], vns[k]}. The
// condition choosing between the two push_backs is on a missing line.
95 vector
<Factor
> pairbeliefs
;
96 pairbeliefs
.reserve( N
* N
);
97 for( size_t j
= 0; j
< N
; j
++ )
98 for( size_t k
= 0; k
< N
; k
++ )
100 pairbeliefs
.push_back( Factor() );
102 pairbeliefs
.push_back( Factor( VarSet(vns
[j
], vns
[k
]) ) );
// Work on a private copy so that obj itself is never clamped.
// NOTE(review): raw owning pointer from clone(). A std::unique_ptr would
// keep it from leaking if an exception escapes.
104 InfAlg
*clamped
= obj
.clone();
// Reference offset for the Z_xj values: -INFINITY means "no finite logZ
// seen yet".
108 Real logZ0
= -INFINITY
;
109 for( size_t j
= 0; j
< N
; j
++ ) {
110 // clamp Var j to its possible values
111 for( size_t j_val
= 0; j_val
< vns
[j
].states(); j_val
++ ) {
// NOTE(review): unlike calcMarginal, no explicit backupFactors() call is
// visible in this function. The trailing `true` presumably asks clamp() to
// back up the factors it modifies; confirm against InfAlg::clamp.
112 clamped
->clamp( varindices
[vns
[j
]], j_val
, true );
// An Exception whose code is NOT_NORMALIZABLE is caught below.
// NOTE(review): the try opening and the handler body are not visible.
121 logZ
= clamped
->logZ();
122 } catch( Exception
&e
) {
123 if( e
.code() == Exception::NOT_NORMALIZABLE
)
// Presumably records the first finite logZ as the offset logZ0 (the
// assignment is on a missing line).
129 if( logZ0
== -INFINITY
)
130 if( logZ
!= -INFINITY
)
// Infeasible clamped value. NOTE(review): this branch's statement is missing
// from this extract.
134 if( logZ
== -INFINITY
)
137 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
// Read the belief of every vns[k] from the clone with vns[j] clamped to
// j_val. NOTE(review): this loop covers all k; a k != j guard presumably
// sits on a missing line.
139 for( size_t k
= 0; k
< N
; k
++ )
141 Factor b_k
= clamped
->belief(vns
[k
]);
// The flat index of entry (j_val, k_val) depends on which variable has the
// smaller label. That variable varies fastest, which presumably matches how
// a Factor over a two-variable VarSet orders its states; confirm.
142 for( size_t k_val
= 0; k_val
< vns
[k
].states(); k_val
++ )
143 if( vns
[j
].label() < vns
[k
].label() )
144 pairbeliefs
[j
* N
+ k
][j_val
+ (k_val
* vns
[j
].states())] = Z_xj
* b_k
[k_val
];
146 pairbeliefs
[j
* N
+ k
][k_val
+ (j_val
* vns
[k
].states())] = Z_xj
* b_k
[k_val
];
149 // restore clamped factors
// NOTE(review): with braces missing from this extract, it is not visible
// whether this runs once per j_val (as the per-value clamp would require)
// or once per j; confirm.
150 clamped
->restoreFactors( ns
);
156 // Calculate result by taking the geometric average
// For each pair j < k, combine the estimate from clamping vns[j]
// (pairbeliefs[j*N + k]) with the one from clamping vns[k]
// (pairbeliefs[k*N + j]) via (a * b) ^ 0.5, then normalize.
157 vector
<Factor
> result
;
158 result
.reserve( N
* (N
- 1) / 2 );
159 for( size_t j
= 0; j
< N
; j
++ )
160 for( size_t k
= j
+1; k
< N
; k
++ )
161 result
.push_back( ((pairbeliefs
[j
* N
+ k
] * pairbeliefs
[k
* N
+ j
]) ^ 0.5).normalized() );
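167 /// Calculates a distribution on ns whose 2nd order interactions approximate those of the marginal of obj on ns, by multiplying the pair beliefs computed by calcPairBeliefs (which clamps single variables in ns).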
168 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
// calcMarginal2ndO -- builds a distribution on ns from pairwise beliefs.
//
// It calls calcPairBeliefs( obj, ns, reInit ), multiplies every returned
// pair belief into Pns, and returns Pns.normalized().
//
// Parameters:
//   obj, ns, reInit - forwarded unchanged to calcPairBeliefs.
//
// NOTE(review): the declaration of Pns is missing from this extract. It is
// presumably a Factor that acts as the identity for *=; confirm against the
// full file.
170 Factor
calcMarginal2ndO( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
171 // returns a probability distribution whose 1st order interactions
172 // are unspecified, whose 2nd order interactions approximate those of
173 // the marginal on ns, and whose higher order interactions are absent.
// One normalized Factor per unordered pair of variables in ns.
175 vector
<Factor
> pairbeliefs
= calcPairBeliefs( obj
, ns
, reInit
);
// The product of the pair factors contains only pairwise interactions.
178 for( size_t ij
= 0; ij
< pairbeliefs
.size(); ij
++ )
179 Pns
*= pairbeliefs
[ij
];
181 return( Pns
.normalized() );
185 /// Calculates 2nd order interactions of the marginal of obj on ns.
// calcPairBeliefsNew -- approximate pairwise beliefs for every pair of
// variables in ns, obtained by clamping both variables of a pair jointly.
//
// The work is done on a clone of obj. For each pair (nj, nk) with nk after nj
// in ns, and for each joint value (j_val, k_val):
//   - the factors connected to ns are backed up;
//   - both variables are clamped;
//   - logZ is evaluated;
//   - pairbelief[j_val + k_val * nj->states()] is set to
//     Z_xj = exp(logZ - logZ0);
//   - the factors are restored.
// Each pair's table is then normalized and appended to result.
//
// Parameters:
//   obj    - inference algorithm to query. It is taken by const reference;
//            all clamping is done on a clone of it.
//   ns     - variables whose pairwise beliefs are wanted.
//   reInit - true if some clamped state could give a zero partition sum.
//            NOTE(review): the lines that read reInit are not visible in
//            this extract.
//
// result ends up holding ns.size()*(ns.size()-1)/2 Factors ordered by
// (nj, nk); the assert at the end checks the count.
// Cost: one logZ evaluation per joint value of every pair.
//
// NOTE(review): the declarations of k, logZ0, logZ and Z_xj, the try block
// opening, handler bodies, closing braces, any delete of clamped, and the
// return of result are missing from this extract. Confirm against the full
// file.
186 vector
<Factor
> calcPairBeliefsNew( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
187 vector
<Factor
> result
;
188 result
.reserve( ns
.size() * (ns
.size() - 1) / 2 );
// Work on a private copy so that obj itself is never clamped.
// NOTE(review): raw owning pointer from clone(). A std::unique_ptr would
// keep it from leaking if an exception escapes.
190 InfAlg
*clamped
= obj
.clone();
// Look up each variable of ns in obj's factor graph once.
194 map
<Var
,size_t> varindices
;
195 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
196 varindices
[*n
] = obj
.fg().findVar( *n
);
// Outer loop: nj runs over all but the last variable of ns. The size is cast
// to long so that for an empty ns the bound is -1 and the loop is skipped,
// instead of ns.size() - 1 wrapping around as size_t.
199 VarSet::const_iterator nj
= ns
.begin();
200 for( long j
= 0; j
< (long)ns
.size() - 1; j
++, nj
++ ) {
// Inner loop: nk runs over every variable after nj. nk is pre-incremented
// before each test, so it never equals nj. NOTE(review): k is declared on a
// line missing from this extract.
202 for( VarSet::const_iterator nk
= nj
; (++nk
) != ns
.end(); k
++ ) {
203 Factor
pairbelief( VarSet(*nj
, *nk
) );
205 // clamp Vars j and k to their possible values
206 for( size_t j_val
= 0; j_val
< nj
->states(); j_val
++ )
207 for( size_t k_val
= 0; k_val
< nk
->states(); k_val
++ ) {
208 // save unclamped factors connected to ns
209 clamped
->backupFactors( ns
);
211 clamped
->clamp( varindices
[*nj
], j_val
);
212 clamped
->clamp( varindices
[*nk
], k_val
);
// logZ of the clone with both variables clamped. An Exception whose code is
// NOT_NORMALIZABLE is caught below. NOTE(review): the try opening and the
// handler body are not visible.
221 logZ
= clamped
->logZ();
222 } catch( Exception
&e
) {
223 if( e
.code() == Exception::NOT_NORMALIZABLE
)
// Presumably records the first finite logZ as the offset logZ0 (the
// assignment is on a missing line).
229 if( logZ0
== -INFINITY
)
230 if( logZ
!= -INFINITY
)
// Infeasible joint value. NOTE(review): this branch's statement is missing
// from this extract.
234 if( logZ
== -INFINITY
)
// Despite its name, Z_xj here is the weight of the joint value (j_val, k_val).
237 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
239 // we assume that j.label() < k.label()
240 // i.e. we make an assumption here about the indexing
// The assumption holds if VarSet iterates in increasing label order, because
// nk always comes after nj. NOTE(review): confirm.
241 pairbelief
[j_val
+ (k_val
* nj
->states())] = Z_xj
;
243 // restore clamped factors
244 clamped
->restoreFactors( ns
);
247 result
.push_back( pairbelief
.normalized() );
// One Factor per unordered pair of variables in ns.
253 assert( result
.size() == (ns
.size() * (ns
.size() - 1) / 2) );
259 } // end of namespace dai