8cd51b0a8e6508dcdad06f4f309a92b7491ac975
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton
8 * Copyright (C) 2009 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
15 #define DAI_FBP_FAST 1
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 const char *FBP::Name = "FBP";
25
26
27 string FBP::identify() const {
28 return string(Name) + printProperties();
29 }
30
31
32 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
33 Real FBP::logZ() const {
34 Real sum = 0.0;
35 for( size_t I = 0; I < nrFactors(); I++ ) {
36 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
37 sum += Weight(I) * beliefF(I).entropy(); // FBP
38 }
39 for( size_t i = 0; i < nrVars(); ++i ) {
40 Real c_i = 0.0;
41 foreach( const Neighbor &I, nbV(i) )
42 c_i += Weight(I);
43 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
44 }
45 return sum;
46 }
47
48
49 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
50 void FBP::calcNewMessage( size_t i, size_t _I ) {
51 // calculate updated message I->i
52 size_t I = nbV(i,_I);
53
54 Real c_I = Weight(I); // FBP: c_I
55
56 Factor Fprod( factor(I) );
57 Prob &prod = Fprod.p();
58 if( props.logdomain ) {
59 prod.takeLog();
60 prod /= c_I; // FBP
61 } else
62 prod ^= (1.0 / c_I); // FBP
63
64 // Calculate product of incoming messages and factor I
65 foreach( const Neighbor &j, nbF(I) )
66 if( j != i ) { // for all j in I \ i
67 // FBP: same as n_jI
68 // prod_j will be the product of messages coming into j
69 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
70 foreach( const Neighbor &J, nbV(j) )
71 if( J != I ) { // for all J in nb(j) \ I
72 if( props.logdomain )
73 prod_j += message( j, J.iter );
74 else
75 prod_j *= message( j, J.iter );
76 } else {
77 // FBP: multiply by m_Ij^(1-1/c_I)
78 if( props.logdomain )
79 prod_j += message( j, J.iter ) * (1.0 - 1.0 / c_I);
80 else
81 prod_j *= message( j, J.iter ) ^ (1.0 - 1.0 / c_I);
82 }
83
84 // multiply prod with prod_j
85 if( !DAI_FBP_FAST ) {
86 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
87 if( props.logdomain )
88 Fprod += Factor( var(j), prod_j );
89 else
90 Fprod *= Factor( var(j), prod_j );
91 } else {
92 /* OPTIMIZED VERSION */
93 size_t _I = j.dual;
94 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
95 const ind_t &ind = index(j, _I);
96 for( size_t r = 0; r < prod.size(); ++r )
97 if( props.logdomain )
98 prod[r] += prod_j[ind[r]];
99 else
100 prod[r] *= prod_j[ind[r]];
101 }
102 }
103
104 if( props.logdomain ) {
105 prod -= prod.max();
106 prod.takeExp();
107 }
108
109 // Marginalize onto i
110 Prob marg;
111 if( !DAI_FBP_FAST ) {
112 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
113 if( props.inference == Properties::InfType::SUMPROD )
114 marg = Fprod.marginal( var(i) ).p();
115 else
116 marg = Fprod.maxMarginal( var(i) ).p();
117 } else {
118 /* OPTIMIZED VERSION */
119 marg = Prob( var(i).states(), 0.0 );
120 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
121 const ind_t ind = index(i,_I);
122 if( props.inference == Properties::InfType::SUMPROD )
123 for( size_t r = 0; r < prod.size(); ++r )
124 marg[ind[r]] += prod[r];
125 else
126 for( size_t r = 0; r < prod.size(); ++r )
127 if( prod[r] > marg[ind[r]] )
128 marg[ind[r]] = prod[r];
129 marg.normalize();
130 }
131
132 // FBP
133 marg ^= c_I;
134
135 // Store result
136 if( props.logdomain )
137 newMessage(i,_I) = marg.log();
138 else
139 newMessage(i,_I) = marg;
140
141 // Update the residual if necessary
142 if( props.updates == Properties::UpdateType::SEQMAX )
143 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
144 }
145
146
147 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
148 void FBP::calcBeliefF( size_t I, Prob &p ) const {
149 Real c_I = Weight(I); // FBP: c_I
150
151 Factor Fprod( factor(I) );
152 Prob &prod = Fprod.p();
153
154 if( props.logdomain ) {
155 prod.takeLog();
156 prod /= c_I; // FBP
157 } else
158 prod ^= (1.0 / c_I); // FBP
159
160 foreach( const Neighbor &j, nbF(I) ) {
161 // prod_j will be the product of messages coming into j
162 // FBP: corresponds to messages n_jI
163 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
164 foreach( const Neighbor &J, nbV(j) )
165 if( J != I ) { // for all J in nb(j) \ I
166 if( props.logdomain )
167 prod_j += newMessage( j, J.iter );
168 else
169 prod_j *= newMessage( j, J.iter );
170 } else {
171 // FBP: multiply by m_Ij^(1-1/c_I)
172 if( props.logdomain )
173 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
174 else
175 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
176 }
177
178 // multiply prod with prod_j
179 if( !DAI_FBP_FAST ) {
180 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
181 if( props.logdomain )
182 Fprod += Factor( var(j), prod_j );
183 else
184 Fprod *= Factor( var(j), prod_j );
185 } else {
186 /* OPTIMIZED VERSION */
187 size_t _I = j.dual;
188 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
189 const ind_t & ind = index(j, _I);
190
191 for( size_t r = 0; r < prod.size(); ++r ) {
192 if( props.logdomain )
193 prod[r] += prod_j[ind[r]];
194 else
195 prod[r] *= prod_j[ind[r]];
196 }
197 }
198 }
199
200 p = prod;
201 }
202
203
204 void FBP::construct() {
205 BP::construct();
206 _weight.resize( nrFactors(), 1.0 );
207 }
208
209
210 } // end of namespace dai