Merged regiongraph.* and daialg.* from SVN head,
[libdai.git] / include / dai / hak.h
index fac40ba..112c881 100644 (file)
@@ -27,6 +27,7 @@
 #include <dai/daialg.h>
 #include <dai/regiongraph.h>
 #include <dai/enum.h>
 #include <dai/daialg.h>
 #include <dai/regiongraph.h>
 #include <dai/enum.h>
+#include <dai/properties.h>
 
 
 namespace dai {
 
 
 namespace dai {
@@ -35,65 +36,80 @@ namespace dai {
 /// HAK provides an implementation of the single and double-loop algorithms by Heskes, Albers and Kappen
 class HAK : public DAIAlgRG {
     protected:
 /// HAK provides an implementation of the single and double-loop algorithms by Heskes, Albers and Kappen
 class HAK : public DAIAlgRG {
     protected:
-        std::vector<Factor>          _Qa;
-        std::vector<Factor>          _Qb;
-        std::vector<Factor>          _muab;
-        std::vector<Factor>          _muba;
+        std::vector<Factor>                _Qa;
+        std::vector<Factor>                _Qb;
+        std::vector<std::vector<Factor> >  _muab;
+        std::vector<std::vector<Factor> >  _muba;
+
+    public:
+        struct Properties {
+            size_t verbose;
+            size_t maxiter;
+            double tol;
+            DAI_ENUM(ClustersType,MIN,DELTA,LOOP)
+            ClustersType clusters;
+            bool doubleloop;
+            size_t loopdepth;
+        } props;
+        double maxdiff;
         
     public:
         /// Default constructor
         
     public:
         /// Default constructor
-        HAK() : DAIAlgRG() {};
+        HAK() : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), props(), maxdiff() {}
 
         /// Copy constructor
 
         /// Copy constructor
-        HAK(const HAK & x) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba) {};
+        HAK(const HAK & x) : DAIAlgRG(x), _Qa(x._Qa), _Qb(x._Qb), _muab(x._muab), _muba(x._muba), props(x.props), maxdiff(x.maxdiff) {}
 
         /// Clone function
         HAK* clone() const { return new HAK(*this); }
         
 
         /// Clone function
         HAK* clone() const { return new HAK(*this); }
         
+        /// Create (virtual constructor)
+        virtual HAK* create() const { return new HAK(); }
+
         /// Construct from RegionGraph
         /// Construct from RegionGraph
-        HAK(const RegionGraph & rg, const Properties &opts);
+        HAK(const RegionGraph & rg, const PropertySet &opts);
 
         /// Construct from RactorGraph using "clusters" option
 
         /// Construct from RactorGraph using "clusters" option
-        HAK(const FactorGraph & fg, const Properties &opts);
+        HAK(const FactorGraph & fg, const PropertySet &opts);
 
         /// Assignment operator
         HAK & operator=(const HAK & x) {
             if( this != &x ) {
                 DAIAlgRG::operator=(x);
 
         /// Assignment operator
         HAK & operator=(const HAK & x) {
             if( this != &x ) {
                 DAIAlgRG::operator=(x);
-                _Qa         = x._Qa;
-                _Qb         = x._Qb;
-                _muab       = x._muab;
-                _muba       = x._muba;
+                _Qa    = x._Qa;
+                _Qb    = x._Qb;
+                _muab  = x._muab;
+                _muba  = x._muba;
+                props  = x.props;
+                maxdiff = x.maxdiff;
             }
             return *this;
         }
         
         static const char *Name;
 
             }
             return *this;
         }
         
         static const char *Name;
 
-        ENUM3(ClustersType,MIN,DELTA,LOOP)
-        ClustersType Clusters() const { return GetPropertyAs<ClustersType>("clusters"); }
-        bool DoubleLoop() { return GetPropertyAs<bool>("doubleloop"); }
-        size_t LoopDepth() { return GetPropertyAs<size_t>("loopdepth"); }
-
-        Factor & muab( size_t alpha, size_t beta ) { return _muab[ORIR2E(alpha,beta)]; }
-        Factor & muba( size_t beta, size_t alpha ) { return _muba[ORIR2E(alpha,beta)]; }
+        Factor & muab( size_t alpha, size_t _beta ) { return _muab[alpha][_beta]; }
+        Factor & muba( size_t alpha, size_t _beta ) { return _muba[alpha][_beta]; }
         const Factor& Qa( size_t alpha ) const { return _Qa[alpha]; };
         const Factor& Qb( size_t beta ) const { return _Qb[beta]; };
 
         const Factor& Qa( size_t alpha ) const { return _Qa[alpha]; };
         const Factor& Qb( size_t beta ) const { return _Qb[beta]; };
 
-//      void Regenerate();
         double doGBP();
         double doDoubleLoop();
         double run();
         void init();
         double doGBP();
         double doDoubleLoop();
         double run();
         void init();
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &ns );
         std::string identify() const;
         Factor belief( const Var &n ) const;
         Factor belief( const VarSet &ns ) const;
         std::vector<Factor> beliefs() const;
         std::string identify() const;
         Factor belief( const Var &n ) const;
         Factor belief( const VarSet &ns ) const;
         std::vector<Factor> beliefs() const;
-        Complex logZ () const;
+        Real logZ () const;
 
 
-        void init( const VarSet &ns );
-        void undoProbs( const VarSet &ns ) { RegionGraph::undoProbs( ns ); init( ns ); }
-        bool checkProperties();
+        void restoreFactors( const VarSet &ns ) { RegionGraph::restoreFactors( ns ); init( ns ); }
+        void setProperties( const PropertySet &opts );
+        PropertySet getProperties() const;
+        std::string printProperties() const;
+        double maxDiff() const { return maxdiff; }
 
     private:
         void constructMessages();
 
     private:
         void constructMessages();