Cleanup of BP_dual code
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 31 Jul 2009 15:25:21 +0000 (17:25 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Fri, 31 Jul 2009 15:25:21 +0000 (17:25 +0200)
include/dai/bp_dual.h
src/alldai.cpp
src/bp_dual.cpp

index ed592bb..0164635 100644 (file)
@@ -20,8 +20,7 @@
 
 /// \file
 /// \brief Defines class BP_dual
-/// \todo Improve documentation
-/// \todo Clean up
+/// \todo This replicates a large part of the functionality of BP; would it not be shorter to adapt BP instead?
 
 
 #ifndef __defined_libdai_bp_dual_h
@@ -36,7 +35,9 @@
 namespace dai {
 
 
-/** Class to estimate "dual" versions of BP messages, and normalizers, given an InfAlg. 
+/// Calculates both types of BP messages and their normalizers from an InfAlg.
+/** BP_dual calculates "dual" versions of BP messages (both messages from factors
+ *  to variables and messages from variables to factors), and normalizers, given an InfAlg. 
  *  These are computed from the variable and factor beliefs of the InfAlg.
  *  This class is used primarily by BBP.
  */
@@ -48,40 +49,57 @@ class BP_dual {
         struct _edges_t : public std::vector<std::vector<T> > {};
 
         /// Groups together the data structures for storing the two types of messages and their normalizers
-        struct messages {
+        struct messages {            
+            /// Unnormalized variable->factor messages
             _edges_t<Prob> n;
+            /// Normalizers of variable->factor messages
             _edges_t<Real> Zn;
+            /// Unnormalized Factor->variable messages
             _edges_t<Prob> m;
+            /// Normalizers of factor->variable messages
             _edges_t<Real> Zm;
         };
+        /// Stores all messages
         messages _msgs;
 
         /// Groups together the data structures for storing the two types of beliefs and their normalizers
         struct beliefs {
-            // indexed by node
+            /// Unnormalized variable beliefs
             std::vector<Prob> b1;
+            /// Normalizers of variable beliefs
             std::vector<Real> Zb1;
-            // indexed by factor
+            /// Unnormalized factor beliefs
             std::vector<Prob> b2;
+            /// Normalizers of factor beliefs
             std::vector<Real> Zb2;
         };
+        /// Stores all beliefs
         beliefs _beliefs;
 
+        /// Pointer to the InfAlg object
         const InfAlg *_ia;
-            
+        
+        /// Does all necessary preprocessing
         void init();
-
+        /// Allocates space for _msgs
         void regenerateMessages();
+        /// Allocates space for _beliefs
         void regenerateBeliefs();
 
+        /// Calculate all messages from InfAlg beliefs
         void calcMessages();
-        void calcBeliefV(size_t i);
-        void calcBeliefF(size_t I);
-        void calcBeliefs();
-
+        /// Update factor->variable message (i->I)
         void calcNewM(size_t i, size_t _I);
+        /// Update variable->factor message (I->i)
         void calcNewN(size_t i, size_t _I);
 
+        /// Calculate all variable and factor beliefs from messages
+        void calcBeliefs();
+        /// Calculate variable belief
+        void calcBeliefV(size_t i);
+        /// Calculate factor belief
+        void calcBeliefF(size_t I);
+
     public:
         /// Construct BP_dual object from (converged) InfAlg object's beliefs and factors. 
         /*  A pointer to the the InfAlg object is stored, 
@@ -89,27 +107,27 @@ class BP_dual {
          */
         BP_dual( const InfAlg *ia ) : _ia(ia) { init(); }
 
+        /// Returns the underlying FactorGraph
         const FactorGraph& fg() const { return _ia->fg(); }
 
-        /// factor -> var message
-        DAI_ACCMUT(Prob & msgM(size_t i, size_t _I), { return _msgs.m[i][_I]; });
-        /// var -> factor message
-        DAI_ACCMUT(Prob & msgN(size_t i, size_t _I), { return _msgs.n[i][_I]; });
-        /// Normalizer for msgM
-        DAI_ACCMUT(Real & zM(size_t i, size_t _I), { return _msgs.Zm[i][_I]; });
-        /// Normalizer for msgN
-        DAI_ACCMUT(Real & zN(size_t i, size_t _I), { return _msgs.Zn[i][_I]; });
-
-        /// Variable belief
-        Factor beliefV(size_t i) const { return Factor(_ia->fg().var(i), _beliefs.b1[i]); }
-        /// Factor belief
-        Factor beliefF(size_t I) const { return Factor(_ia->fg().factor(I).vars(), _beliefs.b2[I]); }
-
-        /// Normalizer for variable belief
-        Real beliefVZ(size_t i) const { return _beliefs.Zb1[i]; }
-        /// Normalizer for factor belief
-        Real beliefFZ(size_t I) const { return _beliefs.Zb2[I]; }
-
+        /// Returns factor -> var message (I->i)
+        DAI_ACCMUT(Prob & msgM( size_t i, size_t _I ), { return _msgs.m[i][_I]; });
+        /// Returns var -> factor message (i->I)
+        DAI_ACCMUT(Prob & msgN( size_t i, size_t _I ), { return _msgs.n[i][_I]; });
+        /// Returns normalizer for msgM
+        DAI_ACCMUT(Real & zM( size_t i, size_t _I ), { return _msgs.Zm[i][_I]; });
+        /// Returns normalizer for msgN
+        DAI_ACCMUT(Real & zN( size_t i, size_t _I ), { return _msgs.Zn[i][_I]; });
+
+        /// Returns variable belief
+        Factor beliefV( size_t i ) const { return Factor( _ia->fg().var(i), _beliefs.b1[i] ); }
+        /// Returns factor belief
+        Factor beliefF( size_t I ) const { return Factor( _ia->fg().factor(I).vars(), _beliefs.b2[I] ); }
+
+        /// Returns normalizer for variable belief
+        Real beliefVZ( size_t i ) const { return _beliefs.Zb1[i]; }
+        /// Returns normalizer for factor belief
+        Real beliefFZ( size_t I ) const { return _beliefs.Zb2[I]; }
 };
 
 
index a65110a..3c5fbba 100644 (file)
@@ -76,17 +76,17 @@ InfAlg *newInfAlg( const std::string &name, const FactorGraph &fg, const Propert
 
 
 /// \todo Make alias file non-testdai-specific, and use it in newInfAlgFromString
-InfAlg *newInfAlgFromString( const std::string &s, const FactorGraph &fg ) {
-    string::size_type pos = s.find_first_of('[');
+InfAlg *newInfAlgFromString( const std::string &nameOpts, const FactorGraph &fg ) {
+    string::size_type pos = nameOpts.find_first_of('[');
     string name;
     PropertySet opts;
     if( pos == string::npos ) {
-        name = s;
+        name = nameOpts;
     } else {
-        name = s.substr(0,pos);
+        name = nameOpts.substr(0,pos);
 
         stringstream ss;
-        ss << s.substr(pos,s.length());
+        ss << nameOpts.substr(pos,nameOpts.length());
         ss >> opts;
     }
     return newInfAlg(name,fg,opts);
index 008813c..44014fe 100644 (file)
@@ -36,6 +36,14 @@ using namespace std;
 typedef BipartiteGraph::Neighbor Neighbor;
 
 
+void BP_dual::init() {
+    regenerateMessages();
+    regenerateBeliefs();
+    calcMessages();
+    calcBeliefs();
+}
+
+
 void BP_dual::regenerateMessages() {
     size_t nv = fg().nrVars();
     _msgs.Zn.resize(nv);
@@ -68,99 +76,40 @@ void BP_dual::regenerateBeliefs() {
 }
 
 
-void BP_dual::init() {
-    regenerateMessages();
-    regenerateBeliefs();
-    calcMessages();
-    calcBeliefs();
-}
-
-
 void BP_dual::calcMessages() {
     // calculate 'n' messages from "factor marginal / factor"
-    vector<Factor> bs;
-    size_t nf = fg().nrFactors();
-    for( size_t I = 0; I < nf; I++ )
-        bs.push_back(_ia->beliefF(I));
-    assert(nf == bs.size());
-    for( size_t I = 0; I < nf; I++ ) {
-        Factor f = bs[I];
-        f /= fg().factor(I);
-        foreach(const Neighbor &i, fg().nbF(I))
-            msgN(i, i.dual) = f.marginal(fg().var(i)).p();
+    for( size_t I = 0; I < fg().nrFactors(); I++ ) {
+        Factor f = _ia->beliefF(I) / fg().factor(I);
+        foreach( const Neighbor &i, fg().nbF(I) )
+            msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
     }
     // calculate 'm' messages and normalizers from 'n' messages
     for( size_t i = 0; i < fg().nrVars(); i++ )
-        foreach(const Neighbor &I, fg().nbV(i))
-            calcNewM(i, I.iter);
+        foreach( const Neighbor &I, fg().nbV(i) )
+            calcNewM( i, I.iter );
     // recalculate 'n' messages and normalizers from 'm' messages
-    for( size_t i = 0; i < fg().nrVars(); i++ ) {
-        foreach(const Neighbor &I, fg().nbV(i)) {
-            Prob oldN = msgN(i,I.iter);
-            calcNewN(i, I.iter);
-            Prob newN = msgN(i,I.iter);
-#if 0
-            // check that new 'n' messages match old ones
-            if((oldN-newN).maxAbs() > 1.0e-5) {
-                cerr << "New 'n' messages don't match old: " <<
-                    "(i,I) = (" << i << ", " << I << 
-                    ") old = " << oldN << ", new = " << newN << endl;
-                DAI_THROW(INTERNAL_ERROR);
-            }
-#endif
-        }
-    }
-}
-
-
-void BP_dual::calcBeliefV(size_t i) {
-    Prob prod( fg().var(i).states(), 1.0 );
-    foreach(const Neighbor &I, fg().nbV(i))
-        prod *= msgM(i,I.iter);
-    _beliefs.Zb1[i] = prod.normalize();
-    _beliefs.b1[i] = prod;
-}
-
-
-void BP_dual::calcBeliefF(size_t I) {
-    Prob prod( fg().factor(I).p() );
-    foreach(const Neighbor &j, fg().nbF(I)) {
-        IndexFor ind (fg().var(j), fg().factor(I).vars() );
-        Prob n(msgN(j,j.dual));
-        for(size_t x=0; ind >= 0; x++, ++ind)
-            prod[x] *= n[ind];
-    }
-    _beliefs.Zb2[I] = prod.normalize();
-    _beliefs.b2[I] = prod;
-}
-
-
-// called after run()
-void BP_dual::calcBeliefs() {
     for( size_t i = 0; i < fg().nrVars(); i++ )
-        calcBeliefV(i);  // calculate b_i
-    for( size_t I = 0; I < fg().nrFactors(); I++ )
-        calcBeliefF(I);  // calculate b_I
+        foreach( const Neighbor &I, fg().nbV(i) )
+            calcNewN(i, I.iter);
 }
 
 
-void BP_dual::calcNewM(size_t i, size_t _I) {
+void BP_dual::calcNewM( size_t i, size_t _I ) {
     // calculate updated message I->i
     const Neighbor &I = fg().nbV(i)[_I];
     Prob prod( fg().factor(I).p() );
-    foreach(const Neighbor &j, fg().nbF(I)) {
-        if( j != i ) {     // for all j in I \ i
-            Prob n(msgN(j,j.dual));
-            IndexFor ind(fg().var(j), fg().factor(I).vars());
-            for(size_t x=0; ind >= 0; x++, ++ind)
+    foreach( const Neighbor &j, fg().nbF(I) )
+        if( j != i ) { // for all j in I \ i
+            Prob &n = msgN(j,j.dual);
+            IndexFor ind( fg().var(j), fg().factor(I).vars() );
+            for( size_t x = 0; ind >= 0; x++, ++ind )
                 prod[x] *= n[ind];
         }
-    }
     // Marginalize onto i
     Prob marg( fg().var(i).states(), 0.0 );
     // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
-    IndexFor ind(fg().var(i), fg().factor(I).vars());
-    for(size_t x=0; ind >= 0; x++, ++ind)
+    IndexFor ind( fg().var(i), fg().factor(I).vars() );
+    for( size_t x = 0; ind >= 0; x++, ++ind )
         marg[ind] += prod[x];
     
     _msgs.Zm[i][_I] = marg.normalize();
@@ -168,17 +117,46 @@ void BP_dual::calcNewM(size_t i, size_t _I) {
 }
 
 
-void BP_dual::calcNewN(size_t i, size_t _I) {
+void BP_dual::calcNewN( size_t i, size_t _I ) {
     // calculate updated message i->I
     const Neighbor &I = fg().nbV(i)[_I];
-    Prob prod(fg().var(i).states(), 1.0);
-    foreach(const Neighbor &J, fg().nbV(i)) {
-        if(J.node != I.node) // for all J in i \ I
+    Prob prod( fg().var(i).states(), 1.0 );
+    foreach( const Neighbor &J, fg().nbV(i) )
+        if( J.node != I.node ) // for all J in i \ I
             prod *= msgM(i,J.iter);
-    }
     _msgs.Zn[i][_I] = prod.normalize();
     _msgs.n[i][_I] = prod;
 }
 
 
+void BP_dual::calcBeliefs() {
+    for( size_t i = 0; i < fg().nrVars(); i++ )
+        calcBeliefV(i);  // calculate b_i
+    for( size_t I = 0; I < fg().nrFactors(); I++ )
+        calcBeliefF(I);  // calculate b_I
+}
+
+
+void BP_dual::calcBeliefV( size_t i ) {
+    Prob prod( fg().var(i).states(), 1.0 );
+    foreach( const Neighbor &I, fg().nbV(i) )
+        prod *= msgM(i,I.iter);
+    _beliefs.Zb1[i] = prod.normalize();
+    _beliefs.b1[i] = prod;
+}
+
+
+void BP_dual::calcBeliefF( size_t I ) {
+    Prob prod( fg().factor(I).p() );
+    foreach( const Neighbor &j, fg().nbF(I) ) {
+        IndexFor ind( fg().var(j), fg().factor(I).vars() );
+        Prob n( msgN(j,j.dual) );
+        for( size_t x = 0; ind >= 0; x++, ++ind )
+            prod[x] *= n[ind];
+    }
+    _beliefs.Zb2[I] = prod.normalize();
+    _beliefs.b2[I] = prod;
+}
+
+
 } // end of namespace dai