Small changes
[libdai.git] / include / dai / bp.h
index 7b66aef..cfb7a15 100644 (file)
@@ -26,6 +26,7 @@
 #include <string>
 #include <dai/daialg.h>
 #include <dai/factorgraph.h>
+#include <dai/properties.h>
 #include <dai/enum.h>
 
 
@@ -34,60 +35,80 @@ namespace dai {
 
 class BP : public DAIAlgFG {
     protected:
-        typedef std::vector<size_t>  _ind_t;
-
-        std::vector<_ind_t>          _indices;
-        std::vector<Prob>            _messages, _newmessages;
-
+        typedef std::vector<size_t> ind_t;
+        struct EdgeProp {
+            ind_t  index;
+            Prob   message;
+            Prob   newMessage;
+            double residual;
+        };
+        std::vector<std::vector<EdgeProp> > edges;
+    
     public:
-        ENUM4(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL)
-        UpdateType Updates() const { return GetPropertyAs<UpdateType>("updates"); }
+        struct Properties {
+            size_t verbose;
+            size_t maxiter;
+            double tol;
+            bool logdomain;
+            DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL)
+            UpdateType updates;
+        } props;
+        double maxdiff;
 
-        // default constructor
-        BP() : DAIAlgFG() {};
-        // copy constructor
-        BP(const BP & x) : DAIAlgFG(x), _indices(x._indices), _messages(x._messages), _newmessages(x._newmessages) {};
+    public:
+        /// Default constructor
+        BP() : DAIAlgFG(), edges(), props(), maxdiff(0.0) {};
+        /// Copy constructor
+        BP( const BP & x ) : DAIAlgFG(x), edges(x.edges), props(x.props), maxdiff(x.maxdiff) {};
+        /// Clone *this
         BP* clone() const { return new BP(*this); }
-        // construct BP object from FactorGraph
-        BP(const FactorGraph & fg, const Properties &opts) : DAIAlgFG(fg, opts) {
-            assert( checkProperties() );
-            Regenerate();
+        /// Construct from FactorGraph fg and PropertySet opts
+        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), edges(), props(), maxdiff(0.0) {
+            setProperties( opts );
+            create();
         }
-        // assignment operator
-        BP & operator=(const BP & x) {
-            if(this!=&x) {
-                DAIAlgFG::operator=(x);
-                _messages = x._messages;
-                _newmessages = x._newmessages;
-                _indices = x._indices;
+        /// Assignment operator
+        BP& operator=( const BP & x ) {
+            if( this != &x ) {
+                DAIAlgFG::operator=( x );
+                edges = x.edges;
+                props = x.props;
+                maxdiff = x.maxdiff;
             }
             return *this;
         }
 
         static const char *Name;
 
-        Prob & message(size_t i1, size_t i2) { return( _messages[VV2E(i1,i2)] ); }  
-        const Prob & message(size_t i1, size_t i2) const { return( _messages[VV2E(i1,i2)] ); }  
-        Prob & newMessage(size_t i1, size_t i2) { return( _newmessages[VV2E(i1,i2)] ); }    
-        const Prob & newMessage(size_t i1, size_t i2) const { return( _newmessages[VV2E(i1,i2)] ); }    
-        _ind_t & index(size_t i1, size_t i2) { return( _indices[VV2E(i1,i2)] ); }
-        const _ind_t & index(size_t i1, size_t i2) const { return( _indices[VV2E(i1,i2)] ); }
+        Prob & message(size_t i, size_t _I) { return edges[i][_I].message; }
+        const Prob & message(size_t i, size_t _I) const { return edges[i][_I].message; }
+        Prob & newMessage(size_t i, size_t _I) { return edges[i][_I].newMessage; }
+        const Prob & newMessage(size_t i, size_t _I) const { return edges[i][_I].newMessage; }
+        ind_t & index(size_t i, size_t _I) { return edges[i][_I].index; }
+        const ind_t & index(size_t i, size_t _I) const { return edges[i][_I].index; }
+        double & residual(size_t i, size_t _I) { return edges[i][_I].residual; }
+        const double & residual(size_t i, size_t _I) const { return edges[i][_I].residual; }
 
         std::string identify() const;
-        void Regenerate();
+        void create();
         void init();
-        void calcNewMessage(size_t iI);
         double run();
-        Factor belief1 (size_t i) const;
-        Factor belief2 (size_t I) const;
+
+        void findMaxResidual( size_t &i, size_t &_I );
+        void calcNewMessage( size_t i, size_t _I );
+        Factor beliefV (size_t i) const;
+        Factor beliefF (size_t I) const;
         Factor belief (const Var &n) const;
         Factor belief (const VarSet &n) const;
         std::vector<Factor> beliefs() const;
-        Complex logZ() const;
+        Real logZ() const;
 
         void init( const VarSet &ns );
         void undoProbs( const VarSet &ns ) { FactorGraph::undoProbs(ns); init(ns); }
-        bool checkProperties();
+
+        void setProperties( const PropertySet &opts );
+        PropertySet getProperties() const;
+        double maxDiff() const { return maxdiff; }
 };