Implementing TRWBP
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton
8 * Copyright (C) 2009 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
// Compile-time switch used by calcNewMessage() and calcBeliefF() below:
// non-zero selects the optimized (precomputed index table) code paths,
// zero selects the simpler, slower reference implementation based on Factor
// arithmetic.
#define DAI_FBP_FAST 1
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 const char *FBP::Name = "FBP";
25
26
27 string FBP::identify() const {
28 return string(Name) + printProperties();
29 }
30
31
32 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
33 Real FBP::logZ() const {
34 Real sum = 0.0;
35 for( size_t I = 0; I < nrFactors(); I++ ) {
36 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
37 sum += scaleF(I) * beliefF(I).entropy(); // FBP
38 }
39 for( size_t i = 0; i < nrVars(); ++i ) {
40 Real c_i = 0.0;
41 foreach( const Neighbor &I, nbV(i) )
42 c_i += scaleF(I);
43 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
44 }
45 return sum;
46 }
47
48
49 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
50 void FBP::calcNewMessage( size_t i, size_t _I ) {
51 // calculate updated message I->i
52 size_t I = nbV(i,_I);
53
54 Real scale = scaleF(I); // FBP: c_I
55
56 Factor Fprod( factor(I) );
57 Prob &prod = Fprod.p();
58 if( props.logdomain ) {
59 prod.takeLog();
60 prod *= (1/scale); // FBP
61 } else
62 prod ^= (1/scale); // FBP
63
64 // Calculate product of incoming messages and factor I
65 foreach( const Neighbor &j, nbF(I) )
66 if( j != i ) { // for all j in I \ i
67 // FBP: same as n_jI
68 // prod_j will be the product of messages coming into j
69 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
70 foreach( const Neighbor &J, nbV(j) )
71 if( J != I ) { // for all J in nb(j) \ I
72 if( props.logdomain )
73 prod_j += message( j, J.iter );
74 else
75 prod_j *= message( j, J.iter );
76 } else {
77 // FBP: multiply by m_Ij^(1-1/c_I)
78 if( props.logdomain )
79 prod_j += message( j, J.iter )*(1-1/scale);
80 else
81 prod_j *= message( j, J.iter )^(1-1/scale);
82 }
83
84 // multiply prod with prod_j
85 if( !DAI_FBP_FAST ) {
86 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
87 if( props.logdomain )
88 Fprod += Factor( var(j), prod_j );
89 else
90 Fprod *= Factor( var(j), prod_j );
91 } else {
92 /* OPTIMIZED VERSION */
93 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
94 const ind_t &ind = index(j, _I);
95 for( size_t r = 0; r < prod.size(); ++r )
96 if( props.logdomain )
97 prod[r] += prod_j[ind[r]];
98 else
99 prod[r] *= prod_j[ind[r]];
100 }
101 }
102
103 if( props.logdomain ) {
104 prod -= prod.max();
105 prod.takeExp();
106 }
107
108 // Marginalize onto i
109 Prob marg;
110 if( !DAI_FBP_FAST ) {
111 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
112 if( props.inference == Properties::InfType::SUMPROD )
113 marg = Fprod.marginal( var(i) ).p();
114 else
115 marg = Fprod.maxMarginal( var(i) ).p();
116 } else {
117 /* OPTIMIZED VERSION */
118 marg = Prob( var(i).states(), 0.0 );
119 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
120 const ind_t ind = index(i,_I);
121 if( props.inference == Properties::InfType::SUMPROD )
122 for( size_t r = 0; r < prod.size(); ++r )
123 marg[ind[r]] += prod[r];
124 else
125 for( size_t r = 0; r < prod.size(); ++r )
126 if( prod[r] > marg[ind[r]] )
127 marg[ind[r]] = prod[r];
128 marg.normalize();
129 }
130
131 // FBP
132 marg ^= scale;
133
134 // Store result
135 if( props.logdomain )
136 newMessage(i,_I) = marg.log();
137 else
138 newMessage(i,_I) = marg;
139
140 // Update the residual if necessary
141 if( props.updates == Properties::UpdateType::SEQMAX )
142 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
143 }
144
145
146 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
147 void FBP::calcBeliefF( size_t I, Prob &p ) const {
148 Real scale = scaleF(I); // FBP: c_I
149
150 Factor Fprod( factor(I) );
151 Prob &prod = Fprod.p();
152
153 if( props.logdomain ) {
154 prod.takeLog();
155 prod /= scale; // FBP
156 } else
157 prod ^= (1/scale); // FBP
158
159 foreach( const Neighbor &j, nbF(I) ) {
160 // prod_j will be the product of messages coming into j
161 // FBP: corresponds to messages n_jI
162 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
163 foreach( const Neighbor &J, nbV(j) )
164 if( J != I ) { // for all J in nb(j) \ I
165 if( props.logdomain )
166 prod_j += newMessage( j, J.iter );
167 else
168 prod_j *= newMessage( j, J.iter );
169 } else {
170 // FBP: multiply by m_Ij^(1-1/c_I)
171 if( props.logdomain )
172 prod_j += newMessage( j, J.iter)*(1-1/scale);
173 else
174 prod_j *= newMessage( j, J.iter)^(1-1/scale);
175 }
176
177 // multiply prod with prod_j
178 if( !DAI_FBP_FAST ) {
179 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
180 if( props.logdomain )
181 Fprod += Factor( var(j), prod_j );
182 else
183 Fprod *= Factor( var(j), prod_j );
184 } else {
185 /* OPTIMIZED VERSION */
186 size_t _I = j.dual;
187 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
188 const ind_t & ind = index(j, _I);
189
190 for( size_t r = 0; r < prod.size(); ++r ) {
191 if( props.logdomain )
192 prod[r] += prod_j[ind[r]];
193 else
194 prod[r] *= prod_j[ind[r]];
195 }
196 }
197 }
198
199 p = prod;
200 }
201
202
203 void FBP::construct() {
204 BP::construct();
205 _scale_factor.resize( nrFactors(), 1.0 );
206 }
207
208
209 } // end of namespace dai