Improved documentation of factor.h, ...
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>
#include <dai/daialg.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
34 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
35 Factor Pns (ns);
36
37 InfAlg *clamped = obj.clone();
38 if( !reInit )
39 clamped->init();
40
41 Real logZ0 = 0.0;
42 for( State s(ns); s.valid(); s++ ) {
43 // save unclamped factors connected to ns
44 clamped->backupFactors( ns );
45
46 // set clamping Factors to delta functions
47 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
48 clamped->clamp( *n, s(*n) );
49
50 // run DAIAlg, calc logZ, store in Pns
51 if( reInit )
52 clamped->init();
53 else
54 clamped->init(ns);
55 clamped->run();
56
57 Real Z;
58 if( s == 0 ) {
59 logZ0 = clamped->logZ();
60 Z = 1.0;
61 } else {
62 // subtract logZ0 to avoid very large numbers
63 Z = exp(clamped->logZ() - logZ0);
64 }
65
66 Pns[s] = Z;
67
68 // restore clamped factors
69 clamped->restoreFactors( ns );
70 }
71
72 delete clamped;
73
74 return( Pns.normalized() );
75 }
76
77
78 /// Calculates beliefs of all pairs in ns (by clamping nodes in ns and calculating logZ and the beliefs for each state).
79 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
80 // convert ns to vector<VarSet>
81 size_t N = ns.size();
82 vector<Var> vns;
83 vns.reserve( N );
84 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
85 vns.push_back( *n );
86
87 vector<Factor> pairbeliefs;
88 pairbeliefs.reserve( N * N );
89 for( size_t j = 0; j < N; j++ )
90 for( size_t k = 0; k < N; k++ )
91 if( j == k )
92 pairbeliefs.push_back( Factor() );
93 else
94 pairbeliefs.push_back( Factor( VarSet(vns[j], vns[k]) ) );
95
96 InfAlg *clamped = obj.clone();
97 if( !reInit )
98 clamped->init();
99
100 Real logZ0 = 0.0;
101 for( size_t j = 0; j < N; j++ ) {
102 // clamp Var j to its possible values
103 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
104 clamped->clamp( vns[j], j_val, true );
105 if( reInit )
106 clamped->init();
107 else
108 clamped->init(ns);
109 clamped->run();
110
111 //if( j == 0 )
112 // logZ0 = obj.logZ();
113 double Z_xj = 1.0;
114 if( j == 0 && j_val == 0 ) {
115 logZ0 = clamped->logZ();
116 } else {
117 // subtract logZ0 to avoid very large numbers
118 Z_xj = exp(clamped->logZ() - logZ0);
119 }
120
121 for( size_t k = 0; k < N; k++ )
122 if( k != j ) {
123 Factor b_k = clamped->belief(vns[k]);
124 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
125 if( vns[j].label() < vns[k].label() )
126 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
127 else
128 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
129 }
130
131 // restore clamped factors
132 clamped->restoreFactors( ns );
133 }
134 }
135
136 delete clamped;
137
138 // Calculate result by taking the geometric average
139 vector<Factor> result;
140 result.reserve( N * (N - 1) / 2 );
141 for( size_t j = 0; j < N; j++ )
142 for( size_t k = j+1; k < N; k++ )
143 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
144
145 return result;
146 }
147
148
149 /// Calculates beliefs of all pairs in ns (by clamping pairs in ns and calculating logZ for each joined state).
150 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
151 // returns a a probability distribution whose 1st order interactions
152 // are unspecified, whose 2nd order interactions approximate those of
153 // the marginal on ns, and whose higher order interactions are absent.
154
155 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
156
157 Factor Pns (ns);
158 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
159 Pns *= pairbeliefs[ij];
160
161 return( Pns.normalized() );
162 }
163
164
165 /// Calculates 2nd order interactions of the marginal of obj on ns.
166 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
167 vector<Factor> result;
168 result.reserve( ns.size() * (ns.size() - 1) / 2 );
169
170 InfAlg *clamped = obj.clone();
171 if( !reInit )
172 clamped->init();
173
174 Real logZ0 = 0.0;
175 VarSet::const_iterator nj = ns.begin();
176 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
177 size_t k = 0;
178 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
179 Factor pairbelief( VarSet(*nj, *nk) );
180
181 // clamp Vars j and k to their possible values
182 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
183 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
184 // save unclamped factors connected to ns
185 clamped->backupFactors( ns );
186
187 clamped->clamp( *nj, j_val );
188 clamped->clamp( *nk, k_val );
189 if( reInit )
190 clamped->init();
191 else
192 clamped->init(ns);
193 clamped->run();
194
195 double Z_xj = 1.0;
196 if( j_val == 0 && k_val == 0 ) {
197 logZ0 = clamped->logZ();
198 } else {
199 // subtract logZ0 to avoid very large numbers
200 Z_xj = exp(clamped->logZ() - logZ0);
201 }
202
203 // we assume that j.label() < k.label()
204 // i.e. we make an assumption here about the indexing
205 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
206
207 // restore clamped factors
208 clamped->restoreFactors( ns );
209 }
210
211 result.push_back( pairbelief );
212 }
213 }
214
215 delete clamped;
216
217 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
218
219 return result;
220 }
221
222
223 } // end of namespace dai