Improvements of TFactor<T>
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 17 Jul 2009 16:55:48 +0000 (18:55 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 17 Jul 2009 16:55:48 +0000 (18:55 +0200)
- Extended functionality of TFactor<T>::operators +,-,+=,-= to handle different VarSets
- Added TFactor<T>::maxMarginal (similar to marginal but with max instead of sum)

include/dai/factor.h

index 24d136f..9027b31 100644 (file)
@@ -32,6 +32,7 @@
 
 
 #include <iostream>
+#include <functional>
 #include <cmath>
 #include <dai/prob.h>
 #include <dai/varset.h>
 namespace dai {
 
 
+// Function object similar to std::divides(), but different in that dividing by zero results in zero
+template<typename T> struct divides0 : public std::binary_function<T, T, T> {
+    T operator()(const T& i, const T& j) const {
+        if( j == (T)0 )
+            return (T)0;
+        else
+            return i / j;
+    }
+};
+
+
 /// Represents a (probability) factor.
 /** Mathematically, a \e factor is a function mapping joint states of some
  *  variables to the nonnegative real numbers.
@@ -222,64 +234,58 @@ template <typename T> class TFactor {
             return *this;
         }
 
+        /// Adds the TFactor f to *this
+        TFactor<T>& operator+= (const TFactor<T>& f) { 
+            if( f._vs == _vs ) // optimize special case
+                _p += f._p;
+            else
+                *this = (*this + f); 
+            return *this;
+        }
+
+        /// Subtracts the TFactor f from *this
+        TFactor<T>& operator-= (const TFactor<T>& f) { 
+            if( f._vs == _vs ) // optimize special case
+                _p -= f._p;
+            else
+                *this = (*this - f); 
+            return *this;
+        }
+
         /// Returns product of *this with the TFactor f
         /** The product of two factors is defined as follows: if 
          *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
          *  \f[fg : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) g(x_M).\f]
          */
-        TFactor<T> operator* (const TFactor<T>& f) const;
+        TFactor<T> operator* (const TFactor<T>& f) const {
+            return pointwiseOp(*this,f,std::multiplies<T>());
+        }
 
         /// Returns quotient of *this by the TFactor f
         /** The quotient of two factors is defined as follows: if 
          *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
          *  \f[\frac{f}{g} : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto \frac{f(x_L)}{g(x_M)}.\f]
          */
-        TFactor<T> operator/ (const TFactor<T>& f) const;
-
-        /// Adds the TFactor f to *this
-        /** \pre this->vars() == f.vars()
-         */
-        TFactor<T>& operator+= (const TFactor<T>& f) { 
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            _p += f._p;
-            return *this;
-        }
-
-        /// Subtracts the TFactor f from *this
-        /** \pre this->vars() == f.vars()
-         */
-        TFactor<T>& operator-= (const TFactor<T>& f) { 
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            _p -= f._p;
-            return *this;
+        TFactor<T> operator/ (const TFactor<T>& f) const {
+            return pointwiseOp(*this,f,divides0<T>());
         }
 
         /// Returns sum of *this and the TFactor f
-        /** \pre this->vars() == f.vars()
+        /** The sum of two factors is defined as follows: if 
+         *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
+         *  \f[f+g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) + g(x_M).\f]
          */
         TFactor<T> operator+ (const TFactor<T>& f) const {
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            TFactor<T> sum(*this); 
-            sum._p += f._p; 
-            return sum; 
+            return pointwiseOp(*this,f,std::plus<T>());
         }
 
         /// Returns *this minus the TFactor f
-        /** \pre this->vars() == f.vars()
+        /** The difference of two factors is defined as follows: if 
+         *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
+         *  \f[f-g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) - g(x_M).\f]
          */
         TFactor<T> operator- (const TFactor<T>& f) const {
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            TFactor<T> sum(*this); 
-            sum._p -= f._p; 
-            return sum; 
+            return pointwiseOp(*this,f,std::minus<T>());
         }
 
 
@@ -372,6 +378,9 @@ template <typename T> class TFactor {
         /// Returns marginal on ns, obtained by summing out all variables except those in ns, and normalizing the result if normed==true
         TFactor<T> marginal(const VarSet & ns, bool normed=true) const;
 
+        /// Returns max-marginal on ns, obtained by maximizing all variables except those in ns, and normalizing the result if normed==true
+        TFactor<T> maxMarginal(const VarSet & ns, bool normed=true) const;
+
         /// Embeds this factor in a larger VarSet
         /** \pre vars() should be a subset of ns
          *
@@ -428,40 +437,39 @@ template<typename T> TFactor<T> TFactor<T>::marginal(const VarSet & ns, bool nor
 }
 
 
-template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& f) const {
-    if( f._vs == _vs ) { // optimizate special case
-        TFactor<T> prod(*this); 
-        prod._p *= f._p; 
-        return prod; 
-    } else {
-        TFactor<T> prod( _vs | f._vs, 0.0 );
+template<typename T> TFactor<T> TFactor<T>::maxMarginal(const VarSet & ns, bool normed) const {
+    VarSet res_ns = ns & _vs;
+
+    TFactor<T> res( res_ns, 0.0 );
 
-        IndexFor i1(_vs, prod._vs);
-        IndexFor i2(f._vs, prod._vs);
+    IndexFor i_res( res_ns, _vs );
+    for( size_t i = 0; i < _p.size(); i++, ++i_res )
+        if( _p[i] > res._p[i_res] )
+            res._p[i_res] = _p[i];
 
-        for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
-            prod._p[i] += _p[i1] * f._p[i2];
+    if( normed )
+        res.normalize( Prob::NORMPROB );
 
-        return prod;
-    }
+    return res;
 }
 
 
-template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& f) const {
-    if( f._vs == _vs ) { // optimizate special case
-        TFactor<T> quot(*this); 
-        quot._p /= f._p; 
-        return quot; 
+template<typename T, typename binaryOp> TFactor<T> pointwiseOp( const TFactor<T> &f, const TFactor<T> &g, binaryOp op ) {
+    if( f.vars() == g.vars() ) { // optimizate special case
+        TFactor<T> result(f); 
+        for( size_t i = 0; i < result.states(); i++ )
+            result[i] = op( result[i], g[i] );
+        return result; 
     } else {
-        TFactor<T> quot( _vs | f._vs, 0.0 );
+        TFactor<T> result( f.vars() | g.vars(), 0.0 );
 
-        IndexFor i1(_vs, quot._vs);
-        IndexFor i2(f._vs, quot._vs);
+        IndexFor i1(f.vars(), result.vars());
+        IndexFor i2(g.vars(), result.vars());
 
-        for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
-            quot._p[i] += _p[i1] / f._p[i2];
+        for( size_t i = 0; i < result.states(); i++, ++i1, ++i2 )
+            result[i] = op( f[i1], g[i2] );
 
-        return quot;
+        return result;
     }
 }