01ec3d9a8928966a2337aa83ece5661a78176595
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <vector>
13 #include <stack>
14 #include <dai/daialg.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 Factor calcMarginal( const InfAlg &obj, const VarSet &vs, bool reInit ) {
24 Factor Pvs (vs);
25
26 InfAlg *clamped = obj.clone();
27 if( !reInit )
28 clamped->init();
29
30 map<Var,size_t> varindices;
31 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
32 varindices[*n] = obj.fg().findVar( *n );
33
34 Real logZ0 = -INFINITY;
35 for( State s(vs); s.valid(); s++ ) {
36 // save unclamped factors connected to vs
37 clamped->backupFactors( vs );
38
39 // set clamping Factors to delta functions
40 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
41 clamped->clamp( varindices[*n], s(*n) );
42
43 // run DAIAlg, calc logZ, store in Pvs
44 if( reInit )
45 clamped->init();
46 else
47 clamped->init(vs);
48
49 Real logZ;
50 try {
51 clamped->run();
52 logZ = clamped->logZ();
53 } catch( Exception &e ) {
54 if( e.code() == Exception::NOT_NORMALIZABLE )
55 logZ = -INFINITY;
56 else
57 throw;
58 }
59
60 if( logZ0 == -INFINITY )
61 if( logZ != -INFINITY )
62 logZ0 = logZ;
63
64 if( logZ == -INFINITY )
65 Pvs.set( s, 0 );
66 else
67 Pvs.set( s, exp(logZ - logZ0) ); // subtract logZ0 to avoid very large numbers
68
69 // restore clamped factors
70 clamped->restoreFactors( vs );
71 }
72
73 delete clamped;
74
75 return( Pvs.normalized() );
76 }
77
78
79 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& vs, bool reInit, bool accurate ) {
80 vector<Factor> result;
81 size_t N = vs.size();
82 result.reserve( N * (N - 1) / 2 );
83
84 InfAlg *clamped = obj.clone();
85 if( !reInit )
86 clamped->init();
87
88 map<Var,size_t> varindices;
89 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
90 varindices[*v] = obj.fg().findVar( *v );
91
92 if( accurate ) {
93 Real logZ0 = 0.0;
94 VarSet::const_iterator nj = vs.begin();
95 for( long j = 0; j < (long)N - 1; j++, nj++ ) {
96 size_t k = 0;
97 for( VarSet::const_iterator nk = nj; (++nk) != vs.end(); k++ ) {
98 Factor pairbelief( VarSet(*nj, *nk) );
99
100 // clamp Vars j and k to their possible values
101 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
102 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
103 // save unclamped factors connected to vs
104 clamped->backupFactors( vs );
105
106 clamped->clamp( varindices[*nj], j_val );
107 clamped->clamp( varindices[*nk], k_val );
108 if( reInit )
109 clamped->init();
110 else
111 clamped->init(vs);
112
113 Real logZ;
114 try {
115 clamped->run();
116 logZ = clamped->logZ();
117 } catch( Exception &e ) {
118 if( e.code() == Exception::NOT_NORMALIZABLE )
119 logZ = -INFINITY;
120 else
121 throw;
122 }
123
124 if( logZ0 == -INFINITY )
125 if( logZ != -INFINITY )
126 logZ0 = logZ;
127
128 Real Z_xj;
129 if( logZ == -INFINITY )
130 Z_xj = 0;
131 else
132 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
133
134 // we assume that j.label() < k.label()
135 // i.e. we make an assumption here about the indexing
136 pairbelief.set( j_val + (k_val * nj->states()), Z_xj );
137
138 // restore clamped factors
139 clamped->restoreFactors( vs );
140 }
141
142 result.push_back( pairbelief.normalized() );
143 }
144 }
145 } else {
146 // convert vs to vector<VarSet>
147 vector<Var> vvs( vs.begin(), vs.end() );
148
149 vector<Factor> pairbeliefs;
150 pairbeliefs.reserve( N * N );
151 for( size_t j = 0; j < N; j++ )
152 for( size_t k = 0; k < N; k++ )
153 if( j == k )
154 pairbeliefs.push_back( Factor() );
155 else
156 pairbeliefs.push_back( Factor( VarSet(vvs[j], vvs[k]) ) );
157
158 Real logZ0 = -INFINITY;
159 for( size_t j = 0; j < N; j++ ) {
160 // clamp Var j to its possible values
161 for( size_t j_val = 0; j_val < vvs[j].states(); j_val++ ) {
162 clamped->clamp( varindices[vvs[j]], j_val, true );
163 if( reInit )
164 clamped->init();
165 else
166 clamped->init(vs);
167
168 Real logZ;
169 try {
170 clamped->run();
171 logZ = clamped->logZ();
172 } catch( Exception &e ) {
173 if( e.code() == Exception::NOT_NORMALIZABLE )
174 logZ = -INFINITY;
175 else
176 throw;
177 }
178
179 if( logZ0 == -INFINITY )
180 if( logZ != -INFINITY )
181 logZ0 = logZ;
182
183 Real Z_xj;
184 if( logZ == -INFINITY )
185 Z_xj = 0;
186 else
187 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
188
189 for( size_t k = 0; k < N; k++ )
190 if( k != j ) {
191 Factor b_k = clamped->belief(vvs[k]);
192 for( size_t k_val = 0; k_val < vvs[k].states(); k_val++ )
193 if( vvs[j].label() < vvs[k].label() )
194 pairbeliefs[j * N + k].set( j_val + (k_val * vvs[j].states()), Z_xj * b_k[k_val] );
195 else
196 pairbeliefs[j * N + k].set( k_val + (j_val * vvs[k].states()), Z_xj * b_k[k_val] );
197 }
198
199 // restore clamped factors
200 clamped->restoreFactors( vs );
201 }
202 }
203
204 // Calculate result by taking the geometric average
205 for( size_t j = 0; j < N; j++ )
206 for( size_t k = j+1; k < N; k++ )
207 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
208 }
209 delete clamped;
210 return result;
211 }
212
213
214 std::vector<size_t> findMaximum( const InfAlg& obj ) {
215 vector<size_t> maximum( obj.fg().nrVars() );
216 vector<bool> visitedVars( obj.fg().nrVars(), false );
217 vector<bool> visitedFactors( obj.fg().nrFactors(), false );
218 stack<size_t> scheduledFactors;
219 for( size_t i = 0; i < obj.fg().nrVars(); ++i ) {
220 if( visitedVars[i] )
221 continue;
222 visitedVars[i] = true;
223
224 // Maximise with respect to variable i
225 Prob prod = obj.beliefV(i).p();
226 maximum[i] = prod.argmax().first;
227
228 foreach( const Neighbor &I, obj.fg().nbV(i) )
229 if( !visitedFactors[I] )
230 scheduledFactors.push(I);
231
232 while( !scheduledFactors.empty() ){
233 size_t I = scheduledFactors.top();
234 scheduledFactors.pop();
235 if( visitedFactors[I] )
236 continue;
237 visitedFactors[I] = true;
238
239 // Evaluate if some neighboring variables still need to be fixed; if not, we're done
240 bool allDetermined = true;
241 foreach( const Neighbor &j, obj.fg().nbF(I) )
242 if( !visitedVars[j.node] ) {
243 allDetermined = false;
244 break;
245 }
246 if( allDetermined )
247 continue;
248
249 // Get marginal of factor I
250 Prob probF = obj.beliefF(I).p();
251
252 // The allowed configuration is restrained according to the variables assigned so far:
253 // pick the argmax amongst the allowed states
254 Real maxProb = -numeric_limits<Real>::max();
255 State maxState( obj.fg().factor(I).vars() );
256 for( State s( obj.fg().factor(I).vars() ); s.valid(); ++s ) {
257 // First, calculate whether this state is consistent with variables that
258 // have been assigned already
259 bool allowedState = true;
260 foreach( const Neighbor &j, obj.fg().nbF(I) )
261 if( visitedVars[j.node] && maximum[j.node] != s(obj.fg().var(j.node)) ) {
262 allowedState = false;
263 break;
264 }
265 // If it is consistent, check if its probability is larger than what we have seen so far
266 if( allowedState && probF[s] > maxProb ) {
267 maxState = s;
268 maxProb = probF[s];
269 }
270 }
271
272 // Decode the argmax
273 foreach( const Neighbor &j, obj.fg().nbF(I) ) {
274 if( visitedVars[j.node] ) {
275 // We have already visited j earlier - hopefully our state is consistent
276 if( maximum[j.node] != maxState( obj.fg().var(j.node) ) )
277 DAI_THROWE(RUNTIME_ERROR,"MAP state inconsistent due to loops");
278 } else {
279 // We found a consistent state for variable j
280 visitedVars[j.node] = true;
281 maximum[j.node] = maxState( obj.fg().var(j.node) );
282 foreach( const Neighbor &J, obj.fg().nbV(j) )
283 if( !visitedFactors[J] )
284 scheduledFactors.push(J);
285 }
286 }
287 }
288 }
289 return maximum;
290 }
291
292
293 } // end of namespace dai