/*  Copyright (C) 2009  Frederik Eaton [frederik at ofb dot net]

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
23 #include <dai/gibbs.h>
25 #include <dai/bipgraph.h>
34 typedef BipartiteGraph::Neighbor Neighbor
;
37 Prob
unnormAdjoint( const Prob
&w
, Real Z_w
, const Prob
&adj_w
) {
38 assert( w
.size() == adj_w
.size() );
39 Prob
adj_w_unnorm( w
.size(), 0.0 );
41 for( size_t i
= 0; i
< w
.size(); i
++ )
43 for( size_t i
= 0; i
< w
.size(); i
++ )
44 adj_w_unnorm
[i
] = (adj_w
[i
] - s
) / Z_w
;
46 // THIS WOULD BE ABOUT 50% SLOWER: return (adj_w - (w * adj_w).sum()) / Z_w;
50 std::vector
<size_t> getGibbsState( const InfAlg
&ia
, size_t iters
) {
51 PropertySet gibbsProps
;
52 gibbsProps
.Set("iters", iters
);
53 gibbsProps
.Set("verbose", size_t(0));
54 Gibbs
gibbs( ia
.fg(), gibbsProps
);
60 size_t getFactorEntryForState( const FactorGraph
&fg
, size_t I
, const vector
<size_t> &state
) {
62 for( int _j
= fg
.nbF(I
).size() - 1; _j
>= 0; _j
-- ) {
63 // note that iterating over nbF(I) yields the same ordering
64 // of variables as iterating over factor(I).vars()
65 size_t j
= fg
.nbF(I
)[_j
];
66 f_entry
*= fg
.var(j
).states();
// Executes `body` once for every joint state of variables i and j.
// Expects `_fg`, `i` and `j` to be in scope at the point of use and provides
// `xi`, `xj` (the individual states) and `xij` (a counter incremented after
// every execution of body) to `body`.  The variable that compares smaller
// under operator> runs in the inner loop.
#define LOOP_ij(body) {                             \
    size_t i_states = _fg->var(i).states();         \
    size_t j_states = _fg->var(j).states();         \
    if(_fg->var(i) > _fg->var(j)) {                 \
        size_t xij=0;                               \
        for(size_t xi=0; xi<i_states; xi++) {       \
            for(size_t xj=0; xj<j_states; xj++) {   \
                body;                               \
                xij++;                              \
            }                                       \
        }                                           \
    } else {                                        \
        size_t xij=0;                               \
        for(size_t xj=0; xj<j_states; xj++) {       \
            for(size_t xi=0; xi<i_states; xi++) {   \
                body;                               \
                xij++;                              \
            }                                       \
        }                                           \
    }                                               \
}
96 void BBP::RegenerateInds() {
97 // initialise _indices
98 // typedef std::vector<size_t> _ind_t;
99 // std::vector<std::vector<_ind_t> > _indices;
100 _indices
.resize( _fg
->nrVars() );
101 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
103 _indices
[i
].reserve( _fg
->nbV(i
).size() );
104 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
106 index
.reserve( _fg
->factor(I
).states() );
107 for( IndexFor
k(_fg
->var(i
), _fg
->factor(I
).vars()); k
>= 0; ++k
)
108 index
.push_back( k
);
109 _indices
[i
].push_back( index
);
115 void BBP::RegenerateT() {
117 _T
.resize( _fg
->nrVars() );
118 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
119 _T
[i
].resize( _fg
->nbV(i
).size() );
120 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
121 Prob
prod( _fg
->var(i
).states(), 1.0 );
122 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
123 if( J
.node
!= I
.node
)
124 prod
*= _bp_dual
.msgM( i
, J
.iter
);
125 _T
[i
][I
.iter
] = prod
;
131 void BBP::RegenerateU() {
133 _U
.resize( _fg
->nrFactors() );
134 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
135 _U
[I
].resize( _fg
->nbF(I
).size() );
136 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
137 Prob
prod( _fg
->factor(I
).states(), 1.0 );
138 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
139 if( i
.node
!= j
.node
) {
140 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
141 const _ind_t
&ind
= _index( j
, j
.dual
);
142 // multiply prod by n_jI
143 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
144 prod
[x_I
] *= n_jI
[ind
[x_I
]];
146 _U
[I
][i
.iter
] = prod
;
152 void BBP::RegenerateS() {
154 _S
.resize( _fg
->nrVars() );
155 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
156 _S
[i
].resize( _fg
->nbV(i
).size() );
157 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
158 _S
[i
][I
.iter
].resize( _fg
->nbF(I
).size() );
159 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
161 Factor
prod( _fg
->factor(I
) );
162 foreach( const Neighbor
&k
, _fg
->nbF(I
) ) {
163 if( k
!= i
&& k
.node
!= j
.node
) {
164 const _ind_t
&ind
= _index( k
, k
.dual
);
165 Prob
p( _bp_dual
.msgN( k
, k
.dual
) );
166 for( size_t x_I
= 0; x_I
< prod
.states(); x_I
++ )
167 prod
.p()[x_I
] *= p
[ind
[x_I
]];
170 // "Marginalize" onto i|j (unnormalized)
172 marg
= prod
.marginal( VarSet(_fg
->var(i
), _fg
->var(j
)), false ).p();
173 _S
[i
][I
.iter
][j
.iter
] = marg
;
180 void BBP::RegenerateR() {
182 _R
.resize( _fg
->nrFactors() );
183 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
184 _R
[I
].resize( _fg
->nbF(I
).size() );
185 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
186 _R
[I
][i
.iter
].resize( _fg
->nbV(i
).size() );
187 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
189 Prob
prod( _fg
->var(i
).states(), 1.0 );
190 foreach( const Neighbor
&K
, _fg
->nbV(i
) )
191 if( K
.node
!= I
&& K
.node
!= J
.node
)
192 prod
*= _bp_dual
.msgM( i
, K
.iter
);
193 _R
[I
][i
.iter
][J
.iter
] = prod
;
201 void BBP::RegenerateInputs() {
202 _adj_b_V_unnorm
.clear();
203 _adj_b_V_unnorm
.reserve( _fg
->nrVars() );
204 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
205 _adj_b_V_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefV(i
).p(), _bp_dual
.beliefVZ(i
), _adj_b_V
[i
] ) );
206 _adj_b_F_unnorm
.clear();
207 _adj_b_F_unnorm
.reserve( _fg
->nrFactors() );
208 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ )
209 _adj_b_F_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefF(I
).p(), _bp_dual
.beliefFZ(I
), _adj_b_F
[I
] ) );
213 void BBP::RegeneratePsiAdjoints() {
215 _adj_psi_V
.reserve( _fg
->nrVars() );
216 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
217 Prob
p( _adj_b_V_unnorm
[i
] );
218 assert( p
.size() == _fg
->var(i
).states() );
219 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
220 p
*= _bp_dual
.msgM( i
, I
.iter
);
221 p
+= _init_adj_psi_V
[i
];
222 _adj_psi_V
.push_back( p
);
225 _adj_psi_F
.reserve( _fg
->nrFactors() );
226 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
227 Prob
p( _adj_b_F_unnorm
[I
] );
228 assert( p
.size() == _fg
->factor(I
).states() );
229 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
230 Prob
n_iI( _bp_dual
.msgN( i
, i
.dual
) );
231 const _ind_t
& ind
= _index( i
, i
.dual
);
232 // multiply prod with n_jI
233 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
234 p
[x_I
] *= n_iI
[ind
[x_I
]];
236 p
+= _init_adj_psi_F
[I
];
237 _adj_psi_F
.push_back( p
);
242 void BBP::RegenerateParMessageAdjoints() {
243 size_t nv
= _fg
->nrVars();
246 _adj_n_unnorm
.resize( nv
);
247 _adj_m_unnorm
.resize( nv
);
248 _new_adj_n
.resize( nv
);
249 _new_adj_m
.resize( nv
);
250 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
251 size_t n_i
= _fg
->nbV(i
).size();
252 _adj_n
[i
].resize( n_i
);
253 _adj_m
[i
].resize( n_i
);
254 _adj_n_unnorm
[i
].resize( n_i
);
255 _adj_m_unnorm
[i
].resize( n_i
);
256 _new_adj_n
[i
].resize( n_i
);
257 _new_adj_m
[i
].resize( n_i
);
258 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
260 Prob
prod( _fg
->factor(I
).p() );
261 prod
*= _adj_b_F_unnorm
[I
];
262 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
264 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
265 const _ind_t
&ind
= _index( j
, j
.dual
);
266 // multiply prod with n_jI
267 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
268 prod
[x_I
] *= n_jI
[ind
[x_I
]];
270 Prob
marg( _fg
->var(i
).states(), 0.0 );
271 const _ind_t
&ind
= _index( i
, I
.iter
);
272 for( size_t r
= 0; r
< prod
.size(); r
++ )
273 marg
[ind
[r
]] += prod
[r
];
274 _new_adj_n
[i
][I
.iter
] = marg
;
279 Prob
prod( _adj_b_V_unnorm
[i
] );
280 assert( prod
.size() == _fg
->var(i
).states() );
281 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
282 if( J
.node
!= I
.node
)
283 prod
*= _bp_dual
.msgM(i
,J
.iter
);
284 _new_adj_m
[i
][I
.iter
] = prod
;
292 void BBP::RegenerateSeqMessageAdjoints() {
293 size_t nv
= _fg
->nrVars();
295 _adj_m_unnorm
.resize( nv
);
296 _new_adj_m
.resize( nv
);
297 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
298 size_t n_i
= _fg
->nbV(i
).size();
299 _adj_m
[i
].resize( n_i
);
300 _adj_m_unnorm
[i
].resize( n_i
);
301 _new_adj_m
[i
].resize( n_i
);
302 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
304 Prob
prod( _adj_b_V_unnorm
[i
] );
305 assert( prod
.size() == _fg
->var(i
).states() );
306 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
307 if( J
.node
!= I
.node
)
308 prod
*= _bp_dual
.msgM( i
, J
.iter
);
309 _adj_m
[i
][I
.iter
] = prod
;
310 calcUnnormMsgM( i
, I
.iter
);
311 _new_adj_m
[i
][I
.iter
] = Prob( _fg
->var(i
).states(), 0.0 );
314 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
315 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
317 Prob
prod( _fg
->factor(I
).p() );
318 prod
*= _adj_b_F_unnorm
[I
];
319 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
321 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
322 const _ind_t
& ind
= _index( j
, j
.dual
);
323 // multiply prod with n_jI
324 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
325 prod
[x_I
] *= n_jI
[ind
[x_I
]];
327 Prob
marg( _fg
->var(i
).states(), 0.0 );
328 const _ind_t
&ind
= _index( i
, I
.iter
);
329 for( size_t r
= 0; r
< prod
.size(); r
++ )
330 marg
[ind
[r
]] += prod
[r
];
331 sendSeqMsgN( i
, I
.iter
,marg
);
337 void BBP::calcNewN( size_t i
, size_t _I
) {
338 _adj_psi_V
[i
] += T(i
,_I
) * _adj_n_unnorm
[i
][_I
];
339 Prob
&new_adj_n_iI
= _new_adj_n
[i
][_I
];
340 new_adj_n_iI
= Prob( _fg
->var(i
).states(), 0.0 );
341 size_t I
= _fg
->nbV(i
)[_I
];
342 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
344 const Prob
&p
= _S
[i
][_I
][j
.iter
];
345 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][j
.dual
];
347 new_adj_n_iI
[xi
] += p
[xij
] * _adj_m_unnorm_jI
[xj
];
349 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
350 Var vi = _fg->var(i);
351 Var vj = _fg->var(j);
352 new_adj_n_iI = (Factor(VarSet(vi, vj), p) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p();
358 void BBP::calcNewM( size_t i
, size_t _I
) {
359 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
360 Prob
p( U(I
, I
.dual
) );
361 const Prob
&adj
= _adj_m_unnorm
[i
][_I
];
362 const _ind_t
&ind
= _index(i
,_I
);
363 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
364 p
[x_I
] *= adj
[ind
[x_I
]];
366 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
367 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, I.dual) ) * Factor( _fg->var(i), _adj_m_unnorm[i][_I] )).p();
370 _new_adj_m
[i
][_I
] = Prob( _fg
->var(i
).states(), 0.0 );
371 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
373 _new_adj_m
[i
][_I
] += _R
[I
][I
.dual
][J
.iter
] * _adj_n_unnorm
[i
][J
.iter
];
377 void BBP::calcUnnormMsgN( size_t i
, size_t _I
) {
378 _adj_n_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), _adj_n
[i
][_I
] );
382 void BBP::calcUnnormMsgM( size_t i
, size_t _I
) {
383 _adj_m_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgM(i
,_I
), _bp_dual
.zM(i
,_I
), _adj_m
[i
][_I
] );
387 void BBP::upMsgN( size_t i
, size_t _I
) {
388 _adj_n
[i
][_I
] = _new_adj_n
[i
][_I
];
389 calcUnnormMsgN( i
, _I
);
393 void BBP::upMsgM( size_t i
, size_t _I
) {
394 _adj_m
[i
][_I
] = _new_adj_m
[i
][_I
];
395 calcUnnormMsgM( i
, _I
);
399 void BBP::doParUpdate() {
400 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
401 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
402 calcNewM( i
, I
.iter
);
403 calcNewN( i
, I
.iter
);
405 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
406 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
413 void BBP::incrSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
414 if( props
.clean_updates
)
415 _new_adj_m
[i
][_I
] += p
;
418 calcUnnormMsgM(i
, _I
);
428 void BBP::updateSeqMsgM( size_t i
, size_t _I
) {
429 if( props
.clean_updates
) {
431 if(_new_adj_m
[i
][_I
].sumAbs() > pv_thresh
||
432 _adj_m
[i
][_I
].sumAbs() > pv_thresh
) {
434 DAI_DMSG("in updateSeqMsgM");
437 DAI_PV(_adj_m
[i
][_I
]);
438 DAI_PV(_new_adj_m
[i
][_I
]);
441 _adj_m
[i
][_I
] += _new_adj_m
[i
][_I
];
442 calcUnnormMsgM( i
, _I
);
443 _new_adj_m
[i
][_I
].fill( 0.0 );
448 void BBP::setSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
450 calcUnnormMsgM( i
, _I
);
454 void BBP::sendSeqMsgN( size_t i
, size_t _I
, const Prob
&f
) {
455 Prob f_unnorm
= unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), f
);
456 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
457 assert( I
.iter
== _I
);
458 _adj_psi_V
[i
] += f_unnorm
* T( i
, _I
);
460 if(f_unnorm
.sumAbs() > pv_thresh
) {
461 DAI_DMSG("in sendSeqMsgN");
465 DAI_PV(_bp_dual
.msgN(i
,_I
));
466 DAI_PV(_bp_dual
.zN(i
,_I
));
467 DAI_PV(_bp_dual
.msgM(i
,_I
));
468 DAI_PV(_bp_dual
.zM(i
,_I
));
469 DAI_PV(_fg
->factor(I
).p());
472 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
473 if( J
.node
!= I
.node
) {
475 if(f_unnorm
.sumAbs() > pv_thresh
) {
476 DAI_DMSG("in sendSeqMsgN loop");
479 DAI_PV(_R
[J
][J
.dual
][_I
]);
480 DAI_PV(f_unnorm
* _R
[J
][J
.dual
][_I
]);
483 incrSeqMsgM( i
, J
.iter
, f_unnorm
* R(J
, J
.dual
, _I
) );
489 void BBP::sendSeqMsgM( size_t j
, size_t _I
) {
490 const Neighbor
&I
= _fg
->nbV(j
)[_I
];
494 // DAI_PV(_adj_m_unnorm_jI);
495 // DAI_PV(_adj_m[j][_I]);
496 // DAI_PV(_bp_dual.zM(j,_I));
499 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][_I
];
501 const _ind_t
&ind
= _index(j
, _I
);
502 for( size_t x_I
= 0; x_I
< um
.size(); x_I
++ )
503 um
[x_I
] *= _adj_m_unnorm_jI
[ind
[x_I
]];
504 um
*= 1 - props
.damping
;
507 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
508 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, _j) ) * Factor( _fg->var(j), _adj_m_unnorm[j][_I] )).p() * (1.0 - props.damping);
511 // DAI_DMSG("in sendSeqMsgM");
515 // DAI_PV(_fg->nbF(I).size());
516 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
518 const Prob
&S
= _S
[i
][i
.dual
][_j
];
519 Prob
msg( _fg
->var(i
).states(), 0.0 );
521 msg
[xi
] += S
[xij
] * _adj_m_unnorm_jI
[xj
];
523 msg
*= 1.0 - props
.damping
;
524 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
525 Var vi = _fg->var(i);
526 Var vj = _fg->var(j);
527 msg = (Factor(VarSet(vi,vj), S) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p() * (1.0 - props.damping);
530 if(msg
.sumAbs() > pv_thresh
) {
531 DAI_DMSG("in sendSeqMsgM loop");
536 DAI_PV(_fg
->nbF(I
).size());
537 DAI_PV(_fg
->factor(I
).p());
538 DAI_PV(_S
[i
][i
.dual
][_j
]);
543 DAI_PV(_fg
->nbV(i
).size());
546 assert( _fg
->nbV(i
)[i
.dual
].node
== I
);
547 sendSeqMsgN( i
, i
.dual
, msg
);
550 setSeqMsgM( j
, _I
, _adj_m
[j
][_I
] * props
.damping
);
554 Real
BBP::getUnMsgMag() {
557 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
558 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
559 s
+= _adj_m_unnorm
[i
][I
.iter
].sumAbs();
560 s
+= _adj_n_unnorm
[i
][I
.iter
].sumAbs();
567 void BBP::getMsgMags( Real
&s
, Real
&new_s
) {
571 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
572 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
573 s
+= _adj_m
[i
][I
.iter
].sumAbs();
574 s
+= _adj_n
[i
][I
.iter
].sumAbs();
575 new_s
+= _new_adj_m
[i
][I
.iter
].sumAbs();
576 new_s
+= _new_adj_n
[i
][I
.iter
].sumAbs();
583 // tuple<size_t,size_t,Real> BBP::getArgMaxPsi1Adj() {
584 // size_t argmax_var=0;
585 // size_t argmax_var_state=0;
587 // for( size_t i = 0; i < _fg->nrVars(); i++ ) {
588 // pair<size_t,Real> argmax_state = adj_psi_V(i).argmax();
589 // if(i==0 || argmax_state.second>max_var) {
591 // max_var = argmax_state.second;
592 // argmax_var_state = argmax_state.first;
595 // assert(/*0 <= argmax_var_state &&*/
596 // argmax_var_state < _fg->var(argmax_var).states());
597 // return tuple<size_t,size_t,Real>(argmax_var,argmax_var_state,max_var);
601 void BBP::getArgmaxMsgM( size_t &out_i
, size_t &out__I
, Real
&mag
) {
603 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
604 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
605 Real thisMag
= _adj_m
[i
][I
.iter
].sumAbs();
606 if( !found
|| mag
< thisMag
) {
617 Real
BBP::getMaxMsgM() {
620 getArgmaxMsgM( dummy
, dummy
, mag
);
625 Real
BBP::getTotalMsgM() {
627 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
628 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
629 mag
+= _adj_m
[i
][I
.iter
].sumAbs();
634 Real
BBP::getTotalNewMsgM() {
636 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
637 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
638 mag
+= _new_adj_m
[i
][I
.iter
].sumAbs();
643 Real
BBP::getTotalMsgN() {
645 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
646 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
647 mag
+= _adj_n
[i
][I
.iter
].sumAbs();
652 void BBP::Regenerate() {
659 RegeneratePsiAdjoints();
660 if( props
.updates
== Properties::UpdateType::PAR
)
661 RegenerateParMessageAdjoints();
663 RegenerateSeqMessageAdjoints();
668 std::vector
<Prob
> BBP::getZeroAdjF( const FactorGraph
&fg
) {
670 adj_2
.reserve( fg
.nrFactors() );
671 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
672 adj_2
.push_back( Prob( fg
.factor(I
).states(), 0.0 ) );
677 std::vector
<Prob
> BBP::getZeroAdjV( const FactorGraph
&fg
) {
679 adj_1
.reserve( fg
.nrVars() );
680 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
681 adj_1
.push_back( Prob( fg
.var(i
).states(), 0.0 ) );
687 typedef BBP::Properties::UpdateType UT
;
688 Real
&tol
= props
.tol
;
689 UT
&updates
= props
.updates
;
692 switch( (size_t)updates
) {
698 getArgmaxMsgM( i
, _I
, mag
);
699 sendSeqMsgM( i
, _I
);
700 } while( mag
> tol
&& _iters
< props
.maxiter
);
702 if( _iters
>= props
.maxiter
)
703 if( props
.verbose
>= 1 )
704 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
706 } case UT::SEQ_FIX
: {
710 mag
= getTotalMsgM();
714 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
715 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
716 sendSeqMsgM( i
, I
.iter
);
717 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
718 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
719 updateSeqMsgM( i
, I
.iter
);
720 } while( mag
> tol
&& _iters
< props
.maxiter
);
722 if( _iters
>= props
.maxiter
)
723 if( props
.verbose
>= 1 )
724 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
726 } case UT::SEQ_BP_REV
:
727 case UT::SEQ_BP_FWD
: {
728 const BP
*bp
= static_cast<const BP
*>(_ia
);
729 vector
<pair
<size_t, size_t> > sentMessages
= bp
->getSentMessages();
730 size_t totalMessages
= sentMessages
.size();
731 if( totalMessages
== 0 ) {
732 cerr
<< "Asked for updates = " << updates
<< " but no BP messages; did you forget to set recordSentMessages?" << endl
;
733 DAI_THROW(INTERNAL_ERROR
);
735 if( updates
==UT::SEQ_BP_FWD
)
736 reverse( sentMessages
.begin(), sentMessages
.end() );
737 // DAI_PV(sentMessages.size());
739 // DAI_PV(props.maxiter);
740 while( sentMessages
.size() > 0 && _iters
< props
.maxiter
) {
741 // DAI_PV(sentMessages.size());
744 pair
<size_t, size_t> e
= sentMessages
.back();
745 sentMessages
.pop_back();
746 size_t i
= e
.first
, _I
= e
.second
;
747 sendSeqMsgM( i
, _I
);
749 if( _iters
>= props
.maxiter
)
750 if( props
.verbose
>= 1 )
751 cerr
<< "Warning: BBP updates limited to " << props
.maxiter
<< " iterations, but using UpdateType " << updates
<< " with " << totalMessages
<< " messages" << endl
;
757 } while( (_iters
< 2 || getUnMsgMag() > tol
) && _iters
< props
.maxiter
);
758 if( _iters
== props
.maxiter
) {
760 getMsgMags( s
, new_s
);
761 if( props
.verbose
>= 1 )
762 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (unnorm message magnitude = " << getUnMsgMag() << ", norm message mags = " << s
<< " -> " << new_s
<< ")" << endl
;
767 if( props
.verbose
>= 3 )
768 cerr
<< "BBP::run() took " << toc()-tic
<< " seconds " << doneIters() << " iterations" << endl
;
772 double numericBBPTest( const InfAlg
&bp
, const vector
<size_t> *state
, const PropertySet
&bbp_props
, bbp_cfn_t cfn
, double h
) {
773 // calculate the value of the unperturbed cost function
774 Real cf0
= getCostFn( bp
, cfn
, state
);
776 // run BBP to estimate adjoints
777 BBP
bbp( &bp
, bbp_props
);
778 initBBPCostFnAdj( bbp
, bp
, cfn
, state
);
782 const FactorGraph
& fg
= bp
.fg();
785 // verify bbp.adj_psi_V
787 // for each variable i
788 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
789 vector
<double> adj_est
;
791 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
792 // Clone 'bp' (which may be any InfAlg)
793 InfAlg
*bp_prb
= bp
.clone();
796 size_t n
= bp_prb
->fg().var(i
).states();
797 Prob
psi_1_prb( n
, 1.0 );
799 // psi_1_prb.normalize();
800 size_t I
= bp_prb
->fg().nbV(i
)[0]; // use first factor in list of neighbors of i
801 bp_prb
->fg().factor(I
) *= Factor( bp_prb
->fg().var(i
), psi_1_prb
);
803 // call 'init' on the perturbed variables
804 bp_prb
->init( bp_prb
->fg().var(i
) );
806 // run copy to convergence
809 // calculate new value of cost function
810 Real cf_prb
= getCostFn( *bp_prb
, cfn
, state
);
812 // use to estimate adjoint for i
813 adj_est
.push_back( (cf_prb
- cf0
) / h
);
815 // free cloned InfAlg
818 Prob
p_adj_est( adj_est
.begin(), adj_est
.end() );
819 // compare this numerical estimate to the BBP estimate; sum the distances
821 << ", p_adj_est: " << p_adj_est
822 << ", bbp.adj_psi_V(i): " << bbp
.adj_psi_V(i
) << endl
;
823 d
+= dist( p_adj_est
, bbp
.adj_psi_V(i
), Prob::DISTL1
);
827 // verify bbp.adj_n and bbp.adj_m
829 // We actually want to check the responsiveness of objective
830 // function to changes in the final messages. But at the end of a
831 // BBP run, the message adjoints are for the initial messages (and
832 // they should be close to zero, see paper). So this resets the
833 // BBP adjoints to the refer to the desired final messages
834 bbp.RegenerateMessageAdjoints();
836 // for each variable i
837 for(size_t i=0; i<bp_dual.nrVars(); i++) {
838 // for each factor I ~ i
839 foreach(size_t I, bp_dual.nbV(i)) {
840 vector<double> adj_n_est;
842 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
843 BP_dual bp_dual_prb(bp_dual);
844 // make h-sized change to newMsgN
845 bp_dual_prb.newMsgN(i,I)[xi] += h;
846 // recalculate beliefs
847 bp_dual_prb.CalcBeliefs();
848 // get cost function value
849 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
850 // add it to list of adjoints
851 adj_n_est.push_back((cf_prb-cf0)/h);
854 vector<double> adj_m_est;
856 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
857 BP_dual bp_dual_prb(bp_dual);
858 // make h-sized change to newMsgM
859 bp_dual_prb.newMsgM(I,i)[xi] += h;
860 // recalculate beliefs
861 bp_dual_prb.CalcBeliefs();
862 // get cost function value
863 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
864 // add it to list of adjoints
865 adj_m_est.push_back((cf_prb-cf0)/h);
868 Prob p_adj_n_est(adj_n_est.begin(), adj_n_est.end());
869 // compare this numerical estimate to the BBP estimate; sum the distances
870 cerr << "i: " << i << ", I: " << I
871 << ", adj_n_est: " << p_adj_n_est
872 << ", bbp.adj_n(i,I): " << bbp.adj_n(i,I) << endl;
873 d += dist(p_adj_n_est, bbp.adj_n(i,I), Prob::DISTL1);
875 Prob p_adj_m_est(adj_m_est.begin(), adj_m_est.end());
876 // compare this numerical estimate to the BBP estimate; sum the distances
877 cerr << "i: " << i << ", I: " << I
878 << ", adj_m_est: " << p_adj_m_est
879 << ", bbp.adj_m(I,i): " << bbp.adj_m(I,i) << endl;
880 d += dist(p_adj_m_est, bbp.adj_m(I,i), Prob::DISTL1);
886 // verify bbp.adj_b_V
887 for(size_t i=0; i<bp_dual.nrVars(); i++) {
888 vector<double> adj_b_V_est;
890 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
891 BP_dual bp_dual_prb(bp_dual);
893 // make h-sized change to b_1(i)[x_i]
894 bp_dual_prb._beliefs.b1[i][xi] += h;
896 // get cost function value
897 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
899 // add it to list of adjoints
900 adj_b_V_est.push_back((cf_prb-cf0)/h);
902 Prob p_adj_b_V_est(adj_b_V_est.begin(), adj_b_V_est.end());
903 // compare this numerical estimate to the BBP estimate; sum the distances
905 << ", adj_b_V_est: " << p_adj_b_V_est
906 << ", bbp.adj_b_V(i): " << bbp.adj_b_V(i) << endl;
907 d += dist(p_adj_b_V_est, bbp.adj_b_V(i), Prob::DISTL1);
912 // return total of distances
917 bool needGibbsState( bbp_cfn_t cfn
) {
918 switch( (size_t)cfn
) {
919 case bbp_cfn_t::CFN_GIBBS_B
:
920 case bbp_cfn_t::CFN_GIBBS_B2
:
921 case bbp_cfn_t::CFN_GIBBS_EXP
:
922 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
923 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
924 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
932 void initBBPCostFnAdj( BBP
&bbp
, const InfAlg
&ia
, bbp_cfn_t cfn_type
, const vector
<size_t> *stateP
) {
933 const FactorGraph
&fg
= ia
.fg();
935 switch( (size_t)cfn_type
) {
936 case bbp_cfn_t::CFN_BETHE_ENT
: {
939 vector
<Prob
> psi1_adj
;
940 vector
<Prob
> psi2_adj
;
941 b1_adj
.reserve( fg
.nrVars() );
942 psi1_adj
.reserve( fg
.nrVars() );
943 b2_adj
.reserve( fg
.nrFactors() );
944 psi2_adj
.reserve( fg
.nrFactors() );
945 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
946 size_t dim
= fg
.var(i
).states();
947 int c
= fg
.nbV(i
).size();
949 for( size_t xi
= 0; xi
< dim
; xi
++ )
950 p
[xi
] = (1 - c
) * (1 + log( ia
.beliefV(i
)[xi
] ));
951 b1_adj
.push_back( p
);
953 for( size_t xi
= 0; xi
< dim
; xi
++ )
954 p
[xi
] = -ia
.beliefV(i
)[xi
];
955 psi1_adj
.push_back( p
);
957 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
958 size_t dim
= fg
.factor(I
).states();
960 for( size_t xI
= 0; xI
< dim
; xI
++ )
961 p
[xI
] = 1 + log( ia
.beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
] );
962 b2_adj
.push_back( p
);
964 for( size_t xI
= 0; xI
< dim
; xI
++ )
965 p
[xI
] = -ia
.beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
];
966 psi2_adj
.push_back( p
);
968 bbp
.init( b1_adj
, b2_adj
, psi1_adj
, psi2_adj
);
970 } case bbp_cfn_t::CFN_FACTOR_ENT
: {
972 b2_adj
.reserve( fg
.nrFactors() );
973 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
974 size_t dim
= fg
.factor(I
).states();
976 for( size_t xI
= 0; xI
< dim
; xI
++ ) {
977 double bIxI
= ia
.beliefF(I
)[xI
];
981 p
[xI
] = 1 + log( bIxI
);
985 bbp
.init( bbp
.getZeroAdjV(fg
), b2_adj
);
987 } case bbp_cfn_t::CFN_VAR_ENT
: {
989 b1_adj
.reserve( fg
.nrVars() );
990 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
991 size_t dim
= fg
.var(i
).states();
993 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
994 double bixi
= ia
.beliefV(i
)[xi
];
998 p
[xi
] = 1 + log( bixi
);
1000 b1_adj
.push_back( p
);
1004 } case bbp_cfn_t::CFN_GIBBS_B
:
1005 case bbp_cfn_t::CFN_GIBBS_B2
:
1006 case bbp_cfn_t::CFN_GIBBS_EXP
: {
1007 // cost functions that use Gibbs sample, summing over variable marginals
1008 vector
<size_t> state
;
1009 if( stateP
== NULL
)
1010 state
= getGibbsState( ia
, 2*ia
.Iterations() );
1013 assert( state
.size() == fg
.nrVars() );
1015 vector
<Prob
> b1_adj
;
1016 b1_adj
.reserve(fg
.nrVars());
1017 for( size_t i
= 0; i
< state
.size(); i
++ ) {
1018 size_t n
= fg
.var(i
).states();
1019 Prob
delta( n
, 0.0 );
1020 assert(/*0<=state[i] &&*/ state
[i
] < n
);
1021 double b
= ia
.beliefV(i
)[state
[i
]];
1022 switch( (size_t)cfn_type
) {
1023 case bbp_cfn_t::CFN_GIBBS_B
:
1024 delta
[state
[i
]] = 1.0;
1026 case bbp_cfn_t::CFN_GIBBS_B2
:
1027 delta
[state
[i
]] = b
;
1029 case bbp_cfn_t::CFN_GIBBS_EXP
:
1030 delta
[state
[i
]] = exp(b
);
1033 DAI_THROW(INTERNAL_ERROR
);
1035 b1_adj
.push_back( delta
);
1039 } case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1040 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1041 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
: {
1042 // cost functions that use Gibbs sample, summing over factor marginals
1043 vector
<size_t> state
;
1044 if( stateP
== NULL
)
1045 state
= getGibbsState( ia
, 2*ia
.Iterations() );
1048 assert( state
.size() == fg
.nrVars() );
1050 vector
<Prob
> b2_adj
;
1051 b2_adj
.reserve( fg
.nrVars() );
1052 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
1053 size_t n
= fg
.factor(I
).states();
1054 Prob
delta( n
, 0.0 );
1056 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
1057 assert(/*0<=x_I &&*/ x_I
< n
);
1059 double b
= ia
.beliefF(I
)[x_I
];
1060 switch( (size_t)cfn_type
) {
1061 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1064 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1067 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
1068 delta
[x_I
] = exp( b
);
1071 DAI_THROW(INTERNAL_ERROR
);
1073 b2_adj
.push_back( delta
);
1075 bbp
.init( bbp
.getZeroAdjV(fg
), b2_adj
);
1078 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1083 Real
getCostFn( const InfAlg
&ia
, bbp_cfn_t cfn_type
, const vector
<size_t> *stateP
) {
1085 const FactorGraph
&fg
= ia
.fg();
1087 switch( (size_t)cfn_type
) {
1088 case bbp_cfn_t::CFN_BETHE_ENT
: // ignores state
1091 case bbp_cfn_t::CFN_VAR_ENT
: // ignores state
1092 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
1093 cf
+= -ia
.beliefV(i
).entropy();
1095 case bbp_cfn_t::CFN_FACTOR_ENT
: // ignores state
1096 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
1097 cf
+= -ia
.beliefF(I
).entropy();
1099 case bbp_cfn_t::CFN_GIBBS_B
:
1100 case bbp_cfn_t::CFN_GIBBS_B2
:
1101 case bbp_cfn_t::CFN_GIBBS_EXP
: {
1102 assert( stateP
!= NULL
);
1103 vector
<size_t> state
= *stateP
;
1104 assert( state
.size() == fg
.nrVars() );
1105 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
1106 double b
= ia
.beliefV(i
)[state
[i
]];
1107 switch( (size_t)cfn_type
) {
1108 case bbp_cfn_t::CFN_GIBBS_B
:
1111 case bbp_cfn_t::CFN_GIBBS_B2
:
1114 case bbp_cfn_t::CFN_GIBBS_EXP
:
1118 DAI_THROW(INTERNAL_ERROR
);
1122 } case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1123 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1124 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
: {
1125 assert( stateP
!= NULL
);
1126 vector
<size_t> state
= *stateP
;
1127 assert( state
.size() == fg
.nrVars() );
1128 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
1129 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
1130 double b
= ia
.beliefF(I
)[x_I
];
1131 switch( (size_t)cfn_type
) {
1132 case bbp_cfn_t::CFN_GIBBS_B_FACTOR
:
1135 case bbp_cfn_t::CFN_GIBBS_B2_FACTOR
:
1138 case bbp_cfn_t::CFN_GIBBS_EXP_FACTOR
:
1142 DAI_THROW(INTERNAL_ERROR
);
1147 cerr
<< "Unknown cost function " << cfn_type
<< endl
;
1148 DAI_THROW(UNKNOWN_ENUM_VALUE
);
1154 } // end of namespace dai
/* {{{ GENERATED CODE: DO NOT EDIT. Created by
    ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
*/
1162 void BBP::Properties::set(const PropertySet
&opts
)
1164 const std::set
<PropertyKey
> &keys
= opts
.allKeys();
1165 std::set
<PropertyKey
>::const_iterator i
;
1167 for(i
=keys
.begin(); i
!=keys
.end(); i
++) {
1168 if(*i
== "verbose") continue;
1169 if(*i
== "maxiter") continue;
1170 if(*i
== "tol") continue;
1171 if(*i
== "damping") continue;
1172 if(*i
== "updates") continue;
1173 if(*i
== "clean_updates") continue;
1174 cerr
<< "BBP: Unknown property " << *i
<< endl
;
1178 DAI_THROW(UNKNOWN_PROPERTY_TYPE
);
1180 if(!opts
.hasKey("verbose")) {
1181 cerr
<< "BBP: Missing property \"verbose\" for method \"BBP\"" << endl
;
1184 if(!opts
.hasKey("maxiter")) {
1185 cerr
<< "BBP: Missing property \"maxiter\" for method \"BBP\"" << endl
;
1188 if(!opts
.hasKey("tol")) {
1189 cerr
<< "BBP: Missing property \"tol\" for method \"BBP\"" << endl
;
1192 if(!opts
.hasKey("damping")) {
1193 cerr
<< "BBP: Missing property \"damping\" for method \"BBP\"" << endl
;
1196 if(!opts
.hasKey("updates")) {
1197 cerr
<< "BBP: Missing property \"updates\" for method \"BBP\"" << endl
;
1200 if(!opts
.hasKey("clean_updates")) {
1201 cerr
<< "BBP: Missing property \"clean_updates\" for method \"BBP\"" << endl
;
1205 DAI_THROW(NOT_ALL_PROPERTIES_SPECIFIED
);
1207 verbose
= opts
.getStringAs
<size_t>("verbose");
1208 maxiter
= opts
.getStringAs
<size_t>("maxiter");
1209 tol
= opts
.getStringAs
<double>("tol");
1210 damping
= opts
.getStringAs
<double>("damping");
1211 updates
= opts
.getStringAs
<UpdateType
>("updates");
1212 clean_updates
= opts
.getStringAs
<bool>("clean_updates");
1214 PropertySet
BBP::Properties::get() const {
1216 opts
.Set("verbose", verbose
);
1217 opts
.Set("maxiter", maxiter
);
1218 opts
.Set("tol", tol
);
1219 opts
.Set("damping", damping
);
1220 opts
.Set("updates", updates
);
1221 opts
.Set("clean_updates", clean_updates
);
1224 string
BBP::Properties::toString() const {
1225 stringstream
s(stringstream::out
);
1227 s
<< "verbose=" << verbose
<< ",";
1228 s
<< "maxiter=" << maxiter
<< ",";
1229 s
<< "tol=" << tol
<< ",";
1230 s
<< "damping=" << damping
<< ",";
1231 s
<< "updates=" << updates
<< ",";
1232 s
<< "clean_updates=" << clean_updates
;
1236 } // end of namespace dai
1237 /* }}} END OF GENERATED CODE */