1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
11 #include <dai/trwbp.h>
14 #define DAI_TRWBP_FAST 1
23 const char *TRWBP::Name
= "TRWBP";
26 string
TRWBP::identify() const {
27 return string(Name
) + printProperties();
31 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
32 Real
TRWBP::logZ() const {
34 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
35 sum
+= (beliefF(I
) * factor(I
).log(true)).sum(); // TRWBP/FBP
36 sum
+= Weight(I
) * beliefF(I
).entropy(); // TRWBP/FBP
38 for( size_t i
= 0; i
< nrVars(); ++i
) {
40 foreach( const Neighbor
&I
, nbV(i
) )
42 sum
+= (1.0 - c_i
) * beliefV(i
).entropy(); // TRWBP/FBP
48 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
49 void TRWBP::calcNewMessage( size_t i
, size_t _I
) {
50 // calculate updated message I->i
52 const Var
&v_i
= var(i
);
53 Real c_I
= Weight(I
); // TRWBP: c_I (\mu_I in the paper)
56 if( factor(I
).vars().size() == 1 ) { // optimization
59 Factor
Fprod( factor(I
) );
60 Prob
&prod
= Fprod
.p();
61 if( props
.logdomain
) {
65 prod
^= (1.0 / c_I
); // TRWBP
67 // Calculate product of incoming messages and factor I
68 foreach( const Neighbor
&j
, nbF(I
) )
69 if( j
!= i
) { // for all j in I \ i
70 const Var
&v_j
= var(j
);
72 // TRWBP: corresponds to messages n_jI
73 // prod_j will be the product of messages coming into j
74 Prob
prod_j( v_j
.states(), props
.logdomain
? 0.0 : 1.0 );
75 foreach( const Neighbor
&J
, nbV(j
) ) {
77 if( J
!= I
) { // for all J in nb(j) \ I
79 prod_j
+= message( j
, J
.iter
) * c_J
;
81 prod_j
*= message( j
, J
.iter
) ^ c_J
;
82 } else { // TRWBP: multiply by m_Ij^(c_I-1)
84 prod_j
+= message( j
, J
.iter
) * (c_J
- 1.0);
86 prod_j
*= message( j
, J
.iter
) ^ (c_J
- 1.0);
90 // multiply prod with prod_j
91 if( !DAI_TRWBP_FAST
) {
92 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
94 Fprod
+= Factor( v_j
, prod_j
);
96 Fprod
*= Factor( v_j
, prod_j
);
98 /* OPTIMIZED VERSION */
100 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
101 const ind_t
&ind
= index(j
, _I
);
102 for( size_t r
= 0; r
< prod
.size(); ++r
)
103 if( props
.logdomain
)
104 prod
[r
] += prod_j
[ind
[r
]];
106 prod
[r
] *= prod_j
[ind
[r
]];
110 if( props
.logdomain
) {
115 // Marginalize onto i
116 if( !DAI_TRWBP_FAST
) {
117 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
118 if( props
.inference
== Properties::InfType::SUMPROD
)
119 marg
= Fprod
.marginal( v_i
).p();
121 marg
= Fprod
.maxMarginal( v_i
).p();
123 /* OPTIMIZED VERSION */
124 marg
= Prob( v_i
.states(), 0.0 );
125 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
126 const ind_t ind
= index(i
,_I
);
127 if( props
.inference
== Properties::InfType::SUMPROD
)
128 for( size_t r
= 0; r
< prod
.size(); ++r
)
129 marg
[ind
[r
]] += prod
[r
];
131 for( size_t r
= 0; r
< prod
.size(); ++r
)
132 if( prod
[r
] > marg
[ind
[r
]] )
133 marg
[ind
[r
]] = prod
[r
];
139 if( props
.logdomain
)
140 newMessage(i
,_I
) = marg
.log();
142 newMessage(i
,_I
) = marg
;
144 // Update the residual if necessary
145 if( props
.updates
== Properties::UpdateType::SEQMAX
)
146 updateResidual( i
, _I
, dist( newMessage( i
, _I
), message( i
, _I
), Prob::DISTLINF
) );
150 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
151 void TRWBP::calcBeliefV( size_t i
, Prob
&p
) const {
152 p
= Prob( var(i
).states(), props
.logdomain
? 0.0 : 1.0 );
153 foreach( const Neighbor
&I
, nbV(i
) ) {
154 Real c_I
= Weight(I
);
155 if( props
.logdomain
)
156 p
+= newMessage( i
, I
.iter
) * c_I
;
158 p
*= newMessage( i
, I
.iter
) ^ c_I
;
163 /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
164 void TRWBP::calcBeliefF( size_t I
, Prob
&p
) const {
165 Real c_I
= Weight(I
); // TRWBP: c_I
167 Factor
Fprod( factor(I
) );
168 Prob
&prod
= Fprod
.p();
170 if( props
.logdomain
) {
172 prod
/= c_I
; // TRWBP
174 prod
^= (1.0 / c_I
); // TRWBP
176 // Calculate product of incoming messages and factor I
177 foreach( const Neighbor
&j
, nbF(I
) ) {
178 const Var
&v_j
= var(j
);
180 // TRWBP: corresponds to messages n_jI
181 // prod_j will be the product of messages coming into j
182 Prob
prod_j( v_j
.states(), props
.logdomain
? 0.0 : 1.0 );
183 foreach( const Neighbor
&J
, nbV(j
) ) {
184 Real c_J
= Weight(J
);
185 if( J
!= I
) { // for all J in nb(j) \ I
186 if( props
.logdomain
)
187 prod_j
+= newMessage( j
, J
.iter
) * c_J
;
189 prod_j
*= newMessage( j
, J
.iter
) ^ c_J
;
190 } else { // TRWBP: multiply by m_Ij^(c_I-1)
191 if( props
.logdomain
)
192 prod_j
+= newMessage( j
, J
.iter
) * (c_J
- 1.0);
194 prod_j
*= newMessage( j
, J
.iter
) ^ (c_J
- 1.0);
198 // multiply prod with prod_j
199 if( !DAI_TRWBP_FAST
) {
200 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
201 if( props
.logdomain
)
202 Fprod
+= Factor( v_j
, prod_j
);
204 Fprod
*= Factor( v_j
, prod_j
);
206 /* OPTIMIZED VERSION */
208 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
209 const ind_t
&ind
= index(j
, _I
);
211 for( size_t r
= 0; r
< prod
.size(); ++r
) {
212 if( props
.logdomain
)
213 prod
[r
] += prod_j
[ind
[r
]];
215 prod
[r
] *= prod_j
[ind
[r
]];
224 void TRWBP::construct() {
226 _weight
.resize( nrFactors(), 1.0 );
230 } // end of namespace dai