Implementing TRWBP
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 22 Jan 2010 12:41:14 +0000 (13:41 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 22 Jan 2010 12:41:14 +0000 (13:41 +0100)
16 files changed:
Makefile
Makefile.CYGWIN
Makefile.LINUX
Makefile.MACOSX
Makefile.WINDOWS
doxygen.conf
include/dai/alldai.h
include/dai/bp.h
include/dai/doc.h
include/dai/factor.h
include/dai/smallset.h
include/dai/trwbp.h [new file with mode: 0644]
src/alldai.cpp
src/fbp.cpp
src/trwbp.cpp [new file with mode: 0644]
tests/aliases.conf

index d2064f6..40ebcbe 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -60,6 +60,10 @@ ifdef WITH_FBP
   CCFLAGS:=$(CCFLAGS) -DDAI_WITH_FBP
   OBJECTS:=$(OBJECTS) fbp$(OE)
 endif
+ifdef WITH_TRWBP
+  CCFLAGS:=$(CCFLAGS) -DDAI_WITH_TRWBP
+  OBJECTS:=$(OBJECTS) trwbp$(OE)
+endif
 ifdef WITH_MF
   CCFLAGS:=$(CCFLAGS) -DDAI_WITH_MF
   OBJECTS:=$(OBJECTS) mf$(OE)
@@ -143,6 +147,9 @@ bp$(OE) : $(SRC)/bp.cpp $(INC)/bp.h $(HEADERS)
 fbp$(OE) : $(SRC)/fbp.cpp $(INC)/fbp.h $(HEADERS)
        $(CC) -c $(SRC)/fbp.cpp
 
+trwbp$(OE) : $(SRC)/trwbp.cpp $(INC)/trwbp.h $(HEADERS)
+       $(CC) -c $(SRC)/trwbp.cpp
+
 bp_dual$(OE) : $(SRC)/bp_dual.cpp $(INC)/bp_dual.h $(HEADERS)
        $(CC) -c $(SRC)/bp_dual.cpp
 
index 0fc6d3a..296cc73 100644 (file)
@@ -19,6 +19,7 @@
 # Enable/disable various approximate inference methods
 WITH_BP=true
 WITH_FBP=true
+WITH_TRWBP=true
 WITH_MF=true
 WITH_HAK=true
 WITH_LC=true
index 0f47e48..657b39f 100644 (file)
@@ -19,6 +19,7 @@
 # Enable/disable various approximate inference methods
 WITH_BP=true
 WITH_FBP=true
+WITH_TRWBP=true
 WITH_MF=true
 WITH_HAK=true
 WITH_LC=true
index fc82a2b..d0066ae 100644 (file)
@@ -19,6 +19,7 @@
 # Enable/disable various approximate inference methods
 WITH_BP=true
 WITH_FBP=true
+WITH_TRWBP=true
 WITH_MF=true
 WITH_HAK=true
 WITH_LC=true
index 8bea63b..24fa9c9 100644 (file)
@@ -20,6 +20,7 @@
 # Enable/disable various approximate inference methods
 WITH_BP=true
 WITH_FBP=true
+WITH_TRWBP=true
 WITH_MF=true
 WITH_HAK=true
 WITH_LC=true
index 0eaeccf..707f3c1 100644 (file)
@@ -1254,6 +1254,7 @@ INCLUDE_FILE_PATTERNS  =
 
 PREDEFINED             = DAI_WITH_BP \
                          DAI_WITH_FBP \
+                         DAI_WITH_TRWBP \
                          DAI_WITH_MF \
                          DAI_WITH_HAK \
                          DAI_WITH_LC \
index ed04ce2..11ccd98 100644 (file)
@@ -30,6 +30,9 @@
 #ifdef DAI_WITH_FBP
     #include <dai/fbp.h>
 #endif
+#ifdef DAI_WITH_TRWBP
+    #include <dai/trwbp.h>
+#endif
 #ifdef DAI_WITH_MF
     #include <dai/mf.h>
 #endif
@@ -105,6 +108,9 @@ static const char* DAINames[] = {
 #ifdef DAI_WITH_FBP
     FBP::Name,
 #endif
+#ifdef DAI_WITH_TRWBP
+    TRWBP::Name,
+#endif
 #ifdef DAI_WITH_MF
     MF::Name,
 #endif
index c7bd160..e178a32 100644 (file)
@@ -239,7 +239,7 @@ class BP : public DAIAlgFG {
         /// Finds the edge which has the maximum residual (difference between new and old message)
         void findMaxResidual( size_t &i, size_t &_I );
         /// Calculates unnormalized belief of variable \a i
-        void calcBeliefV( size_t i, Prob &p ) const;
+        virtual void calcBeliefV( size_t i, Prob &p ) const;
         /// Calculates unnormalized belief of factor \a I
         virtual void calcBeliefF( size_t I, Prob &p ) const;
 
index dbc7904..44dcf90 100644 (file)
  */
 
 /** \page bibliography Bibliography
+ *  \anchor EaG09 \ref EaG09
+ *  F. Eaton and Z. Ghahramani (2009):
+ *  "Choosing a Variable to Clamp",
+ *  <em>Proceedings of the Twelfth International Conference on Artificial Intelligence and Statistics (AISTATS 2009)</em> 5:145-152,
+ *  http://jmlr.csail.mit.edu/proceedings/papers/v5/eaton09a/eaton09a.pdf
+ *
+ *  \anchor EMK06 \ref EMK06
+ *  G. Elidan and I. McGraw and D. Koller (2006):
+ *  "Residual Belief Propagation: Informed Scheduling for Asynchronous Message Passing",
+ *  <em>Proceedings of the 22nd Annual Conference on Uncertainty in Artificial Intelligence (UAI-06)</em>,
+ *  http://uai.sis.pitt.edu/papers/06/UAI2006_0091.pdf
+ *
+ *  \anchor HAK03 \ref HAK03
+ *  T. Heskes and C. A. Albers and H. J. Kappen (2003):
+ *  "Approximate Inference and Constrained Optimization",
+ *  <em>Proceedings of the 19th Annual Conference on Uncertainty in Artificial Intelligence (UAI-03)</em> pp. 313-320,
+ *  http://www.snn.ru.nl/reports/Heskes.uai2003.ps.gz
+ *
  *  \anchor KFL01 \ref KFL01
  *  F. R. Kschischang and B. J. Frey and H.-A. Loeliger (2001):
  *  "Factor Graphs and the Sum-Product Algorithm",
- *  <em>IEEE Transactions on Information Theory</em> 47(2):498-519.
+ *  <em>IEEE Transactions on Information Theory</em> 47(2):498-519,
  *  http://ieeexplore.ieee.org/xpl/freeabs_all.jsp?arnumber=910572
  *
+ *  \anchor Min05 \ref Min05
+ *  T. Minka (2005):
+ *  "Divergence measures and message passing",
+ *  <em>MicroSoft Research Technical Report</em> MSR-TR-2005-173,
+ *  http://research.microsoft.com/en-us/um/people/minka/papers/message-passing/minka-divergence.pdf
+ *
  *  \anchor MiQ04 \ref MiQ04
  *  T. Minka and Y. Qi (2004):
  *  "Tree-structured Approximations by Expectation Propagation",
- *  <em>Advances in Neural Information Processing Systems</em> (NIPS) 16.
+ *  <em>Advances in Neural Information Processing Systems</em> (NIPS) 16,
  *  http://books.nips.cc/papers/files/nips16/NIPS2003_AA25.pdf
  *
- *  \anchor MoR05 \ref MoR05
- *  A. Montanari and T. Rizzo (2005):
- *  "How to Compute Loop Corrections to the Bethe Approximation",
- *  <em>Journal of Statistical Mechanics: Theory and Experiment</em>
- *  2005(10)-P10011.
- *  http://stacks.iop.org/1742-5468/2005/P10011
- *
- *  \anchor YFW05 \ref YFW05
- *  J. S. Yedidia and W. T. Freeman and Y. Weiss (2005):
- *  "Constructing Free-Energy Approximations and Generalized Belief Propagation Algorithms",
- *  <em>IEEE Transactions on Information Theory</em>
- *  51(7):2282-2312.
- *  http://ieeexplore.ieee.org/xpl/freeabs_all.jsp?arnumber=1459044
- *
- *  \anchor HAK03 \ref HAK03
- *  T. Heskes and C. A. Albers and H. J. Kappen (2003):
- *  "Approximate Inference and Constrained Optimization",
- *  <em>Proceedings of the 19th Annual Conference on Uncertainty in Artificial Intelligence (UAI-03)</em> pp. 313-320.
- *  http://www.snn.ru.nl/reports/Heskes.uai2003.ps.gz
- *
  *  \anchor MoK07 \ref MoK07
  *  J. M. Mooij and H. J. Kappen (2007):
  *  "Loop Corrections for Approximate Inference on Factor Graphs",
- *  <em>Journal of Machine Learning Research</em> 8:1113-1143.
+ *  <em>Journal of Machine Learning Research</em> 8:1113-1143,
  *  http://www.jmlr.org/papers/volume8/mooij07a/mooij07a.pdf
  *
  *  \anchor MoK07b \ref MoK07b
  *  J. M. Mooij and H. J. Kappen (2007):
  *  "Sufficient Conditions for Convergence of the Sum-Product Algorithm",
- *  <em>IEEE Transactions on Information Theory</em> 53(12):4422-4437.
+ *  <em>IEEE Transactions on Information Theory</em> 53(12):4422-4437,
  *  http://ieeexplore.ieee.org/xpl/freeabs_all.jsp?arnumber=4385778
  *
- *  \anchor EaG09 \ref EaG09
- *  F. Eaton and Z. Ghahramani (2009):
- *  "Choosing a Variable to Clamp",
- *  <em>Proceedings of the Twelfth International Conference on Artificial Intelligence and Statistics (AISTATS 2009)</em> 5:145-152
- *  http://jmlr.csail.mit.edu/proceedings/papers/v5/eaton09a/eaton09a.pdf
+ *  \anchor MoR05 \ref MoR05
+ *  A. Montanari and T. Rizzo (2005):
+ *  "How to Compute Loop Corrections to the Bethe Approximation",
+ *  <em>Journal of Statistical Mechanics: Theory and Experiment</em> 2005(10)-P10011,
+ *  http://stacks.iop.org/1742-5468/2005/P10011
  *
  *  \anchor StW99 \ref StW99
  *  A. Steger and N. C. Wormald (1999):
  *  "Generating Random Regular Graphs Quickly",
- *  <em>Combinatorics, Probability and Computing</em> Vol 8, Issue 4, pp. 377-396
+ *  <em>Combinatorics, Probability and Computing</em> Vol 8, Issue 4, pp. 377-396,
  *  http://www.math.uwaterloo.ca/~nwormald/papers/randgen.pdf
  *
- *  \anchor EMK06 \ref EMK06
- *  G. Elidan and I. McGraw and D. Koller (2006):
- *  "Residual Belief Propagation: Informed Scheduling for Asynchronous Message Passing",
- *  <em>Proceedings of the 22nd Annual Conference on Uncertainty in Artificial Intelligence (UAI-06)</em>
- *  http://uai.sis.pitt.edu/papers/06/UAI2006_0091.pdf
- *
  *  \anchor WiH03 \ref WiH03
  *  W. Wiegerinck and T. Heskes (2003):
  *  "Fractional Belief Propagation",
- *  <em>Advances in Neural Information Processing Systems</em> (NIPS) 15, pp. 438-445.
+ *  <em>Advances in Neural Information Processing Systems</em> (NIPS) 15, pp. 438-445,
  *  http://books.nips.cc/papers/files/nips15/LT16.pdf
  *
- *  \anchor Min05 \ref Min05
- *  T. Minka (2005):
- *  "Divergence measures and message passing",
- *  <em>MicroSoft Research Technical Report</em> MSR-TR-2005-173,
- *  http://research.microsoft.com/en-us/um/people/minka/papers/message-passing/minka-divergence.pdf
+ *  \anchor WJW03 \ref WJW03
+ *  M. J. Wainwright, T. S. Jaakkola and A. S. Willsky (2003):
+ *  "Tree-reweighted belief propagation algorithms and approximate ML estimation by pseudo-moment matching",
+ *  <em>9th Workshop on Artificial Intelligence and Statistics</em>,
+ *  http://www.eecs.berkeley.edu/~wainwrig/Papers/WJW_AIStat03.pdf
+ *
+ *  \anchor YFW05 \ref YFW05
+ *  J. S. Yedidia and W. T. Freeman and Y. Weiss (2005):
+ *  "Constructing Free-Energy Approximations and Generalized Belief Propagation Algorithms",
+ *  <em>IEEE Transactions on Information Theory</em> 51(7):2282-2312,
+ *  http://ieeexplore.ieee.org/xpl/freeabs_all.jsp?arnumber=1459044
  */
 
 
index d8132cc..4602317 100644 (file)
@@ -198,6 +198,8 @@ template <typename T> class TFactor {
         }
 
         /// Returns normalized copy of \c *this, using the specified norm
+        /** \throw NOT_NORMALIZABLE if the norm is zero
+         */
         TFactor<T> normalized( typename TProb<T>::NormType norm=TProb<T>::NORMPROB ) const {
             TFactor<T> x;
             x._vs = _vs;
@@ -215,6 +217,8 @@ template <typename T> class TFactor {
         TFactor<T>& setUniform () { _p.setUniform(); return *this; }
 
         /// Normalizes factor using the specified norm
+        /** \throw NOT_NORMALIZABLE if the norm is zero
+         */
         T normalize( typename TProb<T>::NormType norm=TProb<T>::NORMPROB ) { return _p.normalize( norm ); }
     //@}
 
index ecb2878..dd2c4c4 100644 (file)
@@ -196,6 +196,11 @@ class SmallSet {
         reverse_iterator rend() { return _elements.rend(); }
         /// Returns constant reverse iterator that points beyond the first element
         const_reverse_iterator rend() const { return _elements.rend(); }
+
+        /// Returns reference to first element
+        T& front() { return _elements.at(0); }
+        /// Returns constant reference to first element
+        const T& front() const { return _elements.at(0); }
     //@}
 
     /// \name Comparison operators
diff --git a/include/dai/trwbp.h b/include/dai/trwbp.h
new file mode 100644 (file)
index 0000000..d37c788
--- /dev/null
@@ -0,0 +1,110 @@
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  libDAI is licensed under the terms of the GNU General Public License version
+ *  2, or (at your option) any later version. libDAI is distributed without any
+ *  warranty. See the file COPYING for more details.
+ *
+ *  Copyright (C) 2010 Joris Mooij
+ */
+
+
+/// \file
+/// \brief Defines class TRWBP, which implements Tree-Reweighted Belief Propagation
+
+
+#ifndef __defined_libdai_trwbp_h
+#define __defined_libdai_trwbp_h
+
+
+#include <string>
+#include <dai/daialg.h>
+#include <dai/factorgraph.h>
+#include <dai/properties.h>
+#include <dai/enum.h>
+#include <dai/bp.h>
+
+
+namespace dai {
+
+
+/// Approximate inference algorithm "Tree-Reweighted Belief Propagation" [\ref WJW03]
+/** The Tree-Reweighted Belief Propagation algorithm is like Belief
+ *  Propagation, but associates each factor with a scale parameter.
+ *  which controls the divergence measure being minimized.
+ *
+ *  The messages \f$m_{I\to i}(x_i)\f$ are passed from factors \f$I\f$ to variables \f$i\f$. 
+ *  The update equation is given by:
+ *    \f[ m_{I\to i}(x_i) \propto \sum_{x_{N_I\setminus\{i\}}} f_I(x_I)^{1/c_I} \prod_{j\in N_I\setminus\{i\}} m_{I\to j}^{c_I-1} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}^{c_J} \f]
+ *  After convergence, the variable beliefs are calculated by:
+ *    \f[ b_i(x_i) \propto \prod_{I\in N_i} m_{I\to i}^{c_I} \f]
+ *  and the factor beliefs are calculated by:
+ *    \f[ b_I(x_I) \propto f_I(x_I)^{1/c_I} \prod_{j \in N_I} m_{I\to j}^{c_I-1} \prod_{J\in N_j\setminus\{I\}} m_{J\to j}^{c_J} \f]
+ *
+ *  \todo Fix documentation
+ */
+class TRWBP : public BP {
+    protected:
+        /// Factor scale parameters (indexed by factor ID)
+        std::vector<Real> _edge_weight;
+
+    public:
+        /// Name of this inference algorithm
+        static const char *Name;
+
+    public:
+    /// \name Constructors/destructors
+    //@{
+        /// Default constructor
+        TRWBP() : BP(), _edge_weight() {}
+
+        /// Construct from FactorGraph \a fg and PropertySet \a opts
+        /** \param opts Parameters @see BP::Properties
+         */
+        TRWBP( const FactorGraph &fg, const PropertySet &opts ) : BP(fg, opts), _edge_weight() {
+            setProperties( opts );
+            construct();
+        }
+    //@}
+
+    /// \name General InfAlg interface
+    //@{
+        virtual TRWBP* clone() const { return new TRWBP(*this); }
+        virtual std::string identify() const;
+        virtual Real logZ() const;
+    //@}
+
+    /// \name TRWBP accessors/mutators for scale parameters
+    //@{
+        /// Returns scale parameter of edge corresponding to the \a I 'th factor
+        Real edgeWeight( size_t I ) const { return _edge_weight[I]; }
+
+        /// Returns constant reference to vector of all factor scale parameters
+        const std::vector<Real>& edgeWeights() const { return _edge_weight; }
+
+        /// Sets the scale parameter of the \a I 'th factor to \a c
+        void setEdgeWeight( size_t I, Real c ) { _edge_weight[I] = c; }
+
+        /// Sets the scale parameters of all factors simultaenously
+        /** \note Faster than calling setScaleF(size_t,Real) for each factor
+         */
+        void setEdgeWeights( const std::vector<Real> &c ) { _edge_weight = c; }
+
+    protected:
+        // Calculate the updated message from the \a _I 'th neighbor of variable \a i to variable \a i
+        virtual void calcNewMessage( size_t i, size_t _I );
+
+        /// Calculates unnormalized belief of variable \a i
+        virtual void calcBeliefV( size_t i, Prob &p ) const;
+
+        // Calculates unnormalized belief of factor \a I
+        virtual void calcBeliefF( size_t I, Prob &p ) const;
+
+        // Helper function for constructors
+        virtual void construct();
+};
+
+
+} // end of namespace dai
+
+
+#endif
index 1468bb6..af520e6 100644 (file)
@@ -32,6 +32,10 @@ InfAlg *newInfAlg( const std::string &name, const FactorGraph &fg, const Propert
     if( name == FBP::Name )
         return new FBP (fg, opts);
 #endif
+#ifdef DAI_WITH_TRWBP
+    if( name == TRWBP::Name )
+        return new TRWBP (fg, opts);
+#endif
 #ifdef DAI_WITH_MF
     if( name == MF::Name )
         return new MF (fg, opts);
index 0e67cc5..893388e 100644 (file)
@@ -73,16 +73,14 @@ void FBP::calcNewMessage( size_t i, size_t _I ) {
                         prod_j += message( j, J.iter );
                     else
                         prod_j *= message( j, J.iter );
+                } else {
+                    // FBP: multiply by m_Ij^(1-1/c_I)
+                    if( props.logdomain )
+                        prod_j += message( j, J.iter )*(1-1/scale);
+                    else
+                        prod_j *= message( j, J.iter )^(1-1/scale);
                 }
 
-
-            size_t _I = j.dual;
-            // FBP: now multiply by m_Ij^(1-1/c_I)
-            if(props.logdomain)
-                prod_j += message( j, _I)*(1-1/scale);
-            else
-                prod_j *= message( j, _I)^(1-1/scale);
-
             // multiply prod with prod_j
             if( !DAI_FBP_FAST ) {
                 /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
@@ -168,16 +166,14 @@ void FBP::calcBeliefF( size_t I, Prob &p ) const {
                     prod_j += newMessage( j, J.iter );
                 else
                     prod_j *= newMessage( j, J.iter );
+            } else {
+                // FBP: multiply by m_Ij^(1-1/c_I)
+                if( props.logdomain )
+                    prod_j += newMessage( j, J.iter)*(1-1/scale);
+                else
+                    prod_j *= newMessage( j, J.iter)^(1-1/scale);
             }
 
-        size_t _I = j.dual;
-
-        // FBP: now multiply by m_Ij^(1-1/c_I)
-        if( props.logdomain )
-            prod_j += newMessage( j, _I)*(1-1/scale);
-        else
-            prod_j *= newMessage( j, _I)^(1-1/scale);
-
         // multiply prod with prod_j
         if( !DAI_FBP_FAST ) {
             /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
diff --git a/src/trwbp.cpp b/src/trwbp.cpp
new file mode 100644 (file)
index 0000000..407d847
--- /dev/null
@@ -0,0 +1,228 @@
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  libDAI is licensed under the terms of the GNU General Public License version
+ *  2, or (at your option) any later version. libDAI is distributed without any
+ *  warranty. See the file COPYING for more details.
+ *
+ *  Copyright (C) 2010  Joris Mooij  [joris dot mooij at libdai dot org]
+ */
+
+
+#include <dai/trwbp.h>
+
+
+#define DAI_TRWBP_FAST 1
+
+
+namespace dai {
+
+
+using namespace std;
+
+
+const char *TRWBP::Name = "TRWBP";
+
+
+string TRWBP::identify() const {
+    return string(Name) + printProperties();
+}
+
+
+/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+Real TRWBP::logZ() const {
+    Real sum = 0.0;
+    for( size_t I = 0; I < nrFactors(); I++ ) {
+        sum += (beliefF(I) * factor(I).log(true)).sum();  // TRWBP
+        if( factor(I).vars().size() == 2 )
+            sum -= edgeWeight(I) * MutualInfo( beliefF(I) );  // TRWBP
+    }
+    for( size_t i = 0; i < nrVars(); ++i )
+        sum += beliefV(i).entropy();  // TRWBP
+    return sum;
+}
+
+
+/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+void TRWBP::calcNewMessage( size_t i, size_t _I ) {
+    // calculate updated message I->i
+    size_t I = nbV(i,_I);
+    const Var &v_i = var(i);
+    const VarSet &v_I = factor(I).vars();
+    Real c_I = edgeWeight(I); // TRWBP: c_I (\mu_I in the paper)
+
+    Prob marg;
+    if( v_I.size() == 1 ) { // optimization
+        marg = factor(I).p();
+    } else
+        Factor Fprod( factor(I) );
+        Prob &prod = Fprod.p();
+        if( props.logdomain ) {
+            prod.takeLog();
+            prod /= c_I;         // TRWBP
+        } else
+            prod ^= (1.0 / c_I); // TRWBP
+    
+        // Calculate product of incoming messages and factor I
+        foreach( const Neighbor &j, nbF(I) )
+            if( j != i ) { // for all j in I \ i
+                const Var &v_j = var(j);
+
+                // TRWBP: corresponds to messages n_jI
+                // prod_j will be the product of messages coming into j
+                Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
+                foreach( const Neighbor &J, nbV(j) ) {
+                    Real c_J = edgeWeight(J);
+                    if( J != I ) { // for all J in nb(j) \ I
+                        if( props.logdomain )
+                            prod_j += message( j, J.iter ) * c_J;
+                        else
+                            prod_j *= message( j, J.iter ) ^ c_J;
+                    } else { // TRWBP: multiply by m_Ij^(c_I-1)
+                        if( props.logdomain )
+                            prod_j += message( j, J.iter ) * (c_J - 1.0);
+                        else
+                            prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
+                    }
+                }
+
+                // multiply prod with prod_j
+                if( !DAI_TRWBP_FAST ) {
+                    /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+                    if( props.logdomain )
+                        Fprod += Factor( v_j, prod_j );
+                    else
+                        Fprod *= Factor( v_j, prod_j );
+                } else {
+                    /* OPTIMIZED VERSION */
+                    // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+                    const ind_t &ind = index(j, _I);
+                    for( size_t r = 0; r < prod.size(); ++r )
+                        if( props.logdomain )
+                            prod[r] += prod_j[ind[r]];
+                        else
+                            prod[r] *= prod_j[ind[r]];
+                }
+            }
+
+        if( props.logdomain ) {
+            prod -= prod.max();
+            prod.takeExp();
+        }
+
+        // Marginalize onto i
+        if( !DAI_TRWBP_FAST ) {
+            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+            if( props.inference == Properties::InfType::SUMPROD )
+                marg = Fprod.marginal( v_i ).p();
+            else
+                marg = Fprod.maxMarginal( v_i ).p();
+        } else {
+            /* OPTIMIZED VERSION */
+            marg = Prob( v_i.states(), 0.0 );
+            // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
+            const ind_t ind = index(i,_I);
+            if( props.inference == Properties::InfType::SUMPROD )
+                for( size_t r = 0; r < prod.size(); ++r )
+                    marg[ind[r]] += prod[r];
+            else
+                for( size_t r = 0; r < prod.size(); ++r )
+                    if( prod[r] > marg[ind[r]] )
+                        marg[ind[r]] = prod[r];
+            marg.normalize();
+        }
+
+        // Store result
+        if( props.logdomain )
+            newMessage(i,_I) = marg.log();
+        else
+            newMessage(i,_I) = marg;
+
+        // Update the residual if necessary
+        if( props.updates == Properties::UpdateType::SEQMAX )
+            updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
+    }
+}
+
+
+/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
+    p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
+    foreach( const Neighbor &I, nbV(i) ) {
+        Real c_I = edgeWeight(I);
+        if( props.logdomain )
+            p += newMessage( i, I.iter ) * c_I;
+        else
+            p *= newMessage( i, I.iter ) ^ c_I;
+    }
+}
+
+
+/* This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour */
+void TRWBP::calcBeliefF( size_t I, Prob &p ) const {
+    Real c_I = edgeWeight(I); // TRWBP: c_I
+    const VarSet &v_I = factor(I).vars();
+
+    Factor Fprod( factor(I) );
+    Prob &prod = Fprod.p();
+
+    if( props.logdomain ) {
+        prod.takeLog();
+        prod /= c_I; // TRWBP
+    } else
+        prod ^= (1.0 / c_I); // TRWBP
+
+    // Calculate product of incoming messages and factor I
+    foreach( const Neighbor &j, nbF(I) ) {
+        const Var &v_j = var(j);
+
+        // TRWBP: corresponds to messages n_jI
+        // prod_j will be the product of messages coming into j
+        Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
+        foreach( const Neighbor &J, nbV(j) ) {
+            Real c_J = edgeWeight(J);
+            if( J != I ) { // for all J in nb(j) \ I
+                if( props.logdomain )
+                    prod_j += newMessage( j, J.iter ) * c_J;
+                else
+                    prod_j *= newMessage( j, J.iter ) ^ c_J;
+            } else { // TRWBP: multiply by m_Ij^(c_I-1)
+                if( props.logdomain )
+                    prod_j += newMessage( j, J.iter ) * (c_J - 1.0);
+                else
+                    prod_j *= newMessage( j, J.iter ) ^ (c_J - 1.0);
+            }
+        }
+
+        // multiply prod with prod_j
+        if( !DAI_TRWBP_FAST ) {
+            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+            if( props.logdomain )
+                Fprod += Factor( v_j, prod_j );
+            else
+                Fprod *= Factor( v_j, prod_j );
+        } else {
+            /* OPTIMIZED VERSION */
+            size_t _I = j.dual;
+            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+            const ind_t &ind = index(j, _I);
+
+            for( size_t r = 0; r < prod.size(); ++r ) {
+                if( props.logdomain )
+                    prod[r] += prod_j[ind[r]];
+                else
+                    prod[r] *= prod_j[ind[r]];
+            }
+        }
+    }
+    
+    p = prod;
+}
+
+
+void TRWBP::construct() {
+    BP::construct();
+    _edge_weight.resize( nrFactors(), 1.0 );
+}
+
+
+} // end of namespace dai
index ff7c8b1..4608cc6 100644 (file)
@@ -26,6 +26,10 @@ MP_PARALL_LOG:                  BP[updates=PARALL,tol=1e-9,maxiter=10000,logdoma
 
 FBP:                            FBP[updates=SEQFIX,tol=1e-9,maxiter=10000,logdomain=0]
 
+# --- TRWBP -------------------
+
+TRWBP:                          TRWBP[updates=SEQFIX,tol=1e-9,maxiter=10000,logdomain=0]
+
 # --- JTREE -------------------
 
 JTREE_HUGIN:                    JTREE[updates=HUGIN,verbose=0]