Implementing TRWBP
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/trwbp.h>
12
13
14 #define DAI_TRWBP_FAST 1
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 const char *TRWBP::Name = "TRWBP";
24
25
26 string TRWBP::identify() const {
27 return string(Name) + printProperties();
28 }
29
30
31 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
32 Real TRWBP::logZ() const {
33 Real sum = 0.0;
34 for( size_t I = 0; I < nrFactors(); I++ ) {
35 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP
36 if( factor(I).vars().size() == 2 )
37 sum -= edgeWeight(I) * MutualInfo( beliefF(I) ); // TRWBP
38 }
39 for( size_t i = 0; i < nrVars(); ++i )
40 sum += beliefV(i).entropy(); // TRWBP
41 return sum;
42 }
43
44
45 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
46 void TRWBP::calcNewMessage( size_t i, size_t _I ) {
47 // calculate updated message I->i
48 size_t I = nbV(i,_I);
49 const Var &v_i = var(i);
50 const VarSet &v_I = factor(I).vars();
51 Real c_I = edgeWeight(I); // TRWBP: c_I (\mu_I in the paper)
52
53 Prob marg;
54 if( v_I.size() == 1 ) { // optimization
55 marg = factor(I).p();
56 } else
57 Factor Fprod( factor(I) );
58 Prob &prod = Fprod.p();
59 if( props.logdomain ) {
60 prod.takeLog();
61 prod /= c_I; // TRWBP
62 } else
63 prod ^= (1.0 / c_I); // TRWBP
64
65 // Calculate product of incoming messages and factor I
66 foreach( const Neighbor &j, nbF(I) )
67 if( j != i ) { // for all j in I \ i
68 const Var &v_j = var(j);
69
70 // TRWBP: corresponds to messages n_jI
71 // prod_j will be the product of messages coming into j
72 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
73 foreach( const Neighbor &J, nbV(j) ) {
74 Real c_J = edgeWeight(J);
75 if( J != I ) { // for all J in nb(j) \ I
76 if( props.logdomain )
77 prod_j += message( j, J.iter ) * c_J;
78 else
79 prod_j *= message( j, J.iter ) ^ c_J;
80 } else { // TRWBP: multiply by m_Ij^(c_I-1)
81 if( props.logdomain )
82 prod_j += message( j, J.iter ) * (c_J - 1.0);
83 else
84 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
85 }
86 }
87
88 // multiply prod with prod_j
89 if( !DAI_TRWBP_FAST ) {
90 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
91 if( props.logdomain )
92 Fprod += Factor( v_j, prod_j );
93 else
94 Fprod *= Factor( v_j, prod_j );
95 } else {
96 /* OPTIMIZED VERSION */
97 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
98 const ind_t &ind = index(j, _I);
99 for( size_t r = 0; r < prod.size(); ++r )
100 if( props.logdomain )
101 prod[r] += prod_j[ind[r]];
102 else
103 prod[r] *= prod_j[ind[r]];
104 }
105 }
106
107 if( props.logdomain ) {
108 prod -= prod.max();
109 prod.takeExp();
110 }
111
112 // Marginalize onto i
113 if( !DAI_TRWBP_FAST ) {
114 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
115 if( props.inference == Properties::InfType::SUMPROD )
116 marg = Fprod.marginal( v_i ).p();
117 else
118 marg = Fprod.maxMarginal( v_i ).p();
119 } else {
120 /* OPTIMIZED VERSION */
121 marg = Prob( v_i.states(), 0.0 );
122 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
123 const ind_t ind = index(i,_I);
124 if( props.inference == Properties::InfType::SUMPROD )
125 for( size_t r = 0; r < prod.size(); ++r )
126 marg[ind[r]] += prod[r];
127 else
128 for( size_t r = 0; r < prod.size(); ++r )
129 if( prod[r] > marg[ind[r]] )
130 marg[ind[r]] = prod[r];
131 marg.normalize();
132 }
133
134 // Store result
135 if( props.logdomain )
136 newMessage(i,_I) = marg.log();
137 else
138 newMessage(i,_I) = marg;
139
140 // Update the residual if necessary
141 if( props.updates == Properties::UpdateType::SEQMAX )
142 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
143 }
144 }
145
146
147 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
148 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
149 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
150 foreach( const Neighbor &I, nbV(i) ) {
151 Real c_I = edgeWeight(I);
152 if( props.logdomain )
153 p += newMessage( i, I.iter ) * c_I;
154 else
155 p *= newMessage( i, I.iter ) ^ c_I;
156 }
157 }
158
159
160 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
161 void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
162 Real c_I = edgeWeight(I); // TRWBP: c_I
163 const VarSet &v_I = factor(I).vars();
164
165 Factor Fprod( factor(I) );
166 Prob &prod = Fprod.p();
167
168 if( props.logdomain ) {
169 prod.takeLog();
170 prod /= c_I; // TRWBP
171 } else
172 prod ^= (1.0 / c_I); // TRWBP
173
174 // Calculate product of incoming messages and factor I
175 foreach( const Neighbor &j, nbF(I) ) {
176 const Var &v_j = var(j);
177
178 // TRWBP: corresponds to messages n_jI
179 // prod_j will be the product of messages coming into j
180 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
181 foreach( const Neighbor &J, nbV(j) ) {
182 Real c_J = edgeWeight(J);
183 if( J != I ) { // for all J in nb(j) \ I
184 if( props.logdomain )
185 prod_j += newMessage( j, J.iter ) * c_J;
186 else
187 prod_j *= newMessage( j, J.iter ) ^ c_J;
188 } else { // TRWBP: multiply by m_Ij^(c_I-1)
189 if( props.logdomain )
190 prod_j += newMessage( j, J.iter ) * (c_J - 1.0);
191 else
192 prod_j *= newMessage( j, J.iter ) ^ (c_J - 1.0);
193 }
194 }
195
196 // multiply prod with prod_j
197 if( !DAI_TRWBP_FAST ) {
198 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
199 if( props.logdomain )
200 Fprod += Factor( v_j, prod_j );
201 else
202 Fprod *= Factor( v_j, prod_j );
203 } else {
204 /* OPTIMIZED VERSION */
205 size_t _I = j.dual;
206 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
207 const ind_t &ind = index(j, _I);
208
209 for( size_t r = 0; r < prod.size(); ++r ) {
210 if( props.logdomain )
211 prod[r] += prod_j[ind[r]];
212 else
213 prod[r] *= prod_j[ind[r]];
214 }
215 }
216 }
217
218 p = prod;
219 }
220
221
222 void TRWBP::construct() {
223 BP::construct();
224 _edge_weight.resize( nrFactors(), 1.0 );
225 }
226
227
228 } // end of namespace dai