Improved HAK (added 'maxtime' property)
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2009-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
// Compile-time switch for the message computations below: when nonzero,
// calcIncomingMessageProduct() and calcNewMessage() use the precomputed
// index tables (fast path); when zero they fall back to the simpler but
// slower Factor arithmetic.
#define DAI_FBP_FAST (1)
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 const char *FBP::Name = "FBP";
25
26
27 string FBP::identify() const {
28 return string(Name) + printProperties();
29 }
30
31
32 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
33 Real FBP::logZ() const {
34 Real sum = 0.0;
35 for( size_t I = 0; I < nrFactors(); I++ ) {
36 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
37 sum += Weight(I) * beliefF(I).entropy(); // FBP
38 }
39 for( size_t i = 0; i < nrVars(); ++i ) {
40 Real c_i = 0.0;
41 foreach( const Neighbor &I, nbV(i) )
42 c_i += Weight(I);
43 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
44 }
45 return sum;
46 }
47
48
49 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
50 Prob FBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
51 Real c_I = Weight(I); // FBP: c_I
52
53 Factor Fprod( factor(I) );
54 Prob &prod = Fprod.p();
55
56 if( props.logdomain ) {
57 prod.takeLog();
58 prod /= c_I; // FBP
59 } else
60 prod ^= (1.0 / c_I); // FBP
61
62 // Calculate product of incoming messages and factor I
63 foreach( const Neighbor &j, nbF(I) )
64 if( !(without_i && (j == i)) ) {
65 // prod_j will be the product of messages coming into j
66 // FBP: corresponds to messages n_jI
67 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
68 foreach( const Neighbor &J, nbV(j) )
69 if( J != I ) { // for all J in nb(j) \ I
70 if( props.logdomain )
71 prod_j += message( j, J.iter );
72 else
73 prod_j *= message( j, J.iter );
74 } else {
75 // FBP: multiply by m_Ij^(1-1/c_I)
76 if( props.logdomain )
77 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
78 else
79 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
80 }
81
82 // multiply prod with prod_j
83 if( !DAI_FBP_FAST ) {
84 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
85 if( props.logdomain )
86 Fprod += Factor( var(j), prod_j );
87 else
88 Fprod *= Factor( var(j), prod_j );
89 } else {
90 // OPTIMIZED VERSION
91 size_t _I = j.dual;
92 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
93 const ind_t &ind = index(j, _I);
94
95 for( size_t r = 0; r < prod.size(); ++r )
96 if( props.logdomain )
97 prod.set( r, prod[r] + prod_j[ind[r]] );
98 else
99 prod.set( r, prod[r] * prod_j[ind[r]] );
100 }
101 }
102 return prod;
103 }
104
105
106 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
107 void FBP::calcNewMessage( size_t i, size_t _I ) {
108 // calculate updated message I->i
109 size_t I = nbV(i,_I);
110
111 Real c_I = Weight(I); // FBP: c_I
112
113 Factor Fprod( factor(I) );
114 Prob &prod = Fprod.p();
115 prod = calcIncomingMessageProduct( I, true, i );
116
117 if( props.logdomain ) {
118 prod -= prod.max();
119 prod.takeExp();
120 }
121
122 // Marginalize onto i
123 Prob marg;
124 if( !DAI_FBP_FAST ) {
125 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
126 if( props.inference == Properties::InfType::SUMPROD )
127 marg = Fprod.marginal( var(i) ).p();
128 else
129 marg = Fprod.maxMarginal( var(i) ).p();
130 } else {
131 // OPTIMIZED VERSION
132 marg = Prob( var(i).states(), 0.0 );
133 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
134 const ind_t ind = index(i,_I);
135 if( props.inference == Properties::InfType::SUMPROD )
136 for( size_t r = 0; r < prod.size(); ++r )
137 marg.set( ind[r], marg[ind[r]] + prod[r] );
138 else
139 for( size_t r = 0; r < prod.size(); ++r )
140 if( prod[r] > marg[ind[r]] )
141 marg.set( ind[r], prod[r] );
142 marg.normalize();
143 }
144
145 // FBP
146 marg ^= c_I;
147
148 // Store result
149 if( props.logdomain )
150 newMessage(i,_I) = marg.log();
151 else
152 newMessage(i,_I) = marg;
153
154 // Update the residual if necessary
155 if( props.updates == Properties::UpdateType::SEQMAX )
156 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
157 }
158
159
160 void FBP::construct() {
161 BP::construct();
162 _weight.resize( nrFactors(), 1.0 );
163 }
164
165
166 } // end of namespace dai