assert( opts.hasKey("maxiter") );
assert( opts.hasKey("logdomain") );
assert( opts.hasKey("updates") );
-
+
props.tol = opts.getStringAs<double>("tol");
props.maxiter = opts.getStringAs<size_t>("maxiter");
props.logdomain = opts.getStringAs<bool>("logdomain");
_edge2lut.reserve( nrVars() );
for( size_t i = 0; i < nrVars(); ++i ) {
_edges.push_back( vector<EdgeProp>() );
- _edges[i].reserve( nbV(i).size() );
+ _edges[i].reserve( nbV(i).size() );
if( props.updates == Properties::UpdateType::SEQMAX ) {
_edge2lut.push_back( vector<LutType::iterator>() );
_edge2lut[i].reserve( nbV(i).size() );
message( i, I.iter ).fill( c );
newMessage( i, I.iter ).fill( c );
if( props.updates == Properties::UpdateType::SEQMAX )
- updateResidual( i, I.iter, 0.0 );
+ updateResidual( i, I.iter, 0.0 );
}
}
}
Factor Fprod( factor(I) );
Prob &prod = Fprod.p();
- if( props.logdomain )
+ if( props.logdomain )
prod.takeLog();
// Calculate product of incoming messages and factor I
foreach( const Neighbor &j, nbF(I) )
if( j != i ) { // for all j in I \ i
// prod_j will be the product of messages coming into j
- Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+ Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
+ if( J != I ) { // for all J in nb(j) \ I
if( props.logdomain )
prod_j += message( j, J.iter );
else
Prob marg;
if( !DAI_BP_FAST ) {
/* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
- if( props.inference == Properties::InfType::SUMPROD )
+ if( props.inference == Properties::InfType::SUMPROD )
marg = Fprod.marginal( var(i) ).p();
else
marg = Fprod.maxMarginal( var(i) ).p();
marg = Prob( var(i).states(), 0.0 );
// ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
const ind_t ind = index(i,_I);
- if( props.inference == Properties::InfType::SUMPROD )
+ if( props.inference == Properties::InfType::SUMPROD )
for( size_t r = 0; r < prod.size(); ++r )
marg[ind[r]] += prod[r];
else
for( size_t r = 0; r < prod.size(); ++r )
- if( prod[r] > marg[ind[r]] )
+ if( prod[r] > marg[ind[r]] )
marg[ind[r]] = prod[r];
marg.normalize();
}
double tic = toc();
Diffs diffs(nrVars(), 1.0);
-
+
vector<Edge> update_seq;
vector<Factor> old_beliefs;
}
}
} else if( props.updates == Properties::UpdateType::PARALL ) {
- // Parallel updates
+ // Parallel updates
for( size_t i = 0; i < nrVars(); ++i )
foreach( const Neighbor &I, nbV(i) )
calcNewMessage( i, I.iter );
// Sequential updates
if( props.updates == Properties::UpdateType::SEQRND )
random_shuffle( update_seq.begin(), update_seq.end() );
-
+
foreach( const Edge &e, update_seq ) {
calcNewMessage( e.first, e.second );
updateMessage( e.first, e.second );
void BP::calcBeliefV( size_t i, Prob &p ) const {
- p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
+ p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &I, nbV(i) )
if( props.logdomain )
p += newMessage( i, I.iter );
Factor Fprod( factor( I ) );
Prob &prod = Fprod.p();
- if( props.logdomain )
+ if( props.logdomain )
prod.takeLog();
foreach( const Neighbor &j, nbF(I) ) {
// prod_j will be the product of messages coming into j
Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
foreach( const Neighbor &J, nbV(j) )
- if( J != I ) { // for all J in nb(j) \ I
+ if( J != I ) { // for all J in nb(j) \ I
if( props.logdomain )
prod_j += newMessage( j, J.iter );
else
}
-string BP::identify() const {
+string BP::identify() const {
return string(Name) + printProperties();
}
void BP::updateResidual( size_t i, size_t _I, double r ) {
- EdgeProp* pEdge = &_edges[i][_I];
- pEdge->residual = r;
-
- // rearrange look-up table (delete and reinsert new key)
- _lut.erase( _edge2lut[i][_I] );
- _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
+ EdgeProp* pEdge = &_edges[i][_I];
+ pEdge->residual = r;
+
+ // rearrange look-up table (delete and reinsert new key)
+ _lut.erase( _edge2lut[i][_I] );
+ _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
}
if( visitedVars[i] )
continue;
visitedVars[i] = true;
-
+
// Maximise with respect to variable i
Prob prod;
calcBeliefV( i, prod );
maximum[i] = max_element( prod.begin(), prod.end() ) - prod.begin();
-
+
foreach( const Neighbor &I, nbV(i) )
- if( !visitedFactors[I] )
+ if( !visitedFactors[I] )
scheduledFactors.push(I);
while( !scheduledFactors.empty() ){
if( visitedFactors[I] )
continue;
visitedFactors[I] = true;
-
+
// Evaluate if some neighboring variables still need to be fixed; if not, we're done
bool allDetermined = true;
- foreach( const Neighbor &j, nbF(I) )
+ foreach( const Neighbor &j, nbF(I) )
if( !visitedVars[j.node] ) {
allDetermined = false;
break;
}
if( allDetermined )
continue;
-
+
// Calculate product of incoming messages on factor I
Prob prod2;
calcBeliefF( I, prod2 );
maxProb = prod2[s];
}
}
-
+
// Decode the argmax
foreach( const Neighbor &j, nbF(I) ) {
if( visitedVars[j.node] ) {
visitedVars[j.node] = true;
maximum[j.node] = maxState( var(j.node) );
foreach( const Neighbor &J, nbV(j) )
- if( !visitedFactors[J] )
+ if( !visitedFactors[J] )
scheduledFactors.push(J);
}
}