Merged SVN head ...
[libdai.git] / include / dai / hak.h
index 9420298..159c2f7 100644 (file)
@@ -1,6 +1,6 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -27,6 +27,7 @@
 #include <dai/daialg.h>
 #include <dai/regiongraph.h>
 #include <dai/enum.h>
+#include <dai/properties.h>
 
 
 namespace dai {
@@ -39,41 +40,89 @@ class HAK : public DAIAlgRG {
         std::vector<Factor>                _Qb;
         std::vector<std::vector<Factor> >  _muab;
         std::vector<std::vector<Factor> >  _muba;
-        
+        /// Maximum difference encountered so far
+        double _maxdiff;
+        /// Number of iterations needed
+        size_t _iters;
+
+    public:
+        struct Properties {
+            size_t verbose;
+            size_t maxiter;
+            double tol;
+            double damping;
+            DAI_ENUM(ClustersType,MIN,DELTA,LOOP)
+            ClustersType clusters;
+            bool doubleloop;
+            size_t loopdepth;
+        } props;
+        /// Name of this inference method
+        static const char *Name;
+
     public:
         /// Default constructor
-        HAK() : DAIAlgRG() {};
+        HAK() : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), _maxdiff(0.0), _iters(0U), props() {}
+
+        /// Construct from FactorGraph fg and PropertySet opts
+        HAK( const FactorGraph &fg, const PropertySet &opts );
+
+        /// Construct from RegionGraph rg and PropertySet opts
+        HAK( const RegionGraph &rg, const PropertySet &opts );
 
         /// Copy constructor
-        HAK(const HAK & x) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba) {};
+        HAK( const HAK &x ) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
 
-        /// Clone function
-        HAK* clone() const { return new HAK(*this); }
-        
-        /// Construct from RegionGraph
-        HAK(const RegionGraph & rg, const Properties &opts);
+        /// Clone *this (virtual copy constructor)
+        virtual HAK* clone() const { return new HAK(*this); }
 
-        /// Construct from RactorGraph using "clusters" option
-        HAK(const FactorGraph & fg, const Properties &opts);
+        /// Create (virtual default constructor)
+        virtual HAK* create() const { return new HAK(); }
 
         /// Assignment operator
-        HAK & operator=(const HAK & x) {
+        HAK& operator=( const HAK &x ) {
             if( this != &x ) {
-                DAIAlgRG::operator=(x);
-                _Qa         = x._Qa;
-                _Qb         = x._Qb;
-                _muab       = x._muab;
-                _muba       = x._muba;
+                DAIAlgRG::operator=( x );
+                _Qa      = x._Qa;
+                _Qb      = x._Qb;
+                _muab    = x._muab;
+                _muba    = x._muba;
+                _maxdiff = x._maxdiff;
+                _iters   = x._iters;
+                props    = x.props;
             }
             return *this;
         }
-        
-        static const char *Name;
 
-        ENUM3(ClustersType,MIN,DELTA,LOOP)
-        ClustersType Clusters() const { return GetPropertyAs<ClustersType>("clusters"); }
-        bool DoubleLoop() { return GetPropertyAs<bool>("doubleloop"); }
-        size_t LoopDepth() { return GetPropertyAs<size_t>("loopdepth"); }
+        /// Identifies itself for logging purposes
+        virtual std::string identify() const;
+
+        /// Get single node belief
+        virtual Factor belief( const Var &n ) const;
+
+        /// Get general belief
+        virtual Factor belief( const VarSet &ns ) const;
+
+        /// Get all beliefs
+        virtual std::vector<Factor> beliefs() const;
+
+        /// Get log partition sum
+        virtual Real logZ() const;
+
+        /// Clear messages and beliefs
+        virtual void init();
+
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &ns );
+
+        /// The actual approximate inference algorithm
+        virtual double run();
+
+        /// Return maximum difference between single node beliefs in the last pass
+        virtual double maxDiff() const { return _maxdiff; }
+
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return _iters; }
+
 
         Factor & muab( size_t alpha, size_t _beta ) { return _muab[alpha][_beta]; }
         Factor & muba( size_t alpha, size_t _beta ) { return _muba[alpha][_beta]; }
@@ -82,17 +131,10 @@ class HAK : public DAIAlgRG {
 
         double doGBP();
         double doDoubleLoop();
-        double run();
-        void init();
-        std::string identify() const;
-        Factor belief( const Var &n ) const;
-        Factor belief( const VarSet &ns ) const;
-        std::vector<Factor> beliefs() const;
-        Real logZ () const;
-
-        void init( const VarSet &ns );
-        void undoProbs( const VarSet &ns ) { RegionGraph::undoProbs( ns ); init( ns ); }
-        bool checkProperties();
+
+        void setProperties( const PropertySet &opts );
+        PropertySet getProperties() const;
+        std::string printProperties() const;
 
     private:
         void constructMessages();