efb1588dc52552384271cd7985f453fa8b5a73d1
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <vector>
24 #include <dai/daialg.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
34 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
35 */
36 Factor calcMarginal( const InfAlg &obj, const VarSet &ns, bool reInit ) {
37 Factor Pns (ns);
38
39 InfAlg *clamped = obj.clone();
40 if( !reInit )
41 clamped->init();
42
43 map<Var,size_t> varindices;
44 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
45 varindices[*n] = obj.fg().findVar( *n );
46
47 Real logZ0 = -INFINITY;
48 for( State s(ns); s.valid(); s++ ) {
49 // save unclamped factors connected to ns
50 clamped->backupFactors( ns );
51
52 // set clamping Factors to delta functions
53 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
54 clamped->clamp( varindices[*n], s(*n) );
55
56 // run DAIAlg, calc logZ, store in Pns
57 if( reInit )
58 clamped->init();
59 else
60 clamped->init(ns);
61
62 Real logZ;
63 try {
64 clamped->run();
65 logZ = clamped->logZ();
66 } catch( Exception &e ) {
67 if( e.code() == Exception::NOT_NORMALIZABLE )
68 logZ = -INFINITY;
69 else
70 throw;
71 }
72
73 if( logZ0 == -INFINITY )
74 if( logZ != -INFINITY )
75 logZ0 = logZ;
76
77 if( logZ == -INFINITY )
78 Pns[s] = 0;
79 else
80 Pns[s] = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
81
82 // restore clamped factors
83 clamped->restoreFactors( ns );
84 }
85
86 delete clamped;
87
88 return( Pns.normalized() );
89 }
90
91
92 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
93 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
94 */
95 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
96 // convert ns to vector<VarSet>
97 size_t N = ns.size();
98 vector<Var> vns;
99 vns.reserve( N );
100 map<Var,size_t> varindices;
101 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
102 vns.push_back( *n );
103 varindices[*n] = obj.fg().findVar( *n );
104 }
105
106 vector<Factor> pairbeliefs;
107 pairbeliefs.reserve( N * N );
108 for( size_t j = 0; j < N; j++ )
109 for( size_t k = 0; k < N; k++ )
110 if( j == k )
111 pairbeliefs.push_back( Factor() );
112 else
113 pairbeliefs.push_back( Factor( VarSet(vns[j], vns[k]) ) );
114
115 InfAlg *clamped = obj.clone();
116 if( !reInit )
117 clamped->init();
118
119 Real logZ0 = -INFINITY;
120 for( size_t j = 0; j < N; j++ ) {
121 // clamp Var j to its possible values
122 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
123 clamped->clamp( varindices[vns[j]], j_val, true );
124 if( reInit )
125 clamped->init();
126 else
127 clamped->init(ns);
128
129 Real logZ;
130 try {
131 clamped->run();
132 logZ = clamped->logZ();
133 } catch( Exception &e ) {
134 if( e.code() == Exception::NOT_NORMALIZABLE )
135 logZ = -INFINITY;
136 else
137 throw;
138 }
139
140 if( logZ0 == -INFINITY )
141 if( logZ != -INFINITY )
142 logZ0 = logZ;
143
144 double Z_xj;
145 if( logZ == -INFINITY )
146 Z_xj = 0;
147 else
148 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
149
150 for( size_t k = 0; k < N; k++ )
151 if( k != j ) {
152 Factor b_k = clamped->belief(vns[k]);
153 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
154 if( vns[j].label() < vns[k].label() )
155 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
156 else
157 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
158 }
159
160 // restore clamped factors
161 clamped->restoreFactors( ns );
162 }
163 }
164
165 delete clamped;
166
167 // Calculate result by taking the geometric average
168 vector<Factor> result;
169 result.reserve( N * (N - 1) / 2 );
170 for( size_t j = 0; j < N; j++ )
171 for( size_t k = j+1; k < N; k++ )
172 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
173
174 return result;
175 }
176
177
178 /// Calculates beliefs of all pairs in ns (by clamping pairs in ns and calculating logZ for each joined state).
179 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
180 */
181 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
182 // returns a a probability distribution whose 1st order interactions
183 // are unspecified, whose 2nd order interactions approximate those of
184 // the marginal on ns, and whose higher order interactions are absent.
185
186 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
187
188 Factor Pns (ns);
189 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
190 Pns *= pairbeliefs[ij];
191
192 return( Pns.normalized() );
193 }
194
195
196 /// Calculates 2nd order interactions of the marginal of obj on ns.
197 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
198 vector<Factor> result;
199 result.reserve( ns.size() * (ns.size() - 1) / 2 );
200
201 InfAlg *clamped = obj.clone();
202 if( !reInit )
203 clamped->init();
204
205 map<Var,size_t> varindices;
206 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
207 varindices[*n] = obj.fg().findVar( *n );
208
209 Real logZ0 = 0.0;
210 VarSet::const_iterator nj = ns.begin();
211 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
212 size_t k = 0;
213 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
214 Factor pairbelief( VarSet(*nj, *nk) );
215
216 // clamp Vars j and k to their possible values
217 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
218 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
219 // save unclamped factors connected to ns
220 clamped->backupFactors( ns );
221
222 clamped->clamp( varindices[*nj], j_val );
223 clamped->clamp( varindices[*nk], k_val );
224 if( reInit )
225 clamped->init();
226 else
227 clamped->init(ns);
228
229 Real logZ;
230 try {
231 clamped->run();
232 logZ = clamped->logZ();
233 } catch( Exception &e ) {
234 if( e.code() == Exception::NOT_NORMALIZABLE )
235 logZ = -INFINITY;
236 else
237 throw;
238 }
239
240 if( logZ0 == -INFINITY )
241 if( logZ != -INFINITY )
242 logZ0 = logZ;
243
244 double Z_xj;
245 if( logZ == -INFINITY )
246 Z_xj = 0;
247 else
248 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
249
250 // we assume that j.label() < k.label()
251 // i.e. we make an assumption here about the indexing
252 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
253
254 // restore clamped factors
255 clamped->restoreFactors( ns );
256 }
257
258 result.push_back( pairbelief.normalized() );
259 }
260 }
261
262 delete clamped;
263
264 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
265
266 return result;
267 }
268
269
270 } // end of namespace dai