Merge branch 'pletscher'
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <vector>
13 #include <dai/daialg.h>
14
15
16 namespace dai {
17
18
19 using namespace std;
20
21
22 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
23 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
24 */
25 Factor calcMarginal( const InfAlg &obj, const VarSet &ns, bool reInit ) {
26 Factor Pns (ns);
27
28 InfAlg *clamped = obj.clone();
29 if( !reInit )
30 clamped->init();
31
32 map<Var,size_t> varindices;
33 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
34 varindices[*n] = obj.fg().findVar( *n );
35
36 Real logZ0 = -INFINITY;
37 for( State s(ns); s.valid(); s++ ) {
38 // save unclamped factors connected to ns
39 clamped->backupFactors( ns );
40
41 // set clamping Factors to delta functions
42 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
43 clamped->clamp( varindices[*n], s(*n) );
44
45 // run DAIAlg, calc logZ, store in Pns
46 if( reInit )
47 clamped->init();
48 else
49 clamped->init(ns);
50
51 Real logZ;
52 try {
53 clamped->run();
54 logZ = clamped->logZ();
55 } catch( Exception &e ) {
56 if( e.code() == Exception::NOT_NORMALIZABLE )
57 logZ = -INFINITY;
58 else
59 throw;
60 }
61
62 if( logZ0 == -INFINITY )
63 if( logZ != -INFINITY )
64 logZ0 = logZ;
65
66 if( logZ == -INFINITY )
67 Pns[s] = 0;
68 else
69 Pns[s] = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
70
71 // restore clamped factors
72 clamped->restoreFactors( ns );
73 }
74
75 delete clamped;
76
77 return( Pns.normalized() );
78 }
79
80
81 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
82 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
83 */
84 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
85 // convert ns to vector<VarSet>
86 size_t N = ns.size();
87 vector<Var> vns;
88 vns.reserve( N );
89 map<Var,size_t> varindices;
90 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
91 vns.push_back( *n );
92 varindices[*n] = obj.fg().findVar( *n );
93 }
94
95 vector<Factor> pairbeliefs;
96 pairbeliefs.reserve( N * N );
97 for( size_t j = 0; j < N; j++ )
98 for( size_t k = 0; k < N; k++ )
99 if( j == k )
100 pairbeliefs.push_back( Factor() );
101 else
102 pairbeliefs.push_back( Factor( VarSet(vns[j], vns[k]) ) );
103
104 InfAlg *clamped = obj.clone();
105 if( !reInit )
106 clamped->init();
107
108 Real logZ0 = -INFINITY;
109 for( size_t j = 0; j < N; j++ ) {
110 // clamp Var j to its possible values
111 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
112 clamped->clamp( varindices[vns[j]], j_val, true );
113 if( reInit )
114 clamped->init();
115 else
116 clamped->init(ns);
117
118 Real logZ;
119 try {
120 clamped->run();
121 logZ = clamped->logZ();
122 } catch( Exception &e ) {
123 if( e.code() == Exception::NOT_NORMALIZABLE )
124 logZ = -INFINITY;
125 else
126 throw;
127 }
128
129 if( logZ0 == -INFINITY )
130 if( logZ != -INFINITY )
131 logZ0 = logZ;
132
133 double Z_xj;
134 if( logZ == -INFINITY )
135 Z_xj = 0;
136 else
137 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
138
139 for( size_t k = 0; k < N; k++ )
140 if( k != j ) {
141 Factor b_k = clamped->belief(vns[k]);
142 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
143 if( vns[j].label() < vns[k].label() )
144 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
145 else
146 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
147 }
148
149 // restore clamped factors
150 clamped->restoreFactors( ns );
151 }
152 }
153
154 delete clamped;
155
156 // Calculate result by taking the geometric average
157 vector<Factor> result;
158 result.reserve( N * (N - 1) / 2 );
159 for( size_t j = 0; j < N; j++ )
160 for( size_t k = j+1; k < N; k++ )
161 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
162
163 return result;
164 }
165
166
167 /// Calculates beliefs of all pairs in ns (by clamping pairs in ns and calculating logZ for each joined state).
168 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
169 */
170 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
171 // returns a a probability distribution whose 1st order interactions
172 // are unspecified, whose 2nd order interactions approximate those of
173 // the marginal on ns, and whose higher order interactions are absent.
174
175 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
176
177 Factor Pns (ns);
178 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
179 Pns *= pairbeliefs[ij];
180
181 return( Pns.normalized() );
182 }
183
184
185 /// Calculates 2nd order interactions of the marginal of obj on ns.
186 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
187 vector<Factor> result;
188 result.reserve( ns.size() * (ns.size() - 1) / 2 );
189
190 InfAlg *clamped = obj.clone();
191 if( !reInit )
192 clamped->init();
193
194 map<Var,size_t> varindices;
195 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
196 varindices[*n] = obj.fg().findVar( *n );
197
198 Real logZ0 = 0.0;
199 VarSet::const_iterator nj = ns.begin();
200 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
201 size_t k = 0;
202 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
203 Factor pairbelief( VarSet(*nj, *nk) );
204
205 // clamp Vars j and k to their possible values
206 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
207 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
208 // save unclamped factors connected to ns
209 clamped->backupFactors( ns );
210
211 clamped->clamp( varindices[*nj], j_val );
212 clamped->clamp( varindices[*nk], k_val );
213 if( reInit )
214 clamped->init();
215 else
216 clamped->init(ns);
217
218 Real logZ;
219 try {
220 clamped->run();
221 logZ = clamped->logZ();
222 } catch( Exception &e ) {
223 if( e.code() == Exception::NOT_NORMALIZABLE )
224 logZ = -INFINITY;
225 else
226 throw;
227 }
228
229 if( logZ0 == -INFINITY )
230 if( logZ != -INFINITY )
231 logZ0 = logZ;
232
233 double Z_xj;
234 if( logZ == -INFINITY )
235 Z_xj = 0;
236 else
237 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
238
239 // we assume that j.label() < k.label()
240 // i.e. we make an assumption here about the indexing
241 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
242
243 // restore clamped factors
244 clamped->restoreFactors( ns );
245 }
246
247 result.push_back( pairbelief.normalized() );
248 }
249 }
250
251 delete clamped;
252
253 DAI_ASSERT( result.size() == (ns.size() * (ns.size() - 1) / 2) );
254
255 return result;
256 }
257
258
259 } // end of namespace dai