Replaced Name members by name() virtual functions (fixing a bug in matlab/dai.cpp)
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 * Copyright (C) 2009-2010 Joris Mooij [joris dot mooij at libdai dot org]
9 */
10
11
12 #include <dai/fbp.h>
13
14
15 #define DAI_FBP_FAST 1
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
25 Real FBP::logZ() const {
26 Real sum = 0.0;
27 for( size_t I = 0; I < nrFactors(); I++ ) {
28 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
29 sum += Weight(I) * beliefF(I).entropy(); // FBP
30 }
31 for( size_t i = 0; i < nrVars(); ++i ) {
32 Real c_i = 0.0;
33 foreach( const Neighbor &I, nbV(i) )
34 c_i += Weight(I);
35 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
36 }
37 return sum;
38 }
39
40
41 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
42 Prob FBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
43 Real c_I = Weight(I); // FBP: c_I
44
45 Factor Fprod( factor(I) );
46 Prob &prod = Fprod.p();
47
48 if( props.logdomain ) {
49 prod.takeLog();
50 prod /= c_I; // FBP
51 } else
52 prod ^= (1.0 / c_I); // FBP
53
54 // Calculate product of incoming messages and factor I
55 foreach( const Neighbor &j, nbF(I) )
56 if( !(without_i && (j == i)) ) {
57 // prod_j will be the product of messages coming into j
58 // FBP: corresponds to messages n_jI
59 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
60 foreach( const Neighbor &J, nbV(j) )
61 if( J != I ) { // for all J in nb(j) \ I
62 if( props.logdomain )
63 prod_j += message( j, J.iter );
64 else
65 prod_j *= message( j, J.iter );
66 } else {
67 // FBP: multiply by m_Ij^(1-1/c_I)
68 if( props.logdomain )
69 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
70 else
71 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
72 }
73
74 // multiply prod with prod_j
75 if( !DAI_FBP_FAST ) {
76 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
77 if( props.logdomain )
78 Fprod += Factor( var(j), prod_j );
79 else
80 Fprod *= Factor( var(j), prod_j );
81 } else {
82 // OPTIMIZED VERSION
83 size_t _I = j.dual;
84 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
85 const ind_t &ind = index(j, _I);
86
87 for( size_t r = 0; r < prod.size(); ++r )
88 if( props.logdomain )
89 prod.set( r, prod[r] + prod_j[ind[r]] );
90 else
91 prod.set( r, prod[r] * prod_j[ind[r]] );
92 }
93 }
94 return prod;
95 }
96
97
98 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
99 void FBP::calcNewMessage( size_t i, size_t _I ) {
100 // calculate updated message I->i
101 size_t I = nbV(i,_I);
102
103 Real c_I = Weight(I); // FBP: c_I
104
105 Factor Fprod( factor(I) );
106 Prob &prod = Fprod.p();
107 prod = calcIncomingMessageProduct( I, true, i );
108
109 if( props.logdomain ) {
110 prod -= prod.max();
111 prod.takeExp();
112 }
113
114 // Marginalize onto i
115 Prob marg;
116 if( !DAI_FBP_FAST ) {
117 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
118 if( props.inference == Properties::InfType::SUMPROD )
119 marg = Fprod.marginal( var(i) ).p();
120 else
121 marg = Fprod.maxMarginal( var(i) ).p();
122 } else {
123 // OPTIMIZED VERSION
124 marg = Prob( var(i).states(), 0.0 );
125 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
126 const ind_t ind = index(i,_I);
127 if( props.inference == Properties::InfType::SUMPROD )
128 for( size_t r = 0; r < prod.size(); ++r )
129 marg.set( ind[r], marg[ind[r]] + prod[r] );
130 else
131 for( size_t r = 0; r < prod.size(); ++r )
132 if( prod[r] > marg[ind[r]] )
133 marg.set( ind[r], prod[r] );
134 marg.normalize();
135 }
136
137 // FBP
138 marg ^= c_I;
139
140 // Store result
141 if( props.logdomain )
142 newMessage(i,_I) = marg.log();
143 else
144 newMessage(i,_I) = marg;
145
146 // Update the residual if necessary
147 if( props.updates == Properties::UpdateType::SEQMAX )
148 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
149 }
150
151
152 void FBP::construct() {
153 BP::construct();
154 _weight.resize( nrFactors(), 1.0 );
155 }
156
157
158 } // end of namespace dai