Updated copyrights
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <vector>
24 #include <dai/daialg.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
34 Factor Pns (ns);
35
36 InfAlg *clamped = obj.clone();
37 if( !reInit )
38 clamped->init();
39
40 Real logZ0 = 0.0;
41 for( State s(ns); s.valid(); s++ ) {
42 // save unclamped factors connected to ns
43 clamped->backupFactors( ns );
44
45 // set clamping Factors to delta functions
46 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
47 clamped->clamp( *n, s(*n) );
48
49 // run DAIAlg, calc logZ, store in Pns
50 if( reInit )
51 clamped->init();
52 else
53 clamped->init(ns);
54 clamped->run();
55
56 Real Z;
57 if( s == 0 ) {
58 logZ0 = clamped->logZ();
59 Z = 1.0;
60 } else {
61 // subtract logZ0 to avoid very large numbers
62 Z = exp(clamped->logZ() - logZ0);
63 }
64
65 Pns[s] = Z;
66
67 // restore clamped factors
68 clamped->restoreFactors( ns );
69 }
70
71 delete clamped;
72
73 return( Pns.normalized() );
74 }
75
76
77 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
78 // convert ns to vector<VarSet>
79 size_t N = ns.size();
80 vector<Var> vns;
81 vns.reserve( N );
82 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
83 vns.push_back( *n );
84
85 vector<Factor> pairbeliefs;
86 pairbeliefs.reserve( N * N );
87 for( size_t j = 0; j < N; j++ )
88 for( size_t k = 0; k < N; k++ )
89 if( j == k )
90 pairbeliefs.push_back( Factor() );
91 else
92 pairbeliefs.push_back( Factor( vns[j] | vns[k] ) );
93
94 InfAlg *clamped = obj.clone();
95 if( !reInit )
96 clamped->init();
97
98 Real logZ0 = 0.0;
99 for( size_t j = 0; j < N; j++ ) {
100 // clamp Var j to its possible values
101 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
102 clamped->clamp( vns[j], j_val, true );
103 if( reInit )
104 clamped->init();
105 else
106 clamped->init(ns);
107 clamped->run();
108
109 //if( j == 0 )
110 // logZ0 = obj.logZ();
111 double Z_xj = 1.0;
112 if( j == 0 && j_val == 0 ) {
113 logZ0 = clamped->logZ();
114 } else {
115 // subtract logZ0 to avoid very large numbers
116 Z_xj = exp(clamped->logZ() - logZ0);
117 }
118
119 for( size_t k = 0; k < N; k++ )
120 if( k != j ) {
121 Factor b_k = clamped->belief(vns[k]);
122 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
123 if( vns[j].label() < vns[k].label() )
124 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
125 else
126 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
127 }
128
129 // restore clamped factors
130 clamped->restoreFactors( ns );
131 }
132 }
133
134 delete clamped;
135
136 // Calculate result by taking the geometric average
137 vector<Factor> result;
138 result.reserve( N * (N - 1) / 2 );
139 for( size_t j = 0; j < N; j++ )
140 for( size_t k = j+1; k < N; k++ )
141 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
142
143 return result;
144 }
145
146
147 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
148 // returns a a probability distribution whose 1st order interactions
149 // are unspecified, whose 2nd order interactions approximate those of
150 // the marginal on ns, and whose higher order interactions are absent.
151
152 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
153
154 Factor Pns (ns);
155 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
156 Pns *= pairbeliefs[ij];
157
158 return( Pns.normalized() );
159 }
160
161
162 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
163 vector<Factor> result;
164 result.reserve( ns.size() * (ns.size() - 1) / 2 );
165
166 InfAlg *clamped = obj.clone();
167 if( !reInit )
168 clamped->init();
169
170 Real logZ0 = 0.0;
171 VarSet::const_iterator nj = ns.begin();
172 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
173 size_t k = 0;
174 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
175 Factor pairbelief( *nj | *nk );
176
177 // clamp Vars j and k to their possible values
178 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
179 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
180 // save unclamped factors connected to ns
181 clamped->backupFactors( ns );
182
183 clamped->clamp( *nj, j_val );
184 clamped->clamp( *nk, k_val );
185 if( reInit )
186 clamped->init();
187 else
188 clamped->init(ns);
189 clamped->run();
190
191 double Z_xj = 1.0;
192 if( j_val == 0 && k_val == 0 ) {
193 logZ0 = clamped->logZ();
194 } else {
195 // subtract logZ0 to avoid very large numbers
196 Z_xj = exp(clamped->logZ() - logZ0);
197 }
198
199 // we assume that j.label() < k.label()
200 // i.e. we make an assumption here about the indexing
201 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
202
203 // restore clamped factors
204 clamped->restoreFactors( ns );
205 }
206
207 result.push_back( pairbelief );
208 }
209 }
210
211 delete clamped;
212
213 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
214
215 return result;
216 }
217
218
219 } // end of namespace dai