[Frederik Eaton] Improved bp_dual
[libdai.git] / src / bp_dual.cpp
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 #include <iostream>
22 #include <sstream>
23 #include <algorithm>
24
25 #include <dai/bp_dual.h>
26 #include <dai/util.h>
27 #include <dai/bipgraph.h>
28
29
30 namespace dai {
31
32
33 using namespace std;
34
35
36 typedef BipartiteGraph::Neighbor Neighbor;
37
38
39 void BP_dual::RegenerateMessages() {
40 size_t nv = fg().nrVars();
41 _msgs.Zn.resize(nv);
42 _msgs.Zm.resize(nv);
43 _msgs.m.resize(nv);
44 _msgs.n.resize(nv);
45 for(size_t i=0; i<nv; i++) {
46 size_t nvf = fg().nbV(i).size();
47 _msgs.Zn[i].resize(nvf, 1.0);
48 _msgs.Zm[i].resize(nvf, 1.0);
49 size_t states = fg().var(i).states();
50 _msgs.n[i].resize(nvf, Prob(states));
51 _msgs.m[i].resize(nvf, Prob(states));
52 }
53 }
54
55
56 void BP_dual::RegenerateBeliefs() {
57 _beliefs.b1.clear();
58 _beliefs.b1.reserve(fg().nrVars());
59 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
60 _beliefs.b2.clear();
61 _beliefs.b2.reserve(fg().nrFactors());
62 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
63
64 for(size_t i=0; i<fg().nrVars(); i++) {
65 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
66 }
67 for(size_t I=0; I<fg().nrFactors(); I++) {
68 _beliefs.b2.push_back( Prob( fg().factor(I).states() ) );
69 }
70 }
71
72
73 void BP_dual::Init() {
74 RegenerateMessages();
75 RegenerateBeliefs();
76 CalcMessages();
77 CalcBeliefs();
78 }
79
80
81 void BP_dual::CalcMessages() {
82 // calculate 'n' messages from "factor marginal / factor"
83 vector<Factor> bs;
84 size_t nf = fg().nrFactors();
85 for( size_t I = 0; I < nf; I++ )
86 bs.push_back(_ia->beliefF(I));
87 assert(nf == bs.size());
88 for( size_t I = 0; I < nf; I++ ) {
89 Factor f = bs[I];
90 f /= fg().factor(I);
91 foreach(const Neighbor &i, fg().nbF(I)) {
92 msgN(i, i.dual) = f.marginal(fg().var(i)).p();
93 }
94 }
95 // calculate 'm' messages and normalizers from 'n' messages
96 for( size_t i = 0; i < fg().nrVars(); i++ )
97 foreach(const Neighbor &I, fg().nbV(i))
98 calcNewM(i, I.iter);
99 // recalculate 'n' messages and normalizers from 'm' messages
100 for( size_t i = 0; i < fg().nrVars(); i++ ) {
101 foreach(const Neighbor &I, fg().nbV(i)) {
102 Prob oldN = msgN(i,I.iter);
103 calcNewN(i, I.iter);
104 Prob newN = msgN(i,I.iter);
105 #if 0
106 // check that new 'n' messages match old ones
107 if((oldN-newN).maxAbs() > 1.0e-5) {
108 cerr << "New 'n' messages don't match old: " <<
109 "(i,I) = (" << i << ", " << I <<
110 ") old = " << oldN << ", new = " << newN << endl;
111 DAI_THROW(INTERNAL_ERROR);
112 }
113 #endif
114 }
115 }
116 }
117
118
119 void BP_dual::CalcBeliefV(size_t i) {
120 Prob prod( fg().var(i).states(), 1.0 );
121 foreach(const Neighbor &I, fg().nbV(i)) {
122 prod *= msgM(i,I.iter);
123 }
124 _beliefs.Zb1[i] = prod.normalize();
125 _beliefs.b1[i] = prod;
126 }
127
128
129 void BP_dual::CalcBeliefF(size_t I) {
130 Prob prod( fg().factor(I).p() );
131 foreach(const Neighbor &j, fg().nbF(I)) {
132 IndexFor ind (fg().var(j), fg().factor(I).vars() );
133 Prob n(msgN(j,j.dual));
134 for(size_t x=0; ind >= 0; x++, ++ind) {
135 prod[x] *= n[ind];
136 }
137 }
138 _beliefs.Zb2[I] = prod.normalize();
139 _beliefs.b2[I] = prod;
140 }
141
142
143 // called after run()
144 void BP_dual::CalcBeliefs() {
145 for( size_t i = 0; i < fg().nrVars(); i++ )
146 CalcBeliefV(i); // calculate b_i
147 for( size_t I = 0; I < fg().nrFactors(); I++ )
148 CalcBeliefF(I); // calculate b_I
149 }
150
151
152 void BP_dual::calcNewM(size_t i, size_t _I) {
153 // calculate updated message I->i
154 const Neighbor &I = fg().nbV(i)[_I];
155 Prob prod( fg().factor(I).p() );
156 foreach(const Neighbor &j, fg().nbF(I)) {
157 if( j != i ) { // for all j in I \ i
158 Prob n(msgN(j,j.dual));
159 IndexFor ind(fg().var(j), fg().factor(I).vars());
160 for(size_t x=0; ind >= 0; x++, ++ind)
161 prod[x] *= n[ind];
162 }
163 }
164 // Marginalize onto i
165 Prob marg( fg().var(i).states(), 0.0 );
166 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
167 IndexFor ind(fg().var(i), fg().factor(I).vars());
168 for(size_t x=0; ind >= 0; x++, ++ind)
169 marg[ind] += prod[x];
170
171 _msgs.Zm[i][_I] = marg.normalize();
172 _msgs.m[i][_I] = marg;
173 }
174
175
176 void BP_dual::calcNewN(size_t i, size_t _I) {
177 // calculate updated message i->I
178 const Neighbor &I = fg().nbV(i)[_I];
179 Prob prod(fg().var(i).states(), 1.0);
180 foreach(const Neighbor &J, fg().nbV(i)) {
181 if(J.node != I.node) { // for all J in i \ I
182 prod *= msgM(i,J.iter);
183 }
184 }
185 _msgs.Zn[i][_I] = prod.normalize();
186 _msgs.n[i][_I] = prod;
187 }
188
189
190 } // end of namespace dai