Merged SVN head ...
[libdai.git] / include / dai / bp.h
index 7b66aef..f5b4886 100644 (file)
@@ -1,6 +1,6 @@
 /*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
     Radboud University Nijmegen, The Netherlands
-    
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -26,6 +26,7 @@
 #include <string>
 #include <dai/daialg.h>
 #include <dai/factorgraph.h>
+#include <dai/properties.h>
 #include <dai/enum.h>
 
 
@@ -33,61 +34,124 @@ namespace dai {
 
 
 class BP : public DAIAlgFG {
-    protected:
-        typedef std::vector<size_t>  _ind_t;
-
-        std::vector<_ind_t>          _indices;
-        std::vector<Prob>            _messages, _newmessages;
+    private:
+        typedef std::vector<size_t> ind_t;
+        struct EdgeProp {
+            ind_t  index;
+            Prob   message;
+            Prob   newMessage;
+            double residual;
+        };
+        std::vector<std::vector<EdgeProp> > _edges;
+        /// Maximum difference encountered so far
+        double _maxdiff;
+        /// Number of iterations needed
+        size_t _iters;
+    
+    public:
+        struct Properties {
+            size_t verbose;
+            size_t maxiter;
+            double tol;
+            bool logdomain;
+            double damping;
+            DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL)
+            UpdateType updates;
+        } props;
+        static const char *Name;
 
     public:
-        ENUM4(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL)
-        UpdateType Updates() const { return GetPropertyAs<UpdateType>("updates"); }
-
-        // default constructor
-        BP() : DAIAlgFG() {};
-        // copy constructor
-        BP(const BP & x) : DAIAlgFG(x), _indices(x._indices), _messages(x._messages), _newmessages(x._newmessages) {};
-        BP* clone() const { return new BP(*this); }
-        // construct BP object from FactorGraph
-        BP(const FactorGraph & fg, const Properties &opts) : DAIAlgFG(fg, opts) {
-            assert( checkProperties() );
-            Regenerate();
+        /// Default constructor
+        BP() : DAIAlgFG(), _edges(), _maxdiff(0.0), _iters(0U), props() {}
+
+        /// Construct from FactorGraph fg and PropertySet opts
+        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), props() {
+            setProperties( opts );
+            construct();
         }
-        // assignment operator
-        BP & operator=(const BP & x) {
-            if(this!=&x) {
-                DAIAlgFG::operator=(x);
-                _messages = x._messages;
-                _newmessages = x._newmessages;
-                _indices = x._indices;
+
+        /// Copy constructor
+        BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
+
+        /// Clone *this (virtual copy constructor)
+        virtual BP* clone() const { return new BP(*this); }
+
+        /// Create (virtual default constructor)
+        virtual BP* create() const { return new BP(); }
+
+        /// Assignment operator
+        BP& operator=( const BP &x ) {
+            if( this != &x ) {
+                DAIAlgFG::operator=( x );
+                _edges = x._edges;
+                _maxdiff = x._maxdiff;
+                _iters = x._iters;
+                props = x.props;
             }
             return *this;
         }
 
-        static const char *Name;
+        /// Identifies itself for logging purposes
+        virtual std::string identify() const;
+
+        /// Get single node belief
+        virtual Factor belief( const Var &n ) const;
+
+        /// Get general belief
+        virtual Factor belief( const VarSet &ns ) const;
+
+        /// Get all beliefs
+        virtual std::vector<Factor> beliefs() const;
+
+        /// Get log partition sum
+        virtual Real logZ() const;
+
+        /// Clear messages and beliefs
+        virtual void init();
+
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &ns );
+
+        /// The actual approximate inference algorithm
+        virtual double run();
+
+        /// Return maximum difference between single node beliefs in the last pass
+        virtual double maxDiff() const { return _maxdiff; }
+
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return _iters; }
+
+
+        Factor beliefV( size_t i ) const;
+        Factor beliefF( size_t I ) const;
+
+    private:
+        const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
+        Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
+        Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
+        const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
+        ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
+        const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
+        double & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
+        const double & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
+
+        void calcNewMessage( size_t i, size_t _I );
+        void updateMessage( size_t i, size_t _I ) {
+            if( props.damping == 0.0 ) {
+                message(i,_I) = newMessage(i,_I);
+                residual(i,_I) = 0.0;
+            } else {
+                message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
+                residual(i,_I) = dist( newMessage(i,_I), message(i,_I), Prob::DISTLINF );
+            }
+        }
+        void findMaxResidual( size_t &i, size_t &_I );
 
-        Prob & message(size_t i1, size_t i2) { return( _messages[VV2E(i1,i2)] ); }  
-        const Prob & message(size_t i1, size_t i2) const { return( _messages[VV2E(i1,i2)] ); }  
-        Prob & newMessage(size_t i1, size_t i2) { return( _newmessages[VV2E(i1,i2)] ); }    
-        const Prob & newMessage(size_t i1, size_t i2) const { return( _newmessages[VV2E(i1,i2)] ); }    
-        _ind_t & index(size_t i1, size_t i2) { return( _indices[VV2E(i1,i2)] ); }
-        const _ind_t & index(size_t i1, size_t i2) const { return( _indices[VV2E(i1,i2)] ); }
-
-        std::string identify() const;
-        void Regenerate();
-        void init();
-        void calcNewMessage(size_t iI);
-        double run();
-        Factor belief1 (size_t i) const;
-        Factor belief2 (size_t I) const;
-        Factor belief (const Var &n) const;
-        Factor belief (const VarSet &n) const;
-        std::vector<Factor> beliefs() const;
-        Complex logZ() const;
-
-        void init( const VarSet &ns );
-        void undoProbs( const VarSet &ns ) { FactorGraph::undoProbs(ns); init(ns); }
-        bool checkProperties();
+        void construct();
+        /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
+        void setProperties( const PropertySet &opts );
+        PropertySet getProperties() const;
+        std::string printProperties() const;
 };