Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/dai_config.h>
10 #ifdef DAI_WITH_FBP
11
12
13 #include <dai/fbp.h>
14
15
16 #define DAI_FBP_FAST 1
17
18
19 namespace dai {
20
21
22 using namespace std;
23
24
25 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
26 Real FBP::logZ() const {
27 Real sum = 0.0;
28 for( size_t I = 0; I < nrFactors(); I++ ) {
29 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
30 sum += Weight(I) * beliefF(I).entropy(); // FBP
31 }
32 for( size_t i = 0; i < nrVars(); ++i ) {
33 Real c_i = 0.0;
34 bforeach( const Neighbor &I, nbV(i) )
35 c_i += Weight(I);
36 if( c_i != 1.0 )
37 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
38 }
39 return sum;
40 }
41
42
43 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
44 Prob FBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
45 Real c_I = Weight(I); // FBP: c_I
46
47 Factor Fprod( factor(I) );
48 Prob &prod = Fprod.p();
49
50 if( props.logdomain ) {
51 prod.takeLog();
52 prod /= c_I; // FBP
53 } else
54 prod ^= (1.0 / c_I); // FBP
55
56 // Calculate product of incoming messages and factor I
57 bforeach( const Neighbor &j, nbF(I) )
58 if( !(without_i && (j == i)) ) {
59 // prod_j will be the product of messages coming into j
60 // FBP: corresponds to messages n_jI
61 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
62 bforeach( const Neighbor &J, nbV(j) )
63 if( J != I ) { // for all J in nb(j) \ I
64 if( props.logdomain )
65 prod_j += message( j, J.iter );
66 else
67 prod_j *= message( j, J.iter );
68 } else if( c_I != 1.0 ) {
69 // FBP: multiply by m_Ij^(1-1/c_I)
70 if( props.logdomain )
71 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
72 else
73 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
74 }
75
76 // multiply prod with prod_j
77 if( !DAI_FBP_FAST ) {
78 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
79 if( props.logdomain )
80 Fprod += Factor( var(j), prod_j );
81 else
82 Fprod *= Factor( var(j), prod_j );
83 } else {
84 // OPTIMIZED VERSION
85 size_t _I = j.dual;
86 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
87 const ind_t &ind = index(j, _I);
88
89 for( size_t r = 0; r < prod.size(); ++r )
90 if( props.logdomain )
91 prod.set( r, prod[r] + prod_j[ind[r]] );
92 else
93 prod.set( r, prod[r] * prod_j[ind[r]] );
94 }
95 }
96 return prod;
97 }
98
99
100 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
101 void FBP::calcNewMessage( size_t i, size_t _I ) {
102 // calculate updated message I->i
103 size_t I = nbV(i,_I);
104
105 Real c_I = Weight(I); // FBP: c_I
106
107 Factor Fprod( factor(I) );
108 Prob &prod = Fprod.p();
109 prod = calcIncomingMessageProduct( I, true, i );
110
111 if( props.logdomain ) {
112 prod -= prod.max();
113 prod.takeExp();
114 }
115
116 // Marginalize onto i
117 Prob marg;
118 if( !DAI_FBP_FAST ) {
119 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
120 if( props.inference == Properties::InfType::SUMPROD )
121 marg = Fprod.marginal( var(i) ).p();
122 else
123 marg = Fprod.maxMarginal( var(i) ).p();
124 } else {
125 // OPTIMIZED VERSION
126 marg = Prob( var(i).states(), 0.0 );
127 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
128 const ind_t ind = index(i,_I);
129 if( props.inference == Properties::InfType::SUMPROD )
130 for( size_t r = 0; r < prod.size(); ++r )
131 marg.set( ind[r], marg[ind[r]] + prod[r] );
132 else
133 for( size_t r = 0; r < prod.size(); ++r )
134 if( prod[r] > marg[ind[r]] )
135 marg.set( ind[r], prod[r] );
136 marg.normalize();
137 }
138
139 // FBP
140 marg ^= c_I;
141
142 // Store result
143 if( props.logdomain )
144 newMessage(i,_I) = marg.log();
145 else
146 newMessage(i,_I) = marg;
147
148 // Update the residual if necessary
149 if( props.updates == Properties::UpdateType::SEQMAX )
150 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
151 }
152
153
154 void FBP::construct() {
155 BP::construct();
156 _weight.resize( nrFactors(), 1.0 );
157 }
158
159
160 } // end of namespace dai
161
162
163 #endif