Various cleanups
[libdai.git] / src / bp_dual.cpp
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 #include <iostream>
22 #include <sstream>
23 #include <algorithm>
24
25 #include <dai/bp_dual.h>
26 #include <dai/util.h>
27 #include <dai/bipgraph.h>
28
29
30 namespace dai {
31
32
33 using namespace std;
34
35
36 typedef BipartiteGraph::Neighbor Neighbor;
37
38
39 void BP_dual::regenerateMessages() {
40 size_t nv = fg().nrVars();
41 _msgs.Zn.resize(nv);
42 _msgs.Zm.resize(nv);
43 _msgs.m.resize(nv);
44 _msgs.n.resize(nv);
45 for( size_t i = 0; i < nv; i++ ) {
46 size_t nvf = fg().nbV(i).size();
47 _msgs.Zn[i].resize(nvf, 1.0);
48 _msgs.Zm[i].resize(nvf, 1.0);
49 size_t states = fg().var(i).states();
50 _msgs.n[i].resize(nvf, Prob(states));
51 _msgs.m[i].resize(nvf, Prob(states));
52 }
53 }
54
55
56 void BP_dual::regenerateBeliefs() {
57 _beliefs.b1.clear();
58 _beliefs.b1.reserve(fg().nrVars());
59 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
60 _beliefs.b2.clear();
61 _beliefs.b2.reserve(fg().nrFactors());
62 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
63
64 for( size_t i = 0; i < fg().nrVars(); i++ )
65 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
66 for( size_t I = 0; I < fg().nrFactors(); I++ )
67 _beliefs.b2.push_back( Prob( fg().factor(I).states() ) );
68 }
69
70
71 void BP_dual::init() {
72 regenerateMessages();
73 regenerateBeliefs();
74 calcMessages();
75 calcBeliefs();
76 }
77
78
79 void BP_dual::calcMessages() {
80 // calculate 'n' messages from "factor marginal / factor"
81 vector<Factor> bs;
82 size_t nf = fg().nrFactors();
83 for( size_t I = 0; I < nf; I++ )
84 bs.push_back(_ia->beliefF(I));
85 assert(nf == bs.size());
86 for( size_t I = 0; I < nf; I++ ) {
87 Factor f = bs[I];
88 f /= fg().factor(I);
89 foreach(const Neighbor &i, fg().nbF(I))
90 msgN(i, i.dual) = f.marginal(fg().var(i)).p();
91 }
92 // calculate 'm' messages and normalizers from 'n' messages
93 for( size_t i = 0; i < fg().nrVars(); i++ )
94 foreach(const Neighbor &I, fg().nbV(i))
95 calcNewM(i, I.iter);
96 // recalculate 'n' messages and normalizers from 'm' messages
97 for( size_t i = 0; i < fg().nrVars(); i++ ) {
98 foreach(const Neighbor &I, fg().nbV(i)) {
99 Prob oldN = msgN(i,I.iter);
100 calcNewN(i, I.iter);
101 Prob newN = msgN(i,I.iter);
102 #if 0
103 // check that new 'n' messages match old ones
104 if((oldN-newN).maxAbs() > 1.0e-5) {
105 cerr << "New 'n' messages don't match old: " <<
106 "(i,I) = (" << i << ", " << I <<
107 ") old = " << oldN << ", new = " << newN << endl;
108 DAI_THROW(INTERNAL_ERROR);
109 }
110 #endif
111 }
112 }
113 }
114
115
116 void BP_dual::calcBeliefV(size_t i) {
117 Prob prod( fg().var(i).states(), 1.0 );
118 foreach(const Neighbor &I, fg().nbV(i))
119 prod *= msgM(i,I.iter);
120 _beliefs.Zb1[i] = prod.normalize();
121 _beliefs.b1[i] = prod;
122 }
123
124
125 void BP_dual::calcBeliefF(size_t I) {
126 Prob prod( fg().factor(I).p() );
127 foreach(const Neighbor &j, fg().nbF(I)) {
128 IndexFor ind (fg().var(j), fg().factor(I).vars() );
129 Prob n(msgN(j,j.dual));
130 for(size_t x=0; ind >= 0; x++, ++ind)
131 prod[x] *= n[ind];
132 }
133 _beliefs.Zb2[I] = prod.normalize();
134 _beliefs.b2[I] = prod;
135 }
136
137
138 // called after run()
139 void BP_dual::calcBeliefs() {
140 for( size_t i = 0; i < fg().nrVars(); i++ )
141 calcBeliefV(i); // calculate b_i
142 for( size_t I = 0; I < fg().nrFactors(); I++ )
143 calcBeliefF(I); // calculate b_I
144 }
145
146
147 void BP_dual::calcNewM(size_t i, size_t _I) {
148 // calculate updated message I->i
149 const Neighbor &I = fg().nbV(i)[_I];
150 Prob prod( fg().factor(I).p() );
151 foreach(const Neighbor &j, fg().nbF(I)) {
152 if( j != i ) { // for all j in I \ i
153 Prob n(msgN(j,j.dual));
154 IndexFor ind(fg().var(j), fg().factor(I).vars());
155 for(size_t x=0; ind >= 0; x++, ++ind)
156 prod[x] *= n[ind];
157 }
158 }
159 // Marginalize onto i
160 Prob marg( fg().var(i).states(), 0.0 );
161 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
162 IndexFor ind(fg().var(i), fg().factor(I).vars());
163 for(size_t x=0; ind >= 0; x++, ++ind)
164 marg[ind] += prod[x];
165
166 _msgs.Zm[i][_I] = marg.normalize();
167 _msgs.m[i][_I] = marg;
168 }
169
170
171 void BP_dual::calcNewN(size_t i, size_t _I) {
172 // calculate updated message i->I
173 const Neighbor &I = fg().nbV(i)[_I];
174 Prob prod(fg().var(i).states(), 1.0);
175 foreach(const Neighbor &J, fg().nbV(i)) {
176 if(J.node != I.node) // for all J in i \ I
177 prod *= msgM(i,J.iter);
178 }
179 _msgs.Zn[i][_I] = prod.normalize();
180 _msgs.n[i][_I] = prod;
181 }
182
183
184 } // end of namespace dai