EM bugfix. Convenience methods in Factor, Permute, Properties, EM.
[libdai.git] / include / dai / factor.h
index b441116..1427002 100644 (file)
@@ -32,6 +32,7 @@
 
 
 #include <iostream>
+#include <functional>
 #include <cmath>
 #include <dai/prob.h>
 #include <dai/varset.h>
 namespace dai {
 
 
+/// Function object similar to std::divides(), but different in that dividing by zero results in zero
+template<typename T> struct divides0 : public std::binary_function<T, T, T> {
+    T operator()(const T& i, const T& j) const {
+        if( j == (T)0 )
+            return (T)0;
+        else
+            return i / j;
+    }
+};
+
+
 /// Represents a (probability) factor.
 /** Mathematically, a \e factor is a function mapping joint states of some
  *  variables to the nonnegative real numbers.
@@ -101,6 +113,12 @@ template <typename T> class TFactor {
             assert( _vs.nrStates() == _p.size() );
 #endif
         }
+        TFactor( const std::vector< Var >& vars, const std::vector< T >& p ) : _vs(vars.begin(), vars.end(), vars.size()), _p(p.size()) {
+            Permute permindex(vars);
+            for (size_t li = 0; li < p.size(); ++li) {
+                _p[permindex.convert_linear_index(li)] = p[li];
+            }
+        }
         
         /// Constructs TFactor depending on the variable n, with uniform distribution
         TFactor( const Var& n ) : _vs(n), _p(n.states()) {}
@@ -222,64 +240,58 @@ template <typename T> class TFactor {
             return *this;
         }
 
+        /// Adds the TFactor f to *this
+        TFactor<T>& operator+= (const TFactor<T>& f) { 
+            if( f._vs == _vs ) // optimize special case
+                _p += f._p;
+            else
+                *this = (*this + f); 
+            return *this;
+        }
+
+        /// Subtracts the TFactor f from *this
+        TFactor<T>& operator-= (const TFactor<T>& f) { 
+            if( f._vs == _vs ) // optimize special case
+                _p -= f._p;
+            else
+                *this = (*this - f); 
+            return *this;
+        }
+
         /// Returns product of *this with the TFactor f
         /** The product of two factors is defined as follows: if 
          *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
          *  \f[fg : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) g(x_M).\f]
          */
-        TFactor<T> operator* (const TFactor<T>& f) const;
+        TFactor<T> operator* (const TFactor<T>& f) const {
+            return pointwiseOp(*this,f,std::multiplies<T>());
+        }
 
         /// Returns quotient of *this by the TFactor f
         /** The quotient of two factors is defined as follows: if 
          *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
          *  \f[\frac{f}{g} : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto \frac{f(x_L)}{g(x_M)}.\f]
          */
-        TFactor<T> operator/ (const TFactor<T>& f) const;
-
-        /// Adds the TFactor f to *this
-        /** \pre this->vars() == f.vars()
-         */
-        TFactor<T>& operator+= (const TFactor<T>& f) { 
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            _p += f._p;
-            return *this;
-        }
-
-        /// Subtracts the TFactor f from *this
-        /** \pre this->vars() == f.vars()
-         */
-        TFactor<T>& operator-= (const TFactor<T>& f) { 
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            _p -= f._p;
-            return *this;
+        TFactor<T> operator/ (const TFactor<T>& f) const {
+            return pointwiseOp(*this,f,divides0<T>());
         }
 
         /// Returns sum of *this and the TFactor f
-        /** \pre this->vars() == f.vars()
+        /** The sum of two factors is defined as follows: if 
+         *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
+         *  \f[f+g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) + g(x_M).\f]
          */
         TFactor<T> operator+ (const TFactor<T>& f) const {
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            TFactor<T> sum(*this); 
-            sum._p += f._p; 
-            return sum; 
+            return pointwiseOp(*this,f,std::plus<T>());
         }
 
         /// Returns *this minus the TFactor f
-        /** \pre this->vars() == f.vars()
+        /** The difference of two factors is defined as follows: if 
+         *  \f$f : \prod_{l\in L} X_l \to [0,\infty)\f$ and \f$g : \prod_{m\in M} X_m \to [0,\infty)\f$, then
+         *  \f[f-g : \prod_{l\in L\cup M} X_l \to [0,\infty) : x \mapsto f(x_L) - g(x_M).\f]
          */
         TFactor<T> operator- (const TFactor<T>& f) const {
-#ifdef DAI_DEBUG
-            assert( f._vs == _vs );
-#endif
-            TFactor<T> sum(*this); 
-            sum._p -= f._p; 
-            return sum; 
+            return pointwiseOp(*this,f,std::minus<T>());
         }
 
 
@@ -372,6 +384,9 @@ template <typename T> class TFactor {
         /// Returns marginal on ns, obtained by summing out all variables except those in ns, and normalizing the result if normed==true
         TFactor<T> marginal(const VarSet & ns, bool normed=true) const;
 
+        /// Returns max-marginal on ns, obtained by maximizing all variables except those in ns, and normalizing the result if normed==true
+        TFactor<T> maxMarginal(const VarSet & ns, bool normed=true) const;
+
         /// Embeds this factor in a larger VarSet
         /** \pre vars() should be a subset of ns
          *
@@ -383,7 +398,7 @@ template <typename T> class TFactor {
             if( _vs == ns )
                 return *this;
             else
-                return (*this) * TFactor<T>(ns / _vs, 1);
+                return (*this) * TFactor<T>(ns / _vs, (T)1);
         }
 
         /// Returns true if *this has NaN values
@@ -428,40 +443,39 @@ template<typename T> TFactor<T> TFactor<T>::marginal(const VarSet & ns, bool nor
 }
 
 
-template<typename T> TFactor<T> TFactor<T>::operator* (const TFactor<T>& f) const {
-    if( f._vs == _vs ) { // optimizate special case
-        TFactor<T> prod(*this); 
-        prod._p *= f._p; 
-        return prod; 
-    } else {
-        TFactor<T> prod( _vs | f._vs, 0.0 );
+template<typename T> TFactor<T> TFactor<T>::maxMarginal(const VarSet & ns, bool normed) const {
+    VarSet res_ns = ns & _vs;
 
-        IndexFor i1(_vs, prod._vs);
-        IndexFor i2(f._vs, prod._vs);
+    TFactor<T> res( res_ns, 0.0 );
 
-        for( size_t i = 0; i < prod._p.size(); i++, ++i1, ++i2 )
-            prod._p[i] += _p[i1] * f._p[i2];
+    IndexFor i_res( res_ns, _vs );
+    for( size_t i = 0; i < _p.size(); i++, ++i_res )
+        if( _p[i] > res._p[i_res] )
+            res._p[i_res] = _p[i];
 
-        return prod;
-    }
+    if( normed )
+        res.normalize( Prob::NORMPROB );
+
+    return res;
 }
 
 
-template<typename T> TFactor<T> TFactor<T>::operator/ (const TFactor<T>& f) const {
-    if( f._vs == _vs ) { // optimizate special case
-        TFactor<T> quot(*this); 
-        quot._p /= f._p; 
-        return quot; 
+template<typename T, typename binaryOp> TFactor<T> pointwiseOp( const TFactor<T> &f, const TFactor<T> &g, binaryOp op ) {
+    if( f.vars() == g.vars() ) { // optimizate special case
+        TFactor<T> result(f); 
+        for( size_t i = 0; i < result.states(); i++ )
+            result[i] = op( result[i], g[i] );
+        return result; 
     } else {
-        TFactor<T> quot( _vs | f._vs, 0.0 );
+        TFactor<T> result( f.vars() | g.vars(), 0.0 );
 
-        IndexFor i1(_vs, quot._vs);
-        IndexFor i2(f._vs, quot._vs);
+        IndexFor i1(f.vars(), result.vars());
+        IndexFor i2(g.vars(), result.vars());
 
-        for( size_t i = 0; i < quot._p.size(); i++, ++i1, ++i2 )
-            quot._p[i] += _p[i1] / f._p[i2];
+        for( size_t i = 0; i < result.states(); i++, ++i1, ++i2 )
+            result[i] = op( f[i1], g[i2] );
 
-        return quot;
+        return result;
     }
 }
 
@@ -554,7 +568,7 @@ template<typename T> Real MutualInfo(const TFactor<T> &f) {
     VarSet::const_iterator it = f.vars().begin();
     Var i = *it; it++; Var j = *it;
     TFactor<T> projection = f.marginal(i) * f.marginal(j);
-    return real( dist( f.normalized(), projection, Prob::DISTKL ) );
+    return dist( f.normalized(), projection, Prob::DISTKL );
 }