New git HEAD version
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <vector>
10 #include <stack>
11 #include <dai/daialg.h>
12
13
14 namespace dai {
15
16
17 using namespace std;
18
19
20 Factor calcMarginal( const InfAlg &obj, const VarSet &vs, bool reInit ) {
21 Factor Pvs (vs);
22
23 InfAlg *clamped = obj.clone();
24 if( !reInit )
25 clamped->init();
26
27 map<Var,size_t> varindices;
28 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
29 varindices[*n] = obj.fg().findVar( *n );
30
31 Real logZ0 = -INFINITY;
32 for( State s(vs); s.valid(); s++ ) {
33 // save unclamped factors connected to vs
34 clamped->backupFactors( vs );
35
36 // set clamping Factors to delta functions
37 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
38 clamped->clamp( varindices[*n], s(*n) );
39
40 // run DAIAlg, calc logZ, store in Pvs
41 if( reInit )
42 clamped->init();
43 else
44 clamped->init(vs);
45
46 Real logZ;
47 try {
48 clamped->run();
49 logZ = clamped->logZ();
50 } catch( Exception &e ) {
51 if( e.getCode() == Exception::NOT_NORMALIZABLE )
52 logZ = -INFINITY;
53 else
54 throw;
55 }
56
57 if( logZ0 == -INFINITY )
58 if( logZ != -INFINITY )
59 logZ0 = logZ;
60
61 if( logZ == -INFINITY )
62 Pvs.set( s, 0 );
63 else
64 Pvs.set( s, exp(logZ - logZ0) ); // subtract logZ0 to avoid very large numbers
65
66 // restore clamped factors
67 clamped->restoreFactors( vs );
68 }
69
70 delete clamped;
71
72 return( Pvs.normalized() );
73 }
74
75
76 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& vs, bool reInit, bool accurate ) {
77 vector<Factor> result;
78 size_t N = vs.size();
79 result.reserve( N * (N - 1) / 2 );
80
81 InfAlg *clamped = obj.clone();
82 if( !reInit )
83 clamped->init();
84
85 map<Var,size_t> varindices;
86 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
87 varindices[*v] = obj.fg().findVar( *v );
88
89 if( accurate ) {
90 Real logZ0 = 0.0;
91 VarSet::const_iterator nj = vs.begin();
92 for( long j = 0; j < (long)N - 1; j++, nj++ ) {
93 size_t k = 0;
94 for( VarSet::const_iterator nk = nj; (++nk) != vs.end(); k++ ) {
95 Factor pairbelief( VarSet(*nj, *nk) );
96
97 // clamp Vars j and k to their possible values
98 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
99 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
100 // save unclamped factors connected to vs
101 clamped->backupFactors( vs );
102
103 clamped->clamp( varindices[*nj], j_val );
104 clamped->clamp( varindices[*nk], k_val );
105 if( reInit )
106 clamped->init();
107 else
108 clamped->init(vs);
109
110 Real logZ;
111 try {
112 clamped->run();
113 logZ = clamped->logZ();
114 } catch( Exception &e ) {
115 if( e.getCode() == Exception::NOT_NORMALIZABLE )
116 logZ = -INFINITY;
117 else
118 throw;
119 }
120
121 if( logZ0 == -INFINITY )
122 if( logZ != -INFINITY )
123 logZ0 = logZ;
124
125 Real Z_xj;
126 if( logZ == -INFINITY )
127 Z_xj = 0;
128 else
129 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
130
131 // we assume that j.label() < k.label()
132 // i.e. we make an assumption here about the indexing
133 pairbelief.set( j_val + (k_val * nj->states()), Z_xj );
134
135 // restore clamped factors
136 clamped->restoreFactors( vs );
137 }
138
139 result.push_back( pairbelief.normalized() );
140 }
141 }
142 } else {
143 // convert vs to vector<VarSet>
144 vector<Var> vvs( vs.begin(), vs.end() );
145
146 vector<Factor> pairbeliefs;
147 pairbeliefs.reserve( N * N );
148 for( size_t j = 0; j < N; j++ )
149 for( size_t k = 0; k < N; k++ )
150 if( j == k )
151 pairbeliefs.push_back( Factor() );
152 else
153 pairbeliefs.push_back( Factor( VarSet(vvs[j], vvs[k]) ) );
154
155 Real logZ0 = -INFINITY;
156 for( size_t j = 0; j < N; j++ ) {
157 // clamp Var j to its possible values
158 for( size_t j_val = 0; j_val < vvs[j].states(); j_val++ ) {
159 clamped->clamp( varindices[vvs[j]], j_val, true );
160 if( reInit )
161 clamped->init();
162 else
163 clamped->init(vs);
164
165 Real logZ;
166 try {
167 clamped->run();
168 logZ = clamped->logZ();
169 } catch( Exception &e ) {
170 if( e.getCode() == Exception::NOT_NORMALIZABLE )
171 logZ = -INFINITY;
172 else
173 throw;
174 }
175
176 if( logZ0 == -INFINITY )
177 if( logZ != -INFINITY )
178 logZ0 = logZ;
179
180 Real Z_xj;
181 if( logZ == -INFINITY )
182 Z_xj = 0;
183 else
184 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
185
186 for( size_t k = 0; k < N; k++ )
187 if( k != j ) {
188 Factor b_k = clamped->belief(vvs[k]);
189 for( size_t k_val = 0; k_val < vvs[k].states(); k_val++ )
190 if( vvs[j].label() < vvs[k].label() )
191 pairbeliefs[j * N + k].set( j_val + (k_val * vvs[j].states()), Z_xj * b_k[k_val] );
192 else
193 pairbeliefs[j * N + k].set( k_val + (j_val * vvs[k].states()), Z_xj * b_k[k_val] );
194 }
195
196 // restore clamped factors
197 clamped->restoreFactors( vs );
198 }
199 }
200
201 // Calculate result by taking the geometric average
202 for( size_t j = 0; j < N; j++ )
203 for( size_t k = j+1; k < N; k++ )
204 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
205 }
206 delete clamped;
207 return result;
208 }
209
210
211 std::vector<size_t> findMaximum( const InfAlg& obj ) {
212 vector<size_t> maximum( obj.fg().nrVars() );
213 vector<bool> visitedVars( obj.fg().nrVars(), false );
214 vector<bool> visitedFactors( obj.fg().nrFactors(), false );
215 stack<size_t> scheduledFactors;
216 size_t nrVisitedFactors = 0;
217 size_t firstUnvisitedFactor = 0;
218 while( nrVisitedFactors < obj.fg().nrFactors() ) {
219 if( scheduledFactors.size() == 0 ) {
220 while( visitedFactors[firstUnvisitedFactor] ) {
221 firstUnvisitedFactor++;
222 if( firstUnvisitedFactor >= obj.fg().nrFactors() )
223 DAI_THROWE(RUNTIME_ERROR,"Internal error in findMaximum()");
224 }
225 scheduledFactors.push( firstUnvisitedFactor );
226 }
227
228 size_t I = scheduledFactors.top();
229 scheduledFactors.pop();
230 if( visitedFactors[I] )
231 continue;
232 visitedFactors[I] = true;
233 nrVisitedFactors++;
234
235 // Get marginal of factor I
236 Prob probF = obj.beliefF(I).p();
237
238 // The allowed configuration is restrained according to the variables assigned so far:
239 // pick the argmax amongst the allowed states
240 Real maxProb = -numeric_limits<Real>::max();
241 State maxState( obj.fg().factor(I).vars() );
242 size_t maxcount = 0;
243 for( State s( obj.fg().factor(I).vars() ); s.valid(); ++s ) {
244 // First, calculate whether this state is consistent with variables that
245 // have been assigned already
246 bool allowedState = true;
247 bforeach( const Neighbor &j, obj.fg().nbF(I) )
248 if( visitedVars[j.node] && maximum[j.node] != s(obj.fg().var(j.node)) ) {
249 allowedState = false;
250 break;
251 }
252 // If it is consistent, check if its probability is larger than what we have seen so far
253 if( allowedState ) {
254 if( probF[s] > maxProb ) {
255 maxState = s;
256 maxProb = probF[s];
257 maxcount = 1;
258 } else
259 maxcount++;
260 }
261 }
262 if( maxProb == 0.0 )
263 DAI_THROWE(RUNTIME_ERROR,"Failed to decode the MAP state (should try harder using a SAT solver, but that's not implemented yet)");
264 DAI_ASSERT( obj.fg().factor(I).p()[maxState] != 0.0 );
265
266 // Decode the argmax
267 bforeach( const Neighbor &j, obj.fg().nbF(I) ) {
268 if( visitedVars[j.node] ) {
269 // We have already visited j earlier - hopefully our state is consistent
270 if( maximum[j.node] != maxState( obj.fg().var(j.node) ) )
271 DAI_THROWE(RUNTIME_ERROR,"Detected inconsistency while decoding MAP state (should try harder using a SAT solver, but that's not implemented yet)");
272 } else {
273 // We found a consistent state for variable j
274 visitedVars[j.node] = true;
275 maximum[j.node] = maxState( obj.fg().var(j.node) );
276 bforeach( const Neighbor &J, obj.fg().nbV(j) )
277 if( !visitedFactors[J] )
278 scheduledFactors.push(J);
279 }
280 }
281 }
282 return maximum;
283 }
284
285
286 } // end of namespace dai