d87f8aa44d719eddde8c658b1b0c97a3fbfc9184
[libdai.git] / src / fbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/fbp.h>
10
11
12 #define DAI_FBP_FAST 1
13
14
15 namespace dai {
16
17
18 using namespace std;
19
20
21 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
22 Real FBP::logZ() const {
23 Real sum = 0.0;
24 for( size_t I = 0; I < nrFactors(); I++ ) {
25 sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP
26 sum += Weight(I) * beliefF(I).entropy(); // FBP
27 }
28 for( size_t i = 0; i < nrVars(); ++i ) {
29 Real c_i = 0.0;
30 bforeach( const Neighbor &I, nbV(i) )
31 c_i += Weight(I);
32 if( c_i != 1.0 )
33 sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP
34 }
35 return sum;
36 }
37
38
39 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
40 Prob FBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
41 Real c_I = Weight(I); // FBP: c_I
42
43 Factor Fprod( factor(I) );
44 Prob &prod = Fprod.p();
45
46 if( props.logdomain ) {
47 prod.takeLog();
48 prod /= c_I; // FBP
49 } else
50 prod ^= (1.0 / c_I); // FBP
51
52 // Calculate product of incoming messages and factor I
53 bforeach( const Neighbor &j, nbF(I) )
54 if( !(without_i && (j == i)) ) {
55 // prod_j will be the product of messages coming into j
56 // FBP: corresponds to messages n_jI
57 Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
58 bforeach( const Neighbor &J, nbV(j) )
59 if( J != I ) { // for all J in nb(j) \ I
60 if( props.logdomain )
61 prod_j += message( j, J.iter );
62 else
63 prod_j *= message( j, J.iter );
64 } else if( c_I != 1.0 ) {
65 // FBP: multiply by m_Ij^(1-1/c_I)
66 if( props.logdomain )
67 prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I);
68 else
69 prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I);
70 }
71
72 // multiply prod with prod_j
73 if( !DAI_FBP_FAST ) {
74 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
75 if( props.logdomain )
76 Fprod += Factor( var(j), prod_j );
77 else
78 Fprod *= Factor( var(j), prod_j );
79 } else {
80 // OPTIMIZED VERSION
81 size_t _I = j.dual;
82 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
83 const ind_t &ind = index(j, _I);
84
85 for( size_t r = 0; r < prod.size(); ++r )
86 if( props.logdomain )
87 prod.set( r, prod[r] + prod_j[ind[r]] );
88 else
89 prod.set( r, prod[r] * prod_j[ind[r]] );
90 }
91 }
92 return prod;
93 }
94
95
96 // This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour
97 void FBP::calcNewMessage( size_t i, size_t _I ) {
98 // calculate updated message I->i
99 size_t I = nbV(i,_I);
100
101 Real c_I = Weight(I); // FBP: c_I
102
103 Factor Fprod( factor(I) );
104 Prob &prod = Fprod.p();
105 prod = calcIncomingMessageProduct( I, true, i );
106
107 if( props.logdomain ) {
108 prod -= prod.max();
109 prod.takeExp();
110 }
111
112 // Marginalize onto i
113 Prob marg;
114 if( !DAI_FBP_FAST ) {
115 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
116 if( props.inference == Properties::InfType::SUMPROD )
117 marg = Fprod.marginal( var(i) ).p();
118 else
119 marg = Fprod.maxMarginal( var(i) ).p();
120 } else {
121 // OPTIMIZED VERSION
122 marg = Prob( var(i).states(), 0.0 );
123 // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
124 const ind_t ind = index(i,_I);
125 if( props.inference == Properties::InfType::SUMPROD )
126 for( size_t r = 0; r < prod.size(); ++r )
127 marg.set( ind[r], marg[ind[r]] + prod[r] );
128 else
129 for( size_t r = 0; r < prod.size(); ++r )
130 if( prod[r] > marg[ind[r]] )
131 marg.set( ind[r], prod[r] );
132 marg.normalize();
133 }
134
135 // FBP
136 marg ^= c_I;
137
138 // Store result
139 if( props.logdomain )
140 newMessage(i,_I) = marg.log();
141 else
142 newMessage(i,_I) = marg;
143
144 // Update the residual if necessary
145 if( props.updates == Properties::UpdateType::SEQMAX )
146 updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), DISTLINF ) );
147 }
148
149
150 void FBP::construct() {
151 BP::construct();
152 _weight.resize( nrFactors(), 1.0 );
153 }
154
155
156 } // end of namespace dai