Merged SVN head ...
[libdai.git] / include / dai / mr.h
index 707d4f6..34ec957 100644 (file)
@@ -1,6 +1,6 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -28,6 +28,8 @@
 #include <dai/factorgraph.h>
 #include <dai/daialg.h>
 #include <dai/enum.h>
+#include <dai/properties.h>
+#include <dai/exceptions.h>
 
 
 namespace dai {
@@ -54,29 +56,104 @@ class MR : public DAIAlgFG {
 
         std::vector<double> Mag;
 
+        double _maxdiff;
+        size_t _iters;
+
     public:
-        ENUM2(UpdateType,FULL,LINEAR)
-        ENUM3(InitType,RESPPROP,CLAMPING,EXACT)
+        struct Properties {
+            size_t verbose;
+            double tol;
+            DAI_ENUM(UpdateType,FULL,LINEAR)
+            DAI_ENUM(InitType,RESPPROP,CLAMPING,EXACT)
+            UpdateType updates;
+            InitType inits;
+        } props;
+        static const char *Name;
+
+    public:
+        /// Default constructor
+        MR() : DAIAlgFG(), supported(), con(), nb(), tJ(), theta(), M(), kindex(), cors(), N(), Mag(), _maxdiff(), _iters(), props() {}
+
+        /// Construct from FactorGraph fg and PropertySet opts
+        MR( const FactorGraph &fg, const PropertySet &opts );
+
+        /// Copy constructor
+        MR( const MR &x ) : DAIAlgFG(x), supported(x.supported), con(x.con), nb(x.nb), tJ(x.tJ), theta(x.theta), M(x.M), kindex(x.kindex), cors(x.cors), N(x.N), Mag(x.Mag), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
+
+        /// Clone *this (virtual copy constructor)
+        virtual MR* clone() const { return new MR(*this); }
 
-        UpdateType Updates() const { return GetPropertyAs<UpdateType>("updates"); }
-        InitType Inits() const { return GetPropertyAs<InitType>("inits"); }
+        /// Create (virtual default constructor)
+        virtual MR* create() const { return new MR(); }
 
-        MR( const FactorGraph & fg, const Properties &opts );
-        void init(size_t _N, double *_w, double *_th);
+        /// Assignment operator
+        MR& operator=( const MR &x ) {
+            if( this != &x ) {
+                DAIAlgFG::operator=(x);
+                supported = x.supported;
+                con       = x.con; 
+                nb        = x.nb;
+                tJ        = x.tJ;
+                theta     = x.theta;
+                M         = x.M;
+                kindex    = x.kindex;
+                cors      = x.cors;
+                N         = x.N;
+                Mag       = x.Mag;
+                _maxdiff  = x._maxdiff;
+                _iters    = x._iters;
+                props     = x.props;
+            }
+            return *this;
+        }
+
+        /// Identifies itself for logging purposes
+        virtual std::string identify() const;
+
+        /// Get single node belief
+        virtual Factor belief( const Var &n ) const;
+
+        /// Get general belief
+        virtual Factor belief( const VarSet &/*ns*/ ) const { 
+            DAI_THROW(NOT_IMPLEMENTED);
+            return Factor(); 
+        }
+
+        /// Get all beliefs
+        virtual std::vector<Factor> beliefs() const;
+
+        /// Get log partition sum
+        virtual Real logZ() const { 
+            DAI_THROW(NOT_IMPLEMENTED);
+            return 0.0; 
+        }
+
+        /// Clear messages and beliefs
+        virtual void init() {}
+
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &/*ns*/ ) {
+            DAI_THROW(NOT_IMPLEMENTED);
+        }
+
+        /// The actual approximate inference algorithm
+        virtual double run();
+
+        /// Return maximum difference between single node beliefs in the last pass
+        virtual double maxDiff() const { return _maxdiff; }
+
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return _iters; }
+
+
+        void init(size_t Nin, double *_w, double *_th);
         void makekindex();
         void read_files();
         void init_cor();
         double init_cor_resp();
         void solvemcav();
         void solveM();
-        double run();
-        Factor belief( const Var &n ) const;
-        Factor belief( const VarSet &/*ns*/ ) const { assert( 0 == 1 ); }
-        std::vector<Factor> beliefs() const;
-        Complex logZ() const { return NAN; }
-        void init() { assert( checkProperties() ); }
-        static const char *Name;
-        std::string identify() const;
+
         double _tJ(size_t i, sub_nb A);
 
         double Omega(size_t i, size_t _j, size_t _l);
@@ -89,10 +166,10 @@ class MR : public DAIAlgFG {
         void sum_subs(size_t j, sub_nb A, double *sum_even, double *sum_odd);
 
         double sign(double a) { return (a >= 0) ? 1.0 : -1.0; }
-        MR* clone() const { assert( 0 == 1 ); }
-
-        bool checkProperties();
-
+        
+        void setProperties( const PropertySet &opts );
+        PropertySet getProperties() const;
+        std::string printProperties() const;
 }; 
 
 
@@ -106,7 +183,7 @@ class sub_nb {
     public:
         // construct full subset containing nr_elmt elements
         sub_nb(size_t nr_elmt) {
-#ifdef DEBUG
+#ifdef DAI_DEBUG
             assert( nr_elmt < sizeof(size_t) / sizeof(char) * 8 );
 #endif
             bits = nr_elmt;
@@ -152,7 +229,7 @@ class sub_nb {
                     else
                         i--;
                 }
-#ifdef DEBUG
+#ifdef DAI_DEBUG
             assert( bit < bits );
 #endif
             return bit;