Various cleanups
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <vector>
24 #include <dai/daialg.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
34 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
35 */
36 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
37 Factor Pns (ns);
38
39 InfAlg *clamped = obj.clone();
40 if( !reInit )
41 clamped->init();
42
43 Real logZ0 = -INFINITY;
44 for( State s(ns); s.valid(); s++ ) {
45 // save unclamped factors connected to ns
46 clamped->backupFactors( ns );
47
48 // set clamping Factors to delta functions
49 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
50 clamped->clamp( *n, s(*n) );
51
52 // run DAIAlg, calc logZ, store in Pns
53 if( reInit )
54 clamped->init();
55 else
56 clamped->init(ns);
57
58 Real logZ;
59 try {
60 clamped->run();
61 logZ = clamped->logZ();
62 } catch( Exception &e ) {
63 if( e.code() == Exception::NOT_NORMALIZABLE )
64 logZ = -INFINITY;
65 else
66 throw;
67 }
68
69 if( logZ0 == -INFINITY )
70 if( logZ != -INFINITY )
71 logZ0 = logZ;
72
73 if( logZ == -INFINITY )
74 Pns[s] = 0;
75 else
76 Pns[s] = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
77
78 // restore clamped factors
79 clamped->restoreFactors( ns );
80 }
81
82 delete clamped;
83
84 return( Pns.normalized() );
85 }
86
87
88 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
89 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
90 */
91 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
92 // convert ns to vector<VarSet>
93 size_t N = ns.size();
94 vector<Var> vns;
95 vns.reserve( N );
96 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
97 vns.push_back( *n );
98
99 vector<Factor> pairbeliefs;
100 pairbeliefs.reserve( N * N );
101 for( size_t j = 0; j < N; j++ )
102 for( size_t k = 0; k < N; k++ )
103 if( j == k )
104 pairbeliefs.push_back( Factor() );
105 else
106 pairbeliefs.push_back( Factor( VarSet(vns[j], vns[k]) ) );
107
108 InfAlg *clamped = obj.clone();
109 if( !reInit )
110 clamped->init();
111
112 Real logZ0 = -INFINITY;
113 for( size_t j = 0; j < N; j++ ) {
114 // clamp Var j to its possible values
115 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
116 clamped->clamp( vns[j], j_val, true );
117 if( reInit )
118 clamped->init();
119 else
120 clamped->init(ns);
121
122 Real logZ;
123 try {
124 clamped->run();
125 logZ = clamped->logZ();
126 } catch( Exception &e ) {
127 if( e.code() == Exception::NOT_NORMALIZABLE )
128 logZ = -INFINITY;
129 else
130 throw;
131 }
132
133 if( logZ0 == -INFINITY )
134 if( logZ != -INFINITY )
135 logZ0 = logZ;
136
137 double Z_xj;
138 if( logZ == -INFINITY )
139 Z_xj = 0;
140 else
141 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
142
143 for( size_t k = 0; k < N; k++ )
144 if( k != j ) {
145 Factor b_k = clamped->belief(vns[k]);
146 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
147 if( vns[j].label() < vns[k].label() )
148 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
149 else
150 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
151 }
152
153 // restore clamped factors
154 clamped->restoreFactors( ns );
155 }
156 }
157
158 delete clamped;
159
160 // Calculate result by taking the geometric average
161 vector<Factor> result;
162 result.reserve( N * (N - 1) / 2 );
163 for( size_t j = 0; j < N; j++ )
164 for( size_t k = j+1; k < N; k++ )
165 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
166
167 return result;
168 }
169
170
171 /// Calculates beliefs of all pairs in ns (by clamping pairs in ns and calculating logZ for each joined state).
172 /* reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
173 */
174 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
175 // returns a a probability distribution whose 1st order interactions
176 // are unspecified, whose 2nd order interactions approximate those of
177 // the marginal on ns, and whose higher order interactions are absent.
178
179 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
180
181 Factor Pns (ns);
182 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
183 Pns *= pairbeliefs[ij];
184
185 return( Pns.normalized() );
186 }
187
188
189 /// Calculates 2nd order interactions of the marginal of obj on ns.
190 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
191 vector<Factor> result;
192 result.reserve( ns.size() * (ns.size() - 1) / 2 );
193
194 InfAlg *clamped = obj.clone();
195 if( !reInit )
196 clamped->init();
197
198 Real logZ0 = 0.0;
199 VarSet::const_iterator nj = ns.begin();
200 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
201 size_t k = 0;
202 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
203 Factor pairbelief( VarSet(*nj, *nk) );
204
205 // clamp Vars j and k to their possible values
206 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
207 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
208 // save unclamped factors connected to ns
209 clamped->backupFactors( ns );
210
211 clamped->clamp( *nj, j_val );
212 clamped->clamp( *nk, k_val );
213 if( reInit )
214 clamped->init();
215 else
216 clamped->init(ns);
217
218 Real logZ;
219 try {
220 clamped->run();
221 logZ = clamped->logZ();
222 } catch( Exception &e ) {
223 if( e.code() == Exception::NOT_NORMALIZABLE )
224 logZ = -INFINITY;
225 else
226 throw;
227 }
228
229 if( logZ0 == -INFINITY )
230 if( logZ != -INFINITY )
231 logZ0 = logZ;
232
233 double Z_xj;
234 if( logZ == -INFINITY )
235 Z_xj = 0;
236 else
237 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
238
239 // we assume that j.label() < k.label()
240 // i.e. we make an assumption here about the indexing
241 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
242
243 // restore clamped factors
244 clamped->restoreFactors( ns );
245 }
246
247 result.push_back( pairbelief.normalized() );
248 }
249 }
250
251 delete clamped;
252
253 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
254
255 return result;
256 }
257
258
259 } // end of namespace dai