Cleaned up variable elimination code in ClusterGraph
[libdai.git] / src / daialg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <vector>
13 #include <dai/daialg.h>
14
15
16 namespace dai {
17
18
19 using namespace std;
20
21
22 Factor calcMarginal( const InfAlg &obj, const VarSet &vs, bool reInit ) {
23 Factor Pvs (vs);
24
25 InfAlg *clamped = obj.clone();
26 if( !reInit )
27 clamped->init();
28
29 map<Var,size_t> varindices;
30 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
31 varindices[*n] = obj.fg().findVar( *n );
32
33 Real logZ0 = -INFINITY;
34 for( State s(vs); s.valid(); s++ ) {
35 // save unclamped factors connected to vs
36 clamped->backupFactors( vs );
37
38 // set clamping Factors to delta functions
39 for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ )
40 clamped->clamp( varindices[*n], s(*n) );
41
42 // run DAIAlg, calc logZ, store in Pvs
43 if( reInit )
44 clamped->init();
45 else
46 clamped->init(vs);
47
48 Real logZ;
49 try {
50 clamped->run();
51 logZ = clamped->logZ();
52 } catch( Exception &e ) {
53 if( e.code() == Exception::NOT_NORMALIZABLE )
54 logZ = -INFINITY;
55 else
56 throw;
57 }
58
59 if( logZ0 == -INFINITY )
60 if( logZ != -INFINITY )
61 logZ0 = logZ;
62
63 if( logZ == -INFINITY )
64 Pvs[s] = 0;
65 else
66 Pvs[s] = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
67
68 // restore clamped factors
69 clamped->restoreFactors( vs );
70 }
71
72 delete clamped;
73
74 return( Pvs.normalized() );
75 }
76
77
78 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& vs, bool reInit, bool accurate ) {
79 vector<Factor> result;
80 size_t N = vs.size();
81 result.reserve( N * (N - 1) / 2 );
82
83 InfAlg *clamped = obj.clone();
84 if( !reInit )
85 clamped->init();
86
87 map<Var,size_t> varindices;
88 for( VarSet::const_iterator v = vs.begin(); v != vs.end(); v++ )
89 varindices[*v] = obj.fg().findVar( *v );
90
91 if( accurate ) {
92 Real logZ0 = 0.0;
93 VarSet::const_iterator nj = vs.begin();
94 for( long j = 0; j < (long)N - 1; j++, nj++ ) {
95 size_t k = 0;
96 for( VarSet::const_iterator nk = nj; (++nk) != vs.end(); k++ ) {
97 Factor pairbelief( VarSet(*nj, *nk) );
98
99 // clamp Vars j and k to their possible values
100 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
101 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
102 // save unclamped factors connected to vs
103 clamped->backupFactors( vs );
104
105 clamped->clamp( varindices[*nj], j_val );
106 clamped->clamp( varindices[*nk], k_val );
107 if( reInit )
108 clamped->init();
109 else
110 clamped->init(vs);
111
112 Real logZ;
113 try {
114 clamped->run();
115 logZ = clamped->logZ();
116 } catch( Exception &e ) {
117 if( e.code() == Exception::NOT_NORMALIZABLE )
118 logZ = -INFINITY;
119 else
120 throw;
121 }
122
123 if( logZ0 == -INFINITY )
124 if( logZ != -INFINITY )
125 logZ0 = logZ;
126
127 Real Z_xj;
128 if( logZ == -INFINITY )
129 Z_xj = 0;
130 else
131 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
132
133 // we assume that j.label() < k.label()
134 // i.e. we make an assumption here about the indexing
135 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
136
137 // restore clamped factors
138 clamped->restoreFactors( vs );
139 }
140
141 result.push_back( pairbelief.normalized() );
142 }
143 }
144 } else {
145 // convert vs to vector<VarSet>
146 vector<Var> vvs( vs.begin(), vs.end() );
147
148 vector<Factor> pairbeliefs;
149 pairbeliefs.reserve( N * N );
150 for( size_t j = 0; j < N; j++ )
151 for( size_t k = 0; k < N; k++ )
152 if( j == k )
153 pairbeliefs.push_back( Factor() );
154 else
155 pairbeliefs.push_back( Factor( VarSet(vvs[j], vvs[k]) ) );
156
157 Real logZ0 = -INFINITY;
158 for( size_t j = 0; j < N; j++ ) {
159 // clamp Var j to its possible values
160 for( size_t j_val = 0; j_val < vvs[j].states(); j_val++ ) {
161 clamped->clamp( varindices[vvs[j]], j_val, true );
162 if( reInit )
163 clamped->init();
164 else
165 clamped->init(vs);
166
167 Real logZ;
168 try {
169 clamped->run();
170 logZ = clamped->logZ();
171 } catch( Exception &e ) {
172 if( e.code() == Exception::NOT_NORMALIZABLE )
173 logZ = -INFINITY;
174 else
175 throw;
176 }
177
178 if( logZ0 == -INFINITY )
179 if( logZ != -INFINITY )
180 logZ0 = logZ;
181
182 Real Z_xj;
183 if( logZ == -INFINITY )
184 Z_xj = 0;
185 else
186 Z_xj = exp(logZ - logZ0); // subtract logZ0 to avoid very large numbers
187
188 for( size_t k = 0; k < N; k++ )
189 if( k != j ) {
190 Factor b_k = clamped->belief(vvs[k]);
191 for( size_t k_val = 0; k_val < vvs[k].states(); k_val++ )
192 if( vvs[j].label() < vvs[k].label() )
193 pairbeliefs[j * N + k][j_val + (k_val * vvs[j].states())] = Z_xj * b_k[k_val];
194 else
195 pairbeliefs[j * N + k][k_val + (j_val * vvs[k].states())] = Z_xj * b_k[k_val];
196 }
197
198 // restore clamped factors
199 clamped->restoreFactors( vs );
200 }
201 }
202
203 // Calculate result by taking the geometric average
204 for( size_t j = 0; j < N; j++ )
205 for( size_t k = j+1; k < N; k++ )
206 result.push_back( ((pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5).normalized() );
207 }
208 delete clamped;
209 return result;
210 }
211
212
213 std::vector<Factor> calcPairBeliefsNew( const InfAlg& obj, const VarSet& vs, bool reInit ) {
214 return calcPairBeliefs( obj, vs, reInit, true );
215 }
216
217
218 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& vs, bool reInit ) {
219 vector<Factor> pairbeliefs = calcPairBeliefs( obj, vs, reInit );
220
221 Factor Pvs (vs);
222 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
223 Pvs *= pairbeliefs[ij];
224
225 return( Pvs.normalized() );
226 }
227
228
229 } // end of namespace dai