c4c3d9255b46d08117e6749bf68dc104f49eb5de
[libdai.git] / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include "daialg.h"
23
24
25 namespace dai {
26
27
28 /// Calculate the marginal of obj on ns by clamping
29 /// all variables in ns and calculating logZ for each joined state
30 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
31 Factor Pns (ns);
32
33 multind mi( ns );
34
35 InfAlg *clamped = obj.clone();
36 if( !reInit )
37 clamped->init();
38
39 Complex logZ0;
40 for( size_t j = 0; j < mi.max(); j++ ) {
41 // save unclamped factors connected to ns
42 clamped->saveProbs( ns );
43
44 // set clamping Factors to delta functions
45 vector<size_t> vi = mi.vi( j );
46 size_t k = 0;
47 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++, k++ )
48 clamped->clamp( *n, vi[k] );
49
50 // run DAIAlg, calc logZ, store in Pns
51 if( clamped->Verbose() >= 2 )
52 cout << j << ": ";
53 if( reInit )
54 clamped->init();
55 clamped->run();
56
57 Complex Z;
58 if( j == 0 ) {
59 logZ0 = clamped->logZ();
60 Z = 1.0;
61 } else {
62 // subtract logZ0 to avoid very large numbers
63 Z = exp(clamped->logZ() - logZ0);
64 if( fabs(imag(Z)) > 1e-5 )
65 cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
66 }
67
68 Pns[j] = real(Z);
69
70 // restore clamped factors
71 clamped->undoProbs( ns );
72 }
73
74 delete clamped;
75
76 return( Pns.normalized(Prob::NORMPROB) );
77 }
78
79
80 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
81 // convert ns to vector<VarSet>
82 size_t N = ns.size();
83 vector<Var> vns;
84 vns.reserve( N );
85 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
86 vns.push_back( *n );
87
88 vector<Factor> pairbeliefs;
89 pairbeliefs.reserve( N * N );
90 for( size_t j = 0; j < N; j++ )
91 for( size_t k = 0; k < N; k++ )
92 if( j == k )
93 pairbeliefs.push_back(Factor());
94 else
95 pairbeliefs.push_back(Factor(vns[j] | vns[k]));
96
97 InfAlg *clamped = obj.clone();
98 if( !reInit )
99 clamped->init();
100
101 Complex logZ0;
102 for( size_t j = 0; j < N; j++ ) {
103 // clamp Var j to its possible values
104 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
105 if( obj.Verbose() >= 2 )
106 cout << j << "/" << N-1 << " (" << j_val << "/" << vns[j].states() << "): ";
107
108 // save unclamped factors connected to ns
109 clamped->saveProbs( ns );
110
111 clamped->clamp( vns[j], j_val );
112 if( reInit )
113 clamped->init();
114 clamped->run();
115
116 //if( j == 0 )
117 // logZ0 = obj.logZ();
118 double Z_xj = 1.0;
119 if( j == 0 && j_val == 0 ) {
120 logZ0 = clamped->logZ();
121 } else {
122 // subtract logZ0 to avoid very large numbers
123 Complex Z = exp(clamped->logZ() - logZ0);
124 if( fabs(imag(Z)) > 1e-5 )
125 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
126 Z_xj = real(Z);
127 }
128
129 for( size_t k = 0; k < N; k++ )
130 if( k != j ) {
131 Factor b_k = clamped->belief(vns[k]);
132 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
133 if( vns[j].label() < vns[k].label() )
134 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
135 else
136 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
137 }
138
139 // restore clamped factors
140 clamped->undoProbs( ns );
141 }
142 }
143
144 delete clamped;
145
146 // Calculate result by taking the geometric average
147 vector<Factor> result;
148 result.reserve( N * (N - 1) / 2 );
149 for( size_t j = 0; j < N; j++ )
150 for( size_t k = j+1; k < N; k++ )
151 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
152
153 return result;
154 }
155
156
157 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
158 // returns a a probability distribution whose 1st order interactions
159 // are unspecified, whose 2nd order interactions approximate those of
160 // the marginal on ns, and whose higher order interactions are absent.
161
162 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
163
164 Factor Pns (ns);
165 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
166 Pns *= pairbeliefs[ij];
167
168 return( Pns.normalized(Prob::NORMPROB) );
169 }
170
171
172 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
173 vector<Factor> result;
174 result.reserve( ns.size() * (ns.size() - 1) / 2 );
175
176 InfAlg *clamped = obj.clone();
177 if( !reInit )
178 clamped->init();
179
180 Complex logZ0;
181 VarSet::const_iterator nj = ns.begin();
182 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
183 size_t k = 0;
184 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
185 Factor pairbelief( *nj | *nk );
186
187 // clamp Vars j and k to their possible values
188 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
189 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
190 // save unclamped factors connected to ns
191 clamped->saveProbs( ns );
192
193 clamped->clamp( *nj, j_val );
194 clamped->clamp( *nk, k_val );
195 if( reInit )
196 clamped->init();
197 clamped->run();
198
199 double Z_xj = 1.0;
200 if( j_val == 0 && k_val == 0 ) {
201 logZ0 = clamped->logZ();
202 } else {
203 // subtract logZ0 to avoid very large numbers
204 Complex Z = exp(clamped->logZ() - logZ0);
205 if( fabs(imag(Z)) > 1e-5 )
206 cout << "calcPairBelief:: Warning: complex Z: " << Z << endl;
207 Z_xj = real(Z);
208 }
209
210 // we assume that j.label() < k.label()
211 // i.e. we make an assumption here about the indexing
212 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
213
214 // restore clamped factors
215 clamped->undoProbs( ns );
216 }
217
218 result.push_back( pairbelief );
219 }
220 }
221
222 delete clamped;
223
224 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
225
226 return result;
227 }
228
229
230 }