Merge branch 'no-edges2'
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <vector>
23 #include <dai/daialg.h>
24
25
26 namespace dai {
27
28
29 using namespace std;
30
31
32 /// Calculate the marginal of obj on ns by clamping
33 /// all variables in ns and calculating logZ for each joined state
34 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
35 Factor Pns (ns);
36
37 multind mi( ns );
38
39 InfAlg *clamped = obj.clone();
40 if( !reInit )
41 clamped->init();
42
43 Complex logZ0;
44 for( size_t j = 0; j < mi.max(); j++ ) {
45 // save unclamped factors connected to ns
46 clamped->saveProbs( ns );
47
48 // set clamping Factors to delta functions
49 vector<size_t> vi = mi.vi( j );
50 size_t k = 0;
51 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++, k++ )
52 clamped->clamp( *n, vi[k] );
53
54 // run DAIAlg, calc logZ, store in Pns
55 if( clamped->Verbose() >= 2 )
56 cout << j << ": ";
57 if( reInit )
58 clamped->init();
59 clamped->run();
60
61 Complex Z;
62 if( j == 0 ) {
63 logZ0 = clamped->logZ();
64 Z = 1.0;
65 } else {
66 // subtract logZ0 to avoid very large numbers
67 Z = exp(clamped->logZ() - logZ0);
68 if( fabs(imag(Z)) > 1e-5 )
69 cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
70 }
71
72 Pns[j] = real(Z);
73
74 // restore clamped factors
75 clamped->undoProbs( ns );
76 }
77
78 delete clamped;
79
80 return( Pns.normalized(Prob::NORMPROB) );
81 }
82
83
84 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
85 // convert ns to vector<VarSet>
86 size_t N = ns.size();
87 vector<Var> vns;
88 vns.reserve( N );
89 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
90 vns.push_back( *n );
91
92 vector<Factor> pairbeliefs;
93 pairbeliefs.reserve( N * N );
94 for( size_t j = 0; j < N; j++ )
95 for( size_t k = 0; k < N; k++ )
96 if( j == k )
97 pairbeliefs.push_back(Factor());
98 else
99 pairbeliefs.push_back(Factor(vns[j] | vns[k]));
100
101 InfAlg *clamped = obj.clone();
102 if( !reInit )
103 clamped->init();
104
105 Complex logZ0;
106 for( size_t j = 0; j < N; j++ ) {
107 // clamp Var j to its possible values
108 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
109 if( obj.Verbose() >= 2 )
110 cout << j << "/" << N-1 << " (" << j_val << "/" << vns[j].states() << "): ";
111
112 // save unclamped factors connected to ns
113 clamped->saveProbs( ns );
114
115 clamped->clamp( vns[j], j_val );
116 if( reInit )
117 clamped->init();
118 clamped->run();
119
120 //if( j == 0 )
121 // logZ0 = obj.logZ();
122 double Z_xj = 1.0;
123 if( j == 0 && j_val == 0 ) {
124 logZ0 = clamped->logZ();
125 } else {
126 // subtract logZ0 to avoid very large numbers
127 Complex Z = exp(clamped->logZ() - logZ0);
128 if( fabs(imag(Z)) > 1e-5 )
129 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
130 Z_xj = real(Z);
131 }
132
133 for( size_t k = 0; k < N; k++ )
134 if( k != j ) {
135 Factor b_k = clamped->belief(vns[k]);
136 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
137 if( vns[j].label() < vns[k].label() )
138 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
139 else
140 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
141 }
142
143 // restore clamped factors
144 clamped->undoProbs( ns );
145 }
146 }
147
148 delete clamped;
149
150 // Calculate result by taking the geometric average
151 vector<Factor> result;
152 result.reserve( N * (N - 1) / 2 );
153 for( size_t j = 0; j < N; j++ )
154 for( size_t k = j+1; k < N; k++ )
155 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
156
157 return result;
158 }
159
160
161 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
162 // returns a a probability distribution whose 1st order interactions
163 // are unspecified, whose 2nd order interactions approximate those of
164 // the marginal on ns, and whose higher order interactions are absent.
165
166 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
167
168 Factor Pns (ns);
169 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
170 Pns *= pairbeliefs[ij];
171
172 return( Pns.normalized(Prob::NORMPROB) );
173 }
174
175
176 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
177 vector<Factor> result;
178 result.reserve( ns.size() * (ns.size() - 1) / 2 );
179
180 InfAlg *clamped = obj.clone();
181 if( !reInit )
182 clamped->init();
183
184 Complex logZ0;
185 VarSet::const_iterator nj = ns.begin();
186 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
187 size_t k = 0;
188 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
189 Factor pairbelief( *nj | *nk );
190
191 // clamp Vars j and k to their possible values
192 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
193 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
194 // save unclamped factors connected to ns
195 clamped->saveProbs( ns );
196
197 clamped->clamp( *nj, j_val );
198 clamped->clamp( *nk, k_val );
199 if( reInit )
200 clamped->init();
201 clamped->run();
202
203 double Z_xj = 1.0;
204 if( j_val == 0 && k_val == 0 ) {
205 logZ0 = clamped->logZ();
206 } else {
207 // subtract logZ0 to avoid very large numbers
208 Complex Z = exp(clamped->logZ() - logZ0);
209 if( fabs(imag(Z)) > 1e-5 )
210 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
211 Z_xj = real(Z);
212 }
213
214 // we assume that j.label() < k.label()
215 // i.e. we make an assumption here about the indexing
216 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
217
218 // restore clamped factors
219 clamped->undoProbs( ns );
220 }
221
222 result.push_back( pairbelief );
223 }
224 }
225
226 delete clamped;
227
228 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
229
230 return result;
231 }
232
233
234 } // end of namespace dai