Merged duplicate code (in calcBeliefF() and calcNewMessage()) in BP, FBP and TRWBP
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/trwbp.h>
12
13
14 #define DAI_TRWBP_FAST 1
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
// Name of this algorithm; identify() puts it in front of the property string.
const char *TRWBP::Name = "TRWBP";
24
25
26 string TRWBP::identify() const {
27 return string(Name) + printProperties();
28 }
29
30
31 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
32 Real TRWBP::logZ() const {
33 Real sum = 0.0;
34 for( size_t I = 0; I < nrFactors(); I++ ) {
35 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP
36 sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP
37 }
38 for( size_t i = 0; i < nrVars(); ++i ) {
39 Real c_i = 0.0;
40 foreach( const Neighbor &I, nbV(i) )
41 c_i += Weight(I);
42 sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP
43 }
44 return sum;
45 }
46
47
48 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
49 Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
50 Real c_I = Weight(I); // TRWBP: c_I
51
52 Factor Fprod( factor(I) );
53 Prob &prod = Fprod.p();
54 if( props.logdomain ) {
55 prod.takeLog();
56 prod /= c_I; // TRWBP
57 } else
58 prod ^= (1.0 / c_I); // TRWBP
59
60 // Calculate product of incoming messages and factor I
61 foreach( const Neighbor &j, nbF(I) )
62 if( !(without_i && (j == i)) ) {
63 const Var &v_j = var(j);
64 // prod_j will be the product of messages coming into j
65 // TRWBP: corresponds to messages n_jI
66 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
67 foreach( const Neighbor &J, nbV(j) ) {
68 Real c_J = Weight(J); // TRWBP
69 if( J != I ) { // for all J in nb(j) \ I
70 if( props.logdomain )
71 prod_j += message( j, J.iter ) * c_J;
72 else
73 prod_j *= message( j, J.iter ) ^ c_J;
74 } else { // TRWBP: multiply by m_Ij^(c_I-1)
75 if( props.logdomain )
76 prod_j += message( j, J.iter ) * (c_J - 1.0);
77 else
78 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
79 }
80 }
81
82 // multiply prod with prod_j
83 if( !DAI_TRWBP_FAST ) {
84 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
85 if( props.logdomain )
86 Fprod += Factor( v_j, prod_j );
87 else
88 Fprod *= Factor( v_j, prod_j );
89 } else {
90 // OPTIMIZED VERSION
91 size_t _I = j.dual;
92 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
93 const ind_t &ind = index(j, _I);
94
95 for( size_t r = 0; r < prod.size(); ++r ) {
96 if( props.logdomain )
97 prod[r] += prod_j[ind[r]];
98 else
99 prod[r] *= prod_j[ind[r]];
100 }
101 }
102 }
103
104 return prod;
105 }
106
107
108 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
109 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
110 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
111 foreach( const Neighbor &I, nbV(i) ) {
112 Real c_I = Weight(I);
113 if( props.logdomain )
114 p += newMessage( i, I.iter ) * c_I;
115 else
116 p *= newMessage( i, I.iter ) ^ c_I;
117 }
118 }
119
120
121 void TRWBP::construct() {
122 BP::construct();
123 _weight.resize( nrFactors(), 1.0 );
124 }
125
126
127 } // end of namespace dai