7ddb49501e8a4680f6cb03b8de5d505debdb3cd3
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <vector>
10 #include <stack>
11 #include <dai/daialg.h>
12
13
14 namespace dai {
15
16
17 using namespace std;
18
19
20 Factor calcMarginal( const InfAlg &obj, const VarSet &vs, bool reInit ) {
21 Factor Pvs (vs);
22
23 InfAlg *clamped = obj.clone();
24 if( !reInit )
25 clamped->init();
26
27 map<Var,size_t> varindices;
28 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
29 varindices[*n] = obj.fg().findVar( *n );
30
31 Real logZ0 = -INFINITY;
32 for( State s(vs); s.valid(); s++ ) {
33 // save unclamped factors connected to vs
34 clamped->backupFactors( vs );
35
36 // set clamping Factors to delta functions
37 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
38 clamped->clamp( varindices[*n], s(*n) );
39
40 // run DAIAlg, calc logZ, store in Pvs
41 if( reInit )
42 clamped->init();
43 else
44 clamped->init(vs);
45
46 Real logZ;
47 try {
48 clamped->run();
49 logZ = clamped->logZ();
50 } catch( Exception &e ) {
51 if( e.getCode() == Exception::NOT_NORMALIZABLE )
52 logZ = -INFINITY;
53 else
54 throw;
55 }
56
57 if( logZ0 == -INFINITY )
58 if( logZ != -INFINITY )
59 logZ0 = logZ;
60
61 if( logZ == -INFINITY )
62 Pvs.set( s, 0 );
63 else
64 Pvs.set( s, exp(logZ - logZ0) ); // subtract logZ0 to avoid very large numbers
65
66 // restore clamped factors
67 clamped->restoreFactors( vs );
68 }
69
70 delete clamped;
71
72 return( Pvs.normalized() );
73 }
74
75
76 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& vs, bool reInit, bool accurate ) {
77 vector<Factor> result;
78 size_t N = vs.size();
79 result.reserve( N * (N - 1) / 2 );
80
81 InfAlg *clamped = obj.clone();
82 if( !reInit )
83 clamped->init();
84
85 map<Var,size_t> varindices;
86 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
87 varindices[*v] = obj.fg().findVar( *v );
88
89 if( accurate ) {
90 Real logZ0 = 0.0;
91 VarSet::const_iterator nj = vs.begin();
92 for( long j = 0; j < (long)N - 1; j++, nj++ ) {
93 size_t k = 0;
94 for( VarSet::const_iterator nk = nj; (++nk) != vs.end(); k++ ) {
95 Factor pairbelief( VarSet(*nj, *nk) );
96
97 // clamp Vars j and k to their possible values
98 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
99 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
100 // save unclamped factors connected to vs
101 clamped->backupFactors( vs );
102
103 clamped->clamp( varindices[*nj], j_val );
104 clamped->clamp( varindices[*nk], k_val );
105 if( reInit )
106 clamped->init();
107 else
108 clamped->init(vs);
109
110 Real logZ;
111 try {
112 clamped->run();
113 logZ = clamped->logZ();
114 } catch( Exception &e ) {
115 if( e.getCode() == Exception::NOT_NORMALIZABLE )
116 logZ = -INFINITY;
117 else
118 throw;
119 }
120
121 if( logZ0 == -INFINITY )
122 if( logZ != -INFINITY )
123 logZ0 = logZ;
124
125 Real Z_xj;
126 if( logZ == -INFINITY )
127 Z_xj = 0;
128 else
129 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
130
131 // we assume that j.label() < k.label()
132 // i.e. we make an assumption here about the indexing
133 pairbelief.set( j_val + (k_val * nj->states()), Z_xj );
134
135 // restore clamped factors
136 clamped->restoreFactors( vs );
137 }
138
139 result.push_back( pairbelief.normalized() );
140 }
141 }
142 } else {
143 // convert vs to vector<VarSet>
144 vector<Var> vvs( vs.begin(), vs.end() );
145
146 vector<Factor> pairbeliefs;
147 pairbeliefs.reserve( N * N );
148 for( size_t j = 0; j < N; j++ )
149 for( size_t k = 0; k < N; k++ )
150 if( j == k )
151 pairbeliefs.push_back( Factor() );
152 else
153 pairbeliefs.push_back( Factor( VarSet(vvs[j], vvs[k]) ) );
154
155 Real logZ0 = -INFINITY;
156 for( size_t j = 0; j < N; j++ ) {
157 // clamp Var j to its possible values
158 for( size_t j_val = 0; j_val < vvs[j].states(); j_val++ ) {
159 clamped->clamp( varindices[vvs[j]], j_val, true );
160 if( reInit )
161 clamped->init();
162 else
163 clamped->init(vs);
164
165 Real logZ;
166 try {
167 clamped->run();
168 logZ = clamped->logZ();
169 } catch( Exception &e ) {
170 if( e.getCode() == Exception::NOT_NORMALIZABLE )
171 logZ = -INFINITY;
172 else
173 throw;
174 }
175
176 if( logZ0 == -INFINITY )
177 if( logZ != -INFINITY )
178 logZ0 = logZ;
179
180 Real Z_xj;
181 if( logZ == -INFINITY )
182 Z_xj = 0;
183 else
184 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
185
186 for( size_t k = 0; k < N; k++ )
187 if( k != j ) {
188 Factor b_k = clamped->belief(vvs[k]);
189 for( size_t k_val = 0; k_val < vvs[k].states(); k_val++ )
190 if( vvs[j].label() < vvs[k].label() )
191 pairbeliefs[j * N + k].set( j_val + (k_val * vvs[j].states()), Z_xj * b_k[k_val] );
192 else
193 pairbeliefs[j * N + k].set( k_val + (j_val * vvs[k].states()), Z_xj * b_k[k_val] );
194 }
195
196 // restore clamped factors
197 clamped->restoreFactors( vs );
198 }
199 }
200
201 // Calculate result by taking the geometric average
202 for( size_t j = 0; j < N; j++ )
203 for( size_t k = j+1; k < N; k++ )
204 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
205 }
206 delete clamped;
207 return result;
208 }
209
210
211 std::vector<size_t> findMaximum( const InfAlg& obj ) {
212 vector<size_t> maximum( obj.fg().nrVars() );
213 vector<bool> visitedVars( obj.fg().nrVars(), false );
214 vector<bool> visitedFactors( obj.fg().nrFactors(), false );
215 stack<size_t> scheduledFactors;
216 scheduledFactors.push( 0 );
217 while( !scheduledFactors.empty() ) {
218 size_t I = scheduledFactors.top();
219 scheduledFactors.pop();
220 if( visitedFactors[I] )
221 continue;
222 visitedFactors[I] = true;
223
224 // Get marginal of factor I
225 Prob probF = obj.beliefF(I).p();
226
227 // The allowed configuration is restrained according to the variables assigned so far:
228 // pick the argmax amongst the allowed states
229 Real maxProb = -numeric_limits<Real>::max();
230 State maxState( obj.fg().factor(I).vars() );
231 size_t maxcount = 0;
232 for( State s( obj.fg().factor(I).vars() ); s.valid(); ++s ) {
233 // First, calculate whether this state is consistent with variables that
234 // have been assigned already
235 bool allowedState = true;
236 bforeach( const Neighbor &j, obj.fg().nbF(I) )
237 if( visitedVars[j.node] && maximum[j.node] != s(obj.fg().var(j.node)) ) {
238 allowedState = false;
239 break;
240 }
241 // If it is consistent, check if its probability is larger than what we have seen so far
242 if( allowedState ) {
243 if( probF[s] > maxProb ) {
244 maxState = s;
245 maxProb = probF[s];
246 maxcount = 1;
247 } else
248 maxcount++;
249 }
250 }
251 if( maxProb == 0.0 )
252 DAI_THROWE(RUNTIME_ERROR,"Failed to decode the MAP state (should try harder using a SAT solver, but that's not implemented yet)");
253 DAI_ASSERT( obj.fg().factor(I).p()[maxState] != 0.0 );
254
255 // Decode the argmax
256 bforeach( const Neighbor &j, obj.fg().nbF(I) ) {
257 if( visitedVars[j.node] ) {
258 // We have already visited j earlier - hopefully our state is consistent
259 if( maximum[j.node] != maxState( obj.fg().var(j.node) ) )
260 DAI_THROWE(RUNTIME_ERROR,"Detected inconsistency while decoding MAP state (should try harder using a SAT solver, but that's not implemented yet)");
261 } else {
262 // We found a consistent state for variable j
263 visitedVars[j.node] = true;
264 maximum[j.node] = maxState( obj.fg().var(j.node) );
265 bforeach( const Neighbor &J, obj.fg().nbV(j) )
266 if( !visitedFactors[J] )
267 scheduledFactors.push(J);
268 }
269 }
270 }
271 return maximum;
272 }
273
274
275 } // end of namespace dai