Merged SVN head ...
[libdai.git] / include / dai / treeep.h
index 9120085..fd039a6 100644 (file)
@@ -1,6 +1,6 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
 namespace dai {
 
 
-class TreeEPSubTree {
-    protected:
-        std::vector<Factor>  _Qa;
-        std::vector<Factor>  _Qb;
-        DEdgeVec             _RTree;
-        std::vector<size_t>  _a;        // _Qa[alpha]  <->  superTree._Qa[_a[alpha]]
-        std::vector<size_t>  _b;        // _Qb[beta]   <->  superTree._Qb[_b[beta]]
-                                        // _Qb[beta]   <->  _RTree[beta]    
-        const Factor *       _I;
-        VarSet               _ns;
-        VarSet               _nsrem;
-        double               _logZ;
-        
-        
-    public:
-        TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
-        TreeEPSubTree( const TreeEPSubTree &x) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
-        TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
-            if( this != &x ) {
-                _Qa         = x._Qa;
-                _Qb         = x._Qb;
-                _RTree      = x._RTree;
-                _a          = x._a;
-                _b          = x._b;
-                _I          = x._I;
-                _ns         = x._ns;
-                _nsrem      = x._nsrem;
-                _logZ       = x._logZ;
-            }
-            return *this;
-        }
-
-        TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
-        void init();
-        void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
-        void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
-        double logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
-        const Factor *& I() { return _I; }
-};
-
-
 class TreeEP : public JTree {
     protected:
-        std::map<size_t, TreeEPSubTree>  _Q;
+        /// Maximum difference encountered so far
+        double                  _maxdiff;
+        /// Number of iterations needed
+        size_t                  _iters;
 
     public:
         struct Properties {
@@ -91,52 +53,116 @@ class TreeEP : public JTree {
             double tol;
             DAI_ENUM(TypeType,ORG,ALT)
             TypeType type;
-        } props;
-        double maxdiff;
+        } props; // FIXME: should be props2 because of conflict with JTree::props?
+        /// Name of this inference method
+        static const char *Name;
+
+    protected:
+        class TreeEPSubTree {
+            protected:
+                std::vector<Factor>  _Qa;
+                std::vector<Factor>  _Qb;
+                DEdgeVec             _RTree;
+                std::vector<size_t>  _a;        // _Qa[alpha]  <->  superTree._Qa[_a[alpha]]
+                std::vector<size_t>  _b;        // _Qb[beta]   <->  superTree._Qb[_b[beta]]
+                                                // _Qb[beta]   <->  _RTree[beta]    
+                const Factor *       _I;
+                VarSet               _ns;
+                VarSet               _nsrem;
+                double               _logZ;
+                
+                
+            public:
+                TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
+                TreeEPSubTree( const TreeEPSubTree &x) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
+                TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
+                    if( this != &x ) {
+                        _Qa         = x._Qa;
+                        _Qb         = x._Qb;
+                        _RTree      = x._RTree;
+                        _a          = x._a;
+                        _b          = x._b;
+                        _I          = x._I;
+                        _ns         = x._ns;
+                        _nsrem      = x._nsrem;
+                        _logZ       = x._logZ;
+                    }
+                    return *this;
+                }
+
+                TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
+                void init();
+                void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
+                void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
+                double logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
+                const Factor *& I() { return _I; }
+        };
+
+        std::map<size_t, TreeEPSubTree>  _Q;
 
     public:
         /// Default constructor
-        TreeEP() : JTree(), _Q(), props(), maxdiff() {};
+        TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {}
+
+        /// Construct from FactorGraph fg and PropertySet opts
+        TreeEP( const FactorGraph &fg, const PropertySet &opts );
+
         /// Copy constructor
-        TreeEP( const TreeEP& x ) : JTree(x), _Q(x._Q), props(x.props), maxdiff(x.maxdiff) {
+        TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) {
             for( size_t I = 0; I < nrFactors(); I++ )
                 if( offtree( I ) )
                     _Q[I].I() = &factor(I);
         }
-        TreeEP* clone() const { return new TreeEP(*this); }
-        /// Create (virtual constructor)
+
+        /// Clone *this (virtual copy constructor)
+        virtual TreeEP* clone() const { return new TreeEP(*this); }
+
+        /// Create (virtual default constructor)
         virtual TreeEP* create() const { return new TreeEP(); }
-        TreeEP & operator=( const TreeEP& x ) {
+
+        /// Assignment operator
+        TreeEP& operator=( const TreeEP &x ) {
             if( this != &x ) {
-                JTree::operator=(x);
-                _Q   = x._Q;
+                JTree::operator=( x );
+                _maxdiff = x._maxdiff;
+                _iters   = x._iters;
+                props    = x.props;
+                _Q       = x._Q;
                 for( size_t I = 0; I < nrFactors(); I++ )
                     if( offtree( I ) )
                         _Q[I].I() = &factor(I);
-                props = x.props;
-                maxdiff = x.maxdiff;
             }
             return *this;
         }
-        TreeEP( const FactorGraph &fg, const PropertySet &opts );
-        void ConstructRG( const DEdgeVec &tree );
 
-        static const char *Name;
-        std::string identify() const;
-        void init();
+        /// Identifies itself for logging purposes
+        virtual std::string identify() const;
+
+        /// Get log partition sum
+        virtual Real logZ() const;
+
+        /// Clear messages and beliefs
+        virtual void init();
+
         /// Clear messages and beliefs corresponding to the nodes in ns
         virtual void init( const VarSet &/*ns*/ ) { init(); }
-        double run();
-        Real logZ() const;
 
-        bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
+        /// The actual approximate inference algorithm
+        virtual double run();
+
+        /// Return maximum difference between single node beliefs in the last pass
+        virtual double maxDiff() const { return _maxdiff; }
 
-        void restoreFactors( const VarSet &ns ) { RegionGraph::restoreFactors( ns ); init( ns ); }
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return _iters; }
+
+
+        void ConstructRG( const DEdgeVec &tree );
+        bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
 
         void setProperties( const PropertySet &opts );
         PropertySet getProperties() const;
         std::string printProperties() const;
-        double maxDiff() const { return maxdiff; }
 };