Changed license from GPL v2+ to FreeBSD (aka BSD 2-clause) license
[libdai.git] / include / dai / bp.h
index ec56dec..8c0a7b2 100644 (file)
@@ -1,16 +1,14 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
  *
- *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
- *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
 /// \file
 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
  */
 
 
 /// \file
 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
+/// \todo Consider using a priority_queue for maximum residual schedule
 
 
 #ifndef __defined_libdai_bp_h
 
 
 #ifndef __defined_libdai_bp_h
@@ -87,6 +85,12 @@ class BP : public DAIAlgFG {
         size_t _iters;
         /// The history of message updates (only recorded if \a recordSentMessages is \c true)
         std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
         size_t _iters;
         /// The history of message updates (only recorded if \a recordSentMessages is \c true)
         std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
+        /// Stores variable beliefs of previous iteration
+        std::vector<Factor> _oldBeliefsV;
+        /// Stores factor beliefs of previous iteration
+        std::vector<Factor> _oldBeliefsF;
+        /// Stores the update schedule
+        std::vector<Edge> _updateSeq;
 
     public:
         /// Parameters for BP
 
     public:
         /// Parameters for BP
@@ -113,6 +117,9 @@ class BP : public DAIAlgFG {
             /// Maximum number of iterations
             size_t maxiter;
 
             /// Maximum number of iterations
             size_t maxiter;
 
+            /// Maximum time (in seconds)
+            double maxtime;
+
             /// Tolerance for convergence test
             Real tol;
 
             /// Tolerance for convergence test
             Real tol;
 
@@ -129,9 +136,6 @@ class BP : public DAIAlgFG {
             InfType inference;
         } props;
 
             InfType inference;
         } props;
 
-        /// Name of this inference algorithm
-        static const char *Name;
-
         /// Specifies whether the history of message updates should be recorded
         bool recordSentMessages;
 
         /// Specifies whether the history of message updates should be recorded
         bool recordSentMessages;
 
@@ -139,22 +143,19 @@ class BP : public DAIAlgFG {
     /// \name Constructors/destructors
     //@{
         /// Default constructor
     /// \name Constructors/destructors
     //@{
         /// Default constructor
-        BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
+        BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {}
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
         /** \param fg Factor graph.
          *  \param opts Parameters @see Properties
          */
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
         /** \param fg Factor graph.
          *  \param opts Parameters @see Properties
          */
-        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
+        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {
             setProperties( opts );
             construct();
         }
 
         /// Copy constructor
             setProperties( opts );
             construct();
         }
 
         /// Copy constructor
-        BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
-            _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
-            props(x.props), recordSentMessages(x.recordSentMessages)
-        {
+        BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut), _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages), _oldBeliefsV(x._oldBeliefsV), _oldBeliefsF(x._oldBeliefsF), _updateSeq(x._updateSeq), props(x.props), recordSentMessages(x.recordSentMessages) {
             for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
                 _edge2lut[l->second.first][l->second.second] = l;
         }
             for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
                 _edge2lut[l->second.first][l->second.second] = l;
         }
@@ -170,6 +171,9 @@ class BP : public DAIAlgFG {
                 _maxdiff = x._maxdiff;
                 _iters = x._iters;
                 _sentMessages = x._sentMessages;
                 _maxdiff = x._maxdiff;
                 _iters = x._iters;
                 _sentMessages = x._sentMessages;
+                _oldBeliefsV = x._oldBeliefsV;
+                _oldBeliefsF = x._oldBeliefsF;
+                _updateSeq = x._updateSeq;
                 props = x.props;
                 recordSentMessages = x.recordSentMessages;
             }
                 props = x.props;
                 recordSentMessages = x.recordSentMessages;
             }
@@ -180,18 +184,22 @@ class BP : public DAIAlgFG {
     /// \name General InfAlg interface
     //@{
         virtual BP* clone() const { return new BP(*this); }
     /// \name General InfAlg interface
     //@{
         virtual BP* clone() const { return new BP(*this); }
-        virtual std::string identify() const;
+        virtual std::string name() const { return "BP"; }
         virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
         virtual Factor belief( const VarSet &vs ) const;
         virtual Factor beliefV( size_t i ) const;
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
         virtual Real logZ() const;
         virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
         virtual Factor belief( const VarSet &vs ) const;
         virtual Factor beliefV( size_t i ) const;
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
         virtual Real logZ() const;
+        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
+         */
+        std::vector<std::size_t> findMaximum() const { return dai::findMaximum( *this ); }
         virtual void init();
         virtual void init( const VarSet &ns );
         virtual Real run();
         virtual Real maxDiff() const { return _maxdiff; }
         virtual size_t Iterations() const { return _iters; }
         virtual void init();
         virtual void init( const VarSet &ns );
         virtual Real run();
         virtual Real maxDiff() const { return _maxdiff; }
         virtual size_t Iterations() const { return _iters; }
+        virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
         virtual void setProperties( const PropertySet &opts );
         virtual PropertySet getProperties() const;
         virtual std::string printProperties() const;
         virtual void setProperties( const PropertySet &opts );
         virtual PropertySet getProperties() const;
         virtual std::string printProperties() const;
@@ -199,11 +207,6 @@ class BP : public DAIAlgFG {
 
     /// \name Additional interface specific for BP
     //@{
 
     /// \name Additional interface specific for BP
     //@{
-        /// Calculates the joint state of all variables that has maximum probability
-        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
-         */
-        std::vector<std::size_t> findMaximum() const;
-
         /// Returns history of which messages have been updated
         const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
             return _sentMessages;
         /// Returns history of which messages have been updated
         const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
             return _sentMessages;