023a64a712c7cf0b19e611f3005fe64f041c0ec1
[libdai.git] / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include "daialg.h"
23
24
25 /// Calculate the marginal of obj on ns by clamping
26 /// all variables in ns and calculating logZ for each joined state
27 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
28 Factor Pns (ns);
29
30 multind mi( ns );
31
32 InfAlg *clamped = obj.clone();
33 if( !reInit )
34 clamped->init();
35
36 Complex logZ0;
37 for( size_t j = 0; j < mi.max(); j++ ) {
38 // save unclamped factors connected to ns
39 clamped->saveProbs( ns );
40
41 // set clamping Factors to delta functions
42 vector<size_t> vi = mi.vi( j );
43 size_t k = 0;
44 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++, k++ )
45 clamped->clamp( *n, vi[k] );
46
47 // run DAIAlg, calc logZ, store in Pns
48 if( clamped->Verbose() >= 2 )
49 cout << j << ": ";
50 if( reInit )
51 clamped->init();
52 clamped->run();
53
54 Complex Z;
55 if( j == 0 ) {
56 logZ0 = clamped->logZ();
57 Z = 1.0;
58 } else {
59 // subtract logZ0 to avoid very large numbers
60 Z = exp(clamped->logZ() - logZ0);
61 if( fabs(imag(Z)) > 1e-5 )
62 cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
63 }
64
65 Pns[j] = real(Z);
66
67 // restore clamped factors
68 clamped->undoProbs( ns );
69 }
70
71 delete clamped;
72
73 return( Pns.normalized(Prob::NORMPROB) );
74 }
75
76
77 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
78 // convert ns to vector<VarSet>
79 size_t N = ns.size();
80 vector<Var> vns;
81 vns.reserve( N );
82 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
83 vns.push_back( *n );
84
85 vector<Factor> pairbeliefs;
86 pairbeliefs.reserve( N * N );
87 for( size_t j = 0; j < N; j++ )
88 for( size_t k = 0; k < N; k++ )
89 if( j == k )
90 pairbeliefs.push_back(Factor());
91 else
92 pairbeliefs.push_back(Factor(vns[j] | vns[k]));
93
94 InfAlg *clamped = obj.clone();
95 if( !reInit )
96 clamped->init();
97
98 Complex logZ0;
99 for( size_t j = 0; j < N; j++ ) {
100 // clamp Var j to its possible values
101 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
102 if( obj.Verbose() >= 2 )
103 cout << j << "/" << N-1 << " (" << j_val << "/" << vns[j].states() << "): ";
104
105 // save unclamped factors connected to ns
106 clamped->saveProbs( ns );
107
108 clamped->clamp( vns[j], j_val );
109 if( reInit )
110 clamped->init();
111 clamped->run();
112
113 //if( j == 0 )
114 // logZ0 = obj.logZ();
115 double Z_xj = 1.0;
116 if( j == 0 && j_val == 0 ) {
117 logZ0 = clamped->logZ();
118 } else {
119 // subtract logZ0 to avoid very large numbers
120 Complex Z = exp(clamped->logZ() - logZ0);
121 if( fabs(imag(Z)) > 1e-5 )
122 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
123 Z_xj = real(Z);
124 }
125
126 for( size_t k = 0; k < N; k++ )
127 if( k != j ) {
128 Factor b_k = clamped->belief(vns[k]);
129 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
130 if( vns[j].label() < vns[k].label() )
131 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
132 else
133 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
134 }
135
136 // restore clamped factors
137 clamped->undoProbs( ns );
138 }
139 }
140
141 delete clamped;
142
143 // Calculate result by taking the geometric average
144 vector<Factor> result;
145 result.reserve( N * (N - 1) / 2 );
146 for( size_t j = 0; j < N; j++ )
147 for( size_t k = j+1; k < N; k++ )
148 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
149
150 return result;
151 }
152
153
154 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
155 // returns a a probability distribution whose 1st order interactions
156 // are unspecified, whose 2nd order interactions approximate those of
157 // the marginal on ns, and whose higher order interactions are absent.
158
159 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
160
161 Factor Pns (ns);
162 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
163 Pns *= pairbeliefs[ij];
164
165 return( Pns.normalized(Prob::NORMPROB) );
166 }
167
168
169 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
170 vector<Factor> result;
171 result.reserve( ns.size() * (ns.size() - 1) / 2 );
172
173 InfAlg *clamped = obj.clone();
174 if( !reInit )
175 clamped->init();
176
177 Complex logZ0;
178 VarSet::const_iterator nj = ns.begin();
179 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
180 size_t k = 0;
181 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
182 Factor pairbelief( *nj | *nk );
183
184 // clamp Vars j and k to their possible values
185 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
186 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
187 // save unclamped factors connected to ns
188 clamped->saveProbs( ns );
189
190 clamped->clamp( *nj, j_val );
191 clamped->clamp( *nk, k_val );
192 if( reInit )
193 clamped->init();
194 clamped->run();
195
196 double Z_xj = 1.0;
197 if( j_val == 0 && k_val == 0 ) {
198 logZ0 = clamped->logZ();
199 } else {
200 // subtract logZ0 to avoid very large numbers
201 Complex Z = exp(clamped->logZ() - logZ0);
202 if( fabs(imag(Z)) > 1e-5 )
203 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
204 Z_xj = real(Z);
205 }
206
207 // we assume that j.label() < k.label()
208 // i.e. we make an assumption here about the indexing
209 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
210
211 // restore clamped factors
212 clamped->undoProbs( ns );
213 }
214
215 result.push_back( pairbelief );
216 }
217 }
218
219 delete clamped;
220
221 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
222
223 return result;
224 }