Merged SVN head ...
[libdai.git] / include / dai / hak.h
index b782fd9..159c2f7 100644 (file)
@@ -1,6 +1,6 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -40,50 +40,89 @@ class HAK : public DAIAlgRG {
         std::vector<Factor>                _Qb;
         std::vector<std::vector<Factor> >  _muab;
         std::vector<std::vector<Factor> >  _muba;
+        /// Maximum difference encountered so far
+        double _maxdiff;
+        /// Number of iterations needed
+        size_t _iters;
 
     public:
         struct Properties {
             size_t verbose;
             size_t maxiter;
             double tol;
+            double damping;
             DAI_ENUM(ClustersType,MIN,DELTA,LOOP)
             ClustersType clusters;
             bool doubleloop;
             size_t loopdepth;
         } props;
-        double maxdiff;
-        
+        /// Name of this inference method
+        static const char *Name;
+
     public:
         /// Default constructor
-        HAK() : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), props(), maxdiff() {}
+        HAK() : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), _maxdiff(0.0), _iters(0U), props() {}
+
+        /// Construct from FactorGraph fg and PropertySet opts
+        HAK( const FactorGraph &fg, const PropertySet &opts );
+
+        /// Construct from RegionGraph rg and PropertySet opts
+        HAK( const RegionGraph &rg, const PropertySet &opts );
 
         /// Copy constructor
-        HAK(const HAK & x) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba), props(x.props), maxdiff(x.maxdiff) {}
+        HAK( const HAK &x ) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
 
-        /// Clone function
-        HAK* clone() const { return new HAK(*this); }
-        
-        /// Construct from RegionGraph
-        HAK(const RegionGraph & rg, const PropertySet &opts);
+        /// Clone *this (virtual copy constructor)
+        virtual HAK* clone() const { return new HAK(*this); }
 
-        /// Construct from RactorGraph using "clusters" option
-        HAK(const FactorGraph & fg, const PropertySet &opts);
+        /// Create (virtual default constructor)
+        virtual HAK* create() const { return new HAK(); }
 
         /// Assignment operator
-        HAK & operator=(const HAK & x) {
+        HAK& operator=( const HAK &x ) {
             if( this != &x ) {
-                DAIAlgRG::operator=(x);
-                _Qa    = x._Qa;
-                _Qb    = x._Qb;
-                _muab  = x._muab;
-                _muba  = x._muba;
-                props  = x.props;
-                maxdiff = x.maxdiff;
+                DAIAlgRG::operator=( x );
+                _Qa      = x._Qa;
+                _Qb      = x._Qb;
+                _muab    = x._muab;
+                _muba    = x._muba;
+                _maxdiff = x._maxdiff;
+                _iters   = x._iters;
+                props    = x.props;
             }
             return *this;
         }
-        
-        static const char *Name;
+
+        /// Identifies itself for logging purposes
+        virtual std::string identify() const;
+
+        /// Get single node belief
+        virtual Factor belief( const Var &n ) const;
+
+        /// Get general belief
+        virtual Factor belief( const VarSet &ns ) const;
+
+        /// Get all beliefs
+        virtual std::vector<Factor> beliefs() const;
+
+        /// Get log partition sum
+        virtual Real logZ() const;
+
+        /// Clear messages and beliefs
+        virtual void init();
+
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &ns );
+
+        /// The actual approximate inference algorithm
+        virtual double run();
+
+        /// Return maximum difference between single node beliefs in the last pass
+        virtual double maxDiff() const { return _maxdiff; }
+
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return _iters; }
+
 
         Factor & muab( size_t alpha, size_t _beta ) { return _muab[alpha][_beta]; }
         Factor & muba( size_t alpha, size_t _beta ) { return _muba[alpha][_beta]; }
@@ -92,20 +131,10 @@ class HAK : public DAIAlgRG {
 
         double doGBP();
         double doDoubleLoop();
-        double run();
-        void init();
-        std::string identify() const;
-        Factor belief( const Var &n ) const;
-        Factor belief( const VarSet &ns ) const;
-        std::vector<Factor> beliefs() const;
-        Real logZ () const;
-
-        void init( const VarSet &ns );
-        void undoProbs( const VarSet &ns ) { RegionGraph::undoProbs( ns ); init( ns ); }
+
         void setProperties( const PropertySet &opts );
         PropertySet getProperties() const;
         std::string printProperties() const;
-        double maxDiff() const { return maxdiff; }
 
     private:
         void constructMessages();