Fixed bug in findMaximum(): inconsistent max-beliefs are now detected,
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 * Copyright (C) 2008-2009 Giuseppe Passino
10 */
11
12
13 #include <vector>
14 #include <stack>
15 #include <dai/daialg.h>
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 Factor calcMarginal( const InfAlg &obj, const VarSet &vs, bool reInit ) {
25 Factor Pvs (vs);
26
27 InfAlg *clamped = obj.clone();
28 if( !reInit )
29 clamped->init();
30
31 map<Var,size_t> varindices;
32 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
33 varindices[*n] = obj.fg().findVar( *n );
34
35 Real logZ0 = -INFINITY;
36 for( State s(vs); s.valid(); s++ ) {
37 // save unclamped factors connected to vs
38 clamped->backupFactors( vs );
39
40 // set clamping Factors to delta functions
41 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
42 clamped->clamp( varindices[*n], s(*n) );
43
44 // run DAIAlg, calc logZ, store in Pvs
45 if( reInit )
46 clamped->init();
47 else
48 clamped->init(vs);
49
50 Real logZ;
51 try {
52 clamped->run();
53 logZ = clamped->logZ();
54 } catch( Exception &e ) {
55 if( e.code() == Exception::NOT_NORMALIZABLE )
56 logZ = -INFINITY;
57 else
58 throw;
59 }
60
61 if( logZ0 == -INFINITY )
62 if( logZ != -INFINITY )
63 logZ0 = logZ;
64
65 if( logZ == -INFINITY )
66 Pvs.set( s, 0 );
67 else
68 Pvs.set( s, exp(logZ - logZ0) ); // subtract logZ0 to avoid very large numbers
69
70 // restore clamped factors
71 clamped->restoreFactors( vs );
72 }
73
74 delete clamped;
75
76 return( Pvs.normalized() );
77 }
78
79
80 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& vs, bool reInit, bool accurate ) {
81 vector<Factor> result;
82 size_t N = vs.size();
83 result.reserve( N * (N - 1) / 2 );
84
85 InfAlg *clamped = obj.clone();
86 if( !reInit )
87 clamped->init();
88
89 map<Var,size_t> varindices;
90 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
91 varindices[*v] = obj.fg().findVar( *v );
92
93 if( accurate ) {
94 Real logZ0 = 0.0;
95 VarSet::const_iterator nj = vs.begin();
96 for( long j = 0; j < (long)N - 1; j++, nj++ ) {
97 size_t k = 0;
98 for( VarSet::const_iterator nk = nj; (++nk) != vs.end(); k++ ) {
99 Factor pairbelief( VarSet(*nj, *nk) );
100
101 // clamp Vars j and k to their possible values
102 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
103 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
104 // save unclamped factors connected to vs
105 clamped->backupFactors( vs );
106
107 clamped->clamp( varindices[*nj], j_val );
108 clamped->clamp( varindices[*nk], k_val );
109 if( reInit )
110 clamped->init();
111 else
112 clamped->init(vs);
113
114 Real logZ;
115 try {
116 clamped->run();
117 logZ = clamped->logZ();
118 } catch( Exception &e ) {
119 if( e.code() == Exception::NOT_NORMALIZABLE )
120 logZ = -INFINITY;
121 else
122 throw;
123 }
124
125 if( logZ0 == -INFINITY )
126 if( logZ != -INFINITY )
127 logZ0 = logZ;
128
129 Real Z_xj;
130 if( logZ == -INFINITY )
131 Z_xj = 0;
132 else
133 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
134
135 // we assume that j.label() < k.label()
136 // i.e. we make an assumption here about the indexing
137 pairbelief.set( j_val + (k_val * nj->states()), Z_xj );
138
139 // restore clamped factors
140 clamped->restoreFactors( vs );
141 }
142
143 result.push_back( pairbelief.normalized() );
144 }
145 }
146 } else {
147 // convert vs to vector<VarSet>
148 vector<Var> vvs( vs.begin(), vs.end() );
149
150 vector<Factor> pairbeliefs;
151 pairbeliefs.reserve( N * N );
152 for( size_t j = 0; j < N; j++ )
153 for( size_t k = 0; k < N; k++ )
154 if( j == k )
155 pairbeliefs.push_back( Factor() );
156 else
157 pairbeliefs.push_back( Factor( VarSet(vvs[j], vvs[k]) ) );
158
159 Real logZ0 = -INFINITY;
160 for( size_t j = 0; j < N; j++ ) {
161 // clamp Var j to its possible values
162 for( size_t j_val = 0; j_val < vvs[j].states(); j_val++ ) {
163 clamped->clamp( varindices[vvs[j]], j_val, true );
164 if( reInit )
165 clamped->init();
166 else
167 clamped->init(vs);
168
169 Real logZ;
170 try {
171 clamped->run();
172 logZ = clamped->logZ();
173 } catch( Exception &e ) {
174 if( e.code() == Exception::NOT_NORMALIZABLE )
175 logZ = -INFINITY;
176 else
177 throw;
178 }
179
180 if( logZ0 == -INFINITY )
181 if( logZ != -INFINITY )
182 logZ0 = logZ;
183
184 Real Z_xj;
185 if( logZ == -INFINITY )
186 Z_xj = 0;
187 else
188 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
189
190 for( size_t k = 0; k < N; k++ )
191 if( k != j ) {
192 Factor b_k = clamped->belief(vvs[k]);
193 for( size_t k_val = 0; k_val < vvs[k].states(); k_val++ )
194 if( vvs[j].label() < vvs[k].label() )
195 pairbeliefs[j * N + k].set( j_val + (k_val * vvs[j].states()), Z_xj * b_k[k_val] );
196 else
197 pairbeliefs[j * N + k].set( k_val + (j_val * vvs[k].states()), Z_xj * b_k[k_val] );
198 }
199
200 // restore clamped factors
201 clamped->restoreFactors( vs );
202 }
203 }
204
205 // Calculate result by taking the geometric average
206 for( size_t j = 0; j < N; j++ )
207 for( size_t k = j+1; k < N; k++ )
208 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
209 }
210 delete clamped;
211 return result;
212 }
213
214
215 std::vector<size_t> findMaximum( const InfAlg& obj ) {
216 vector<size_t> maximum( obj.fg().nrVars() );
217 vector<bool> visitedVars( obj.fg().nrVars(), false );
218 vector<bool> visitedFactors( obj.fg().nrFactors(), false );
219 stack<size_t> scheduledFactors;
220 scheduledFactors.push( 0 );
221 while( !scheduledFactors.empty() ) {
222 size_t I = scheduledFactors.top();
223 scheduledFactors.pop();
224 if( visitedFactors[I] )
225 continue;
226 visitedFactors[I] = true;
227
228 // Get marginal of factor I
229 Prob probF = obj.beliefF(I).p();
230
231 // The allowed configuration is restrained according to the variables assigned so far:
232 // pick the argmax amongst the allowed states
233 Real maxProb = -numeric_limits<Real>::max();
234 State maxState( obj.fg().factor(I).vars() );
235 size_t maxcount = 0;
236 for( State s( obj.fg().factor(I).vars() ); s.valid(); ++s ) {
237 // First, calculate whether this state is consistent with variables that
238 // have been assigned already
239 bool allowedState = true;
240 foreach( const Neighbor &j, obj.fg().nbF(I) )
241 if( visitedVars[j.node] && maximum[j.node] != s(obj.fg().var(j.node)) ) {
242 allowedState = false;
243 break;
244 }
245 // If it is consistent, check if its probability is larger than what we have seen so far
246 if( allowedState ) {
247 if( probF[s] > maxProb ) {
248 maxState = s;
249 maxProb = probF[s];
250 maxcount = 1;
251 } else
252 maxcount++;
253 }
254 }
255 DAI_ASSERT( maxProb != 0.0 );
256 DAI_ASSERT( obj.fg().factor(I).p()[maxState] != 0.0 );
257
258 // Decode the argmax
259 foreach( const Neighbor &j, obj.fg().nbF(I) ) {
260 if( visitedVars[j.node] ) {
261 // We have already visited j earlier - hopefully our state is consistent
262 if( maximum[j.node] != maxState( obj.fg().var(j.node) ) )
263 DAI_THROWE(RUNTIME_ERROR,"MAP state inconsistent due to loops");
264 } else {
265 // We found a consistent state for variable j
266 visitedVars[j.node] = true;
267 maximum[j.node] = maxState( obj.fg().var(j.node) );
268 foreach( const Neighbor &J, obj.fg().nbV(j) )
269 if( !visitedFactors[J] )
270 scheduledFactors.push(J);
271 }
272 }
273 }
274 return maximum;
275 }
276
277
278 } // end of namespace dai