a303ec6a2538b7759a3268ec7f27b85351ac4ef0
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
3 This file is part of libDAI.
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
25 #include <dai/bp_dual.h>
27 #include <dai/bipgraph.h>
36 typedef BipartiteGraph::Neighbor Neighbor
;
39 void BP_dual::RegenerateMessages() {
40 size_t nv
= fg().nrVars();
45 for(size_t i
=0; i
<nv
; i
++) {
46 size_t nvf
= fg().nbV(i
).size();
47 _msgs
.Zn
[i
].resize(nvf
, 1.0);
48 _msgs
.Zm
[i
].resize(nvf
, 1.0);
49 size_t states
= fg().var(i
).states();
50 _msgs
.n
[i
].resize(nvf
, Prob(states
));
51 _msgs
.m
[i
].resize(nvf
, Prob(states
));
56 void BP_dual::RegenerateBeliefs() {
58 _beliefs
.b1
.reserve(fg().nrVars());
59 _beliefs
.Zb1
.resize(fg().nrVars(), 1.0);
61 _beliefs
.b2
.reserve(fg().nrFactors());
62 _beliefs
.Zb2
.resize(fg().nrFactors(), 1.0);
64 for(size_t i
=0; i
<fg().nrVars(); i
++) {
65 _beliefs
.b1
.push_back( Prob( fg().var(i
).states() ) );
67 for(size_t I
=0; I
<fg().nrFactors(); I
++) {
68 _beliefs
.b2
.push_back( Prob( fg().factor(I
).states() ) );
73 void BP_dual::Init() {
81 void BP_dual::CalcMessages() {
82 // calculate 'n' messages from "factor marginal / factor"
84 size_t nf
= fg().nrFactors();
85 for( size_t I
= 0; I
< nf
; I
++ )
86 bs
.push_back(_ia
->beliefF(I
));
87 assert(nf
== bs
.size());
88 for( size_t I
= 0; I
< nf
; I
++ ) {
91 foreach(const Neighbor
&i
, fg().nbF(I
)) {
92 msgN(i
, i
.dual
) = f
.marginal(fg().var(i
)).p();
95 // calculate 'm' messages and normalizers from 'n' messages
96 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
97 foreach(const Neighbor
&I
, fg().nbV(i
))
99 // recalculate 'n' messages and normalizers from 'm' messages
100 for( size_t i
= 0; i
< fg().nrVars(); i
++ ) {
101 foreach(const Neighbor
&I
, fg().nbV(i
)) {
102 Prob oldN
= msgN(i
,I
.iter
);
104 Prob newN
= msgN(i
,I
.iter
);
106 // check that new 'n' messages match old ones
107 if((oldN
-newN
).maxAbs() > 1.0e-5) {
108 cerr
<< "New 'n' messages don't match old: " <<
109 "(i,I) = (" << i
<< ", " << I
<<
110 ") old = " << oldN
<< ", new = " << newN
<< endl
;
111 DAI_THROW(INTERNAL_ERROR
);
119 void BP_dual::CalcBeliefV(size_t i
) {
120 Prob
prod( fg().var(i
).states(), 1.0 );
121 foreach(const Neighbor
&I
, fg().nbV(i
)) {
122 prod
*= msgM(i
,I
.iter
);
124 _beliefs
.Zb1
[i
] = prod
.normalize();
125 _beliefs
.b1
[i
] = prod
;
129 void BP_dual::CalcBeliefF(size_t I
) {
130 Prob
prod( fg().factor(I
).p() );
131 foreach(const Neighbor
&j
, fg().nbF(I
)) {
132 IndexFor
ind (fg().var(j
), fg().factor(I
).vars() );
133 Prob
n(msgN(j
,j
.dual
));
134 for(size_t x
=0; ind
>= 0; x
++, ++ind
) {
138 _beliefs
.Zb2
[I
] = prod
.normalize();
139 _beliefs
.b2
[I
] = prod
;
143 // called after run()
144 void BP_dual::CalcBeliefs() {
145 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
146 CalcBeliefV(i
); // calculate b_i
147 for( size_t I
= 0; I
< fg().nrFactors(); I
++ )
148 CalcBeliefF(I
); // calculate b_I
152 void BP_dual::calcNewM(size_t i
, size_t _I
) {
153 // calculate updated message I->i
154 const Neighbor
&I
= fg().nbV(i
)[_I
];
155 Prob
prod( fg().factor(I
).p() );
156 foreach(const Neighbor
&j
, fg().nbF(I
)) {
157 if( j
!= i
) { // for all j in I \ i
158 Prob
n(msgN(j
,j
.dual
));
159 IndexFor
ind(fg().var(j
), fg().factor(I
).vars());
160 for(size_t x
=0; ind
>= 0; x
++, ++ind
)
164 // Marginalize onto i
165 Prob
marg( fg().var(i
).states(), 0.0 );
166 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
167 IndexFor
ind(fg().var(i
), fg().factor(I
).vars());
168 for(size_t x
=0; ind
>= 0; x
++, ++ind
)
169 marg
[ind
] += prod
[x
];
171 _msgs
.Zm
[i
][_I
] = marg
.normalize();
172 _msgs
.m
[i
][_I
] = marg
;
176 void BP_dual::calcNewN(size_t i
, size_t _I
) {
177 // calculate updated message i->I
178 const Neighbor
&I
= fg().nbV(i
)[_I
];
179 Prob
prod(fg().var(i
).states(), 1.0);
180 foreach(const Neighbor
&J
, fg().nbV(i
)) {
181 if(J
.node
!= I
.node
) { // for all J in i \ I
182 prod
*= msgM(i
,J
.iter
);
185 _msgs
.Zn
[i
][_I
] = prod
.normalize();
186 _msgs
.n
[i
][_I
] = prod
;
190 } // end of namespace dai