Cleaned up variable elimination code in ClusterGraph
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton
8 * Copyright (C) 2009 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
// Compile-time switch used by calcNewMessage() and calcBeliefF(): a nonzero value selects the
// optimized code paths that use the precalculated index tables, zero selects the simpler
// (but slower) Factor-based code paths.
#define DAI_FBP_FAST (1)
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 const char *FBP::Name = "FBP";
25
26
27 string FBP::identify() const {
28 return string(Name) + printProperties();
29 }
30
31
32 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
33 Real FBP::logZ() const {
34 Real sum = 0.0;
35 for( size_t I = 0; I < nrFactors(); I++ ) {
36 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
37 sum += scaleF(I) * beliefF(I).entropy(); // FBP
38 }
39 for( size_t i = 0; i < nrVars(); ++i ) {
40 Real c_i = 0.0;
41 foreach( const Neighbor &I, nbV(i) )
42 c_i += scaleF(I);
43 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
44 }
45 return sum;
46 }
47
48
49 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
50 void FBP::calcNewMessage( size_t i, size_t _I ) {
51 // calculate updated message I->i
52 size_t I = nbV(i,_I);
53
54 Real scale = scaleF(I); // FBP: c_I
55
56 Factor Fprod( factor(I) );
57 Prob &prod = Fprod.p();
58 if( props.logdomain ) {
59 prod.takeLog();
60 prod *= (1/scale); // FBP
61 } else
62 prod ^= (1/scale); // FBP
63
64 // Calculate product of incoming messages and factor I
65 foreach( const Neighbor &j, nbF(I) )
66 if( j != i ) { // for all j in I \ i
67 // FBP: same as n_jI
68 // prod_j will be the product of messages coming into j
69 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
70 foreach( const Neighbor &J, nbV(j) )
71 if( J != I ) { // for all J in nb(j) \ I
72 if( props.logdomain )
73 prod_j += message( j, J.iter );
74 else
75 prod_j *= message( j, J.iter );
76 }
77
78
79 size_t _I = j.dual;
80 // FBP: now multiply by m_Ij^(1-1/c_I)
81 if(props.logdomain)
82 prod_j += message( j, _I)*(1-1/scale);
83 else
84 prod_j *= message( j, _I)^(1-1/scale);
85
86 // multiply prod with prod_j
87 if( !DAI_FBP_FAST ) {
88 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
89 if( props.logdomain )
90 Fprod += Factor( var(j), prod_j );
91 else
92 Fprod *= Factor( var(j), prod_j );
93 } else {
94 /* OPTIMIZED VERSION */
95 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
96 const ind_t &ind = index(j, _I);
97 for( size_t r = 0; r < prod.size(); ++r )
98 if( props.logdomain )
99 prod[r] += prod_j[ind[r]];
100 else
101 prod[r] *= prod_j[ind[r]];
102 }
103 }
104
105 if( props.logdomain ) {
106 prod -= prod.max();
107 prod.takeExp();
108 }
109
110 // Marginalize onto i
111 Prob marg;
112 if( !DAI_FBP_FAST ) {
113 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
114 if( props.inference == Properties::InfType::SUMPROD )
115 marg = Fprod.marginal( var(i) ).p();
116 else
117 marg = Fprod.maxMarginal( var(i) ).p();
118 } else {
119 /* OPTIMIZED VERSION */
120 marg = Prob( var(i).states(), 0.0 );
121 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
122 const ind_t ind = index(i,_I);
123 if( props.inference == Properties::InfType::SUMPROD )
124 for( size_t r = 0; r < prod.size(); ++r )
125 marg[ind[r]] += prod[r];
126 else
127 for( size_t r = 0; r < prod.size(); ++r )
128 if( prod[r] > marg[ind[r]] )
129 marg[ind[r]] = prod[r];
130 marg.normalize();
131 }
132
133 // FBP
134 marg ^= scale;
135
136 // Store result
137 if( props.logdomain )
138 newMessage(i,_I) = marg.log();
139 else
140 newMessage(i,_I) = marg;
141
142 // Update the residual if necessary
143 if( props.updates == Properties::UpdateType::SEQMAX )
144 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
145 }
146
147
148 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
149 void FBP::calcBeliefF( size_t I, Prob &p ) const {
150 Real scale = scaleF(I); // FBP: c_I
151
152 Factor Fprod( factor(I) );
153 Prob &prod = Fprod.p();
154
155 if( props.logdomain ) {
156 prod.takeLog();
157 prod /= scale; // FBP
158 } else
159 prod ^= (1/scale); // FBP
160
161 foreach( const Neighbor &j, nbF(I) ) {
162 // prod_j will be the product of messages coming into j
163 // FBP: corresponds to messages n_jI
164 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
165 foreach( const Neighbor &J, nbV(j) )
166 if( J != I ) { // for all J in nb(j) \ I
167 if( props.logdomain )
168 prod_j += newMessage( j, J.iter );
169 else
170 prod_j *= newMessage( j, J.iter );
171 }
172
173 size_t _I = j.dual;
174
175 // FBP: now multiply by m_Ij^(1-1/c_I)
176 if( props.logdomain )
177 prod_j += newMessage( j, _I)*(1-1/scale);
178 else
179 prod_j *= newMessage( j, _I)^(1-1/scale);
180
181 // multiply prod with prod_j
182 if( !DAI_FBP_FAST ) {
183 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
184 if( props.logdomain )
185 Fprod += Factor( var(j), prod_j );
186 else
187 Fprod *= Factor( var(j), prod_j );
188 } else {
189 /* OPTIMIZED VERSION */
190 size_t _I = j.dual;
191 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
192 const ind_t & ind = index(j, _I);
193
194 for( size_t r = 0; r < prod.size(); ++r ) {
195 if( props.logdomain )
196 prod[r] += prod_j[ind[r]];
197 else
198 prod[r] *= prod_j[ind[r]];
199 }
200 }
201 }
202
203 p = prod;
204 }
205
206
207 void FBP::construct() {
208 BP::construct();
209 _scale_factor.resize( nrFactors(), 1.0 );
210 }
211
212
213 } // end of namespace dai