1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
13 #include <dai/gibbs.h>
15 #include <dai/bipgraph.h>
24 /// Convenience typedef
25 typedef BipartiteGraph::Neighbor Neighbor
;
28 Prob
unnormAdjoint( const Prob
&w
, Real Z_w
, const Prob
&adj_w
) {
29 assert( w
.size() == adj_w
.size() );
30 Prob
adj_w_unnorm( w
.size(), 0.0 );
32 for( size_t i
= 0; i
< w
.size(); i
++ )
34 for( size_t i
= 0; i
< w
.size(); i
++ )
35 adj_w_unnorm
[i
] = (adj_w
[i
] - s
) / Z_w
;
37 // THIS WOULD BE ABOUT 50% SLOWER: return (adj_w - (w * adj_w).sum()) / Z_w;
41 std::vector
<size_t> getGibbsState( const InfAlg
&ia
, size_t iters
) {
42 PropertySet gibbsProps
;
43 gibbsProps
.Set("iters", iters
);
44 gibbsProps
.Set("verbose", size_t(0));
45 Gibbs
gibbs( ia
.fg(), gibbsProps
);
51 /// Returns the entry of the I'th factor corresponding to a global state
52 size_t getFactorEntryForState( const FactorGraph
&fg
, size_t I
, const vector
<size_t> &state
) {
54 for( int _j
= fg
.nbF(I
).size() - 1; _j
>= 0; _j
-- ) {
55 // note that iterating over nbF(I) yields the same ordering
56 // of variables as iterating over factor(I).vars()
57 size_t j
= fg
.nbF(I
)[_j
];
58 f_entry
*= fg
.var(j
).states();
65 #define LOOP_ij(body) { \
66 size_t i_states = _fg->var(i).states(); \
67 size_t j_states = _fg->var(j).states(); \
68 if(_fg->var(i) > _fg->var(j)) { \
70 for(size_t xi=0; xi<i_states; xi++) { \
71 for(size_t xj=0; xj<j_states; xj++) { \
78 for(size_t xj=0; xj<j_states; xj++) { \
79 for(size_t xi=0; xi<i_states; xi++) { \
88 void BBP::RegenerateInds() {
89 // initialise _indices
90 // typedef std::vector<size_t> _ind_t;
91 // std::vector<std::vector<_ind_t> > _indices;
92 _indices
.resize( _fg
->nrVars() );
93 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
95 _indices
[i
].reserve( _fg
->nbV(i
).size() );
96 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
98 index
.reserve( _fg
->factor(I
).states() );
99 for( IndexFor
k(_fg
->var(i
), _fg
->factor(I
).vars()); k
>= 0; ++k
)
100 index
.push_back( k
);
101 _indices
[i
].push_back( index
);
107 void BBP::RegenerateT() {
109 _T
.resize( _fg
->nrVars() );
110 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
111 _T
[i
].resize( _fg
->nbV(i
).size() );
112 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
113 Prob
prod( _fg
->var(i
).states(), 1.0 );
114 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
115 if( J
.node
!= I
.node
)
116 prod
*= _bp_dual
.msgM( i
, J
.iter
);
117 _T
[i
][I
.iter
] = prod
;
123 void BBP::RegenerateU() {
125 _U
.resize( _fg
->nrFactors() );
126 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
127 _U
[I
].resize( _fg
->nbF(I
).size() );
128 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
129 Prob
prod( _fg
->factor(I
).states(), 1.0 );
130 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
131 if( i
.node
!= j
.node
) {
132 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
133 const _ind_t
&ind
= _index( j
, j
.dual
);
134 // multiply prod by n_jI
135 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
136 prod
[x_I
] *= n_jI
[ind
[x_I
]];
138 _U
[I
][i
.iter
] = prod
;
144 void BBP::RegenerateS() {
146 _S
.resize( _fg
->nrVars() );
147 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
148 _S
[i
].resize( _fg
->nbV(i
).size() );
149 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
150 _S
[i
][I
.iter
].resize( _fg
->nbF(I
).size() );
151 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
153 Factor
prod( _fg
->factor(I
) );
154 foreach( const Neighbor
&k
, _fg
->nbF(I
) ) {
155 if( k
!= i
&& k
.node
!= j
.node
) {
156 const _ind_t
&ind
= _index( k
, k
.dual
);
157 Prob
p( _bp_dual
.msgN( k
, k
.dual
) );
158 for( size_t x_I
= 0; x_I
< prod
.states(); x_I
++ )
159 prod
.p()[x_I
] *= p
[ind
[x_I
]];
162 // "Marginalize" onto i|j (unnormalized)
164 marg
= prod
.marginal( VarSet(_fg
->var(i
), _fg
->var(j
)), false ).p();
165 _S
[i
][I
.iter
][j
.iter
] = marg
;
172 void BBP::RegenerateR() {
174 _R
.resize( _fg
->nrFactors() );
175 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
176 _R
[I
].resize( _fg
->nbF(I
).size() );
177 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
178 _R
[I
][i
.iter
].resize( _fg
->nbV(i
).size() );
179 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
181 Prob
prod( _fg
->var(i
).states(), 1.0 );
182 foreach( const Neighbor
&K
, _fg
->nbV(i
) )
183 if( K
.node
!= I
&& K
.node
!= J
.node
)
184 prod
*= _bp_dual
.msgM( i
, K
.iter
);
185 _R
[I
][i
.iter
][J
.iter
] = prod
;
193 void BBP::RegenerateInputs() {
194 _adj_b_V_unnorm
.clear();
195 _adj_b_V_unnorm
.reserve( _fg
->nrVars() );
196 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
197 _adj_b_V_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefV(i
).p(), _bp_dual
.beliefVZ(i
), _adj_b_V
[i
] ) );
198 _adj_b_F_unnorm
.clear();
199 _adj_b_F_unnorm
.reserve( _fg
->nrFactors() );
200 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ )
201 _adj_b_F_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefF(I
).p(), _bp_dual
.beliefFZ(I
), _adj_b_F
[I
] ) );
205 void BBP::RegeneratePsiAdjoints() {
207 _adj_psi_V
.reserve( _fg
->nrVars() );
208 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
209 Prob
p( _adj_b_V_unnorm
[i
] );
210 assert( p
.size() == _fg
->var(i
).states() );
211 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
212 p
*= _bp_dual
.msgM( i
, I
.iter
);
213 p
+= _init_adj_psi_V
[i
];
214 _adj_psi_V
.push_back( p
);
217 _adj_psi_F
.reserve( _fg
->nrFactors() );
218 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
219 Prob
p( _adj_b_F_unnorm
[I
] );
220 assert( p
.size() == _fg
->factor(I
).states() );
221 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
222 Prob
n_iI( _bp_dual
.msgN( i
, i
.dual
) );
223 const _ind_t
& ind
= _index( i
, i
.dual
);
224 // multiply prod with n_jI
225 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
226 p
[x_I
] *= n_iI
[ind
[x_I
]];
228 p
+= _init_adj_psi_F
[I
];
229 _adj_psi_F
.push_back( p
);
234 void BBP::RegenerateParMessageAdjoints() {
235 size_t nv
= _fg
->nrVars();
238 _adj_n_unnorm
.resize( nv
);
239 _adj_m_unnorm
.resize( nv
);
240 _new_adj_n
.resize( nv
);
241 _new_adj_m
.resize( nv
);
242 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
243 size_t n_i
= _fg
->nbV(i
).size();
244 _adj_n
[i
].resize( n_i
);
245 _adj_m
[i
].resize( n_i
);
246 _adj_n_unnorm
[i
].resize( n_i
);
247 _adj_m_unnorm
[i
].resize( n_i
);
248 _new_adj_n
[i
].resize( n_i
);
249 _new_adj_m
[i
].resize( n_i
);
250 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
252 Prob
prod( _fg
->factor(I
).p() );
253 prod
*= _adj_b_F_unnorm
[I
];
254 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
256 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
257 const _ind_t
&ind
= _index( j
, j
.dual
);
258 // multiply prod with n_jI
259 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
260 prod
[x_I
] *= n_jI
[ind
[x_I
]];
262 Prob
marg( _fg
->var(i
).states(), 0.0 );
263 const _ind_t
&ind
= _index( i
, I
.iter
);
264 for( size_t r
= 0; r
< prod
.size(); r
++ )
265 marg
[ind
[r
]] += prod
[r
];
266 _new_adj_n
[i
][I
.iter
] = marg
;
271 Prob
prod( _adj_b_V_unnorm
[i
] );
272 assert( prod
.size() == _fg
->var(i
).states() );
273 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
274 if( J
.node
!= I
.node
)
275 prod
*= _bp_dual
.msgM(i
,J
.iter
);
276 _new_adj_m
[i
][I
.iter
] = prod
;
284 void BBP::RegenerateSeqMessageAdjoints() {
285 size_t nv
= _fg
->nrVars();
287 _adj_m_unnorm
.resize( nv
);
288 _new_adj_m
.resize( nv
);
289 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
290 size_t n_i
= _fg
->nbV(i
).size();
291 _adj_m
[i
].resize( n_i
);
292 _adj_m_unnorm
[i
].resize( n_i
);
293 _new_adj_m
[i
].resize( n_i
);
294 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
296 Prob
prod( _adj_b_V_unnorm
[i
] );
297 assert( prod
.size() == _fg
->var(i
).states() );
298 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
299 if( J
.node
!= I
.node
)
300 prod
*= _bp_dual
.msgM( i
, J
.iter
);
301 _adj_m
[i
][I
.iter
] = prod
;
302 calcUnnormMsgM( i
, I
.iter
);
303 _new_adj_m
[i
][I
.iter
] = Prob( _fg
->var(i
).states(), 0.0 );
306 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
307 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
309 Prob
prod( _fg
->factor(I
).p() );
310 prod
*= _adj_b_F_unnorm
[I
];
311 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
313 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
314 const _ind_t
& ind
= _index( j
, j
.dual
);
315 // multiply prod with n_jI
316 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
317 prod
[x_I
] *= n_jI
[ind
[x_I
]];
319 Prob
marg( _fg
->var(i
).states(), 0.0 );
320 const _ind_t
&ind
= _index( i
, I
.iter
);
321 for( size_t r
= 0; r
< prod
.size(); r
++ )
322 marg
[ind
[r
]] += prod
[r
];
323 sendSeqMsgN( i
, I
.iter
,marg
);
329 void BBP::calcNewN( size_t i
, size_t _I
) {
330 _adj_psi_V
[i
] += T(i
,_I
) * _adj_n_unnorm
[i
][_I
];
331 Prob
&new_adj_n_iI
= _new_adj_n
[i
][_I
];
332 new_adj_n_iI
= Prob( _fg
->var(i
).states(), 0.0 );
333 size_t I
= _fg
->nbV(i
)[_I
];
334 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
336 const Prob
&p
= _S
[i
][_I
][j
.iter
];
337 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][j
.dual
];
339 new_adj_n_iI
[xi
] += p
[xij
] * _adj_m_unnorm_jI
[xj
];
341 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
342 Var vi = _fg->var(i);
343 Var vj = _fg->var(j);
344 new_adj_n_iI = (Factor(VarSet(vi, vj), p) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p();
350 void BBP::calcNewM( size_t i
, size_t _I
) {
351 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
352 Prob
p( U(I
, I
.dual
) );
353 const Prob
&adj
= _adj_m_unnorm
[i
][_I
];
354 const _ind_t
&ind
= _index(i
,_I
);
355 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
356 p
[x_I
] *= adj
[ind
[x_I
]];
358 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
359 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, I.dual) ) * Factor( _fg->var(i), _adj_m_unnorm[i][_I] )).p();
362 _new_adj_m
[i
][_I
] = Prob( _fg
->var(i
).states(), 0.0 );
363 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
365 _new_adj_m
[i
][_I
] += _R
[I
][I
.dual
][J
.iter
] * _adj_n_unnorm
[i
][J
.iter
];
369 void BBP::calcUnnormMsgN( size_t i
, size_t _I
) {
370 _adj_n_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), _adj_n
[i
][_I
] );
374 void BBP::calcUnnormMsgM( size_t i
, size_t _I
) {
375 _adj_m_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgM(i
,_I
), _bp_dual
.zM(i
,_I
), _adj_m
[i
][_I
] );
379 void BBP::upMsgN( size_t i
, size_t _I
) {
380 _adj_n
[i
][_I
] = _new_adj_n
[i
][_I
];
381 calcUnnormMsgN( i
, _I
);
385 void BBP::upMsgM( size_t i
, size_t _I
) {
386 _adj_m
[i
][_I
] = _new_adj_m
[i
][_I
];
387 calcUnnormMsgM( i
, _I
);
391 void BBP::doParUpdate() {
392 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
393 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
394 calcNewM( i
, I
.iter
);
395 calcNewN( i
, I
.iter
);
397 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
398 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
405 void BBP::incrSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
406 /* if( props.clean_updates )
407 _new_adj_m[i][_I] += p;
410 calcUnnormMsgM(i
, _I
);
421 void BBP::updateSeqMsgM( size_t i, size_t _I ) {
422 if( props.clean_updates ) {
424 if(_new_adj_m[i][_I].sumAbs() > pv_thresh ||
425 _adj_m[i][_I].sumAbs() > pv_thresh) {
427 DAI_DMSG("in updateSeqMsgM");
430 DAI_PV(_adj_m[i][_I]);
431 DAI_PV(_new_adj_m[i][_I]);
434 _adj_m[i][_I] += _new_adj_m[i][_I];
435 calcUnnormMsgM( i, _I );
436 _new_adj_m[i][_I].fill( 0.0 );
441 void BBP::setSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
443 calcUnnormMsgM( i
, _I
);
447 void BBP::sendSeqMsgN( size_t i
, size_t _I
, const Prob
&f
) {
448 Prob f_unnorm
= unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), f
);
449 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
450 assert( I
.iter
== _I
);
451 _adj_psi_V
[i
] += f_unnorm
* T( i
, _I
);
453 if(f_unnorm
.sumAbs() > pv_thresh
) {
454 DAI_DMSG("in sendSeqMsgN");
458 DAI_PV(_bp_dual
.msgN(i
,_I
));
459 DAI_PV(_bp_dual
.zN(i
,_I
));
460 DAI_PV(_bp_dual
.msgM(i
,_I
));
461 DAI_PV(_bp_dual
.zM(i
,_I
));
462 DAI_PV(_fg
->factor(I
).p());
465 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
466 if( J
.node
!= I
.node
) {
468 if(f_unnorm
.sumAbs() > pv_thresh
) {
469 DAI_DMSG("in sendSeqMsgN loop");
472 DAI_PV(_R
[J
][J
.dual
][_I
]);
473 DAI_PV(f_unnorm
* _R
[J
][J
.dual
][_I
]);
476 incrSeqMsgM( i
, J
.iter
, f_unnorm
* R(J
, J
.dual
, _I
) );
482 void BBP::sendSeqMsgM( size_t j
, size_t _I
) {
483 const Neighbor
&I
= _fg
->nbV(j
)[_I
];
487 // DAI_PV(_adj_m_unnorm_jI);
488 // DAI_PV(_adj_m[j][_I]);
489 // DAI_PV(_bp_dual.zM(j,_I));
492 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][_I
];
494 const _ind_t
&ind
= _index(j
, _I
);
495 for( size_t x_I
= 0; x_I
< um
.size(); x_I
++ )
496 um
[x_I
] *= _adj_m_unnorm_jI
[ind
[x_I
]];
497 um
*= 1 - props
.damping
;
500 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
501 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, _j) ) * Factor( _fg->var(j), _adj_m_unnorm[j][_I] )).p() * (1.0 - props.damping);
504 // DAI_DMSG("in sendSeqMsgM");
508 // DAI_PV(_fg->nbF(I).size());
509 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
511 const Prob
&S
= _S
[i
][i
.dual
][_j
];
512 Prob
msg( _fg
->var(i
).states(), 0.0 );
514 msg
[xi
] += S
[xij
] * _adj_m_unnorm_jI
[xj
];
516 msg
*= 1.0 - props
.damping
;
517 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
518 Var vi = _fg->var(i);
519 Var vj = _fg->var(j);
520 msg = (Factor(VarSet(vi,vj), S) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p() * (1.0 - props.damping);
523 if(msg
.sumAbs() > pv_thresh
) {
524 DAI_DMSG("in sendSeqMsgM loop");
529 DAI_PV(_fg
->nbF(I
).size());
530 DAI_PV(_fg
->factor(I
).p());
531 DAI_PV(_S
[i
][i
.dual
][_j
]);
536 DAI_PV(_fg
->nbV(i
).size());
539 assert( _fg
->nbV(i
)[i
.dual
].node
== I
);
540 sendSeqMsgN( i
, i
.dual
, msg
);
543 setSeqMsgM( j
, _I
, _adj_m
[j
][_I
] * props
.damping
);
547 Real
BBP::getUnMsgMag() {
550 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
551 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
552 s
+= _adj_m_unnorm
[i
][I
.iter
].sumAbs();
553 s
+= _adj_n_unnorm
[i
][I
.iter
].sumAbs();
560 void BBP::getMsgMags( Real
&s
, Real
&new_s
) {
564 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
565 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
566 s
+= _adj_m
[i
][I
.iter
].sumAbs();
567 s
+= _adj_n
[i
][I
.iter
].sumAbs();
568 new_s
+= _new_adj_m
[i
][I
.iter
].sumAbs();
569 new_s
+= _new_adj_n
[i
][I
.iter
].sumAbs();
576 // tuple<size_t,size_t,Real> BBP::getArgMaxPsi1Adj() {
577 // size_t argmax_var=0;
578 // size_t argmax_var_state=0;
580 // for( size_t i = 0; i < _fg->nrVars(); i++ ) {
581 // pair<size_t,Real> argmax_state = adj_psi_V(i).argmax();
582 // if(i==0 || argmax_state.second>max_var) {
584 // max_var = argmax_state.second;
585 // argmax_var_state = argmax_state.first;
588 // assert(/*0 <= argmax_var_state &&*/
589 // argmax_var_state < _fg->var(argmax_var).states());
590 // return tuple<size_t,size_t,Real>(argmax_var,argmax_var_state,max_var);
594 void BBP::getArgmaxMsgM( size_t &out_i
, size_t &out__I
, Real
&mag
) {
596 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
597 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
598 Real thisMag
= _adj_m
[i
][I
.iter
].sumAbs();
599 if( !found
|| mag
< thisMag
) {
610 Real
BBP::getMaxMsgM() {
613 getArgmaxMsgM( dummy
, dummy
, mag
);
618 Real
BBP::getTotalMsgM() {
620 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
621 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
622 mag
+= _adj_m
[i
][I
.iter
].sumAbs();
627 Real
BBP::getTotalNewMsgM() {
629 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
630 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
631 mag
+= _new_adj_m
[i
][I
.iter
].sumAbs();
636 Real
BBP::getTotalMsgN() {
638 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
639 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
640 mag
+= _adj_n
[i
][I
.iter
].sumAbs();
645 void BBP::Regenerate() {
652 RegeneratePsiAdjoints();
653 if( props
.updates
== Properties::UpdateType::PAR
)
654 RegenerateParMessageAdjoints();
656 RegenerateSeqMessageAdjoints();
661 std::vector
<Prob
> BBP::getZeroAdjF( const FactorGraph
&fg
) {
663 adj_2
.reserve( fg
.nrFactors() );
664 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
665 adj_2
.push_back( Prob( fg
.factor(I
).states(), 0.0 ) );
670 std::vector
<Prob
> BBP::getZeroAdjV( const FactorGraph
&fg
) {
672 adj_1
.reserve( fg
.nrVars() );
673 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
674 adj_1
.push_back( Prob( fg
.var(i
).states(), 0.0 ) );
680 typedef BBP::Properties::UpdateType UT
;
681 Real
&tol
= props
.tol
;
682 UT
&updates
= props
.updates
;
685 switch( (size_t)updates
) {
691 getArgmaxMsgM( i
, _I
, mag
);
692 sendSeqMsgM( i
, _I
);
693 } while( mag
> tol
&& _iters
< props
.maxiter
);
695 if( _iters
>= props
.maxiter
)
696 if( props
.verbose
>= 1 )
697 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
699 } case UT::SEQ_FIX
: {
703 mag
= getTotalMsgM();
707 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
708 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
709 sendSeqMsgM( i
, I
.iter
);
710 /* for( size_t i = 0; i < _fg->nrVars(); i++ )
711 foreach( const Neighbor &I, _fg->nbV(i) )
712 updateSeqMsgM( i, I.iter );*/
713 } while( mag
> tol
&& _iters
< props
.maxiter
);
715 if( _iters
>= props
.maxiter
)
716 if( props
.verbose
>= 1 )
717 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
719 } case UT::SEQ_BP_REV
:
720 case UT::SEQ_BP_FWD
: {
721 const BP
*bp
= static_cast<const BP
*>(_ia
);
722 vector
<pair
<size_t, size_t> > sentMessages
= bp
->getSentMessages();
723 size_t totalMessages
= sentMessages
.size();
724 if( totalMessages
== 0 )
725 DAI_THROWE(INTERNAL_ERROR
, "Asked for updates=" + std::string(updates
) + " but no BP messages; did you forget to set recordSentMessages?");
726 if( updates
==UT::SEQ_BP_FWD
)
727 reverse( sentMessages
.begin(), sentMessages
.end() );
728 // DAI_PV(sentMessages.size());
730 // DAI_PV(props.maxiter);
731 while( sentMessages
.size() > 0 && _iters
< props
.maxiter
) {
732 // DAI_PV(sentMessages.size());
735 pair
<size_t, size_t> e
= sentMessages
.back();
736 sentMessages
.pop_back();
737 size_t i
= e
.first
, _I
= e
.second
;
738 sendSeqMsgM( i
, _I
);
740 if( _iters
>= props
.maxiter
)
741 if( props
.verbose
>= 1 )
742 cerr
<< "Warning: BBP updates limited to " << props
.maxiter
<< " iterations, but using UpdateType " << updates
<< " with " << totalMessages
<< " messages" << endl
;
748 } while( (_iters
< 2 || getUnMsgMag() > tol
) && _iters
< props
.maxiter
);
749 if( _iters
== props
.maxiter
) {
751 getMsgMags( s
, new_s
);
752 if( props
.verbose
>= 1 )
753 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (unnorm message magnitude = " << getUnMsgMag() << ", norm message mags = " << s
<< " -> " << new_s
<< ")" << endl
;
758 if( props
.verbose
>= 3 )
759 cerr
<< "BBP::run() took " << toc()-tic
<< " seconds " << doneIters() << " iterations" << endl
;
763 double numericBBPTest( const InfAlg
&bp
, const vector
<size_t> *state
, const PropertySet
&bbp_props
, bbp_cfn_t cfn
, double h
) {
764 // calculate the value of the unperturbed cost function
765 Real cf0
= getCostFn( bp
, cfn
, state
);
767 // run BBP to estimate adjoints
768 BBP
bbp( &bp
, bbp_props
);
769 initBBPCostFnAdj( bbp
, bp
, cfn
, state
);
773 const FactorGraph
& fg
= bp
.fg();
776 // verify bbp.adj_psi_V
778 // for each variable i
779 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
780 vector
<double> adj_est
;
782 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
783 // Clone 'bp' (which may be any InfAlg)
784 InfAlg
*bp_prb
= bp
.clone();
787 size_t n
= bp_prb
->fg().var(i
).states();
788 Prob
psi_1_prb( n
, 1.0 );
790 // psi_1_prb.normalize();
791 size_t I
= bp_prb
->fg().nbV(i
)[0]; // use first factor in list of neighbors of i
792 bp_prb
->fg().factor(I
) *= Factor( bp_prb
->fg().var(i
), psi_1_prb
);
794 // call 'init' on the perturbed variables
795 bp_prb
->init( bp_prb
->fg().var(i
) );
797 // run copy to convergence
800 // calculate new value of cost function
801 Real cf_prb
= getCostFn( *bp_prb
, cfn
, state
);
803 // use to estimate adjoint for i
804 adj_est
.push_back( (cf_prb
- cf0
) / h
);
806 // free cloned InfAlg
809 Prob
p_adj_est( adj_est
.begin(), adj_est
.end() );
810 // compare this numerical estimate to the BBP estimate; sum the distances
812 << ", p_adj_est: " << p_adj_est
813 << ", bbp.adj_psi_V(i): " << bbp
.adj_psi_V(i
) << endl
;
814 d
+= dist( p_adj_est
, bbp
.adj_psi_V(i
), Prob::DISTL1
);
818 // verify bbp.adj_n and bbp.adj_m
820 // We actually want to check the responsiveness of objective
821 // function to changes in the final messages. But at the end of a
822 // BBP run, the message adjoints are for the initial messages (and
823 // they should be close to zero, see paper). So this resets the
824 // BBP adjoints to the refer to the desired final messages
825 bbp.RegenerateMessageAdjoints();
827 // for each variable i
828 for(size_t i=0; i<bp_dual.nrVars(); i++) {
829 // for each factor I ~ i
830 foreach(size_t I, bp_dual.nbV(i)) {
831 vector<double> adj_n_est;
833 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
834 BP_dual bp_dual_prb(bp_dual);
835 // make h-sized change to newMsgN
836 bp_dual_prb.newMsgN(i,I)[xi] += h;
837 // recalculate beliefs
838 bp_dual_prb.CalcBeliefs();
839 // get cost function value
840 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
841 // add it to list of adjoints
842 adj_n_est.push_back((cf_prb-cf0)/h);
845 vector<double> adj_m_est;
847 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
848 BP_dual bp_dual_prb(bp_dual);
849 // make h-sized change to newMsgM
850 bp_dual_prb.newMsgM(I,i)[xi] += h;
851 // recalculate beliefs
852 bp_dual_prb.CalcBeliefs();
853 // get cost function value
854 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
855 // add it to list of adjoints
856 adj_m_est.push_back((cf_prb-cf0)/h);
859 Prob p_adj_n_est(adj_n_est.begin(), adj_n_est.end());
860 // compare this numerical estimate to the BBP estimate; sum the distances
861 cerr << "i: " << i << ", I: " << I
862 << ", adj_n_est: " << p_adj_n_est
863 << ", bbp.adj_n(i,I): " << bbp.adj_n(i,I) << endl;
864 d += dist(p_adj_n_est, bbp.adj_n(i,I), Prob::DISTL1);
866 Prob p_adj_m_est(adj_m_est.begin(), adj_m_est.end());
867 // compare this numerical estimate to the BBP estimate; sum the distances
868 cerr << "i: " << i << ", I: " << I
869 << ", adj_m_est: " << p_adj_m_est
870 << ", bbp.adj_m(I,i): " << bbp.adj_m(I,i) << endl;
871 d += dist(p_adj_m_est, bbp.adj_m(I,i), Prob::DISTL1);
877 // verify bbp.adj_b_V
878 for(size_t i=0; i<bp_dual.nrVars(); i++) {
879 vector<double> adj_b_V_est;
881 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
882 BP_dual bp_dual_prb(bp_dual);
884 // make h-sized change to b_1(i)[x_i]
885 bp_dual_prb._beliefs.b1[i][xi] += h;
887 // get cost function value
888 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
890 // add it to list of adjoints
891 adj_b_V_est.push_back((cf_prb-cf0)/h);
893 Prob p_adj_b_V_est(adj_b_V_est.begin(), adj_b_V_est.end());
894 // compare this numerical estimate to the BBP estimate; sum the distances
896 << ", adj_b_V_est: " << p_adj_b_V_est
897 << ", bbp.adj_b_V(i): " << bbp.adj_b_V(i) << endl;
898 d += dist(p_adj_b_V_est, bbp.adj_b_V(i), Prob::DISTL1);
903 // return total of distances
908 bool needGibbsState( bbp_cfn_t cfn
) {
909 switch( (size_t)cfn
) {
910 case bbp_cfn_t::CFN_GIBBS_B
:
911 case bbp_cfn_t::CFN_GIBBS_B2
:
912 case bbp_cfn_t::CFN_GIBBS_EXP
:
913 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
914 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
915 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
923 void initBBPCostFnAdj( BBP
&bbp
, const InfAlg
&ia
, bbp_cfn_t cfn_type
, const vector
<size_t> *stateP
) {
924 const FactorGraph
&fg
= ia
.fg();
926 switch( (size_t)cfn_type
) {
927 case bbp_cfn_t::CFN_BETHE_ENT
: {
930 vector
<Prob
> psi1_adj
;
931 vector
<Prob
> psi2_adj
;
932 b1_adj
.reserve( fg
.nrVars() );
933 psi1_adj
.reserve( fg
.nrVars() );
934 b2_adj
.reserve( fg
.nrFactors() );
935 psi2_adj
.reserve( fg
.nrFactors() );
936 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
937 size_t dim
= fg
.var(i
).states();
938 int c
= fg
.nbV(i
).size();
940 for( size_t xi
= 0; xi
< dim
; xi
++ )
941 p
[xi
] = (1 - c
) * (1 + log( ia
.beliefV(i
)[xi
] ));
942 b1_adj
.push_back( p
);
944 for( size_t xi
= 0; xi
< dim
; xi
++ )
945 p
[xi
] = -ia
.beliefV(i
)[xi
];
946 psi1_adj
.push_back( p
);
948 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
949 size_t dim
= fg
.factor(I
).states();
951 for( size_t xI
= 0; xI
< dim
; xI
++ )
952 p
[xI
] = 1 + log( ia
.beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
] );
953 b2_adj
.push_back( p
);
955 for( size_t xI
= 0; xI
< dim
; xI
++ )
956 p
[xI
] = -ia
.beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
];
957 psi2_adj
.push_back( p
);
959 bbp
.init( b1_adj
, b2_adj
, psi1_adj
, psi2_adj
);
961 } case bbp_cfn_t::CFN_FACTOR_ENT
: {
963 b2_adj
.reserve( fg
.nrFactors() );
964 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
965 size_t dim
= fg
.factor(I
).states();
967 for( size_t xI
= 0; xI
< dim
; xI
++ ) {
968 double bIxI
= ia
.beliefF(I
)[xI
];
972 p
[xI
] = 1 + log( bIxI
);
976 bbp
.init( bbp
.getZeroAdjV(fg
), b2_adj
);
978 } case bbp_cfn_t::CFN_VAR_ENT
: {
980 b1_adj
.reserve( fg
.nrVars() );
981 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
982 size_t dim
= fg
.var(i
).states();
984 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
985 double bixi
= ia
.beliefV(i
)[xi
];
989 p
[xi
] = 1 + log( bixi
);
991 b1_adj
.push_back( p
);
995 } case bbp_cfn_t::CFN_GIBBS_B
:
996 case bbp_cfn_t::CFN_GIBBS_B2
:
997 case bbp_cfn_t::CFN_GIBBS_EXP
: {
998 // cost functions that use Gibbs sample, summing over variable marginals
999 vector
<size_t> state
;
1000 if( stateP
== NULL
)
1001 state
= getGibbsState( ia
, 2*ia
.Iterations() );
1004 assert( state
.size() == fg
.nrVars() );
1006 vector
<Prob
> b1_adj
;
1007 b1_adj
.reserve(fg
.nrVars());
1008 for( size_t i
= 0; i
< state
.size(); i
++ ) {
1009 size_t n
= fg
.var(i
).states();
1010 Prob
delta( n
, 0.0 );
1011 assert(/*0<=state[i] &&*/ state
[i
] < n
);
1012 double b
= ia
.beliefV(i
)[state
[i
]];
1013 switch( (size_t)cfn_type
) {
1014 case bbp_cfn_t::CFN_GIBBS_B
:
1015 delta
[state
[i
]] = 1.0;
1017 case bbp_cfn_t::CFN_GIBBS_B2
:
1018 delta
[state
[i
]] = b
;
1020 case bbp_cfn_t::CFN_GIBBS_EXP
:
1021 delta
[state
[i
]] = exp(b
);
1024 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1026 b1_adj
.push_back( delta
);
1030 } case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1031 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1032 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
: {
1033 // cost functions that use Gibbs sample, summing over factor marginals
1034 vector
<size_t> state
;
1035 if( stateP
== NULL
)
1036 state
= getGibbsState( ia
, 2*ia
.Iterations() );
1039 assert( state
.size() == fg
.nrVars() );
1041 vector
<Prob
> b2_adj
;
1042 b2_adj
.reserve( fg
.nrVars() );
1043 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
1044 size_t n
= fg
.factor(I
).states();
1045 Prob
delta( n
, 0.0 );
1047 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
1048 assert(/*0<=x_I &&*/ x_I
< n
);
1050 double b
= ia
.beliefF(I
)[x_I
];
1051 switch( (size_t)cfn_type
) {
1052 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1055 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1058 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
1059 delta
[x_I
] = exp( b
);
1062 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1064 b2_adj
.push_back( delta
);
1066 bbp
.init( bbp
.getZeroAdjV(fg
), b2_adj
);
1069 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1074 Real
getCostFn( const InfAlg
&ia
, bbp_cfn_t cfn_type
, const vector
<size_t> *stateP
) {
1076 const FactorGraph
&fg
= ia
.fg();
1078 switch( (size_t)cfn_type
) {
1079 case bbp_cfn_t::CFN_BETHE_ENT
: // ignores state
1082 case bbp_cfn_t::CFN_VAR_ENT
: // ignores state
1083 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
1084 cf
+= -ia
.beliefV(i
).entropy();
1086 case bbp_cfn_t::CFN_FACTOR_ENT
: // ignores state
1087 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
1088 cf
+= -ia
.beliefF(I
).entropy();
1090 case bbp_cfn_t::CFN_GIBBS_B
:
1091 case bbp_cfn_t::CFN_GIBBS_B2
:
1092 case bbp_cfn_t::CFN_GIBBS_EXP
: {
1093 assert( stateP
!= NULL
);
1094 vector
<size_t> state
= *stateP
;
1095 assert( state
.size() == fg
.nrVars() );
1096 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
1097 double b
= ia
.beliefV(i
)[state
[i
]];
1098 switch( (size_t)cfn_type
) {
1099 case bbp_cfn_t::CFN_GIBBS_B
:
1102 case bbp_cfn_t::CFN_GIBBS_B2
:
1105 case bbp_cfn_t::CFN_GIBBS_EXP
:
1109 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1113 } case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1114 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1115 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
: {
1116 assert( stateP
!= NULL
);
1117 vector
<size_t> state
= *stateP
;
1118 assert( state
.size() == fg
.nrVars() );
1119 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
1120 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
1121 double b
= ia
.beliefF(I
)[x_I
];
1122 switch( (size_t)cfn_type
) {
1123 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1126 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1129 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
1133 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1138 DAI_THROWE(UNKNOWN_ENUM_VALUE
, "Unknown cost function " + std::string(cfn_type
));
1144 } // end of namespace dai
1147 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
1148 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
1152 void BBP::Properties::set(const PropertySet
&opts
)
1154 const std::set
<PropertyKey
> &keys
= opts
.allKeys();
1155 std::set
<PropertyKey
>::const_iterator i
;
1156 for(i
=keys
.begin(); i
!=keys
.end(); i
++) {
1157 if(*i
== "verbose") continue;
1158 if(*i
== "maxiter") continue;
1159 if(*i
== "tol") continue;
1160 if(*i
== "damping") continue;
1161 if(*i
== "updates") continue;
1162 DAI_THROWE(UNKNOWN_PROPERTY_TYPE
, "BBP: Unknown property " + *i
);
1164 if(!opts
.hasKey("verbose"))
1165 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"verbose\" for method \"BBP\"");
1166 if(!opts
.hasKey("maxiter"))
1167 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"maxiter\" for method \"BBP\"");
1168 if(!opts
.hasKey("tol"))
1169 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"tol\" for method \"BBP\"");
1170 if(!opts
.hasKey("damping"))
1171 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"damping\" for method \"BBP\"");
1172 if(!opts
.hasKey("updates"))
1173 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"updates\" for method \"BBP\"");
1174 verbose
= opts
.getStringAs
<size_t>("verbose");
1175 maxiter
= opts
.getStringAs
<size_t>("maxiter");
1176 tol
= opts
.getStringAs
<double>("tol");
1177 damping
= opts
.getStringAs
<double>("damping");
1178 updates
= opts
.getStringAs
<UpdateType
>("updates");
1180 PropertySet
BBP::Properties::get() const {
1182 opts
.Set("verbose", verbose
);
1183 opts
.Set("maxiter", maxiter
);
1184 opts
.Set("tol", tol
);
1185 opts
.Set("damping", damping
);
1186 opts
.Set("updates", updates
);
1189 string
BBP::Properties::toString() const {
1190 stringstream
s(stringstream::out
);
1192 s
<< "verbose=" << verbose
<< ",";
1193 s
<< "maxiter=" << maxiter
<< ",";
1194 s
<< "tol=" << tol
<< ",";
1195 s
<< "damping=" << damping
<< ",";
1196 s
<< "updates=" << updates
;
1200 } // end of namespace dai
1201 /* }}} END OF GENERATED CODE */