void BP::findMaxResidual( size_t &i, size_t &_I ) {
-/*
- i = 0;
- _I = 0;
- double maxres = residual( i, _I );
- for( size_t j = 0; j < nrVars(); ++j )
- foreach( const Neighbor &I, nbV(j) )
- if( residual( j, I.iter ) > maxres ) {
- i = j;
- _I = I.iter;
- maxres = residual( i, _I );
- }
-*/
assert( !_lut.empty() );
LutType::const_iterator largestEl = _lut.end();
--largestEl;
// calculate updated message I->i
size_t I = nbV(i,_I);
- if( !DAI_BP_FAST ) {
- /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
- Factor prod( factor( I ) );
- foreach( const Neighbor &j, nbF(I) )
- if( j != i ) { // for all j in I \ i
- foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
- prod *= Factor( var(j), message(j, J.iter) );
- }
- }
- newMessage(i,_I) = prod.marginal( var(i) ).p();
- } else {
- /* OPTIMIZED VERSION */
- Prob prod( factor(I).p() );
- if( props.logdomain )
- prod.takeLog();
+ Factor Fprod( factor(I) );
+ Prob &prod = Fprod.p();
+ if( props.logdomain )
+ prod.takeLog();
+
+ // Calculate product of incoming messages and factor I
+ foreach( const Neighbor &j, nbF(I) )
+ if( j != i ) { // for all j in I \ i
+ // prod_j will be the product of messages coming into j
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) )
+ if( J != I ) { // for all J in nb(j) \ I
+ if( props.logdomain )
+ prod_j += message( j, J.iter );
+ else
+ prod_j *= message( j, J.iter );
+ }
- // Calculate product of incoming messages and factor I
- foreach( const Neighbor &j, nbF(I) ) {
- if( j != i ) { // for all j in I \ i
+ // multiply prod with prod_j
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ if( props.logdomain )
+ Fprod += Factor( var(j), prod_j );
+ else
+ Fprod *= Factor( var(j), prod_j );
+ } else {
+ /* OPTIMIZED VERSION */
size_t _I = j.dual;
// ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
const ind_t &ind = index(j, _I);
-
- // prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
- foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
- if( props.logdomain )
- prod_j += message( j, J.iter );
- else
- prod_j *= message( j, J.iter );
- }
-
- // multiply prod with prod_j
for( size_t r = 0; r < prod.size(); ++r )
if( props.logdomain )
prod[r] += prod_j[ind[r]];
prod[r] *= prod_j[ind[r]];
}
}
- if( props.logdomain ) {
- prod -= prod.max();
- prod.takeExp();
- }
- // Marginalize onto i
- Prob marg( var(i).states(), 0.0 );
+ if( props.logdomain ) {
+ prod -= prod.max();
+ prod.takeExp();
+ }
+
+ // Marginalize onto i
+ Prob marg;
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+ if( props.inference == Properties::InfType::SUMPROD )
+ marg = Fprod.marginal( var(i) ).p();
+ else
+ marg = Fprod.maxMarginal( var(i) ).p();
+ } else {
+ /* OPTIMIZED VERSION */
+ marg = Prob( var(i).states(), 0.0 );
// ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
const ind_t ind = index(i,_I);
if( props.inference == Properties::InfType::SUMPROD )
if( prod[r] > marg[ind[r]] )
marg[ind[r]] = prod[r];
marg.normalize();
-
- // Store result
- if( props.logdomain )
- newMessage(i,_I) = marg.log();
- else
- newMessage(i,_I) = marg;
}
+ // Store result
+ if( props.logdomain )
+ newMessage(i,_I) = marg.log();
+ else
+ newMessage(i,_I) = marg;
+
// Update the residual if necessary
if( props.updates == Properties::UpdateType::SEQMAX )
updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
void BP::calcBeliefF( size_t I, Prob &p ) const {
- p = factor(I).p();
- if( props.logdomain )
- p.takeLog();
+ Factor Fprod( factor( I ) );
+ Prob &prod = Fprod.p();
- foreach( const Neighbor &j, nbF(I) ) {
- size_t _I = j.dual;
- // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
- const ind_t & ind = index(j, _I);
+ if( props.logdomain )
+ prod.takeLog();
+ foreach( const Neighbor &j, nbF(I) ) {
// prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
- foreach( const Neighbor &J, nbV(j) ) {
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ foreach( const Neighbor &J, nbV(j) )
if( J != I ) { // for all J in nb(j) \ I
if( props.logdomain )
prod_j += newMessage( j, J.iter );
else
prod_j *= newMessage( j, J.iter );
}
- }
- // multiply p with prod_j
- for( size_t r = 0; r < p.size(); ++r ) {
+ // multiply prod with prod_j
+ if( !DAI_BP_FAST ) {
+ /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
if( props.logdomain )
- p[r] += prod_j[ind[r]];
+ Fprod += Factor( var(j), prod_j );
else
- p[r] *= prod_j[ind[r]];
+ Fprod *= Factor( var(j), prod_j );
+ } else {
+ /* OPTIMIZED VERSION */
+ size_t _I = j.dual;
+ // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+ const ind_t & ind = index(j, _I);
+
+ for( size_t r = 0; r < prod.size(); ++r ) {
+ if( props.logdomain )
+ prod[r] += prod_j[ind[r]];
+ else
+ prod[r] *= prod_j[ind[r]];
+ }
}
}
+
+ p = prod;
}
p -= p.max();
p.takeExp();
}
-
p.normalize();
+
return( Factor( var(i), p ) );
}
+Factor BP::beliefF( size_t I ) const {
+ Prob p;
+ calcBeliefF( I, p );
+
+ if( props.logdomain ) {
+ p -= p.max();
+ p.takeExp();
+ }
+ p.normalize();
+
+ return( Factor( factor(I).vars(), p ) );
+}
+
+
Factor BP::belief( const Var &n ) const {
return( beliefV( findVar( n ) ) );
}
}
-Factor BP::beliefF( size_t I ) const {
- if( !DAI_BP_FAST ) {
- /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-
- Factor prod( factor(I) );
- foreach( const Neighbor &j, nbF(I) ) {
- foreach( const Neighbor &J, nbV(j) ) {
- if( J != I ) // for all J in nb(j) \ I
- prod *= Factor( var(j), newMessage(j, J.iter) );
- }
- }
- return prod.normalized();
- } else {
- /* OPTIMIZED VERSION */
-
- Prob prod;
- calcBeliefF( I, prod );
-
- if( props.logdomain ) {
- prod -= prod.max();
- prod.takeExp();
- }
- prod.normalize();
-
- Factor result( factor(I).vars(), prod );
-
- return( result );
- }
-}
-
-
Real BP::logZ() const {
Real sum = 0.0;
for(size_t i = 0; i < nrVars(); ++i )