Changed license from GPL v2+ to FreeBSD (aka BSD 2-clause) license
[libdai.git] / include / dai / bp.h
index 3bae74d..8c0a7b2 100644 (file)
@@ -1,16 +1,14 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
- *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
 /// \file
 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
+/// \todo Consider using a priority_queue for maximum residual schedule
 
 
 #ifndef __defined_libdai_bp_h
@@ -87,6 +85,12 @@ class BP : public DAIAlgFG {
         size_t _iters;
         /// The history of message updates (only recorded if \a recordSentMessages is \c true)
         std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
+        /// Stores variable beliefs of previous iteration
+        std::vector<Factor> _oldBeliefsV;
+        /// Stores factor beliefs of previous iteration
+        std::vector<Factor> _oldBeliefsF;
+        /// Stores the update schedule
+        std::vector<Edge> _updateSeq;
 
     public:
         /// Parameters for BP
@@ -113,6 +117,9 @@ class BP : public DAIAlgFG {
             /// Maximum number of iterations
             size_t maxiter;
 
+            /// Maximum time (in seconds)
+            double maxtime;
+
             /// Tolerance for convergence test
             Real tol;
 
@@ -129,9 +136,6 @@ class BP : public DAIAlgFG {
             InfType inference;
         } props;
 
-        /// Name of this inference algorithm
-        static const char *Name;
-
         /// Specifies whether the history of message updates should be recorded
         bool recordSentMessages;
 
@@ -139,21 +143,19 @@ class BP : public DAIAlgFG {
     /// \name Constructors/destructors
     //@{
         /// Default constructor
-        BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
+        BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {}
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
-        /** \param opts Parameters @see Properties
+        /** \param fg Factor graph.
+         *  \param opts Parameters @see Properties
          */
-        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
+        BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), _oldBeliefsV(), _oldBeliefsF(), _updateSeq(), props(), recordSentMessages(false) {
             setProperties( opts );
             construct();
         }
 
         /// Copy constructor
-        BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut),
-            _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages),
-            props(x.props), recordSentMessages(x.recordSentMessages)
-        {
+        BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _edge2lut(x._edge2lut), _lut(x._lut), _maxdiff(x._maxdiff), _iters(x._iters), _sentMessages(x._sentMessages), _oldBeliefsV(x._oldBeliefsV), _oldBeliefsF(x._oldBeliefsF), _updateSeq(x._updateSeq), props(x.props), recordSentMessages(x.recordSentMessages) {
             for( LutType::iterator l = _lut.begin(); l != _lut.end(); ++l )
                 _edge2lut[l->second.first][l->second.second] = l;
         }
@@ -169,6 +171,9 @@ class BP : public DAIAlgFG {
                 _maxdiff = x._maxdiff;
                 _iters = x._iters;
                 _sentMessages = x._sentMessages;
+                _oldBeliefsV = x._oldBeliefsV;
+                _oldBeliefsF = x._oldBeliefsF;
+                _updateSeq = x._updateSeq;
                 props = x.props;
                 recordSentMessages = x.recordSentMessages;
             }
@@ -179,18 +184,22 @@ class BP : public DAIAlgFG {
     /// \name General InfAlg interface
     //@{
         virtual BP* clone() const { return new BP(*this); }
-        virtual std::string identify() const;
-        virtual Factor belief( const Var &n ) const;
-        virtual Factor belief( const VarSet &ns ) const;
+        virtual std::string name() const { return "BP"; }
+        virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
+        virtual Factor belief( const VarSet &vs ) const;
         virtual Factor beliefV( size_t i ) const;
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
         virtual Real logZ() const;
+        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
+         */
+        std::vector<std::size_t> findMaximum() const { return dai::findMaximum( *this ); }
         virtual void init();
         virtual void init( const VarSet &ns );
         virtual Real run();
         virtual Real maxDiff() const { return _maxdiff; }
         virtual size_t Iterations() const { return _iters; }
+        virtual void setMaxIter( size_t maxiter ) { props.maxiter = maxiter; }
         virtual void setProperties( const PropertySet &opts );
         virtual PropertySet getProperties() const;
         virtual std::string printProperties() const;
@@ -198,11 +207,6 @@ class BP : public DAIAlgFG {
 
     /// \name Additional interface specific for BP
     //@{
-        /// Calculates the joint state of all variables that has maximum probability
-        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
-         */
-        std::vector<std::size_t> findMaximum() const;
-
         /// Returns history of which messages have been updated
         const std::vector<std::pair<std::size_t, std::size_t> >& getSentMessages() const {
             return _sentMessages;
@@ -230,6 +234,11 @@ class BP : public DAIAlgFG {
         /// Returns reference to residual for the edge between variable \a i and its \a _I 'th neighbor
         Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
 
+        /// Calculate the product of factor \a I and the incoming messages
+        /** If \a without_i == \c true, the message coming from variable \a i is omitted from the product
+         *  \note This function is used by calcNewMessage() and calcBeliefF()
+         */
+        virtual Prob calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const;
         /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
         virtual void calcNewMessage( size_t i, size_t _I );
         /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
@@ -239,9 +248,11 @@ class BP : public DAIAlgFG {
         /// Finds the edge which has the maximum residual (difference between new and old message)
         void findMaxResidual( size_t &i, size_t &_I );
         /// Calculates unnormalized belief of variable \a i
-        void calcBeliefV( size_t i, Prob &p ) const;
+        virtual void calcBeliefV( size_t i, Prob &p ) const;
         /// Calculates unnormalized belief of factor \a I
-        virtual void calcBeliefF( size_t I, Prob &p ) const;
+        virtual void calcBeliefF( size_t I, Prob &p ) const {
+            p = calcIncomingMessageProduct( I, false, 0 );
+        }
 
         /// Helper function for constructors
         virtual void construct();