bda5bc768b01237d88e55c294eb1397ef9c2bce8
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/trwbp.h>
12
13
14 #define DAI_TRWBP_FAST 1
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 const char *TRWBP::Name = "TRWBP";
24
25
26 string TRWBP::identify() const {
27 return string(Name) + printProperties();
28 }
29
30
31 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
32 Real TRWBP::logZ() const {
33 Real sum = 0.0;
34 for( size_t I = 0; I < nrFactors(); I++ ) {
35 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP
36 sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP
37 }
38 for( size_t i = 0; i < nrVars(); ++i ) {
39 Real c_i = 0.0;
40 foreach( const Neighbor &I, nbV(i) )
41 c_i += Weight(I);
42 sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP
43 }
44 return sum;
45 }
46
47
48 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
49 void TRWBP::calcNewMessage( size_t i, size_t _I ) {
50 // calculate updated message I->i
51 size_t I = nbV(i,_I);
52 const Var &v_i = var(i);
53 Real c_I = Weight(I); // TRWBP: c_I (\mu_I in the paper)
54
55 Prob marg;
56 if( factor(I).vars().size() == 1 ) { // optimization
57 marg = factor(I).p();
58 } else {
59 Factor Fprod( factor(I) );
60 Prob &prod = Fprod.p();
61 if( props.logdomain ) {
62 prod.takeLog();
63 prod /= c_I; // TRWBP
64 } else
65 prod ^= (1.0 / c_I); // TRWBP
66
67 // Calculate product of incoming messages and factor I
68 foreach( const Neighbor &j, nbF(I) )
69 if( j != i ) { // for all j in I \ i
70 const Var &v_j = var(j);
71
72 // TRWBP: corresponds to messages n_jI
73 // prod_j will be the product of messages coming into j
74 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
75 foreach( const Neighbor &J, nbV(j) ) {
76 Real c_J = Weight(J);
77 if( J != I ) { // for all J in nb(j) \ I
78 if( props.logdomain )
79 prod_j += message( j, J.iter ) * c_J;
80 else
81 prod_j *= message( j, J.iter ) ^ c_J;
82 } else { // TRWBP: multiply by m_Ij^(c_I-1)
83 if( props.logdomain )
84 prod_j += message( j, J.iter ) * (c_J - 1.0);
85 else
86 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
87 }
88 }
89
90 // multiply prod with prod_j
91 if( !DAI_TRWBP_FAST ) {
92 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
93 if( props.logdomain )
94 Fprod += Factor( v_j, prod_j );
95 else
96 Fprod *= Factor( v_j, prod_j );
97 } else {
98 /* OPTIMIZED VERSION */
99 size_t _I = j.dual;
100 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
101 const ind_t &ind = index(j, _I);
102 for( size_t r = 0; r < prod.size(); ++r )
103 if( props.logdomain )
104 prod[r] += prod_j[ind[r]];
105 else
106 prod[r] *= prod_j[ind[r]];
107 }
108 }
109
110 if( props.logdomain ) {
111 prod -= prod.max();
112 prod.takeExp();
113 }
114
115 // Marginalize onto i
116 if( !DAI_TRWBP_FAST ) {
117 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
118 if( props.inference == Properties::InfType::SUMPROD )
119 marg = Fprod.marginal( v_i ).p();
120 else
121 marg = Fprod.maxMarginal( v_i ).p();
122 } else {
123 /* OPTIMIZED VERSION */
124 marg = Prob( v_i.states(), 0.0 );
125 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
126 const ind_t ind = index(i,_I);
127 if( props.inference == Properties::InfType::SUMPROD )
128 for( size_t r = 0; r < prod.size(); ++r )
129 marg[ind[r]] += prod[r];
130 else
131 for( size_t r = 0; r < prod.size(); ++r )
132 if( prod[r] > marg[ind[r]] )
133 marg[ind[r]] = prod[r];
134 marg.normalize();
135 }
136 }
137
138 // Store result
139 if( props.logdomain )
140 newMessage(i,_I) = marg.log();
141 else
142 newMessage(i,_I) = marg;
143
144 // Update the residual if necessary
145 if( props.updates == Properties::UpdateType::SEQMAX )
146 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
147 }
148
149
150 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
151 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
152 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
153 foreach( const Neighbor &I, nbV(i) ) {
154 Real c_I = Weight(I);
155 if( props.logdomain )
156 p += newMessage( i, I.iter ) * c_I;
157 else
158 p *= newMessage( i, I.iter ) ^ c_I;
159 }
160 }
161
162
163 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
164 void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
165 Real c_I = Weight(I); // TRWBP: c_I
166
167 Factor Fprod( factor(I) );
168 Prob &prod = Fprod.p();
169
170 if( props.logdomain ) {
171 prod.takeLog();
172 prod /= c_I; // TRWBP
173 } else
174 prod ^= (1.0 / c_I); // TRWBP
175
176 // Calculate product of incoming messages and factor I
177 foreach( const Neighbor &j, nbF(I) ) {
178 const Var &v_j = var(j);
179
180 // TRWBP: corresponds to messages n_jI
181 // prod_j will be the product of messages coming into j
182 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
183 foreach( const Neighbor &J, nbV(j) ) {
184 Real c_J = Weight(J);
185 if( J != I ) { // for all J in nb(j) \ I
186 if( props.logdomain )
187 prod_j += newMessage( j, J.iter ) * c_J;
188 else
189 prod_j *= newMessage( j, J.iter ) ^ c_J;
190 } else { // TRWBP: multiply by m_Ij^(c_I-1)
191 if( props.logdomain )
192 prod_j += newMessage( j, J.iter ) * (c_J - 1.0);
193 else
194 prod_j *= newMessage( j, J.iter ) ^ (c_J - 1.0);
195 }
196 }
197
198 // multiply prod with prod_j
199 if( !DAI_TRWBP_FAST ) {
200 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
201 if( props.logdomain )
202 Fprod += Factor( v_j, prod_j );
203 else
204 Fprod *= Factor( v_j, prod_j );
205 } else {
206 /* OPTIMIZED VERSION */
207 size_t _I = j.dual;
208 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
209 const ind_t &ind = index(j, _I);
210
211 for( size_t r = 0; r < prod.size(); ++r ) {
212 if( props.logdomain )
213 prod[r] += prod_j[ind[r]];
214 else
215 prod[r] *= prod_j[ind[r]];
216 }
217 }
218 }
219
220 p = prod;
221 }
222
223
224 void TRWBP::construct() {
225 BP::construct();
226 _weight.resize( nrFactors(), 1.0 );
227 }
228
229
230 } // end of namespace dai