Cleaned up variable elimination code in ClusterGraph
[libdai.git] / include / dai / bp.h
index 536164f..c7bd160 100644 (file)
@@ -37,9 +37,9 @@ namespace dai {
  *
  *  The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$. 
  *  In case of the sum-product algorith, the update equation is: 
  *
  *  The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$. 
  *  In case of the sum-product algorith, the update equation is: 
- *    \f[ m_{I\to i}(x_i) \propto \sum_{x_{I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
+ *    \f[ m_{I\to i}(x_i) \propto \sum_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
  *  and in case of the max-product algorithm:
  *  and in case of the max-product algorithm:
- *    \f[ m_{I\to i}(x_i) \propto \max_{x_{I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
+ *    \f[ m_{I\to i}(x_i) \propto \max_{x_{N_I\setminus\{i\}}} f_I(x_I) \prod_{j\in N_I\setminus\{i\}} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}\f]
  *  In order to improve convergence, the updates can be damped. For improved numerical stability,
  *  the updates can be done in the log-domain alternatively.
  *
  *  In order to improve convergence, the updates can be damped. For improved numerical stability,
  *  the updates can be done in the log-domain alternatively.
  *
@@ -50,12 +50,6 @@ namespace dai {
  *  The logarithm of the partition sum is calculated by:
  *    \f[ \log Z = \sum_i (1 - |N_i|) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
  *
  *  The logarithm of the partition sum is calculated by:
  *    \f[ \log Z = \sum_i (1 - |N_i|) \sum_{x_i} b_i(x_i) \log b_i(x_i) - \sum_I \sum_{x_I} b_I(x_I) \log \frac{b_I(x_I)}{f_I(x_I)} \f]
  *
- *  There are several predefined update schedules:
- *    - PARALL parallel updates
- *    - SEQFIX sequential updates using a fixed sequence
- *    - SEQRND sequential updates using a random sequence
- *    - SEQMAX maximum-residual updates [\ref EMK06]
- *
  *  For the max-product algorithm, a heuristic way of finding the MAP state (the 
  *  joint configuration of all variables which has maximum probability) is provided
  *  by the findMaximum() method, which can be called after convergence.
  *  For the max-product algorithm, a heuristic way of finding the MAP state (the 
  *  joint configuration of all variables which has maximum probability) is provided
  *  by the findMaximum() method, which can be called after convergence.
@@ -65,7 +59,7 @@ namespace dai {
  *  enabled by defining DAI_BP_FAST as false in the source file.
  */
 class BP : public DAIAlgFG {
  *  enabled by defining DAI_BP_FAST as false in the source file.
  */
 class BP : public DAIAlgFG {
-    private:
+    protected:
         /// Type used for index cache
         typedef std::vector<size_t> ind_t;
         /// Type used for storing edge properties
         /// Type used for index cache
         typedef std::vector<size_t> ind_t;
         /// Type used for storing edge properties
@@ -95,15 +89,25 @@ class BP : public DAIAlgFG {
         std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
 
     public:
         std::vector<std::pair<std::size_t, std::size_t> > _sentMessages;
 
     public:
-        /// Parameters of this inference algorithm
+        /// Parameters for BP
         struct Properties {
             /// Enumeration of possible update schedules
         struct Properties {
             /// Enumeration of possible update schedules
+            /** The following update schedules have been defined:
+             *  - PARALL parallel updates
+             *  - SEQFIX sequential updates using a fixed sequence
+             *  - SEQRND sequential updates using a random sequence
+             *  - SEQMAX maximum-residual updates [\ref EMK06]
+             */
             DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
 
             /// Enumeration of inference variants
             DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL);
 
             /// Enumeration of inference variants
+            /** There are two inference variants:
+             *  - SUMPROD Sum-Product
+             *  - MAXPROD Max-Product (equivalent to Min-Sum)
+             */
             DAI_ENUM(InfType,SUMPROD,MAXPROD);
 
             DAI_ENUM(InfType,SUMPROD,MAXPROD);
 
-            /// Verbosity
+            /// Verbosity (amount of output sent to stderr)
             size_t verbose;
 
             /// Maximum number of iterations
             size_t verbose;
 
             /// Maximum number of iterations
@@ -121,7 +125,7 @@ class BP : public DAIAlgFG {
             /// Message update schedule
             UpdateType updates;
 
             /// Message update schedule
             UpdateType updates;
 
-            /// Type of inference: sum-product or max-product?
+            /// Inference variant
             InfType inference;
         } props;
 
             InfType inference;
         } props;
 
@@ -138,6 +142,8 @@ class BP : public DAIAlgFG {
         BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
         BP() : DAIAlgFG(), _edges(), _edge2lut(), _lut(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {}
 
         /// Construct from FactorGraph \a fg and PropertySet \a opts
+        /** \param opts Parameters @see Properties
+         */
         BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
             setProperties( opts );
             construct();
         BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), _sentMessages(), props(), recordSentMessages(false) {
             setProperties( opts );
             construct();
@@ -174,8 +180,8 @@ class BP : public DAIAlgFG {
     //@{
         virtual BP* clone() const { return new BP(*this); }
         virtual std::string identify() const;
     //@{
         virtual BP* clone() const { return new BP(*this); }
         virtual std::string identify() const;
-        virtual Factor belief( const Var &n ) const;
-        virtual Factor belief( const VarSet &ns ) const;
+        virtual Factor belief( const Var &v ) const { return beliefV( findVar( v ) ); }
+        virtual Factor belief( const VarSet &vs ) const;
         virtual Factor beliefV( size_t i ) const;
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
         virtual Factor beliefV( size_t i ) const;
         virtual Factor beliefF( size_t I ) const;
         virtual std::vector<Factor> beliefs() const;
@@ -206,7 +212,7 @@ class BP : public DAIAlgFG {
         void clearSentMessages() { _sentMessages.clear(); }
     //@}
 
         void clearSentMessages() { _sentMessages.clear(); }
     //@}
 
-    private:
+    protected:
         /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
         const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
         /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
         /// Returns constant reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
         const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
         /// Returns reference to message from the \a _I 'th neighbor of variable \a i to variable \a i
@@ -225,7 +231,7 @@ class BP : public DAIAlgFG {
         Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
 
         /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
         Real & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
 
         /// Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
-        void calcNewMessage( size_t i, size_t _I );
+        virtual void calcNewMessage( size_t i, size_t _I );
         /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
         void updateMessage( size_t i, size_t _I );
         /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
         /// Replace the "old" message from the \a _I 'th neighbor of variable \a i to variable \a i by the "new" (updated) message
         void updateMessage( size_t i, size_t _I );
         /// Set the residual (difference between new and old message) for the edge between variable \a i and its \a _I 'th neighbor to \a r
@@ -235,10 +241,10 @@ class BP : public DAIAlgFG {
         /// Calculates unnormalized belief of variable \a i
         void calcBeliefV( size_t i, Prob &p ) const;
         /// Calculates unnormalized belief of factor \a I
         /// Calculates unnormalized belief of variable \a i
         void calcBeliefV( size_t i, Prob &p ) const;
         /// Calculates unnormalized belief of factor \a I
-        void calcBeliefF( size_t I, Prob &p ) const;
+        virtual void calcBeliefF( size_t I, Prob &p ) const;
 
         /// Helper function for constructors
 
         /// Helper function for constructors
-        void construct();
+        virtual void construct();
 };
 
 
 };