1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
13 #include <dai/gibbs.h>
15 #include <dai/bipgraph.h>
24 /// Convenience typedef
25 typedef BipartiteGraph::Neighbor Neighbor
;
28 /// Returns the entry of the I'th factor corresponding to a global state
29 size_t getFactorEntryForState( const FactorGraph
&fg
, size_t I
, const vector
<size_t> &state
) {
31 for( int _j
= fg
.nbF(I
).size() - 1; _j
>= 0; _j
-- ) {
32 // note that iterating over nbF(I) yields the same ordering
33 // of variables as iterating over factor(I).vars()
34 size_t j
= fg
.nbF(I
)[_j
];
35 f_entry
*= fg
.var(j
).states();
42 bool BBPCostFunction::needGibbsState() const {
43 switch( (size_t)(*this) ) {
47 case CFN_GIBBS_B_FACTOR
:
48 case CFN_GIBBS_B2_FACTOR
:
49 case CFN_GIBBS_EXP_FACTOR
:
57 Real
BBPCostFunction::evaluate( const InfAlg
&ia
, const vector
<size_t> *stateP
) const {
59 const FactorGraph
&fg
= ia
.fg();
61 switch( (size_t)(*this) ) {
62 case CFN_BETHE_ENT
: // ignores state
65 case CFN_VAR_ENT
: // ignores state
66 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
67 cf
+= -ia
.beliefV(i
).entropy();
69 case CFN_FACTOR_ENT
: // ignores state
70 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
71 cf
+= -ia
.beliefF(I
).entropy();
76 DAI_ASSERT( stateP
!= NULL
);
77 vector
<size_t> state
= *stateP
;
78 DAI_ASSERT( state
.size() == fg
.nrVars() );
79 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
80 Real b
= ia
.beliefV(i
)[state
[i
]];
81 switch( (size_t)(*this) ) {
92 DAI_THROW(UNKNOWN_ENUM_VALUE
);
96 } case CFN_GIBBS_B_FACTOR
:
97 case CFN_GIBBS_B2_FACTOR
:
98 case CFN_GIBBS_EXP_FACTOR
: {
99 DAI_ASSERT( stateP
!= NULL
);
100 vector
<size_t> state
= *stateP
;
101 DAI_ASSERT( state
.size() == fg
.nrVars() );
102 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
103 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
104 Real b
= ia
.beliefF(I
)[x_I
];
105 switch( (size_t)(*this) ) {
106 case CFN_GIBBS_B_FACTOR
:
109 case CFN_GIBBS_B2_FACTOR
:
112 case CFN_GIBBS_EXP_FACTOR
:
116 DAI_THROW(UNKNOWN_ENUM_VALUE
);
121 DAI_THROWE(UNKNOWN_ENUM_VALUE
, "Unknown cost function " + std::string(*this));
127 #define LOOP_ij(body) { \
128 size_t i_states = _fg->var(i).states(); \
129 size_t j_states = _fg->var(j).states(); \
130 if(_fg->var(i) > _fg->var(j)) { \
132 for(size_t xi=0; xi<i_states; xi++) { \
133 for(size_t xj=0; xj<j_states; xj++) { \
140 for(size_t xj=0; xj<j_states; xj++) { \
141 for(size_t xi=0; xi<i_states; xi++) { \
150 void BBP::RegenerateInds() {
151 // initialise _indices
152 // typedef std::vector<size_t> _ind_t;
153 // std::vector<std::vector<_ind_t> > _indices;
154 _indices
.resize( _fg
->nrVars() );
155 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
157 _indices
[i
].reserve( _fg
->nbV(i
).size() );
158 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
160 index
.reserve( _fg
->factor(I
).states() );
161 for( IndexFor
k(_fg
->var(i
), _fg
->factor(I
).vars()); k
.valid(); ++k
)
162 index
.push_back( k
);
163 _indices
[i
].push_back( index
);
169 void BBP::RegenerateT() {
171 _T
.resize( _fg
->nrVars() );
172 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
173 _T
[i
].resize( _fg
->nbV(i
).size() );
174 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
175 Prob
prod( _fg
->var(i
).states(), 1.0 );
176 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
177 if( J
.node
!= I
.node
)
178 prod
*= _bp_dual
.msgM( i
, J
.iter
);
179 _T
[i
][I
.iter
] = prod
;
185 void BBP::RegenerateU() {
187 _U
.resize( _fg
->nrFactors() );
188 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
189 _U
[I
].resize( _fg
->nbF(I
).size() );
190 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
191 Prob
prod( _fg
->factor(I
).states(), 1.0 );
192 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
193 if( i
.node
!= j
.node
) {
194 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
195 const _ind_t
&ind
= _index( j
, j
.dual
);
196 // multiply prod by n_jI
197 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
198 prod
[x_I
] *= n_jI
[ind
[x_I
]];
200 _U
[I
][i
.iter
] = prod
;
206 void BBP::RegenerateS() {
208 _S
.resize( _fg
->nrVars() );
209 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
210 _S
[i
].resize( _fg
->nbV(i
).size() );
211 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
212 _S
[i
][I
.iter
].resize( _fg
->nbF(I
).size() );
213 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
215 Factor
prod( _fg
->factor(I
) );
216 foreach( const Neighbor
&k
, _fg
->nbF(I
) ) {
217 if( k
!= i
&& k
.node
!= j
.node
) {
218 const _ind_t
&ind
= _index( k
, k
.dual
);
219 Prob
p( _bp_dual
.msgN( k
, k
.dual
) );
220 for( size_t x_I
= 0; x_I
< prod
.states(); x_I
++ )
221 prod
.p()[x_I
] *= p
[ind
[x_I
]];
224 // "Marginalize" onto i|j (unnormalized)
226 marg
= prod
.marginal( VarSet(_fg
->var(i
), _fg
->var(j
)), false ).p();
227 _S
[i
][I
.iter
][j
.iter
] = marg
;
234 void BBP::RegenerateR() {
236 _R
.resize( _fg
->nrFactors() );
237 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
238 _R
[I
].resize( _fg
->nbF(I
).size() );
239 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
240 _R
[I
][i
.iter
].resize( _fg
->nbV(i
).size() );
241 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
243 Prob
prod( _fg
->var(i
).states(), 1.0 );
244 foreach( const Neighbor
&K
, _fg
->nbV(i
) )
245 if( K
.node
!= I
&& K
.node
!= J
.node
)
246 prod
*= _bp_dual
.msgM( i
, K
.iter
);
247 _R
[I
][i
.iter
][J
.iter
] = prod
;
255 void BBP::RegenerateInputs() {
256 _adj_b_V_unnorm
.clear();
257 _adj_b_V_unnorm
.reserve( _fg
->nrVars() );
258 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
259 _adj_b_V_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefV(i
).p(), _bp_dual
.beliefVZ(i
), _adj_b_V
[i
] ) );
260 _adj_b_F_unnorm
.clear();
261 _adj_b_F_unnorm
.reserve( _fg
->nrFactors() );
262 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ )
263 _adj_b_F_unnorm
.push_back( unnormAdjoint( _bp_dual
.beliefF(I
).p(), _bp_dual
.beliefFZ(I
), _adj_b_F
[I
] ) );
267 void BBP::RegeneratePsiAdjoints() {
269 _adj_psi_V
.reserve( _fg
->nrVars() );
270 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
271 Prob
p( _adj_b_V_unnorm
[i
] );
272 DAI_ASSERT( p
.size() == _fg
->var(i
).states() );
273 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
274 p
*= _bp_dual
.msgM( i
, I
.iter
);
275 p
+= _init_adj_psi_V
[i
];
276 _adj_psi_V
.push_back( p
);
279 _adj_psi_F
.reserve( _fg
->nrFactors() );
280 for( size_t I
= 0; I
< _fg
->nrFactors(); I
++ ) {
281 Prob
p( _adj_b_F_unnorm
[I
] );
282 DAI_ASSERT( p
.size() == _fg
->factor(I
).states() );
283 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
284 Prob
n_iI( _bp_dual
.msgN( i
, i
.dual
) );
285 const _ind_t
& ind
= _index( i
, i
.dual
);
286 // multiply prod with n_jI
287 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
288 p
[x_I
] *= n_iI
[ind
[x_I
]];
290 p
+= _init_adj_psi_F
[I
];
291 _adj_psi_F
.push_back( p
);
296 void BBP::RegenerateParMessageAdjoints() {
297 size_t nv
= _fg
->nrVars();
300 _adj_n_unnorm
.resize( nv
);
301 _adj_m_unnorm
.resize( nv
);
302 _new_adj_n
.resize( nv
);
303 _new_adj_m
.resize( nv
);
304 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
305 size_t n_i
= _fg
->nbV(i
).size();
306 _adj_n
[i
].resize( n_i
);
307 _adj_m
[i
].resize( n_i
);
308 _adj_n_unnorm
[i
].resize( n_i
);
309 _adj_m_unnorm
[i
].resize( n_i
);
310 _new_adj_n
[i
].resize( n_i
);
311 _new_adj_m
[i
].resize( n_i
);
312 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
314 Prob
prod( _fg
->factor(I
).p() );
315 prod
*= _adj_b_F_unnorm
[I
];
316 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
318 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
319 const _ind_t
&ind
= _index( j
, j
.dual
);
320 // multiply prod with n_jI
321 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
322 prod
[x_I
] *= n_jI
[ind
[x_I
]];
324 Prob
marg( _fg
->var(i
).states(), 0.0 );
325 const _ind_t
&ind
= _index( i
, I
.iter
);
326 for( size_t r
= 0; r
< prod
.size(); r
++ )
327 marg
[ind
[r
]] += prod
[r
];
328 _new_adj_n
[i
][I
.iter
] = marg
;
333 Prob
prod( _adj_b_V_unnorm
[i
] );
334 DAI_ASSERT( prod
.size() == _fg
->var(i
).states() );
335 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
336 if( J
.node
!= I
.node
)
337 prod
*= _bp_dual
.msgM(i
,J
.iter
);
338 _new_adj_m
[i
][I
.iter
] = prod
;
346 void BBP::RegenerateSeqMessageAdjoints() {
347 size_t nv
= _fg
->nrVars();
349 _adj_m_unnorm
.resize( nv
);
350 _new_adj_m
.resize( nv
);
351 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
352 size_t n_i
= _fg
->nbV(i
).size();
353 _adj_m
[i
].resize( n_i
);
354 _adj_m_unnorm
[i
].resize( n_i
);
355 _new_adj_m
[i
].resize( n_i
);
356 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
358 Prob
prod( _adj_b_V_unnorm
[i
] );
359 DAI_ASSERT( prod
.size() == _fg
->var(i
).states() );
360 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
361 if( J
.node
!= I
.node
)
362 prod
*= _bp_dual
.msgM( i
, J
.iter
);
363 _adj_m
[i
][I
.iter
] = prod
;
364 calcUnnormMsgM( i
, I
.iter
);
365 _new_adj_m
[i
][I
.iter
] = Prob( _fg
->var(i
).states(), 0.0 );
368 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ ) {
369 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
371 Prob
prod( _fg
->factor(I
).p() );
372 prod
*= _adj_b_F_unnorm
[I
];
373 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
375 Prob
n_jI( _bp_dual
.msgN( j
, j
.dual
) );
376 const _ind_t
& ind
= _index( j
, j
.dual
);
377 // multiply prod with n_jI
378 for( size_t x_I
= 0; x_I
< prod
.size(); x_I
++ )
379 prod
[x_I
] *= n_jI
[ind
[x_I
]];
381 Prob
marg( _fg
->var(i
).states(), 0.0 );
382 const _ind_t
&ind
= _index( i
, I
.iter
);
383 for( size_t r
= 0; r
< prod
.size(); r
++ )
384 marg
[ind
[r
]] += prod
[r
];
385 sendSeqMsgN( i
, I
.iter
,marg
);
391 void BBP::Regenerate() {
398 RegeneratePsiAdjoints();
399 if( props
.updates
== Properties::UpdateType::PAR
)
400 RegenerateParMessageAdjoints();
402 RegenerateSeqMessageAdjoints();
407 void BBP::calcNewN( size_t i
, size_t _I
) {
408 _adj_psi_V
[i
] += T(i
,_I
) * _adj_n_unnorm
[i
][_I
];
409 Prob
&new_adj_n_iI
= _new_adj_n
[i
][_I
];
410 new_adj_n_iI
= Prob( _fg
->var(i
).states(), 0.0 );
411 size_t I
= _fg
->nbV(i
)[_I
];
412 foreach( const Neighbor
&j
, _fg
->nbF(I
) )
414 const Prob
&p
= _S
[i
][_I
][j
.iter
];
415 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][j
.dual
];
417 new_adj_n_iI
[xi
] += p
[xij
] * _adj_m_unnorm_jI
[xj
];
419 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
420 Var vi = _fg->var(i);
421 Var vj = _fg->var(j);
422 new_adj_n_iI = (Factor(VarSet(vi, vj), p) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p();
428 void BBP::calcNewM( size_t i
, size_t _I
) {
429 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
430 Prob
p( U(I
, I
.dual
) );
431 const Prob
&adj
= _adj_m_unnorm
[i
][_I
];
432 const _ind_t
&ind
= _index(i
,_I
);
433 for( size_t x_I
= 0; x_I
< p
.size(); x_I
++ )
434 p
[x_I
] *= adj
[ind
[x_I
]];
436 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
437 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, I.dual) ) * Factor( _fg->var(i), _adj_m_unnorm[i][_I] )).p();
440 _new_adj_m
[i
][_I
] = Prob( _fg
->var(i
).states(), 0.0 );
441 foreach( const Neighbor
&J
, _fg
->nbV(i
) )
443 _new_adj_m
[i
][_I
] += _R
[I
][I
.dual
][J
.iter
] * _adj_n_unnorm
[i
][J
.iter
];
447 void BBP::calcUnnormMsgN( size_t i
, size_t _I
) {
448 _adj_n_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), _adj_n
[i
][_I
] );
452 void BBP::calcUnnormMsgM( size_t i
, size_t _I
) {
453 _adj_m_unnorm
[i
][_I
] = unnormAdjoint( _bp_dual
.msgM(i
,_I
), _bp_dual
.zM(i
,_I
), _adj_m
[i
][_I
] );
457 void BBP::upMsgN( size_t i
, size_t _I
) {
458 _adj_n
[i
][_I
] = _new_adj_n
[i
][_I
];
459 calcUnnormMsgN( i
, _I
);
463 void BBP::upMsgM( size_t i
, size_t _I
) {
464 _adj_m
[i
][_I
] = _new_adj_m
[i
][_I
];
465 calcUnnormMsgM( i
, _I
);
469 void BBP::doParUpdate() {
470 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
471 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
472 calcNewM( i
, I
.iter
);
473 calcNewN( i
, I
.iter
);
475 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
476 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
483 void BBP::incrSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
484 /* if( props.clean_updates )
485 _new_adj_m[i][_I] += p;
488 calcUnnormMsgM(i
, _I
);
499 void BBP::updateSeqMsgM( size_t i, size_t _I ) {
500 if( props.clean_updates ) {
502 if(_new_adj_m[i][_I].sumAbs() > pv_thresh ||
503 _adj_m[i][_I].sumAbs() > pv_thresh) {
505 DAI_DMSG("in updateSeqMsgM");
508 DAI_PV(_adj_m[i][_I]);
509 DAI_PV(_new_adj_m[i][_I]);
512 _adj_m[i][_I] += _new_adj_m[i][_I];
513 calcUnnormMsgM( i, _I );
514 _new_adj_m[i][_I].fill( 0.0 );
519 void BBP::setSeqMsgM( size_t i
, size_t _I
, const Prob
&p
) {
521 calcUnnormMsgM( i
, _I
);
525 void BBP::sendSeqMsgN( size_t i
, size_t _I
, const Prob
&f
) {
526 Prob f_unnorm
= unnormAdjoint( _bp_dual
.msgN(i
,_I
), _bp_dual
.zN(i
,_I
), f
);
527 const Neighbor
&I
= _fg
->nbV(i
)[_I
];
528 DAI_ASSERT( I
.iter
== _I
);
529 _adj_psi_V
[i
] += f_unnorm
* T( i
, _I
);
531 if(f_unnorm
.sumAbs() > pv_thresh
) {
532 DAI_DMSG("in sendSeqMsgN");
536 DAI_PV(_bp_dual
.msgN(i
,_I
));
537 DAI_PV(_bp_dual
.zN(i
,_I
));
538 DAI_PV(_bp_dual
.msgM(i
,_I
));
539 DAI_PV(_bp_dual
.zM(i
,_I
));
540 DAI_PV(_fg
->factor(I
).p());
543 foreach( const Neighbor
&J
, _fg
->nbV(i
) ) {
544 if( J
.node
!= I
.node
) {
546 if(f_unnorm
.sumAbs() > pv_thresh
) {
547 DAI_DMSG("in sendSeqMsgN loop");
550 DAI_PV(_R
[J
][J
.dual
][_I
]);
551 DAI_PV(f_unnorm
* _R
[J
][J
.dual
][_I
]);
554 incrSeqMsgM( i
, J
.iter
, f_unnorm
* R(J
, J
.dual
, _I
) );
560 void BBP::sendSeqMsgM( size_t j
, size_t _I
) {
561 const Neighbor
&I
= _fg
->nbV(j
)[_I
];
565 // DAI_PV(_adj_m_unnorm_jI);
566 // DAI_PV(_adj_m[j][_I]);
567 // DAI_PV(_bp_dual.zM(j,_I));
570 const Prob
&_adj_m_unnorm_jI
= _adj_m_unnorm
[j
][_I
];
572 const _ind_t
&ind
= _index(j
, _I
);
573 for( size_t x_I
= 0; x_I
< um
.size(); x_I
++ )
574 um
[x_I
] *= _adj_m_unnorm_jI
[ind
[x_I
]];
575 um
*= 1 - props
.damping
;
578 /* THE FOLLOWING WOULD BE SLIGHTLY SLOWER:
579 _adj_psi_F[I] += (Factor( _fg->factor(I).vars(), U(I, _j) ) * Factor( _fg->var(j), _adj_m_unnorm[j][_I] )).p() * (1.0 - props.damping);
582 // DAI_DMSG("in sendSeqMsgM");
586 // DAI_PV(_fg->nbF(I).size());
587 foreach( const Neighbor
&i
, _fg
->nbF(I
) ) {
589 const Prob
&S
= _S
[i
][i
.dual
][_j
];
590 Prob
msg( _fg
->var(i
).states(), 0.0 );
592 msg
[xi
] += S
[xij
] * _adj_m_unnorm_jI
[xj
];
594 msg
*= 1.0 - props
.damping
;
595 /* THE FOLLOWING WOULD BE ABOUT TWICE AS SLOW:
596 Var vi = _fg->var(i);
597 Var vj = _fg->var(j);
598 msg = (Factor(VarSet(vi,vj), S) * Factor(vj,_adj_m_unnorm_jI)).marginal(vi,false).p() * (1.0 - props.damping);
601 if(msg
.sumAbs() > pv_thresh
) {
602 DAI_DMSG("in sendSeqMsgM loop");
607 DAI_PV(_fg
->nbF(I
).size());
608 DAI_PV(_fg
->factor(I
).p());
609 DAI_PV(_S
[i
][i
.dual
][_j
]);
614 DAI_PV(_fg
->nbV(i
).size());
617 DAI_ASSERT( _fg
->nbV(i
)[i
.dual
].node
== I
);
618 sendSeqMsgN( i
, i
.dual
, msg
);
621 setSeqMsgM( j
, _I
, _adj_m
[j
][_I
] * props
.damping
);
625 Prob
BBP::unnormAdjoint( const Prob
&w
, Real Z_w
, const Prob
&adj_w
) {
626 DAI_ASSERT( w
.size() == adj_w
.size() );
627 Prob
adj_w_unnorm( w
.size(), 0.0 );
629 for( size_t i
= 0; i
< w
.size(); i
++ )
630 s
+= w
[i
] * adj_w
[i
];
631 for( size_t i
= 0; i
< w
.size(); i
++ )
632 adj_w_unnorm
[i
] = (adj_w
[i
] - s
) / Z_w
;
634 // THIS WOULD BE ABOUT 50% SLOWER: return (adj_w - (w * adj_w).sum()) / Z_w;
638 Real
BBP::getUnMsgMag() {
641 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
642 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
643 s
+= _adj_m_unnorm
[i
][I
.iter
].sumAbs();
644 s
+= _adj_n_unnorm
[i
][I
.iter
].sumAbs();
651 void BBP::getMsgMags( Real
&s
, Real
&new_s
) {
655 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
656 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
657 s
+= _adj_m
[i
][I
.iter
].sumAbs();
658 s
+= _adj_n
[i
][I
.iter
].sumAbs();
659 new_s
+= _new_adj_m
[i
][I
.iter
].sumAbs();
660 new_s
+= _new_adj_n
[i
][I
.iter
].sumAbs();
667 // tuple<size_t,size_t,Real> BBP::getArgMaxPsi1Adj() {
668 // size_t argmax_var=0;
669 // size_t argmax_var_state=0;
671 // for( size_t i = 0; i < _fg->nrVars(); i++ ) {
672 // pair<size_t,Real> argmax_state = adj_psi_V(i).argmax();
673 // if(i==0 || argmax_state.second>max_var) {
675 // max_var = argmax_state.second;
676 // argmax_var_state = argmax_state.first;
679 // DAI_ASSERT(/*0 <= argmax_var_state &&*/
680 // argmax_var_state < _fg->var(argmax_var).states());
681 // return tuple<size_t,size_t,Real>(argmax_var,argmax_var_state,max_var);
685 void BBP::getArgmaxMsgM( size_t &out_i
, size_t &out__I
, Real
&mag
) {
687 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
688 foreach( const Neighbor
&I
, _fg
->nbV(i
) ) {
689 Real thisMag
= _adj_m
[i
][I
.iter
].sumAbs();
690 if( !found
|| mag
< thisMag
) {
701 Real
BBP::getMaxMsgM() {
704 getArgmaxMsgM( dummy
, dummy
, mag
);
709 Real
BBP::getTotalMsgM() {
711 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
712 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
713 mag
+= _adj_m
[i
][I
.iter
].sumAbs();
718 Real
BBP::getTotalNewMsgM() {
720 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
721 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
722 mag
+= _new_adj_m
[i
][I
.iter
].sumAbs();
727 Real
BBP::getTotalMsgN() {
729 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
730 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
731 mag
+= _adj_n
[i
][I
.iter
].sumAbs();
736 std::vector
<Prob
> BBP::getZeroAdjF( const FactorGraph
&fg
) {
738 adj_2
.reserve( fg
.nrFactors() );
739 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
740 adj_2
.push_back( Prob( fg
.factor(I
).states(), 0.0 ) );
745 std::vector
<Prob
> BBP::getZeroAdjV( const FactorGraph
&fg
) {
747 adj_1
.reserve( fg
.nrVars() );
748 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
749 adj_1
.push_back( Prob( fg
.var(i
).states(), 0.0 ) );
754 void BBP::initCostFnAdj( const BBPCostFunction
&cfn
, const vector
<size_t> *stateP
) {
755 const FactorGraph
&fg
= _ia
->fg();
757 switch( (size_t)cfn
) {
758 case BBPCostFunction::CFN_BETHE_ENT
: {
761 vector
<Prob
> psi1_adj
;
762 vector
<Prob
> psi2_adj
;
763 b1_adj
.reserve( fg
.nrVars() );
764 psi1_adj
.reserve( fg
.nrVars() );
765 b2_adj
.reserve( fg
.nrFactors() );
766 psi2_adj
.reserve( fg
.nrFactors() );
767 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
768 size_t dim
= fg
.var(i
).states();
769 int c
= fg
.nbV(i
).size();
771 for( size_t xi
= 0; xi
< dim
; xi
++ )
772 p
[xi
] = (1 - c
) * (1 + log( _ia
->beliefV(i
)[xi
] ));
773 b1_adj
.push_back( p
);
775 for( size_t xi
= 0; xi
< dim
; xi
++ )
776 p
[xi
] = -_ia
->beliefV(i
)[xi
];
777 psi1_adj
.push_back( p
);
779 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
780 size_t dim
= fg
.factor(I
).states();
782 for( size_t xI
= 0; xI
< dim
; xI
++ )
783 p
[xI
] = 1 + log( _ia
->beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
] );
784 b2_adj
.push_back( p
);
786 for( size_t xI
= 0; xI
< dim
; xI
++ )
787 p
[xI
] = -_ia
->beliefF(I
)[xI
] / fg
.factor(I
).p()[xI
];
788 psi2_adj
.push_back( p
);
790 init( b1_adj
, b2_adj
, psi1_adj
, psi2_adj
);
792 } case BBPCostFunction::CFN_FACTOR_ENT
: {
794 b2_adj
.reserve( fg
.nrFactors() );
795 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
796 size_t dim
= fg
.factor(I
).states();
798 for( size_t xI
= 0; xI
< dim
; xI
++ ) {
799 Real bIxI
= _ia
->beliefF(I
)[xI
];
803 p
[xI
] = 1 + log( bIxI
);
809 } case BBPCostFunction::CFN_VAR_ENT
: {
811 b1_adj
.reserve( fg
.nrVars() );
812 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
813 size_t dim
= fg
.var(i
).states();
815 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
816 Real bixi
= _ia
->beliefV(i
)[xi
];
820 p
[xi
] = 1 + log( bixi
);
822 b1_adj
.push_back( p
);
826 } case BBPCostFunction::CFN_GIBBS_B
:
827 case BBPCostFunction::CFN_GIBBS_B2
:
828 case BBPCostFunction::CFN_GIBBS_EXP
: {
829 // cost functions that use Gibbs sample, summing over variable marginals
830 vector
<size_t> state
;
832 state
= getGibbsState( _ia
->fg(), 2*_ia
->Iterations() );
835 DAI_ASSERT( state
.size() == fg
.nrVars() );
838 b1_adj
.reserve(fg
.nrVars());
839 for( size_t i
= 0; i
< state
.size(); i
++ ) {
840 size_t n
= fg
.var(i
).states();
841 Prob
delta( n
, 0.0 );
842 DAI_ASSERT(/*0<=state[i] &&*/ state
[i
] < n
);
843 Real b
= _ia
->beliefV(i
)[state
[i
]];
844 switch( (size_t)cfn
) {
845 case BBPCostFunction::CFN_GIBBS_B
:
846 delta
[state
[i
]] = 1.0;
848 case BBPCostFunction::CFN_GIBBS_B2
:
851 case BBPCostFunction::CFN_GIBBS_EXP
:
852 delta
[state
[i
]] = exp(b
);
855 DAI_THROW(UNKNOWN_ENUM_VALUE
);
857 b1_adj
.push_back( delta
);
861 } case BBPCostFunction::CFN_GIBBS_B_FACTOR
:
862 case BBPCostFunction::CFN_GIBBS_B2_FACTOR
:
863 case BBPCostFunction::CFN_GIBBS_EXP_FACTOR
: {
864 // cost functions that use Gibbs sample, summing over factor marginals
865 vector
<size_t> state
;
867 state
= getGibbsState( _ia
->fg(), 2*_ia
->Iterations() );
870 DAI_ASSERT( state
.size() == fg
.nrVars() );
873 b2_adj
.reserve( fg
.nrVars() );
874 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
875 size_t n
= fg
.factor(I
).states();
876 Prob
delta( n
, 0.0 );
878 size_t x_I
= getFactorEntryForState( fg
, I
, state
);
879 DAI_ASSERT(/*0<=x_I &&*/ x_I
< n
);
881 Real b
= _ia
->beliefF(I
)[x_I
];
882 switch( (size_t)cfn
) {
883 case BBPCostFunction::CFN_GIBBS_B_FACTOR
:
886 case BBPCostFunction::CFN_GIBBS_B2_FACTOR
:
889 case BBPCostFunction::CFN_GIBBS_EXP_FACTOR
:
890 delta
[x_I
] = exp( b
);
893 DAI_THROW(UNKNOWN_ENUM_VALUE
);
895 b2_adj
.push_back( delta
);
900 DAI_THROW(UNKNOWN_ENUM_VALUE
);
906 typedef BBP::Properties::UpdateType UT
;
907 Real tol
= props
.tol
;
908 UT
&updates
= props
.updates
;
911 switch( (size_t)updates
) {
917 getArgmaxMsgM( i
, _I
, mag
);
918 sendSeqMsgM( i
, _I
);
919 } while( mag
> tol
&& _iters
< props
.maxiter
);
921 if( _iters
>= props
.maxiter
)
922 if( props
.verbose
>= 1 )
923 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
925 } case UT::SEQ_FIX
: {
929 mag
= getTotalMsgM();
933 for( size_t i
= 0; i
< _fg
->nrVars(); i
++ )
934 foreach( const Neighbor
&I
, _fg
->nbV(i
) )
935 sendSeqMsgM( i
, I
.iter
);
936 /* for( size_t i = 0; i < _fg->nrVars(); i++ )
937 foreach( const Neighbor &I, _fg->nbV(i) )
938 updateSeqMsgM( i, I.iter );*/
939 } while( mag
> tol
&& _iters
< props
.maxiter
);
941 if( _iters
>= props
.maxiter
)
942 if( props
.verbose
>= 1 )
943 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (greatest message magnitude = " << mag
<< ")" << endl
;
945 } case UT::SEQ_BP_REV
:
946 case UT::SEQ_BP_FWD
: {
947 const BP
*bp
= static_cast<const BP
*>(_ia
);
948 vector
<pair
<size_t, size_t> > sentMessages
= bp
->getSentMessages();
949 size_t totalMessages
= sentMessages
.size();
950 if( totalMessages
== 0 )
951 DAI_THROWE(INTERNAL_ERROR
, "Asked for updates=" + std::string(updates
) + " but no BP messages; did you forget to set recordSentMessages?");
952 if( updates
==UT::SEQ_BP_FWD
)
953 reverse( sentMessages
.begin(), sentMessages
.end() );
954 // DAI_PV(sentMessages.size());
956 // DAI_PV(props.maxiter);
957 while( sentMessages
.size() > 0 && _iters
< props
.maxiter
) {
958 // DAI_PV(sentMessages.size());
961 pair
<size_t, size_t> e
= sentMessages
.back();
962 sentMessages
.pop_back();
963 size_t i
= e
.first
, _I
= e
.second
;
964 sendSeqMsgM( i
, _I
);
966 if( _iters
>= props
.maxiter
)
967 if( props
.verbose
>= 1 )
968 cerr
<< "Warning: BBP updates limited to " << props
.maxiter
<< " iterations, but using UpdateType " << updates
<< " with " << totalMessages
<< " messages" << endl
;
974 } while( (_iters
< 2 || getUnMsgMag() > tol
) && _iters
< props
.maxiter
);
975 if( _iters
== props
.maxiter
) {
977 getMsgMags( s
, new_s
);
978 if( props
.verbose
>= 1 )
979 cerr
<< "Warning: BBP didn't converge in " << _iters
<< " iterations (unnorm message magnitude = " << getUnMsgMag() << ", norm message mags = " << s
<< " -> " << new_s
<< ")" << endl
;
984 if( props
.verbose
>= 3 )
985 cerr
<< "BBP::run() took " << toc()-tic
<< " seconds " << Iterations() << " iterations" << endl
;
989 Real
numericBBPTest( const InfAlg
&bp
, const std::vector
<size_t> *state
, const PropertySet
&bbp_props
, const BBPCostFunction
&cfn
, Real h
) {
990 BBP
bbp( &bp
, bbp_props
);
991 // calculate the value of the unperturbed cost function
992 Real cf0
= cfn
.evaluate( bp
, state
);
993 // run BBP to estimate adjoints
994 bbp
.initCostFnAdj( cfn
, state
);
998 const FactorGraph
& fg
= bp
.fg();
1001 // verify bbp.adj_psi_V
1003 // for each variable i
1004 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
1005 vector
<Real
> adj_est
;
1006 // for each value xi
1007 for( size_t xi
= 0; xi
< fg
.var(i
).states(); xi
++ ) {
1008 // Clone 'bp' (which may be any InfAlg)
1009 InfAlg
*bp_prb
= bp
.clone();
1012 size_t n
= bp_prb
->fg().var(i
).states();
1013 Prob
psi_1_prb( n
, 1.0 );
1015 // psi_1_prb.normalize();
1016 size_t I
= bp_prb
->fg().nbV(i
)[0]; // use first factor in list of neighbors of i
1017 bp_prb
->fg().factor(I
) *= Factor( bp_prb
->fg().var(i
), psi_1_prb
);
1019 // call 'init' on the perturbed variables
1020 bp_prb
->init( bp_prb
->fg().var(i
) );
1022 // run copy to convergence
1025 // calculate new value of cost function
1026 Real cf_prb
= cfn
.evaluate( *bp_prb
, state
);
1028 // use to estimate adjoint for i
1029 adj_est
.push_back( (cf_prb
- cf0
) / h
);
1031 // free cloned InfAlg
1034 Prob
p_adj_est( adj_est
);
1035 // compare this numerical estimate to the BBP estimate; sum the distances
1037 << ", p_adj_est: " << p_adj_est
1038 << ", bbp.adj_psi_V(i): " << bbp
.adj_psi_V(i
) << endl
;
1039 d
+= dist( p_adj_est
, bbp
.adj_psi_V(i
), Prob::DISTL1
);
1043 // verify bbp.adj_n and bbp.adj_m
1045 // We actually want to check the responsiveness of objective
1046 // function to changes in the final messages. But at the end of a
1047 // BBP run, the message adjoints are for the initial messages (and
1048 // they should be close to zero, see paper). So this resets the
1049 // BBP adjoints to refer to the desired final messages
1050 bbp.RegenerateMessageAdjoints();
1052 // for each variable i
1053 for(size_t i=0; i<bp_dual.nrVars(); i++) {
1054 // for each factor I ~ i
1055 foreach(size_t I, bp_dual.nbV(i)) {
1056 vector<Real> adj_n_est;
1057 // for each value xi
1058 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
1059 BP_dual bp_dual_prb(bp_dual);
1060 // make h-sized change to newMsgN
1061 bp_dual_prb.newMsgN(i,I)[xi] += h;
1062 // recalculate beliefs
1063 bp_dual_prb.CalcBeliefs();
1064 // get cost function value
1065 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
1066 // add it to list of adjoints
1067 adj_n_est.push_back((cf_prb-cf0)/h);
1070 vector<Real> adj_m_est;
1071 // for each value xi
1072 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
1073 BP_dual bp_dual_prb(bp_dual);
1074 // make h-sized change to newMsgM
1075 bp_dual_prb.newMsgM(I,i)[xi] += h;
1076 // recalculate beliefs
1077 bp_dual_prb.CalcBeliefs();
1078 // get cost function value
1079 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
1080 // add it to list of adjoints
1081 adj_m_est.push_back((cf_prb-cf0)/h);
1084 Prob p_adj_n_est( adj_n_est );
1085 // compare this numerical estimate to the BBP estimate; sum the distances
1086 cerr << "i: " << i << ", I: " << I
1087 << ", adj_n_est: " << p_adj_n_est
1088 << ", bbp.adj_n(i,I): " << bbp.adj_n(i,I) << endl;
1089 d += dist(p_adj_n_est, bbp.adj_n(i,I), Prob::DISTL1);
1091 Prob p_adj_m_est( adj_m_est );
1092 // compare this numerical estimate to the BBP estimate; sum the distances
1093 cerr << "i: " << i << ", I: " << I
1094 << ", adj_m_est: " << p_adj_m_est
1095 << ", bbp.adj_m(I,i): " << bbp.adj_m(I,i) << endl;
1096 d += dist(p_adj_m_est, bbp.adj_m(I,i), Prob::DISTL1);
1102 // verify bbp.adj_b_V
1103 for(size_t i=0; i<bp_dual.nrVars(); i++) {
1104 vector<Real> adj_b_V_est;
1105 // for each value xi
1106 for(size_t xi=0; xi<bp_dual.var(i).states(); xi++) {
1107 BP_dual bp_dual_prb(bp_dual);
1109 // make h-sized change to b_1(i)[x_i]
1110 bp_dual_prb._beliefs.b1[i][xi] += h;
1112 // get cost function value
1113 Real cf_prb = getCostFn(bp_dual_prb, cfn, &state);
1115 // add it to list of adjoints
1116 adj_b_V_est.push_back((cf_prb-cf0)/h);
1118 Prob p_adj_b_V_est( adj_b_V_est );
1119 // compare this numerical estimate to the BBP estimate; sum the distances
1121 << ", adj_b_V_est: " << p_adj_b_V_est
1122 << ", bbp.adj_b_V(i): " << bbp.adj_b_V(i) << endl;
1123 d += dist(p_adj_b_V_est, bbp.adj_b_V(i), Prob::DISTL1);
1128 // return total of distances
1133 } // end of namespace dai
1136 /* {{{ GENERATED CODE: DO NOT EDIT. Created by
1137 ./scripts/regenerate-properties include/dai/bbp.h src/bbp.cpp
1141 void BBP::Properties::set(const PropertySet
&opts
)
1143 const std::set
<PropertyKey
> &keys
= opts
.keys();
1144 std::set
<PropertyKey
>::const_iterator i
;
1145 for(i
=keys
.begin(); i
!=keys
.end(); i
++) {
1146 if(*i
== "verbose") continue;
1147 if(*i
== "maxiter") continue;
1148 if(*i
== "tol") continue;
1149 if(*i
== "damping") continue;
1150 if(*i
== "updates") continue;
1151 DAI_THROWE(UNKNOWN_PROPERTY_TYPE
, "BBP: Unknown property " + *i
);
1153 if(!opts
.hasKey("verbose"))
1154 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"verbose\" for method \"BBP\"");
1155 if(!opts
.hasKey("maxiter"))
1156 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"maxiter\" for method \"BBP\"");
1157 if(!opts
.hasKey("tol"))
1158 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"tol\" for method \"BBP\"");
1159 if(!opts
.hasKey("damping"))
1160 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"damping\" for method \"BBP\"");
1161 if(!opts
.hasKey("updates"))
1162 DAI_THROWE(NOT_ALL_PROPERTIES_SPECIFIED
,"BBP: Missing property \"updates\" for method \"BBP\"");
1163 verbose
= opts
.getStringAs
<size_t>("verbose");
1164 maxiter
= opts
.getStringAs
<size_t>("maxiter");
1165 tol
= opts
.getStringAs
<Real
>("tol");
1166 damping
= opts
.getStringAs
<Real
>("damping");
1167 updates
= opts
.getStringAs
<UpdateType
>("updates");
1169 PropertySet
BBP::Properties::get() const {
1171 opts
.Set("verbose", verbose
);
1172 opts
.Set("maxiter", maxiter
);
1173 opts
.Set("tol", tol
);
1174 opts
.Set("damping", damping
);
1175 opts
.Set("updates", updates
);
1178 string
BBP::Properties::toString() const {
1179 stringstream
s(stringstream::out
);
1181 s
<< "verbose=" << verbose
<< ",";
1182 s
<< "maxiter=" << maxiter
<< ",";
1183 s
<< "tol=" << tol
<< ",";
1184 s
<< "damping=" << damping
<< ",";
1185 s
<< "updates=" << updates
;
1189 } // end of namespace dai
1190 /* }}} END OF GENERATED CODE */