Fixed numerical issues in MF, FBP and TRWBP (discovered in sparse branch)
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2009-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
15 #define DAI_FBP_FAST 1
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
25 Real FBP::logZ() const {
26 Real sum = 0.0;
27 for( size_t I = 0; I < nrFactors(); I++ ) {
28 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
29 sum += Weight(I) * beliefF(I).entropy(); // FBP
30 }
31 for( size_t i = 0; i < nrVars(); ++i ) {
32 Real c_i = 0.0;
33 foreach( const Neighbor &I, nbV(i) )
34 c_i += Weight(I);
35 if( c_i != 1.0 )
36 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
37 }
38 return sum;
39 }
40
41
42 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
43 Prob FBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
44 Real c_I = Weight(I); // FBP: c_I
45
46 Factor Fprod( factor(I) );
47 Prob &prod = Fprod.p();
48
49 if( props.logdomain ) {
50 prod.takeLog();
51 prod /= c_I; // FBP
52 } else
53 prod ^= (1.0 / c_I); // FBP
54
55 // Calculate product of incoming messages and factor I
56 foreach( const Neighbor &j, nbF(I) )
57 if( !(without_i && (j == i)) ) {
58 // prod_j will be the product of messages coming into j
59 // FBP: corresponds to messages n_jI
60 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
61 foreach( const Neighbor &J, nbV(j) )
62 if( J != I ) { // for all J in nb(j) \ I
63 if( props.logdomain )
64 prod_j += message( j, J.iter );
65 else
66 prod_j *= message( j, J.iter );
67 } else if( c_I != 1.0 ) {
68 // FBP: multiply by m_Ij^(1-1/c_I)
69 if( props.logdomain )
70 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
71 else
72 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
73 }
74
75 // multiply prod with prod_j
76 if( !DAI_FBP_FAST ) {
77 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
78 if( props.logdomain )
79 Fprod += Factor( var(j), prod_j );
80 else
81 Fprod *= Factor( var(j), prod_j );
82 } else {
83 // OPTIMIZED VERSION
84 size_t _I = j.dual;
85 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
86 const ind_t &ind = index(j, _I);
87
88 for( size_t r = 0; r < prod.size(); ++r )
89 if( props.logdomain )
90 prod.set( r, prod[r] + prod_j[ind[r]] );
91 else
92 prod.set( r, prod[r] * prod_j[ind[r]] );
93 }
94 }
95 return prod;
96 }
97
98
99 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
100 void FBP::calcNewMessage( size_t i, size_t _I ) {
101 // calculate updated message I->i
102 size_t I = nbV(i,_I);
103
104 Real c_I = Weight(I); // FBP: c_I
105
106 Factor Fprod( factor(I) );
107 Prob &prod = Fprod.p();
108 prod = calcIncomingMessageProduct( I, true, i );
109
110 if( props.logdomain ) {
111 prod -= prod.max();
112 prod.takeExp();
113 }
114
115 // Marginalize onto i
116 Prob marg;
117 if( !DAI_FBP_FAST ) {
118 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
119 if( props.inference == Properties::InfType::SUMPROD )
120 marg = Fprod.marginal( var(i) ).p();
121 else
122 marg = Fprod.maxMarginal( var(i) ).p();
123 } else {
124 // OPTIMIZED VERSION
125 marg = Prob( var(i).states(), 0.0 );
126 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
127 const ind_t ind = index(i,_I);
128 if( props.inference == Properties::InfType::SUMPROD )
129 for( size_t r = 0; r < prod.size(); ++r )
130 marg.set( ind[r], marg[ind[r]] + prod[r] );
131 else
132 for( size_t r = 0; r < prod.size(); ++r )
133 if( prod[r] > marg[ind[r]] )
134 marg.set( ind[r], prod[r] );
135 marg.normalize();
136 }
137
138 // FBP
139 marg ^= c_I;
140
141 // Store result
142 if( props.logdomain )
143 newMessage(i,_I) = marg.log();
144 else
145 newMessage(i,_I) = marg;
146
147 // Update the residual if necessary
148 if( props.updates == Properties::UpdateType::SEQMAX )
149 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
150 }
151
152
153 void FBP::construct() {
154 BP::construct();
155 _weight.resize( nrFactors(), 1.0 );
156 }
157
158
159 } // end of namespace dai