Replaced sub_nb class in mr.h by boost::dynamic_bitset
[libdai.git] / src / daialg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <vector>
23 #include <dai/daialg.h>
24
25
26 namespace dai {
27
28
29 using namespace std;
30
31
32 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
33 Factor Pns (ns);
34
35 InfAlg *clamped = obj.clone();
36 if( !reInit )
37 clamped->init();
38
39 Real logZ0 = 0.0;
40 for( State s(ns); s.valid(); s++ ) {
41 // save unclamped factors connected to ns
42 clamped->backupFactors( ns );
43
44 // set clamping Factors to delta functions
45 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
46 clamped->clamp( *n, s(*n) );
47
48 // run DAIAlg, calc logZ, store in Pns
49 if( reInit )
50 clamped->init();
51 else
52 clamped->init(ns);
53 clamped->run();
54
55 Real Z;
56 if( s == 0 ) {
57 logZ0 = clamped->logZ();
58 Z = 1.0;
59 } else {
60 // subtract logZ0 to avoid very large numbers
61 Z = exp(clamped->logZ() - logZ0);
62 }
63
64 Pns[s] = Z;
65
66 // restore clamped factors
67 clamped->restoreFactors( ns );
68 }
69
70 delete clamped;
71
72 return( Pns.normalized() );
73 }
74
75
76 vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit ) {
77 // convert ns to vector<VarSet>
78 size_t N = ns.size();
79 vector<Var> vns;
80 vns.reserve( N );
81 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
82 vns.push_back( *n );
83
84 vector<Factor> pairbeliefs;
85 pairbeliefs.reserve( N * N );
86 for( size_t j = 0; j < N; j++ )
87 for( size_t k = 0; k < N; k++ )
88 if( j == k )
89 pairbeliefs.push_back( Factor() );
90 else
91 pairbeliefs.push_back( Factor( vns[j] | vns[k] ) );
92
93 InfAlg *clamped = obj.clone();
94 if( !reInit )
95 clamped->init();
96
97 Real logZ0 = 0.0;
98 for( size_t j = 0; j < N; j++ ) {
99 // clamp Var j to its possible values
100 for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
101 clamped->clamp( vns[j], j_val, true );
102 if( reInit )
103 clamped->init();
104 else
105 clamped->init(ns);
106 clamped->run();
107
108 //if( j == 0 )
109 // logZ0 = obj.logZ();
110 double Z_xj = 1.0;
111 if( j == 0 && j_val == 0 ) {
112 logZ0 = clamped->logZ();
113 } else {
114 // subtract logZ0 to avoid very large numbers
115 Z_xj = exp(clamped->logZ() - logZ0);
116 }
117
118 for( size_t k = 0; k < N; k++ )
119 if( k != j ) {
120 Factor b_k = clamped->belief(vns[k]);
121 for( size_t k_val = 0; k_val < vns[k].states(); k_val++ )
122 if( vns[j].label() < vns[k].label() )
123 pairbeliefs[j * N + k][j_val + (k_val * vns[j].states())] = Z_xj * b_k[k_val];
124 else
125 pairbeliefs[j * N + k][k_val + (j_val * vns[k].states())] = Z_xj * b_k[k_val];
126 }
127
128 // restore clamped factors
129 clamped->restoreFactors( ns );
130 }
131 }
132
133 delete clamped;
134
135 // Calculate result by taking the geometric average
136 vector<Factor> result;
137 result.reserve( N * (N - 1) / 2 );
138 for( size_t j = 0; j < N; j++ )
139 for( size_t k = j+1; k < N; k++ )
140 result.push_back( (pairbeliefs[j * N + k] * pairbeliefs[k * N + j]) ^ 0.5 );
141
142 return result;
143 }
144
145
146 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
147 // returns a a probability distribution whose 1st order interactions
148 // are unspecified, whose 2nd order interactions approximate those of
149 // the marginal on ns, and whose higher order interactions are absent.
150
151 vector<Factor> pairbeliefs = calcPairBeliefs( obj, ns, reInit );
152
153 Factor Pns (ns);
154 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
155 Pns *= pairbeliefs[ij];
156
157 return( Pns.normalized() );
158 }
159
160
161 vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit ) {
162 vector<Factor> result;
163 result.reserve( ns.size() * (ns.size() - 1) / 2 );
164
165 InfAlg *clamped = obj.clone();
166 if( !reInit )
167 clamped->init();
168
169 Real logZ0 = 0.0;
170 VarSet::const_iterator nj = ns.begin();
171 for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
172 size_t k = 0;
173 for( VarSet::const_iterator nk = nj; (++nk) != ns.end(); k++ ) {
174 Factor pairbelief( *nj | *nk );
175
176 // clamp Vars j and k to their possible values
177 for( size_t j_val = 0; j_val < nj->states(); j_val++ )
178 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
179 // save unclamped factors connected to ns
180 clamped->backupFactors( ns );
181
182 clamped->clamp( *nj, j_val );
183 clamped->clamp( *nk, k_val );
184 if( reInit )
185 clamped->init();
186 else
187 clamped->init(ns);
188 clamped->run();
189
190 double Z_xj = 1.0;
191 if( j_val == 0 && k_val == 0 ) {
192 logZ0 = clamped->logZ();
193 } else {
194 // subtract logZ0 to avoid very large numbers
195 Z_xj = exp(clamped->logZ() - logZ0);
196 }
197
198 // we assume that j.label() < k.label()
199 // i.e. we make an assumption here about the indexing
200 pairbelief[j_val + (k_val * nj->states())] = Z_xj;
201
202 // restore clamped factors
203 clamped->restoreFactors( ns );
204 }
205
206 result.push_back( pairbelief );
207 }
208 }
209
210 delete clamped;
211
212 assert( result.size() == (ns.size() * (ns.size() - 1) / 2) );
213
214 return result;
215 }
216
217
218 } // end of namespace dai