44014fe939674069d599987c5c62a6d8e15f5e7b
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
3 This file is part of libDAI.
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
25 #include <dai/bp_dual.h>
27 #include <dai/bipgraph.h>
36 typedef BipartiteGraph::Neighbor Neighbor
;
39 void BP_dual::init() {
47 void BP_dual::regenerateMessages() {
48 size_t nv
= fg().nrVars();
53 for( size_t i
= 0; i
< nv
; i
++ ) {
54 size_t nvf
= fg().nbV(i
).size();
55 _msgs
.Zn
[i
].resize(nvf
, 1.0);
56 _msgs
.Zm
[i
].resize(nvf
, 1.0);
57 size_t states
= fg().var(i
).states();
58 _msgs
.n
[i
].resize(nvf
, Prob(states
));
59 _msgs
.m
[i
].resize(nvf
, Prob(states
));
64 void BP_dual::regenerateBeliefs() {
66 _beliefs
.b1
.reserve(fg().nrVars());
67 _beliefs
.Zb1
.resize(fg().nrVars(), 1.0);
69 _beliefs
.b2
.reserve(fg().nrFactors());
70 _beliefs
.Zb2
.resize(fg().nrFactors(), 1.0);
72 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
73 _beliefs
.b1
.push_back( Prob( fg().var(i
).states() ) );
74 for( size_t I
= 0; I
< fg().nrFactors(); I
++ )
75 _beliefs
.b2
.push_back( Prob( fg().factor(I
).states() ) );
79 void BP_dual::calcMessages() {
80 // calculate 'n' messages from "factor marginal / factor"
81 for( size_t I
= 0; I
< fg().nrFactors(); I
++ ) {
82 Factor f
= _ia
->beliefF(I
) / fg().factor(I
);
83 foreach( const Neighbor
&i
, fg().nbF(I
) )
84 msgN(i
, i
.dual
) = f
.marginal( fg().var(i
) ).p();
86 // calculate 'm' messages and normalizers from 'n' messages
87 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
88 foreach( const Neighbor
&I
, fg().nbV(i
) )
89 calcNewM( i
, I
.iter
);
90 // recalculate 'n' messages and normalizers from 'm' messages
91 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
92 foreach( const Neighbor
&I
, fg().nbV(i
) )
97 void BP_dual::calcNewM( size_t i
, size_t _I
) {
98 // calculate updated message I->i
99 const Neighbor
&I
= fg().nbV(i
)[_I
];
100 Prob
prod( fg().factor(I
).p() );
101 foreach( const Neighbor
&j
, fg().nbF(I
) )
102 if( j
!= i
) { // for all j in I \ i
103 Prob
&n
= msgN(j
,j
.dual
);
104 IndexFor
ind( fg().var(j
), fg().factor(I
).vars() );
105 for( size_t x
= 0; ind
>= 0; x
++, ++ind
)
108 // Marginalize onto i
109 Prob
marg( fg().var(i
).states(), 0.0 );
110 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
111 IndexFor
ind( fg().var(i
), fg().factor(I
).vars() );
112 for( size_t x
= 0; ind
>= 0; x
++, ++ind
)
113 marg
[ind
] += prod
[x
];
115 _msgs
.Zm
[i
][_I
] = marg
.normalize();
116 _msgs
.m
[i
][_I
] = marg
;
120 void BP_dual::calcNewN( size_t i
, size_t _I
) {
121 // calculate updated message i->I
122 const Neighbor
&I
= fg().nbV(i
)[_I
];
123 Prob
prod( fg().var(i
).states(), 1.0 );
124 foreach( const Neighbor
&J
, fg().nbV(i
) )
125 if( J
.node
!= I
.node
) // for all J in i \ I
126 prod
*= msgM(i
,J
.iter
);
127 _msgs
.Zn
[i
][_I
] = prod
.normalize();
128 _msgs
.n
[i
][_I
] = prod
;
132 void BP_dual::calcBeliefs() {
133 for( size_t i
= 0; i
< fg().nrVars(); i
++ )
134 calcBeliefV(i
); // calculate b_i
135 for( size_t I
= 0; I
< fg().nrFactors(); I
++ )
136 calcBeliefF(I
); // calculate b_I
140 void BP_dual::calcBeliefV( size_t i
) {
141 Prob
prod( fg().var(i
).states(), 1.0 );
142 foreach( const Neighbor
&I
, fg().nbV(i
) )
143 prod
*= msgM(i
,I
.iter
);
144 _beliefs
.Zb1
[i
] = prod
.normalize();
145 _beliefs
.b1
[i
] = prod
;
149 void BP_dual::calcBeliefF( size_t I
) {
150 Prob
prod( fg().factor(I
).p() );
151 foreach( const Neighbor
&j
, fg().nbF(I
) ) {
152 IndexFor
ind( fg().var(j
), fg().factor(I
).vars() );
153 Prob
n( msgN(j
,j
.dual
) );
154 for( size_t x
= 0; ind
>= 0; x
++, ++ind
)
157 _beliefs
.Zb2
[I
] = prod
.normalize();
158 _beliefs
.b2
[I
] = prod
;
162 } // end of namespace dai