Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <vector>
23 #include <dai/daialg.h>
24
25
26 namespace dai {
27
28
29 using namespace std;
30
31
32 /// Calculate the marginal of obj on ns by clamping
33 /// all variables in ns and calculating logZ for each joined state
34 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
35 Factor Pns (ns);
36
37 InfAlg *clamped = obj.clone();
38 if( !reInit )
39 clamped->init();
40
41 Complex logZ0;
42 for( State s(ns); s.valid(); s++ ) {
43 // save unclamped factors connected to ns
44 clamped->saveProbs( ns );
45
46 // set clamping Factors to delta functions
47 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
48 clamped->clamp( *n, s(*n) );
49
50 // run DAIAlg, calc logZ, store in Pns
51 if( reInit )
52 clamped->init();
53 clamped->run();
54
55 Complex Z;
56 if( s == 0 ) {
57 logZ0 = clamped->logZ();
58 Z = 1.0;
59 } else {
60 // subtract logZ0 to avoid very large numbers
61 Z = exp(clamped->logZ() - logZ0);
62 if( fabs(imag(Z)) > 1e-5 )
63 cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
64 }
65
66 Pns[s] = real(Z);
67
68 // restore clamped factors
69 clamped->undoProbs( ns );
70 }
71
72 delete clamped;
73
74 return( Pns.normalized(Prob::NORMPROB) );
75 }
76
77
78 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
79 // convert ns to vector<VarSet>
80 size_t N = ns.size();
81 vector<Var> vns;
82 vns.reserve( N );
83 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
84 vns.push_back( *n );
85
86 vector<Factor> pairbeliefs;
87 pairbeliefs.reserve( N * N );
88 for( size_t j = 0; j < N; j++ )
89 for( size_t k = 0; k < N; k++ )
90 if( j == k )
91 pairbeliefs.push_back(Factor());
92 else
93 pairbeliefs.push_back(Factor(vns[j] | vns[k]));
94
95 InfAlg *clamped = obj.clone();
96 if( !reInit )
97 clamped->init();
98
99 Complex logZ0;
100 for( size_t j = 0; j < N; j++ ) {
101 // clamp Var j to its possible values
102 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
103 // save unclamped factors connected to ns
104 clamped->saveProbs( ns );
105
106 clamped->clamp( vns[j], j_val );
107 if( reInit )
108 clamped->init();
109 clamped->run();
110
111 //if( j == 0 )
112 // logZ0 = obj.logZ();
113 double Z_xj = 1.0;
114 if( j == 0 && j_val == 0 ) {
115 logZ0 = clamped->logZ();
116 } else {
117 // subtract logZ0 to avoid very large numbers
118 Complex Z = exp(clamped->logZ() - logZ0);
119 if( fabs(imag(Z)) > 1e-5 )
120 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
121 Z_xj = real(Z);
122 }
123
124 for( size_t k = 0; k < N; k++ )
125 if( k != j ) {
126 Factor b_k = clamped->belief(vns[k]);
127 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
128 if( vns[j].label() < vns[k].label() )
129 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
130 else
131 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
132 }
133
134 // restore clamped factors
135 clamped->undoProbs( ns );
136 }
137 }
138
139 delete clamped;
140
141 // Calculate result by taking the geometric average
142 vector<Factor> result;
143 result.reserve( N * (N - 1) / 2 );
144 for( size_t j = 0; j < N; j++ )
145 for( size_t k = j+1; k < N; k++ )
146 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
147
148 return result;
149 }
150
151
152 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
153 // returns a a probability distribution whose 1st order interactions
154 // are unspecified, whose 2nd order interactions approximate those of
155 // the marginal on ns, and whose higher order interactions are absent.
156
157 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
158
159 Factor Pns (ns);
160 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
161 Pns *= pairbeliefs[ij];
162
163 return( Pns.normalized(Prob::NORMPROB) );
164 }
165
166
167 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
168 vector<Factor> result;
169 result.reserve( ns.size() * (ns.size() - 1) / 2 );
170
171 InfAlg *clamped = obj.clone();
172 if( !reInit )
173 clamped->init();
174
175 Complex logZ0;
176 VarSet::const_iterator nj = ns.begin();
177 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
178 size_t k = 0;
179 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
180 Factor pairbelief( *nj | *nk );
181
182 // clamp Vars j and k to their possible values
183 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
184 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
185 // save unclamped factors connected to ns
186 clamped->saveProbs( ns );
187
188 clamped->clamp( *nj, j_val );
189 clamped->clamp( *nk, k_val );
190 if( reInit )
191 clamped->init();
192 clamped->run();
193
194 double Z_xj = 1.0;
195 if( j_val == 0 && k_val == 0 ) {
196 logZ0 = clamped->logZ();
197 } else {
198 // subtract logZ0 to avoid very large numbers
199 Complex Z = exp(clamped->logZ() - logZ0);
200 if( fabs(imag(Z)) > 1e-5 )
201 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
202 Z_xj = real(Z);
203 }
204
205 // we assume that j.label() < k.label()
206 // i.e. we make an assumption here about the indexing
207 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
208
209 // restore clamped factors
210 clamped->undoProbs( ns );
211 }
212
213 result.push_back( pairbelief );
214 }
215 }
216
217 delete clamped;
218
219 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
220
221 return result;
222 }
223
224
225 } // end of namespace dai