From 33b02362c3dd18a6c52950221838506bf27ad3c7 Mon Sep 17 00:00:00 2001 From: Joris Mooij Date: Fri, 22 Jan 2010 16:20:39 +0100 Subject: [PATCH] Fixed regression FBP and bugs in TRWBP --- ChangeLog | 1 + include/dai/bp.h | 2 + include/dai/doc.h | 5 ++ include/dai/emalg.h | 1 + include/dai/fbp.h | 43 +++++++++------- include/dai/trwbp.h | 40 +++++++++------ src/bp.cpp | 118 +++++++++++++++++++++++--------------------- src/fbp.cpp | 29 +++++------ src/trwbp.cpp | 50 ++++++++++--------- tests/testall | 2 +- tests/testall.bat | 2 +- tests/testfast.out | 34 +++++++++++++ 12 files changed, 197 insertions(+), 130 deletions(-) diff --git a/ChangeLog b/ChangeLog index 12a59d0..b1ba06c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,4 @@ +* Implemented Tree-Reweighted BP * Implemented various heuristics for choosing a variable elimination sequence in JTree * Added BETHE method for GBP/HAK cluster choice diff --git a/include/dai/bp.h b/include/dai/bp.h index e178a32..aa0441d 100644 --- a/include/dai/bp.h +++ b/include/dai/bp.h @@ -57,6 +57,8 @@ namespace dai { * \note There are two implementations, an optimized one (the default) which caches IndexFor objects, * and a slower, less complicated one which is easier to maintain/understand. The slower one can be * enabled by defining DAI_BP_FAST as false in the source file. + * + * \todo Merge duplicate code in calcNewMessage() and calcBeliefF() */ class BP : public DAIAlgFG { protected: diff --git a/include/dai/doc.h b/include/dai/doc.h index 127359f..2c85683 100644 --- a/include/dai/doc.h +++ b/include/dai/doc.h @@ -22,6 +22,8 @@ * * \todo Implement routines for UAI probabilistic inference evaluation data * + * \todo Improve SWIG interfaces + * * \idea Adapt (part of the) guidelines in http://www.boost.org/development/requirements.html#Design_and_Programming * * \idea Use "gcc -MM" to generate dependencies for targets: http://make.paulandlesley.org/autodep.html @@ -86,6 +88,7 @@ * - Mean Field; * - Loopy Belief Propagation [\ref KFL01]; * - Fractional Belief Propagation [\ref WiH03]; + * - Tree-Reweighted Belief Propagation [\ref WJW03]; * - Tree Expectation Propagation [\ref MiQ04]; * - Generalized Belief Propagation [\ref YFW05]; * - Double-loop GBP [\ref HAK03]; @@ -437,6 +440,8 @@ * Approximate inference: * - Mean Field: dai::MF * - (Loopy) Belief Propagation: dai::BP [\ref KFL01] + * - Fractional Belief Propagation: dai::FBP [\ref WiH03] + * - Tree-Reweighted Belief Propagation: dai::TRWBP [\ref WJW03] * - Tree Expectation Propagation: dai::TreeEP [\ref MiQ04] * - Generalized Belief Propagation: dai::HAK [\ref YFW05] * - Double-loop GBP: dai::HAK [\ref HAK03] diff --git a/include/dai/emalg.h b/include/dai/emalg.h index d9379da..2d3f2f6 100644 --- a/include/dai/emalg.h +++ b/include/dai/emalg.h @@ -25,6 +25,7 @@ /// \file /// \brief Defines classes related to Expectation Maximization (EMAlg, ParameterEstimation, CondProbEstimation and SharedParameters) +/// \todo Implement parameter estimation for undirected models / factor graphs. namespace dai { diff --git a/include/dai/fbp.h b/include/dai/fbp.h index 0fdc7a1..bcfa83a 100644 --- a/include/dai/fbp.h +++ b/include/dai/fbp.h @@ -29,13 +29,13 @@ namespace dai { /// Approximate inference algorithm "Fractional Belief Propagation" [\ref WiH03] /** The Fractional Belief Propagation algorithm is like Belief - * Propagation, but associates each factor with a scale parameter + * Propagation, but associates each factor with a weight (scale parameter) * which controls the divergence measure being minimized. Standard - * Belief Propagation corresponds to the case of FBP where each scale - * parameter is 1. When cast as an EP algorithm, BP (and EP) minimize + * Belief Propagation corresponds to the case of FBP where each weight + * is 1. When cast as an EP algorithm, BP (and EP) minimize * the inclusive KL-divergence, i.e. \f$\min_q KL(p||q)\f$ (note that the * Bethe free energy is typically derived from \f$ KL(q||p) \f$). If each - * factor \a I has scale parameter \f$ c_I \f$, then FBP minimizes the + * factor \a I has weight \f$ c_I \f$, then FBP minimizes the * alpha-divergence with \f$ \alpha=1/c_I \f$ for that factor, which also * corresponds to Power EP [\ref Min05]. * @@ -46,14 +46,19 @@ namespace dai { * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i} \f] * and the factor beliefs are calculated by: * \f[ b_I(x_I) \propto f_I(x_I)^{1/c_I} \prod_{j \in N_I} m_{I\to j}^{1-1/c_I} \prod_{J\in N_j\setminus\{I\}} m_{J\to j} \f] + * The logarithm of the partition sum is approximated by: + * \f[ \log Z = \sum_{I} \sum_{x_I} b_I(x_I) \big( \log f_I(x_I) - c_I \log b_I(x_I) \big) + \sum_{i} (c_i - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i) \f] + * where the variable weights are defined as + * \f[ c_i := \sum_{I \in N_i} c_I \f] * - * \todo Add nice way to set scale parameters + * \todo Add nice way to set weights + * \todo Merge duplicate code in calcNewMessage() and calcBeliefF() * \author Frederik Eaton */ class FBP : public BP { protected: - /// Factor scale parameters (indexed by factor ID) - std::vector _scale_factor; + /// Factor weights (indexed by factor ID) + std::vector _weight; public: /// Name of this inference algorithm @@ -63,12 +68,12 @@ class FBP : public BP { /// \name Constructors/destructors //@{ /// Default constructor - FBP() : BP(), _scale_factor() {} + FBP() : BP(), _weight() {} /// Construct from FactorGraph \a fg and PropertySet \a opts /** \param opts Parameters @see BP::Properties */ - FBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _scale_factor() { + FBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _weight() { setProperties( opts ); construct(); } @@ -81,21 +86,21 @@ class FBP : public BP { virtual Real logZ() const; //@} - /// \name FBP accessors/mutators for scale parameters + /// \name FBP accessors/mutators for weights //@{ - /// Returns scale parameter of the \a I 'th factor - Real scaleF( size_t I ) const { return _scale_factor[I]; } + /// Returns weight of the \a I 'th factor + Real Weight( size_t I ) const { return _weight[I]; } - /// Returns constant reference to vector of all factor scale parameters - const std::vector& scaleFs() const { return _scale_factor; } + /// Returns constant reference to vector of all factor weights + const std::vector& Weights() const { return _weight; } - /// Sets the scale parameter of the \a I 'th factor to \a c - void setScaleF( size_t I, Real c ) { _scale_factor[I] = c; } + /// Sets the weight of the \a I 'th factor to \a c + void setWeight( size_t I, Real c ) { _weight[I] = c; } - /// Sets the scale parameters of all factors simultaenously - /** \note Faster than calling setScaleF(size_t,Real) for each factor + /// Sets the weights of all factors simultaenously + /** \note Faster than calling setWeight(size_t,Real) for each factor */ - void setScaleFs( const std::vector &c ) { _scale_factor = c; } + void setWeights( const std::vector &c ) { _weight = c; } protected: // Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i diff --git a/include/dai/trwbp.h b/include/dai/trwbp.h index d37c788..83f39d5 100644 --- a/include/dai/trwbp.h +++ b/include/dai/trwbp.h @@ -39,13 +39,25 @@ namespace dai { * \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}^{c_I} \f] * and the factor beliefs are calculated by: * \f[ b_I(x_I) \propto f_I(x_I)^{1/c_I} \prod_{j \in N_I} m_{I\to j}^{c_I-1} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}^{c_J} \f] + * The logarithm of the partition sum is approximated by: + * \f[ \log Z = \sum_{I} \sum_{x_I} b_I(x_I) \big( \log f_I(x_I) - c_I \log b_I(x_I) \big) + \sum_{i} (c_i - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i) \f] + * where the variable weights are defined as + * \f[ c_i := \sum_{I \in N_i} c_I \f] * - * \todo Fix documentation + * \note TRWBP is actually equivalent to FBP + * \todo Merge duplicate code in calcNewMessage() and calcBeliefF() + * \todo Add nice way to set weights + * \todo Merge code of FBP and TRWBP */ class TRWBP : public BP { protected: - /// Factor scale parameters (indexed by factor ID) - std::vector _edge_weight; + /// "Edge weights" (indexed by factor ID) + /** In [\ref WJW03], only unary or pairwise factors are considered. + * Here we are more general by having a weight for each factor in the + * factor graph. If unary factors have weight 1, and higher-order factors + * are absent, then we have the special case considered in [\ref WJW03]. + */ + std::vector _weight; public: /// Name of this inference algorithm @@ -55,12 +67,12 @@ class TRWBP : public BP { /// \name Constructors/destructors //@{ /// Default constructor - TRWBP() : BP(), _edge_weight() {} + TRWBP() : BP(), _weight() {} /// Construct from FactorGraph \a fg and PropertySet \a opts /** \param opts Parameters @see BP::Properties */ - TRWBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _edge_weight() { + TRWBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _weight() { setProperties( opts ); construct(); } @@ -75,19 +87,19 @@ class TRWBP : public BP { /// \name TRWBP accessors/mutators for scale parameters //@{ - /// Returns scale parameter of edge corresponding to the \a I 'th factor - Real edgeWeight( size_t I ) const { return _edge_weight[I]; } + /// Returns weight corresponding to the \a I 'th factor + Real Weight( size_t I ) const { return _weight[I]; } - /// Returns constant reference to vector of all factor scale parameters - const std::vector& edgeWeights() const { return _edge_weight; } + /// Returns constant reference to vector of all weights + const std::vector& Weights() const { return _weight; } - /// Sets the scale parameter of the \a I 'th factor to \a c - void setEdgeWeight( size_t I, Real c ) { _edge_weight[I] = c; } + /// Sets the weight of the \a I 'th factor to \a c + void setWeight( size_t I, Real c ) { _weight[I] = c; } - /// Sets the scale parameters of all factors simultaenously - /** \note Faster than calling setScaleF(size_t,Real) for each factor + /// Sets the weights of all factors simultaenously + /** \note Faster than calling setWeight(size_t,Real) for each factor */ - void setEdgeWeights( const std::vector &c ) { _edge_weight = c; } + void setWeights( const std::vector &c ) { _weight = c; } protected: // Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i diff --git a/src/bp.cpp b/src/bp.cpp index 5d4a79a..dc7dd91 100644 --- a/src/bp.cpp +++ b/src/bp.cpp @@ -145,70 +145,74 @@ void BP::calcNewMessage( size_t i, size_t _I ) { // calculate updated message I->i size_t I = nbV(i,_I); - Factor Fprod( factor(I) ); - Prob &prod = Fprod.p(); - if( props.logdomain ) - prod.takeLog(); + Prob marg; + if( factor(I).vars().size() == 1 ) // optimization + marg = factor(I).p(); + else { + Factor Fprod( factor(I) ); + Prob &prod = Fprod.p(); + if( props.logdomain ) + prod.takeLog(); + + // Calculate product of incoming messages and factor I + foreach( const Neighbor &j, nbF(I) ) + if( j != i ) { // for all j in I \ i + // prod_j will be the product of messages coming into j + Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); + foreach( const Neighbor &J, nbV(j) ) + if( J != I ) { // for all J in nb(j) \ I + if( props.logdomain ) + prod_j += message( j, J.iter ); + else + prod_j *= message( j, J.iter ); + } - // Calculate product of incoming messages and factor I - foreach( const Neighbor &j, nbF(I) ) - if( j != i ) { // for all j in I \ i - // prod_j will be the product of messages coming into j - Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); - foreach( const Neighbor &J, nbV(j) ) - if( J != I ) { // for all J in nb(j) \ I + // multiply prod with prod_j + if( !DAI_BP_FAST ) { + /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */ if( props.logdomain ) - prod_j += message( j, J.iter ); + Fprod += Factor( var(j), prod_j ); else - prod_j *= message( j, J.iter ); + Fprod *= Factor( var(j), prod_j ); + } else { + /* OPTIMIZED VERSION */ + size_t _I = j.dual; + // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k] + const ind_t &ind = index(j, _I); + for( size_t r = 0; r < prod.size(); ++r ) + if( props.logdomain ) + prod[r] += prod_j[ind[r]]; + else + prod[r] *= prod_j[ind[r]]; } - - // multiply prod with prod_j - if( !DAI_BP_FAST ) { - /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */ - if( props.logdomain ) - Fprod += Factor( var(j), prod_j ); - else - Fprod *= Factor( var(j), prod_j ); - } else { - /* OPTIMIZED VERSION */ - size_t _I = j.dual; - // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k] - const ind_t &ind = index(j, _I); - for( size_t r = 0; r < prod.size(); ++r ) - if( props.logdomain ) - prod[r] += prod_j[ind[r]]; - else - prod[r] *= prod_j[ind[r]]; } - } - if( props.logdomain ) { - prod -= prod.max(); - prod.takeExp(); - } + if( props.logdomain ) { + prod -= prod.max(); + prod.takeExp(); + } - // Marginalize onto i - Prob marg; - if( !DAI_BP_FAST ) { - /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */ - if( props.inference == Properties::InfType::SUMPROD ) - marg = Fprod.marginal( var(i) ).p(); - else - marg = Fprod.maxMarginal( var(i) ).p(); - } else { - /* OPTIMIZED VERSION */ - marg = Prob( var(i).states(), 0.0 ); - // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k] - const ind_t ind = index(i,_I); - if( props.inference == Properties::InfType::SUMPROD ) - for( size_t r = 0; r < prod.size(); ++r ) - marg[ind[r]] += prod[r]; - else - for( size_t r = 0; r < prod.size(); ++r ) - if( prod[r] > marg[ind[r]] ) - marg[ind[r]] = prod[r]; - marg.normalize(); + // Marginalize onto i + if( !DAI_BP_FAST ) { + /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */ + if( props.inference == Properties::InfType::SUMPROD ) + marg = Fprod.marginal( var(i) ).p(); + else + marg = Fprod.maxMarginal( var(i) ).p(); + } else { + /* OPTIMIZED VERSION */ + marg = Prob( var(i).states(), 0.0 ); + // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k] + const ind_t ind = index(i,_I); + if( props.inference == Properties::InfType::SUMPROD ) + for( size_t r = 0; r < prod.size(); ++r ) + marg[ind[r]] += prod[r]; + else + for( size_t r = 0; r < prod.size(); ++r ) + if( prod[r] > marg[ind[r]] ) + marg[ind[r]] = prod[r]; + marg.normalize(); + } } // Store result diff --git a/src/fbp.cpp b/src/fbp.cpp index 893388e..8cd51b0 100644 --- a/src/fbp.cpp +++ b/src/fbp.cpp @@ -34,12 +34,12 @@ Real FBP::logZ() const { Real sum = 0.0; for( size_t I = 0; I < nrFactors(); I++ ) { sum += (beliefF(I) * factor(I).log(true)).sum(); // FBP - sum += scaleF(I) * beliefF(I).entropy(); // FBP + sum += Weight(I) * beliefF(I).entropy(); // FBP } for( size_t i = 0; i < nrVars(); ++i ) { Real c_i = 0.0; foreach( const Neighbor &I, nbV(i) ) - c_i += scaleF(I); + c_i += Weight(I); sum += (1.0 - c_i) * beliefV(i).entropy(); // FBP } return sum; @@ -51,15 +51,15 @@ void FBP::calcNewMessage( size_t i, size_t _I ) { // calculate updated message I->i size_t I = nbV(i,_I); - Real scale = scaleF(I); // FBP: c_I + Real c_I = Weight(I); // FBP: c_I Factor Fprod( factor(I) ); Prob &prod = Fprod.p(); if( props.logdomain ) { prod.takeLog(); - prod *= (1/scale); // FBP + prod /= c_I; // FBP } else - prod ^= (1/scale); // FBP + prod ^= (1.0 / c_I); // FBP // Calculate product of incoming messages and factor I foreach( const Neighbor &j, nbF(I) ) @@ -76,9 +76,9 @@ void FBP::calcNewMessage( size_t i, size_t _I ) { } else { // FBP: multiply by m_Ij^(1-1/c_I) if( props.logdomain ) - prod_j += message( j, J.iter )*(1-1/scale); + prod_j += message( j, J.iter ) * (1.0 - 1.0 / c_I); else - prod_j *= message( j, J.iter )^(1-1/scale); + prod_j *= message( j, J.iter ) ^ (1.0 - 1.0 / c_I); } // multiply prod with prod_j @@ -90,6 +90,7 @@ void FBP::calcNewMessage( size_t i, size_t _I ) { Fprod *= Factor( var(j), prod_j ); } else { /* OPTIMIZED VERSION */ + size_t _I = j.dual; // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k] const ind_t &ind = index(j, _I); for( size_t r = 0; r < prod.size(); ++r ) @@ -129,7 +130,7 @@ void FBP::calcNewMessage( size_t i, size_t _I ) { } // FBP - marg ^= scale; + marg ^= c_I; // Store result if( props.logdomain ) @@ -145,16 +146,16 @@ void FBP::calcNewMessage( size_t i, size_t _I ) { /* This code has been copied from bp.cpp, except where comments indicate FBP-specific behaviour */ void FBP::calcBeliefF( size_t I, Prob &p ) const { - Real scale = scaleF(I); // FBP: c_I + Real c_I = Weight(I); // FBP: c_I Factor Fprod( factor(I) ); Prob &prod = Fprod.p(); if( props.logdomain ) { prod.takeLog(); - prod /= scale; // FBP + prod /= c_I; // FBP } else - prod ^= (1/scale); // FBP + prod ^= (1.0 / c_I); // FBP foreach( const Neighbor &j, nbF(I) ) { // prod_j will be the product of messages coming into j @@ -169,9 +170,9 @@ void FBP::calcBeliefF( size_t I, Prob &p ) const { } else { // FBP: multiply by m_Ij^(1-1/c_I) if( props.logdomain ) - prod_j += newMessage( j, J.iter)*(1-1/scale); + prod_j += newMessage( j, J.iter) * (1.0 - 1.0 / c_I); else - prod_j *= newMessage( j, J.iter)^(1-1/scale); + prod_j *= newMessage( j, J.iter) ^ (1.0 - 1.0 / c_I); } // multiply prod with prod_j @@ -202,7 +203,7 @@ void FBP::calcBeliefF( size_t I, Prob &p ) const { void FBP::construct() { BP::construct(); - _scale_factor.resize( nrFactors(), 1.0 ); + _weight.resize( nrFactors(), 1.0 ); } diff --git a/src/trwbp.cpp b/src/trwbp.cpp index 407d847..bda5bc7 100644 --- a/src/trwbp.cpp +++ b/src/trwbp.cpp @@ -32,12 +32,15 @@ string TRWBP::identify() const { Real TRWBP::logZ() const { Real sum = 0.0; for( size_t I = 0; I < nrFactors(); I++ ) { - sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP - if( factor(I).vars().size() == 2 ) - sum -= edgeWeight(I) * MutualInfo( beliefF(I) ); // TRWBP + sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP + sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP + } + for( size_t i = 0; i < nrVars(); ++i ) { + Real c_i = 0.0; + foreach( const Neighbor &I, nbV(i) ) + c_i += Weight(I); + sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP } - for( size_t i = 0; i < nrVars(); ++i ) - sum += beliefV(i).entropy(); // TRWBP return sum; } @@ -47,13 +50,12 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) { // calculate updated message I->i size_t I = nbV(i,_I); const Var &v_i = var(i); - const VarSet &v_I = factor(I).vars(); - Real c_I = edgeWeight(I); // TRWBP: c_I (\mu_I in the paper) + Real c_I = Weight(I); // TRWBP: c_I (\mu_I in the paper) Prob marg; - if( v_I.size() == 1 ) { // optimization + if( factor(I).vars().size() == 1 ) { // optimization marg = factor(I).p(); - } else + } else { Factor Fprod( factor(I) ); Prob &prod = Fprod.p(); if( props.logdomain ) { @@ -71,7 +73,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) { // prod_j will be the product of messages coming into j Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 ); foreach( const Neighbor &J, nbV(j) ) { - Real c_J = edgeWeight(J); + Real c_J = Weight(J); if( J != I ) { // for all J in nb(j) \ I if( props.logdomain ) prod_j += message( j, J.iter ) * c_J; @@ -94,6 +96,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) { Fprod *= Factor( v_j, prod_j ); } else { /* OPTIMIZED VERSION */ + size_t _I = j.dual; // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k] const ind_t &ind = index(j, _I); for( size_t r = 0; r < prod.size(); ++r ) @@ -130,17 +133,17 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) { marg[ind[r]] = prod[r]; marg.normalize(); } + } - // Store result - if( props.logdomain ) - newMessage(i,_I) = marg.log(); - else - newMessage(i,_I) = marg; + // Store result + if( props.logdomain ) + newMessage(i,_I) = marg.log(); + else + newMessage(i,_I) = marg; - // Update the residual if necessary - if( props.updates == Properties::UpdateType::SEQMAX ) - updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) ); - } + // Update the residual if necessary + if( props.updates == Properties::UpdateType::SEQMAX ) + updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) ); } @@ -148,7 +151,7 @@ void TRWBP::calcNewMessage( size_t i, size_t _I ) { void TRWBP::calcBeliefV( size_t i, Prob &p ) const { p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 ); foreach( const Neighbor &I, nbV(i) ) { - Real c_I = edgeWeight(I); + Real c_I = Weight(I); if( props.logdomain ) p += newMessage( i, I.iter ) * c_I; else @@ -159,8 +162,7 @@ void TRWBP::calcBeliefV( size_t i, Prob &p ) const { /* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */ void TRWBP::calcBeliefF( size_t I, Prob &p ) const { - Real c_I = edgeWeight(I); // TRWBP: c_I - const VarSet &v_I = factor(I).vars(); + Real c_I = Weight(I); // TRWBP: c_I Factor Fprod( factor(I) ); Prob &prod = Fprod.p(); @@ -179,7 +181,7 @@ void TRWBP::calcBeliefF( size_t I, Prob &p ) const { // prod_j will be the product of messages coming into j Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 ); foreach( const Neighbor &J, nbV(j) ) { - Real c_J = edgeWeight(J); + Real c_J = Weight(J); if( J != I ) { // for all J in nb(j) \ I if( props.logdomain ) prod_j += newMessage( j, J.iter ) * c_J; @@ -221,7 +223,7 @@ void TRWBP::calcBeliefF( size_t I, Prob &p ) const { void TRWBP::construct() { BP::construct(); - _edge_weight.resize( nrFactors(), 1.0 ); + _weight.resize( nrFactors(), 1.0 ); } diff --git a/tests/testall b/tests/testall index 5aac2d2..1a63d64 100755 --- a/tests/testall +++ b/tests/testall @@ -1,2 +1,2 @@ #!/bin/bash -./testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename $1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP +./testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename $1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP diff --git a/tests/testall.bat b/tests/testall.bat index a7ce8b8..12eda6f 100755 --- a/tests/testall.bat +++ b/tests/testall.bat @@ -1 +1 @@ -@testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename %1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP +@testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename %1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP diff --git a/tests/testfast.out b/tests/testfast.out index 46a5013..30c1492 100644 --- a/tests/testfast.out +++ b/tests/testfast.out @@ -187,6 +187,40 @@ BP_PARALL_LOG 9.483e-02 3.078e-02 +1.737e-02 1.000e-09 # ({x13}, (5.266e-01, 4.734e-01)) # ({x14}, (6.033e-01, 3.967e-01)) # ({x15}, (1.558e-01, 8.442e-01)) +FBP 9.483e-02 3.078e-02 +1.737e-02 1.000e-09 +# ({x0}, (4.233e-01, 5.767e-01)) +# ({x1}, (5.422e-01, 4.578e-01)) +# ({x2}, (4.662e-01, 5.338e-01)) +# ({x3}, (5.424e-01, 4.576e-01)) +# ({x4}, (6.042e-01, 3.958e-01)) +# ({x5}, (1.845e-01, 8.155e-01)) +# ({x6}, (8.203e-01, 1.797e-01)) +# ({x7}, (2.292e-01, 7.708e-01)) +# ({x8}, (3.119e-01, 6.881e-01)) +# ({x9}, (2.975e-01, 7.025e-01)) +# ({x10}, (7.268e-01, 2.732e-01)) +# ({x11}, (1.485e-01, 8.515e-01)) +# ({x12}, (4.512e-01, 5.488e-01)) +# ({x13}, (5.266e-01, 4.734e-01)) +# ({x14}, (6.033e-01, 3.967e-01)) +# ({x15}, (1.558e-01, 8.442e-01)) +TRWBP 9.483e-02 3.078e-02 +1.737e-02 1.000e-09 +# ({x0}, (4.233e-01, 5.767e-01)) +# ({x1}, (5.422e-01, 4.578e-01)) +# ({x2}, (4.662e-01, 5.338e-01)) +# ({x3}, (5.424e-01, 4.576e-01)) +# ({x4}, (6.042e-01, 3.958e-01)) +# ({x5}, (1.845e-01, 8.155e-01)) +# ({x6}, (8.203e-01, 1.797e-01)) +# ({x7}, (2.292e-01, 7.708e-01)) +# ({x8}, (3.119e-01, 6.881e-01)) +# ({x9}, (2.975e-01, 7.025e-01)) +# ({x10}, (7.268e-01, 2.732e-01)) +# ({x11}, (1.485e-01, 8.515e-01)) +# ({x12}, (4.512e-01, 5.488e-01)) +# ({x13}, (5.266e-01, 4.734e-01)) +# ({x14}, (6.033e-01, 3.967e-01)) +# ({x15}, (1.558e-01, 8.442e-01)) MF_SEQRND 3.607e-01 1.904e-01 -9.409e-02 1.000e-09 # ({x0}, (2.053e-01, 7.947e-01)) # ({x1}, (9.163e-01, 8.373e-02)) -- 2.20.1