1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28 #include <dai/diffs.h>
30 #include <dai/properties.h>
39 const char *BP::Name
= "BP";
// Checks that the properties BP relies on ("updates", "tol", "maxiter" and
// "verbose") have been set, then converts each one to its typed form.
// NOTE(review): the content numbering skips 44, 46, 48, 50-51 and 56-58. The
// statement that should follow each `if( !HasProperty(...) )` is missing from
// this copy, as are the final return value and the closing brace. As written,
// the four `if`s nest into one another. Presumably a missing property should
// make the function return false, and success should return true -- verify
// upstream.
42 bool BP::checkProperties() {
43 if( !HasProperty("updates") )
45 if( !HasProperty("tol") )
47 if (!HasProperty("maxiter") )
49 if (!HasProperty("verbose") )
// Convert the stored property values to the types presumably read back by
// Tol() (double), MaxIter() and Verbose() (size_t) and Updates() (UpdateType).
52 ConvertPropertyTo
<double>("tol");
53 ConvertPropertyTo
<size_t>("maxiter");
54 ConvertPropertyTo
<size_t>("verbose");
55 ConvertPropertyTo
<UpdateType
>("updates");
// Rebuilds the per-edge BP data after calling DAIAlgFG::Regenerate().
// For every edge (variable iI->first, factor iI->second) it appends:
//  - one message over the states of the variable;
//  - one index table with one entry per joint state of the factor
//    (factor(...).stateSpace() entries), meant to map each joint factor state
//    to the corresponding variable state.
// Finally _newmessages is initialised as a copy of _messages.
61 void BP::Regenerate() {
62 DAIAlgFG::Regenerate();
// NOTE(review): content lines 63-65 and 67-69 are missing from this copy. If they
// do not clear _messages and _indices, a second Regenerate() call appends a
// second set of entries after the first -- verify upstream.
66 _messages
.reserve(nr_edges());
70 _indices
.reserve(nr_edges());
72 // create messages and indices
73 for( vector
<_edge_t
>::const_iterator iI
=edges().begin(); iI
!=edges().end(); ++iI
) {
74 _messages
.push_back( Prob( var(iI
->first
).states() ) );
76 vector
<size_t> ind( factor(iI
->second
).stateSpace(), 0 );
// Presumably Index i iterates over the joint states of the factor's variables
// (until i >= 0 fails), with j counting the joint states -- verify Index.
// NOTE(review): the loop body (content line 79) is missing from this copy. As
// written, the push_back below becomes the loop body and ind keeps its zero
// initialisation. Presumably the body stored the current Index value into
// ind[j] -- verify upstream. The closing braces (content lines 81 and 85) are
// missing as well.
77 Index
i (var(iI
->first
), factor(iI
->second
).vars() );
78 for( size_t j
= 0; i
>= 0; ++i
,++j
)
80 _indices
.push_back( ind
);
83 // create new_messages
84 _newmessages
= _messages
;
// NOTE(review): the function header for this body (content lines 85-88) is
// missing from this copy. Presumably it is a parameterless initialiser such as
// BP::init() -- verify upstream.
// Resets every message in _messages to a uniform vector, filling each entry
// with 1 / size(). It then copies the result into _newmessages, so the pending
// messages equal the current ones.
// NOTE(review): checkProperties() also converts the properties to their typed
// form. Because it is called only inside assert(), it does not run at all when
// NDEBUG is defined.
89 assert( checkProperties() );
90 for( vector
<Prob
>::iterator mij
= _messages
.begin(); mij
!= _messages
.end(); ++mij
)
91 mij
->fill(1.0 / mij
->size());
92 _newmessages
= _messages
;
96 void BP::calcNewMessage (size_t iI
) {
97 // calculate updated message I->i
98 size_t i
= edge(iI
).first
;
99 size_t I
= edge(iI
).second
;
101 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
103 Factor prod( factor( I ) );
104 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ )
105 if( *j != i ) { // for all j in I \ i
106 for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); J++ )
107 if( *J != I ) { // for all J in nb(j) \ I
108 prod *= Factor( *j, message(*j,*J) );
109 Factor marg = prod.marginal(var(i));
112 Prob
prod( factor(I
).p() );
114 // Calculate product of incoming messages and factor I
115 for( _nb_cit j
= nb2(I
).begin(); j
!= nb2(I
).end(); ++j
)
116 if( *j
!= i
) { // for all j in I \ i
117 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
118 _ind_t
* ind
= &(index(*j
,I
));
120 // prod_j will be the product of messages coming into j
121 Prob
prod_j( var(*j
).states() );
122 for( _nb_cit J
= nb1(*j
).begin(); J
!= nb1(*j
).end(); ++J
)
123 if( *J
!= I
) // for all J in nb(j) \ I
124 prod_j
*= message(*j
,*J
);
126 // multiply prod with prod_j
127 for( size_t r
= 0; r
< prod
.size(); ++r
)
128 prod
[r
] *= prod_j
[(*ind
)[r
]];
131 // Marginalize onto i
132 Prob
marg( var(i
).states(), 0.0 );
133 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
134 _ind_t
* ind
= &(index(i
,I
));
135 for( size_t r
= 0; r
< prod
.size(); ++r
)
136 marg
[(*ind
)[r
]] += prod
[r
];
137 marg
.normalize( _normtype
);
140 _newmessages
[iI
] = marg
;
144 // BP::run does not check for NaNs, for performance reasons;
145 // in practice NaNs rarely seem to arise in BP.
// NOTE(review): this is presumably the body of BP::run(). Its header (content
// lines 146-147) and many statements are missing from this copy, for example:
//  - 149-152, where `tic` is presumably defined;
//  - the loop bodies at 168-169, 207-208 and 218;
//  - several closing braces and the final return.
// The comments below describe only the statements that are present.
// Announce the run, including identify() (name followed by the properties).
148 cout
<< "Starting " << identify() << "...";
// diffs collects per-variable belief changes; the main loop below runs while
// diffs.max() > Tol(). It is constructed from nrVars() and 1.0 (presumably a
// capacity and an initial value, so that the loop is entered -- verify Diffs).
153 Diffs
diffs(nrVars(), 1.0);
155 vector
<size_t> edge_seq
;
156 vector
<double> residuals
;
// Snapshot of the single-variable beliefs; each pass compares against it.
158 vector
<Factor
> old_beliefs
;
159 old_beliefs
.reserve( nrVars() );
160 for( size_t i
= 0; i
< nrVars(); ++i
)
161 old_beliefs
.push_back(belief1(i
));
// Setup for the residual schedule (SEQMAX).
165 if( Updates() == UpdateType::SEQMAX
) {
// NOTE(review): the body of this loop (content lines 168-169) is missing.
// Presumably it computes the pending message of every edge, so that the
// residuals below compare real updates -- verify upstream.
167 for(size_t iI
= 0; iI
< nr_edges(); ++iI
)
170 // calculate initial residuals
// The residual of edge iI is the Prob::DISTLINF distance (presumably the
// L-infinity distance) between its pending and its current message.
171 residuals
.reserve(nr_edges());
172 for( size_t iI
= 0; iI
< nr_edges(); ++iI
)
173 residuals
.push_back( dist( _newmessages
[iI
], _messages
[iI
], Prob::DISTLINF
) );
// edge_seq holds the edge indices 0 .. nr_edges()-1. It is the update order
// for the sequential schedules and is shuffled for SEQRND below.
175 edge_seq
.reserve( nr_edges() );
176 for( size_t i
= 0; i
< nr_edges(); ++i
)
177 edge_seq
.push_back( i
);
180 // do several passes over the network until maximum number of iterations has
181 // been reached or until the maximum belief difference is smaller than tolerance
182 for( iter
=0; iter
< MaxIter() && diffs
.max() > Tol(); ++iter
) {
// NOTE(review): `iter` is not declared in the visible part of this function
// (presumably a data member, or declared in one of the missing lines).
183 if( Updates() == UpdateType::SEQMAX
) {
184 // Residual BP (Elidan, McGraw & Koller): in each pass, nr_edges() times
// apply the pending update with the largest residual.
185 for( size_t t
= 0; t
< nr_edges(); ++t
) {
186 // update the message with the largest residual
187 size_t iI
= max_element(residuals
.begin(), residuals
.end()) - residuals
.begin();
188 _messages
[iI
] = _newmessages
[iI
];
// NOTE(review): content lines 189-190 are missing here (presumably they reset
// residuals[iI]) -- verify upstream.
191 // I->i has been updated, which means that residuals for all
192 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
// NOTE(review): the statements at content lines 196, 198 and 200 are missing.
// They are presumably the J != I and j != i filters and the recomputation of
// the pending message jJ, so the nesting shown below is incomplete.
193 size_t i
= edge(iI
).first
;
194 size_t I
= edge(iI
).second
;
195 for( _nb_cit J
= nb1(i
).begin(); J
!= nb1(i
).end(); ++J
)
197 for( _nb_cit j
= nb2(*J
).begin(); j
!= nb2(*J
).end(); ++j
)
199 size_t jJ
= VV2E(*j
,*J
);
201 residuals
[jJ
] = dist( _newmessages
[jJ
], _messages
[jJ
], Prob::DISTLINF
);
204 } else if( Updates() == UpdateType::PARALL
) {
// Parallel schedule: update all pending messages, then apply them all at once.
// NOTE(review): the body of the first loop (content lines 207-208) is missing.
// Presumably it recomputes the pending message of every edge t.
206 for( size_t t
= 0; t
< nr_edges(); ++t
)
209 for( size_t t
= 0; t
< nr_edges(); ++t
)
210 _messages
[t
] = _newmessages
[t
];
// NOTE(review): the opening of the `else` branch (content line 211) is missing.
212 // Sequential updates
// For SEQRND the edge order is reshuffled at the start of every pass.
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17. Its replacement is std::shuffle with an explicit random engine, which
// changes the random sequence.
213 if( Updates() == UpdateType::SEQRND
)
214 random_shuffle( edge_seq
.begin(), edge_seq
.end() );
// Visit the edges in edge_seq order and apply each pending update.
// NOTE(review): content line 218 is missing (presumably it computes the pending
// message of edge k before it is applied).
216 for( size_t t
= 0; t
< nr_edges(); ++t
) {
217 size_t k
= edge_seq
[t
];
219 _messages
[k
] = _newmessages
[k
];
223 // calculate new beliefs and compare with old ones
224 for( size_t i
= 0; i
< nrVars(); ++i
) {
225 Factor
nb( belief1(i
) );
226 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
// NOTE(review): content lines 227-230 are missing. They presumably refresh
// old_beliefs[i] and add a verbosity guard for the progress line below.
231 cout
<< "BP::run: maxdiff " << diffs
.max() << " after " << iter
+1 << " passes" << endl
;
// Report the final maximum belief difference.
234 updateMaxDiff( diffs
.max() );
// Final status, printed when Verbose() >= 1. If diffs.max() still exceeds
// Tol() after the loop, a warning is printed; otherwise the number of passes
// used. `toc() - tic` is printed as elapsed clocks.
// NOTE(review): the `else` between the two messages (content lines 241-243) is
// missing.
236 if( Verbose() >= 1 ) {
237 if( diffs
.max() > Tol() ) {
240 cout
<< "BP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic
<< " clocks)...final maxdiff:" << diffs
.max() << endl
;
244 cout
<< "converged in " << iter
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
252 Factor
BP::belief1( size_t i
) const {
253 Prob
prod( var(i
).states() );
254 for( _nb_cit I
= nb1(i
).begin(); I
!= nb1(i
).end(); ++I
)
255 prod
*= newMessage(i
,*I
);
257 prod
.normalize( Prob::NORMPROB
);
258 return( Factor( var(i
), prod
) );
262 Factor
BP::belief (const Var
&n
) const {
263 return( belief1( findVar( n
) ) );
267 vector
<Factor
> BP::beliefs() const {
268 vector
<Factor
> result
;
269 for( size_t i
= 0; i
< nrVars(); ++i
)
270 result
.push_back( belief1(i
) );
271 for( size_t I
= 0; I
< nrFactors(); ++I
)
272 result
.push_back( belief2(I
) );
// Returns the belief of the variable set ns. Two paths are visible:
//  - a shortcut that returns the single-variable belief of the first element
//    of ns;
//  - a search for a factor I whose variable set relates to ns via `>>`
//    (presumably "contains"); belief2(I) is then marginalised onto ns.
// NOTE(review): content lines 278, 280-281, 284 and 287-288 are missing from
// this copy. They presumably hold:
//  - the condition selecting the shortcut (e.g. ns.size() == 1) and an `else`;
//  - the declaration of I;
//  - a `break` once a matching factor is found;
//  - the closing braces.
// As written, the first return is unconditional. If the search loop has no
// early exit, I equals nrFactors() afterwards and the assert fails.
// *(ns.begin()) also requires ns to be non-empty.
277 Factor
BP::belief( const VarSet
&ns
) const {
279 return belief( *(ns
.begin()) );
282 for( I
= 0; I
< nrFactors(); I
++ )
283 if( factor(I
).vars() >> ns
)
285 assert( I
!= nrFactors() );
286 return belief2(I
).marginal(ns
);
291 Factor
BP::belief2 (size_t I
) const {
292 Prob
prod( factor(I
).p() );
294 for( _nb_cit j
= nb2(I
).begin(); j
!= nb2(I
).end(); ++j
) {
295 // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
296 const _ind_t
*ind
= &(index(*j
, I
));
298 // prod_j will be the product of messages coming into j
299 Prob
prod_j( var(*j
).states() );
300 for( _nb_cit J
= nb1(*j
).begin(); J
!= nb1(*j
).end(); ++J
)
301 if( *J
!= I
) // for all J in nb(j) \ I
302 prod_j
*= newMessage(*j
,*J
);
304 // multiply prod with prod_j
305 for( size_t r
= 0; r
< prod
.size(); ++r
)
306 prod
[r
] *= prod_j
[(*ind
)[r
]];
309 Factor
result( factor(I
).vars(), prod
);
310 result
.normalize( Prob::NORMPROB
);
314 /* UNOPTIMIZED VERSION
316 Factor prod( factor(I) );
317 for( _nb_cit i = nb2(I).begin(); i != nb2(I).end(); i++ ) {
318 for( _nb_cit J = nb1(*i).begin(); J != nb1(*i).end(); J++ )
320 prod *= Factor( var(*i), newMessage(*i,*J)) );
322 return prod.normalize( Prob::NORMPROB );*/
326 Complex
BP::logZ() const {
328 for(size_t i
= 0; i
< nrVars(); ++i
)
329 sum
+= Complex(1.0 - nb1(i
).size()) * belief1(i
).entropy();
330 for( size_t I
= 0; I
< nrFactors(); ++I
)
331 sum
-= KL_dist( belief2(I
), factor(I
) );
336 string
BP::identify() const {
337 stringstream
result (stringstream::out
);
338 result
<< Name
<< GetProperties();
343 void BP::init( const VarSet
&ns
) {
344 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); ++n
) {
345 size_t ni
= findVar( *n
);
346 for( _nb_cit I
= nb1(ni
).begin(); I
!= nb1(ni
).end(); ++I
)
347 message(ni
,*I
).fill( 1.0 );
352 } // end of namespace dai