1 /* This file is part of libDAI - http://www.libdai.org/
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
11 #include <dai/daialg.h>
// calcMarginal( obj, vs, reInit )
//
// Builds a factor over the variables in vs by clamping: for every joint state s of vs,
// the variables of vs are clamped to their values in s on a clone of obj, logZ() of the
// clone is read, and exp(logZ - logZ0) is stored for s. The normalized factor is returned.
//
//   obj    - inference algorithm; taken by const reference, only clone() and fg().findVar()
//            are called on it in the visible code.
//   vs     - variables whose joint states are enumerated.
//   reInit - NOTE(review): no use of this flag is visible in this listing; confirm its
//            effect against the full file.
//
// NOTE(review): this listing omits several statements (the declarations of Pvs and logZ,
// the opening of the try block, the bodies of the if statements, and whatever releases the
// clone). The comments below describe only what is visible.
20 Factor
calcMarginal( const InfAlg
&obj
, const VarSet
&vs
, bool reInit
) {
// All clamping is applied to this clone, not to obj.
23 InfAlg
*clamped
= obj
.clone();
// Look up the factor-graph index of each variable once, before the state loop.
27 map
<Var
,size_t> varindices
;
28 for( VarSet::const_iterator n
= vs
.begin(); n
!= vs
.end(); n
++ )
29 varindices
[*n
] = obj
.fg().findVar( *n
);
// logZ0 is the offset subtracted from logZ before exponentiating; it starts at -INFINITY.
31 Real logZ0
= -INFINITY
;
// One iteration per joint state s of vs.
32 for( State
s(vs
); s
.valid(); s
++ ) {
33 // save unclamped factors connected to vs
34 clamped
->backupFactors( vs
);
36 // set clamping Factors to delta functions
37 for( VarSet::const_iterator n
= vs
.begin(); n
!= vs
.end(); n
++ )
38 clamped
->clamp( varindices
[*n
], s(*n
) );
40 // run DAIAlg, calc logZ, store in Pvs
49 logZ
= clamped
->logZ();
// NOT_NORMALIZABLE is singled out from other exceptions; the handling itself is not
// visible in this listing.
50 } catch( Exception
&e
) {
51 if( e
.getCode() == Exception::NOT_NORMALIZABLE
)
// Presumably the first finite logZ becomes logZ0 (the assignment is not visible here).
57 if( logZ0
== -INFINITY
)
58 if( logZ
!= -INFINITY
)
// logZ == -INFINITY is treated separately from the exp() case below (its body is not
// visible here).
61 if( logZ
== -INFINITY
)
64 Pvs
.set( s
, exp(logZ
- logZ0
) ); // subtract logZ0 to avoid very large numbers
66 // restore clamped factors
67 clamped
->restoreFactors( vs
);
// The per-state weights are normalized before being returned.
72 return( Pvs
.normalized() );
// calcPairBeliefs( obj, vs, reInit, accurate )
//
// Builds pairwise beliefs for pairs of variables in vs, using a clone of obj that is
// repeatedly clamped and restored. Two strategies are visible:
//
//   1. Pair clamping: for each pair (nj, nk) with nk after nj in vs, and for every joint
//      value (j_val, k_val), both variables are clamped, logZ() is read, and
//      exp(logZ - logZ0) is stored in the pair factor, which is normalized and appended
//      to result.
//   2. Single clamping: for each variable j and value j_val, j is clamped, logZ() is read,
//      and for every k the single-variable belief of k scaled by exp(logZ - logZ0) is
//      written into an N*N table of pair factors. Entries (j,k) and (k,j) are then
//      multiplied, raised to the power 0.5 and normalized.
//
//   obj      - inference algorithm; only clone() and fg().findVar() are called on it here.
//   vs       - variables for which pair beliefs are computed.
//   reInit   - NOTE(review): no use of this flag is visible in this listing.
//   accurate - NOTE(review): presumably selects between the two strategies above; the
//              branch statement is not visible in this listing.
//
// NOTE(review): the declarations of N, k, logZ, logZ0 (first strategy) and Z_xj, the try
// block openers, several if/else bodies and the return statement are not visible here.
// N is presumably the number of variables in vs; confirm against the full file.
76 vector
<Factor
> calcPairBeliefs( const InfAlg
& obj
, const VarSet
& vs
, bool reInit
, bool accurate
) {
77 vector
<Factor
> result
;
// One entry per unordered pair of distinct variables.
79 result
.reserve( N
* (N
- 1) / 2 );
// All clamping is applied to this clone, not to obj.
81 InfAlg
*clamped
= obj
.clone();
// Look up the factor-graph index of each variable once.
85 map
<Var
,size_t> varindices
;
86 for( VarSet::const_iterator v
= vs
.begin(); v
!= vs
.end(); v
++ )
87 varindices
[*v
] = obj
.fg().findVar( *v
);
// Strategy 1: iterate over pairs (nj, nk), with nk strictly after nj in vs.
91 VarSet::const_iterator nj
= vs
.begin();
92 for( long j
= 0; j
< (long)N
- 1; j
++, nj
++ ) {
// nk starts at nj and is pre-incremented in the loop condition, so the first nk
// visited is the element after nj.
94 for( VarSet::const_iterator nk
= nj
; (++nk
) != vs
.end(); k
++ ) {
95 Factor
pairbelief( VarSet(*nj
, *nk
) );
97 // clamp Vars j and k to their possible values
98 for( size_t j_val
= 0; j_val
< nj
->states(); j_val
++ )
99 for( size_t k_val
= 0; k_val
< nk
->states(); k_val
++ ) {
100 // save unclamped factors connected to vs
101 clamped
->backupFactors( vs
);
103 clamped
->clamp( varindices
[*nj
], j_val
);
104 clamped
->clamp( varindices
[*nk
], k_val
);
113 logZ
= clamped
->logZ();
// NOT_NORMALIZABLE is singled out from other exceptions; the handling itself is not
// visible in this listing.
114 } catch( Exception
&e
) {
115 if( e
.getCode() == Exception::NOT_NORMALIZABLE
)
// Presumably the first finite logZ becomes logZ0 (the assignment is not visible here).
121 if( logZ0
== -INFINITY
)
122 if( logZ
!= -INFINITY
)
126 if( logZ
== -INFINITY
)
129 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
131 // we assume that j.label() < k.label()
132 // i.e. we make an assumption here about the indexing
// Linear index into the pair factor: j_val varies fastest, k_val has stride nj->states().
133 pairbelief
.set( j_val
+ (k_val
* nj
->states()), Z_xj
);
135 // restore clamped factors
136 clamped
->restoreFactors( vs
);
139 result
.push_back( pairbelief
.normalized() );
// Strategy 2 starts here: work on an indexable copy of the variables.
143 // convert vs to vector<Var>
144 vector
<Var
> vvs( vs
.begin(), vs
.end() );
// N*N table of pair factors, addressed as pairbeliefs[j * N + k].
146 vector
<Factor
> pairbeliefs
;
147 pairbeliefs
.reserve( N
* N
);
148 for( size_t j
= 0; j
< N
; j
++ )
149 for( size_t k
= 0; k
< N
; k
++ )
// Either an empty Factor or a factor over {vvs[j], vvs[k]} is appended; the condition
// choosing between them is not visible (presumably j == k gets the empty one).
151 pairbeliefs
.push_back( Factor() );
153 pairbeliefs
.push_back( Factor( VarSet(vvs
[j
], vvs
[k
]) ) );
155 Real logZ0
= -INFINITY
;
156 for( size_t j
= 0; j
< N
; j
++ ) {
157 // clamp Var j to its possible values
158 for( size_t j_val
= 0; j_val
< vvs
[j
].states(); j_val
++ ) {
// NOTE(review): the third argument `true` presumably asks clamp() to back up the affected
// factors (no separate backupFactors call is visible in this loop) - confirm against
// InfAlg::clamp.
159 clamped
->clamp( varindices
[vvs
[j
]], j_val
, true );
168 logZ
= clamped
->logZ();
169 } catch( Exception
&e
) {
170 if( e
.getCode() == Exception::NOT_NORMALIZABLE
)
176 if( logZ0
== -INFINITY
)
177 if( logZ
!= -INFINITY
)
181 if( logZ
== -INFINITY
)
184 Z_xj
= exp(logZ
- logZ0
); // subtract logZ0 to avoid very large numbers
// Combine the weight of (j = j_val) with the belief of each variable k under that clamping.
186 for( size_t k
= 0; k
< N
; k
++ )
188 Factor b_k
= clamped
->belief(vvs
[k
]);
189 for( size_t k_val
= 0; k_val
< vvs
[k
].states(); k_val
++ )
// The linear index depends on which variable has the smaller label: the value of the
// smaller-labelled variable varies fastest.
190 if( vvs
[j
].label() < vvs
[k
].label() )
191 pairbeliefs
[j
* N
+ k
].set( j_val
+ (k_val
* vvs
[j
].states()), Z_xj
* b_k
[k_val
] );
193 pairbeliefs
[j
* N
+ k
].set( k_val
+ (j_val
* vvs
[k
].states()), Z_xj
* b_k
[k_val
] );
196 // restore clamped factors
197 clamped
->restoreFactors( vs
);
201 // Calculate result by taking the geometric average
// Entries (j,k) and (k,j) are multiplied and raised to the power 0.5, then normalized;
// only pairs with j < k are emitted.
202 for( size_t j
= 0; j
< N
; j
++ )
203 for( size_t k
= j
+1; k
< N
; k
++ )
204 result
.push_back( ((pairbeliefs
[j
* N
+ k
] * pairbeliefs
[k
* N
+ j
]) ^ 0.5).normalized() );
// findMaximum( obj )
//
// Decodes a joint assignment of all variables of obj.fg() from the factor beliefs of obj.
// Factors are processed one at a time from a stack. For each factor, the states of its
// variables are scanned and compared by factor-belief probability, taking into account
// the values already assigned to neighboring variables. The values of the chosen state
// are recorded in `maximum` for variables not yet visited, and their unvisited
// neighboring factors are scheduled.
//
//   obj - inference algorithm; fg() and beliefF() are called on it here.
//
// The vector `maximum` has one entry per variable of obj.fg() (presumably the return
// value; the return statement is not visible in this listing).
//
// Throws RUNTIME_ERROR (via DAI_THROWE) in three visible places: when no unvisited factor
// can be found, when decoding the MAP state fails, and when the chosen state disagrees
// with an already assigned variable.
//
// NOTE(review): this listing omits several statements (the body of the already-visited
// check, the update of nrVisitedFactors, the maxProb/maxState update, the condition
// guarding the "Failed to decode" throw, and the else branch in the final loop).
211 std::vector
<size_t> findMaximum( const InfAlg
& obj
) {
// maximum[i] holds the value assigned to variable i once visitedVars[i] is true.
212 vector
<size_t> maximum( obj
.fg().nrVars() );
213 vector
<bool> visitedVars( obj
.fg().nrVars(), false );
214 vector
<bool> visitedFactors( obj
.fg().nrFactors(), false );
// Factors waiting to be processed.
215 stack
<size_t> scheduledFactors
;
216 size_t nrVisitedFactors
= 0;
217 size_t firstUnvisitedFactor
= 0;
218 while( nrVisitedFactors
< obj
.fg().nrFactors() ) {
// Nothing scheduled: seed the stack with the lowest-indexed factor not yet visited.
219 if( scheduledFactors
.size() == 0 ) {
220 while( visitedFactors
[firstUnvisitedFactor
] ) {
221 firstUnvisitedFactor
++;
// Running off the end contradicts the outer loop condition, so it is reported as an
// internal error.
222 if( firstUnvisitedFactor
>= obj
.fg().nrFactors() )
223 DAI_THROWE(RUNTIME_ERROR
,"Internal error in findMaximum()");
225 scheduledFactors
.push( firstUnvisitedFactor
);
// Take the next factor I from the stack.
228 size_t I
= scheduledFactors
.top();
229 scheduledFactors
.pop();
// A factor may have been pushed more than once; the handling of an already visited
// factor is not visible in this listing.
230 if( visitedFactors
[I
] )
232 visitedFactors
[I
] = true;
235 // Get marginal of factor I
236 Prob probF
= obj
.beliefF(I
).p();
238 // The allowed configuration is restrained according to the variables assigned so far:
239 // pick the argmax amongst the allowed states
// maxProb starts at the lowest representable Real so any allowed state can exceed it.
240 Real maxProb
= -numeric_limits
<Real
>::max();
241 State
maxState( obj
.fg().factor(I
).vars() );
243 for( State
s( obj
.fg().factor(I
).vars() ); s
.valid(); ++s
) {
244 // First, calculate whether this state is consistent with variables that
245 // have been assigned already
246 bool allowedState
= true;
247 bforeach( const Neighbor
&j
, obj
.fg().nbF(I
) )
248 if( visitedVars
[j
.node
] && maximum
[j
.node
] != s(obj
.fg().var(j
.node
)) ) {
249 allowedState
= false;
252 // If it is consistent, check if its probability is larger than what we have seen so far
254 if( probF
[s
] > maxProb
) {
263 DAI_THROWE(RUNTIME_ERROR
,"Failed to decode the MAP state (should try harder using a SAT solver, but that's not implemented yet)");
// The chosen state must have nonzero value in the original factor I, not only in its belief.
264 DAI_ASSERT( obj
.fg().factor(I
).p()[maxState
] != 0.0 );
// Record the chosen state on the variables of factor I.
267 bforeach( const Neighbor
&j
, obj
.fg().nbF(I
) ) {
268 if( visitedVars
[j
.node
] ) {
269 // We have already visited j earlier - hopefully our state is consistent
270 if( maximum
[j
.node
] != maxState( obj
.fg().var(j
.node
) ) )
271 DAI_THROWE(RUNTIME_ERROR
,"Detected inconsistency while decoding MAP state (should try harder using a SAT solver, but that's not implemented yet)");
273 // We found a consistent state for variable j
274 visitedVars
[j
.node
] = true;
275 maximum
[j
.node
] = maxState( obj
.fg().var(j
.node
) );
// Schedule every not-yet-visited factor adjacent to variable j.
276 bforeach( const Neighbor
&J
, obj
.fg().nbV(j
) )
277 if( !visitedFactors
[J
] )
278 scheduledFactors
.push(J
);
286 } // end of namespace dai