From bebd3a3b241a326cf60d017026ccc91e75fa90e4 Mon Sep 17 00:00:00 2001 From: Joris Mooij Date: Fri, 17 Jul 2009 18:55:48 +0200 Subject: [PATCH] Improvements of TFactor - Extended functionality of TFactor::operators +,-,+=,-= to handle different VarSets - Added TFactor::maxMarginal (similar to marginal but with max instead of sum) --- include/dai/factor.h | 130 +++++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/include/dai/factor.h b/include/dai/factor.h index 24d136f..9027b31 100644 --- a/include/dai/factor.h +++ b/include/dai/factor.h @@ -32,6 +32,7 @@ #include +#include #include #include #include @@ -41,6 +42,17 @@ namespace dai { +// Function object similar to std::divides(), but different in that dividing by zero results in zero +template struct divides0 : public std::binary_function { + T operator()(const T& i, const T& j) const { + if( j == (T)0 ) + return (T)0; + else + return i / j; + } +}; + + /// Represents a (probability) factor. /** Mathematically, a \e factor is a function mapping joint states of some * variables to the nonnegative real numbers. @@ -222,64 +234,58 @@ template class TFactor { return *this; } + /// Adds the TFactor f to *this + TFactor& operator+= (const TFactor& f) { + if( f._vs == _vs ) // optimize special case + _p += f._p; + else + *this = (*this + f); + return *this; + } + + /// Subtracts the TFactor f from *this + TFactor& operator-= (const TFactor& f) { + if( f._vs == _vs ) // optimize special case + _p -= f._p; + else + *this = (*this - f); + return *this; + } + /// Returns product of *this with the TFactor f /** The product of two factors is defined as follows: if * \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then * \f[fg : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) g(x_M).\f] */ - TFactor operator* (const TFactor& f) const; + TFactor operator* (const TFactor& f) const { + return pointwiseOp(*this,f,std::multiplies()); + } /// Returns quotient of *this by the TFactor f /** The quotient of two factors is defined as follows: if * \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then * \f[\frac{f}{g} : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto \frac{f(x_L)}{g(x_M)}.\f] */ - TFactor operator/ (const TFactor& f) const; - - /// Adds the TFactor f to *this - /** \pre this->vars() == f.vars() - */ - TFactor& operator+= (const TFactor& f) { -#ifdef DAI_DEBUG - assert( f._vs == _vs ); -#endif - _p += f._p; - return *this; - } - - /// Subtracts the TFactor f from *this - /** \pre this->vars() == f.vars() - */ - TFactor& operator-= (const TFactor& f) { -#ifdef DAI_DEBUG - assert( f._vs == _vs ); -#endif - _p -= f._p; - return *this; + TFactor operator/ (const TFactor& f) const { + return pointwiseOp(*this,f,divides0()); } /// Returns sum of *this and the TFactor f - /** \pre this->vars() == f.vars() + /** The sum of two factors is defined as follows: if + * \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then + * \f[f+g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) + g(x_M).\f] */ TFactor operator+ (const TFactor& f) const { -#ifdef DAI_DEBUG - assert( f._vs == _vs ); -#endif - TFactor sum(*this); - sum._p += f._p; - return sum; + return pointwiseOp(*this,f,std::plus()); } /// Returns *this minus the TFactor f - /** \pre this->vars() == f.vars() + /** The difference of two factors is defined as follows: if + * \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then + * \f[f-g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) - g(x_M).\f] */ TFactor operator- (const TFactor& f) const { -#ifdef DAI_DEBUG - assert( f._vs == _vs ); -#endif - TFactor sum(*this); - sum._p -= f._p; - return sum; + return pointwiseOp(*this,f,std::minus()); } @@ -372,6 +378,9 @@ template class TFactor { /// Returns marginal on ns, obtained by summing out all variables except those in ns, and normalizing the result if normed==true TFactor marginal(const VarSet & ns, bool normed=true) const; + /// Returns max-marginal on ns, obtained by maximizing all variables except those in ns, and normalizing the result if normed==true + TFactor maxMarginal(const VarSet & ns, bool normed=true) const; + /// Embeds this factor in a larger VarSet /** \pre vars() should be a subset of ns * @@ -428,40 +437,39 @@ template TFactor TFactor::marginal(const VarSet & ns, bool nor } -template TFactor TFactor::operator* (const TFactor& f) const { - if( f._vs == _vs ) { // optimizate special case - TFactor prod(*this); - prod._p *= f._p; - return prod; - } else { - TFactor prod( _vs | f._vs, 0.0 ); +template TFactor TFactor::maxMarginal(const VarSet & ns, bool normed) const { + VarSet res_ns = ns & _vs; + + TFactor res( res_ns, 0.0 ); - IndexFor i1(_vs, prod._vs); - IndexFor i2(f._vs, prod._vs); + IndexFor i_res( res_ns, _vs ); + for( size_t i = 0; i < _p.size(); i++, ++i_res ) + if( _p[i] > res._p[i_res] ) + res._p[i_res] = _p[i]; - for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 ) - prod._p[i] += _p[i1] * f._p[i2]; + if( normed ) + res.normalize( Prob::NORMPROB ); - return prod; - } + return res; } -template TFactor TFactor::operator/ (const TFactor& f) const { - if( f._vs == _vs ) { // optimizate special case - TFactor quot(*this); - quot._p /= f._p; - return quot; +template TFactor pointwiseOp( const TFactor &f, const TFactor &g, binaryOp op ) { + if( f.vars() == g.vars() ) { // optimizate special case + TFactor result(f); + for( size_t i = 0; i < result.states(); i++ ) + result[i] = op( result[i], g[i] ); + return result; } else { - TFactor quot( _vs | f._vs, 0.0 ); + TFactor result( f.vars() | g.vars(), 0.0 ); - IndexFor i1(_vs, quot._vs); - IndexFor i2(f._vs, quot._vs); + IndexFor i1(f.vars(), result.vars()); + IndexFor i2(g.vars(), result.vars()); - for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 ) - quot._p[i] += _p[i1] / f._p[i2]; + for( size_t i = 0; i < result.states(); i++, ++i1, ++i2 ) + result[i] = op( f[i1], g[i2] ); - return quot; + return result; } } -- 2.20.1