[Frederik Eaton] Added Fractional Belief Propagation
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton
8 */
9
10
11 #include <dai/fbp.h>
12
13
14 #define DAI_FBP_FAST 1
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 const char *FBP::Name = "FBP";
24
25
26 string FBP::identify() const {
27 return string(Name) + printProperties();
28 }
29
30
31 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
32 void FBP::calcNewMessage( size_t i, size_t _I ) {
33 // calculate updated message I->i
34 size_t I = nbV(i,_I);
35
36 Real scale = scaleF(I); // FBP: c_I
37
38 Factor Fprod( factor(I) );
39 Prob &prod = Fprod.p();
40 if( props.logdomain ) {
41 prod.takeLog();
42 prod *= (1/scale); // FBP
43 } else
44 prod ^= (1/scale); // FBP
45
46 // Calculate product of incoming messages and factor I
47 foreach( const Neighbor &j, nbF(I) )
48 if( j != i ) { // for all j in I \ i
49 // FBP: same as n_jI
50 // prod_j will be the product of messages coming into j
51 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
52 foreach( const Neighbor &J, nbV(j) )
53 if( J != I ) { // for all J in nb(j) \ I
54 if( props.logdomain )
55 prod_j += message( j, J.iter );
56 else
57 prod_j *= message( j, J.iter );
58 }
59
60
61 size_t _I = j.dual;
62 // FBP: now multiply by m_Ij^(1-1/c_I)
63 if(props.logdomain)
64 prod_j += message( j, _I)*(1-1/scale);
65 else
66 prod_j *= message( j, _I)^(1-1/scale);
67
68 // multiply prod with prod_j
69 if( !DAI_FBP_FAST ) {
70 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
71 if( props.logdomain )
72 Fprod += Factor( var(j), prod_j );
73 else
74 Fprod *= Factor( var(j), prod_j );
75 } else {
76 /* OPTIMIZED VERSION */
77 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
78 const ind_t &ind = index(j, _I);
79 for( size_t r = 0; r < prod.size(); ++r )
80 if( props.logdomain )
81 prod[r] += prod_j[ind[r]];
82 else
83 prod[r] *= prod_j[ind[r]];
84 }
85 }
86
87 if( props.logdomain ) {
88 prod -= prod.max();
89 prod.takeExp();
90 }
91
92 // Marginalize onto i
93 Prob marg;
94 if( !DAI_FBP_FAST ) {
95 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
96 if( props.inference == Properties::InfType::SUMPROD )
97 marg = Fprod.marginal( var(i) ).p();
98 else
99 marg = Fprod.maxMarginal( var(i) ).p();
100 } else {
101 /* OPTIMIZED VERSION */
102 marg = Prob( var(i).states(), 0.0 );
103 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
104 const ind_t ind = index(i,_I);
105 if( props.inference == Properties::InfType::SUMPROD )
106 for( size_t r = 0; r < prod.size(); ++r )
107 marg[ind[r]] += prod[r];
108 else
109 for( size_t r = 0; r < prod.size(); ++r )
110 if( prod[r] > marg[ind[r]] )
111 marg[ind[r]] = prod[r];
112 marg.normalize();
113 }
114
115 // FBP
116 marg ^= scale;
117
118 // Store result
119 if( props.logdomain )
120 newMessage(i,_I) = marg.log();
121 else
122 newMessage(i,_I) = marg;
123
124 // Update the residual if necessary
125 if( props.updates == Properties::UpdateType::SEQMAX )
126 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
127 }
128
129
130 /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */
131 void FBP::calcBeliefF( size_t I, Prob &p ) const {
132 Real scale = scaleF(I); // FBP: c_I
133
134 Factor Fprod( factor(I) );
135 Prob &prod = Fprod.p();
136
137 if( props.logdomain ) {
138 prod.takeLog();
139 prod /= scale; // FBP
140 } else
141 prod ^= (1/scale); // FBP
142
143 foreach( const Neighbor &j, nbF(I) ) {
144 // prod_j will be the product of messages coming into j
145 // FBP: corresponds to messages n_jI
146 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
147 foreach( const Neighbor &J, nbV(j) )
148 if( J != I ) { // for all J in nb(j) \ I
149 if( props.logdomain )
150 prod_j += newMessage( j, J.iter );
151 else
152 prod_j *= newMessage( j, J.iter );
153 }
154
155 size_t _I = j.dual;
156
157 // FBP: now multiply by m_Ij^(1-1/c_I)
158 if( props.logdomain )
159 prod_j += newMessage( j, _I)*(1-1/scale);
160 else
161 prod_j *= newMessage( j, _I)^(1-1/scale);
162
163 // multiply prod with prod_j
164 if( !DAI_FBP_FAST ) {
165 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
166 if( props.logdomain )
167 Fprod += Factor( var(j), prod_j );
168 else
169 Fprod *= Factor( var(j), prod_j );
170 } else {
171 /* OPTIMIZED VERSION */
172 size_t _I = j.dual;
173 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
174 const ind_t & ind = index(j, _I);
175
176 for( size_t r = 0; r < prod.size(); ++r ) {
177 if( props.logdomain )
178 prod[r] += prod_j[ind[r]];
179 else
180 prod[r] *= prod_j[ind[r]];
181 }
182 }
183 }
184
185 p = prod;
186 }
187
188
// Initializes the algorithm: first runs the base-class BP construction, then
// constructScaleParams(). The call order is kept as written; constructScaleParams()
// presumably sets up the per-factor scale parameters read via scaleF() (TODO confirm
// against fbp.h).
void FBP::construct() {
    BP::construct();
    constructScaleParams();
}
193
194
195 } // end of namespace dai