Improved mf.h/cpp by adding "init" and "updates" options
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 19 Apr 2010 15:59:27 +0000 (17:59 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 19 Apr 2010 15:59:27 +0000 (17:59 +0200)
ChangeLog
include/dai/mf.h
src/mf.cpp
tests/aliases.conf
tests/testall
tests/testall.bat
tests/testfast.out

index 2825f55..7d9fac3 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -6,6 +6,10 @@ git master HEAD
 * [Stefano Pellegrini] Fixed bug in JTree::findMaximum()
 * Fixed some bugs in the MatLab interface build system
 * Fixed a bug in utils/fginfo.cpp
 * [Stefano Pellegrini] Fixed bug in JTree::findMaximum()
 * Fixed some bugs in the MatLab interface build system
 * Fixed a bug in utils/fginfo.cpp
+* Improved mf.h/cpp:
+  - Added an "init" option that allows to initialize randomly
+  - Added an "updates" option that allows to choose between standard
+    mean field and "hard-spin" mean field
 * Improved treeep.h/cpp:
   - changed TreeEP::construct( const RootedTree& ) into
     TreeEP::construct( const FactorGraph&, const RootedTree& )
 * Improved treeep.h/cpp:
   - changed TreeEP::construct( const RootedTree& ) into
     TreeEP::construct( const FactorGraph&, const RootedTree& )
index 596cfd6..0f16815 100644 (file)
@@ -18,6 +18,7 @@
 
 
 #include <string>
 
 
 #include <string>
+#include <dai/enum.h>
 #include <dai/daialg.h>
 #include <dai/factorgraph.h>
 #include <dai/properties.h>
 #include <dai/daialg.h>
 #include <dai/factorgraph.h>
 #include <dai/properties.h>
@@ -31,6 +32,9 @@ namespace dai {
  *  single variable marginals (beliefs). The update equation for 
  *  a single belief \f$b_i\f$ is given by:
  *    \f[ b_i^{\mathrm{new}}(x_i) \propto \prod_{I\in N_i} \exp \left( \sum_{x_{N_I \setminus \{i\}}} \log f_I(x_I) \prod_{j \in N_I \setminus \{i\}} b_j(x_j) \right) \f]
  *  single variable marginals (beliefs). The update equation for 
  *  a single belief \f$b_i\f$ is given by:
  *    \f[ b_i^{\mathrm{new}}(x_i) \propto \prod_{I\in N_i} \exp \left( \sum_{x_{N_I \setminus \{i\}}} \log f_I(x_I) \prod_{j \in N_I \setminus \{i\}} b_j(x_j) \right) \f]
+ *  for naive mean field and by
+ *    \f[ b_i^{\mathrm{new}}(x_i) \propto \prod_{I\in N_i} \left( \sum_{x_{N_I \setminus \{i\}}} f_I(x_I) \prod_{j \in N_I \setminus \{i\}} b_j(x_j) \right) \f]
+ *  for hard-spin mean field.
  *  These update equations are performed for all variables until convergence.
  */
 class MF : public DAIAlgFG {
  *  These update equations are performed for all variables until convergence.
  */
 class MF : public DAIAlgFG {
@@ -45,6 +49,12 @@ class MF : public DAIAlgFG {
     public:
         /// Parameters for MF
         struct Properties {
     public:
         /// Parameters for MF
         struct Properties {
+            /// Enumeration of possible message initializations
+            DAI_ENUM(InitType,UNIFORM,RANDOM);
+
+            /// Enumeration of possible update types
+            DAI_ENUM(UpdateType,NAIVE,HARDSPIN);
+
             /// Verbosity (amount of output sent to stderr)
             size_t verbose;
 
             /// Verbosity (amount of output sent to stderr)
             size_t verbose;
 
@@ -56,6 +66,12 @@ class MF : public DAIAlgFG {
 
             /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
             Real damping;
 
             /// Damping constant (0.0 means no damping, 1.0 is maximum damping)
             Real damping;
+            
+            /// How to initialize the messages/beliefs
+            InitType init;
+
+            /// How to update the messages/beliefs
+            UpdateType updates;
         } props;
 
         /// Name of this inference algorithm
         } props;
 
         /// Name of this inference algorithm
index 535ceed..d2be915 100644 (file)
@@ -40,6 +40,14 @@ void MF::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<Real>("damping");
     else
         props.damping = 0.0;
         props.damping = opts.getStringAs<Real>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("init") )
+        props.init = opts.getStringAs<Properties::InitType>("init");
+    else
+        props.init = Properties::InitType::UNIFORM;
+    if( opts.hasKey("updates") )
+        props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+    else
+        props.updates = Properties::UpdateType::NAIVE;
 }
 
 
 }
 
 
@@ -49,6 +57,8 @@ PropertySet MF::getProperties() const {
     opts.set( "maxiter", props.maxiter );
     opts.set( "verbose", props.verbose );
     opts.set( "damping", props.damping );
     opts.set( "maxiter", props.maxiter );
     opts.set( "verbose", props.verbose );
     opts.set( "damping", props.damping );
+    opts.set( "init", props.init );
+    opts.set( "updates", props.updates );
     return opts;
 }
 
     return opts;
 }
 
@@ -59,6 +69,8 @@ string MF::printProperties() const {
     s << "tol=" << props.tol << ",";
     s << "maxiter=" << props.maxiter << ",";
     s << "verbose=" << props.verbose << ",";
     s << "tol=" << props.tol << ",";
     s << "maxiter=" << props.maxiter << ",";
     s << "verbose=" << props.verbose << ",";
+    s << "init=" << props.init << ",";
+    s << "updates=" << props.updates << ",";
     s << "damping=" << props.damping << "]";
     return s.str();
 }
     s << "damping=" << props.damping << "]";
     return s.str();
 }
@@ -79,23 +91,30 @@ string MF::identify() const {
 
 
 void MF::init() {
 
 
 void MF::init() {
-    for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
-        qi->fill(1.0);
+    if( props.init == Properties::InitType::UNIFORM )
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].fill( 1.0 );
+    else
+        for( size_t i = 0; i < nrVars(); i++ )
+            _beliefs[i].randomize();
 }
 
 
 Factor MF::calcNewBelief( size_t i ) {
     Factor result;
     foreach( const Neighbor &I, nbV(i) ) {
 }
 
 
 Factor MF::calcNewBelief( size_t i ) {
     Factor result;
     foreach( const Neighbor &I, nbV(i) ) {
-        Factor henk;
+        Factor belief_I_minus_i;
         foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
             if( j != i )
         foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
             if( j != i )
-                henk *= _beliefs[j];
-        Factor piet = factor(I).log(true);
-        piet *= henk;
-        piet = piet.marginal(var(i), false);
-        piet = piet.exp();
-        result *= piet;
+                belief_I_minus_i *= _beliefs[j];
+        Factor f_I = factor(I);
+        if( props.updates == Properties::UpdateType::NAIVE )
+            f_I.takeLog();
+        Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
+        if( props.updates == Properties::UpdateType::NAIVE )
+            result *= msg_I_i.exp();
+        else
+            result *= msg_I_i;
     }
     result.normalize();
     return result;
     }
     result.normalize();
     return result;
@@ -204,10 +223,13 @@ Real MF::logZ() const {
 
 
 void MF::init( const VarSet &ns ) {
 
 
 void MF::init( const VarSet &ns ) {
-    for( size_t i = 0; i < nrVars(); i++ ) {
-        if( ns.contains(var(i) ) )
-            _beliefs[i].fill( 1.0 );
-    }
+    for( size_t i = 0; i < nrVars(); i++ )
+        if( ns.contains(var(i) ) ) {
+            if( props.init == Properties::InitType::UNIFORM )
+                _beliefs[i].fill( 1.0 );
+            else
+                _beliefs[i].randomize();
+        }
 }
 
 
 }
 
 
index d692553..bb30e97 100644 (file)
@@ -98,7 +98,12 @@ JTREE_MINNEIGHBORS_SHSH_MAP:    JTREE[inference=MAXPROD,heuristic=MINNEIGHBORS,u
 
 # --- MF ----------------------
 
 
 # --- MF ----------------------
 
-MF_SEQRND:                      MF[tol=1e-9,maxiter=10000,damping=0.0]
+MF:                             MF[tol=1e-9,maxiter=10000,damping=0.0,init=RANDOM,updates=NAIVE]
+
+MF_NAIVE_UNI:                   MF[tol=1e-9,maxiter=10000,damping=0.0,init=UNIFORM,updates=NAIVE]
+MF_NAIVE_RND:                   MF[tol=1e-9,maxiter=10000,damping=0.0,init=RANDOM,updates=NAIVE]
+MF_HARDSPIN_UNI:                MF[tol=1e-9,maxiter=10000,damping=0.0,init=UNIFORM,updates=HARDSPIN]
+MF_HARDSPIN_RND:                MF[tol=1e-9,maxiter=10000,damping=0.0,init=RANDOM,updates=HARDSPIN]
 
 # --- TREEEP ------------------
 
 
 # --- TREEEP ------------------
 
index 1a63d64..5e9c25e 100755 (executable)
@@ -1,2 +1,2 @@
 #!/bin/bash
 #!/bin/bash
-./testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename $1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP
+./testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename $1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP
index 12eda6f..3717c25 100755 (executable)
@@ -1 +1 @@
-@testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename %1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF_SEQRND TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP
+@testdai --report-iters false --report-time false --marginals VAR --aliases aliases.conf --filename %1 --methods EXACT JTREE_HUGIN JTREE_SHSH BP_SEQFIX BP_SEQRND BP_SEQMAX BP_PARALL BP_SEQFIX_LOG BP_SEQRND_LOG BP_SEQMAX_LOG BP_PARALL_LOG FBP TRWBP MF TREEEP TREEEPWC GBP_MIN GBP_BETHE GBP_DELTA GBP_LOOP3 GBP_LOOP4 GBP_LOOP5 GBP_LOOP6 GBP_LOOP7 HAK_MIN HAK_BETHE HAK_DELTA HAK_LOOP3 HAK_LOOP4 HAK_LOOP5 MR_RESPPROP_FULL MR_CLAMPING_FULL MR_EXACT_FULL MR_RESPPROP_LINEAR MR_CLAMPING_LINEAR MR_EXACT_LINEAR LCBP_FULLCAV_SEQFIX LCBP_FULLCAVin_SEQFIX LCBP_FULLCAV_SEQRND LCBP_FULLCAVin_SEQRND LCBP_FULLCAV_NONE LCBP_FULLCAVin_NONE LCBP_PAIRCAV_SEQFIX LCBP_PAIRCAVin_SEQFIX LCBP_PAIRCAV_SEQRND LCBP_PAIRCAVin_SEQRND LCBP_PAIRCAV_NONE LCBP_PAIRCAVin_NONE LCBP_PAIR2CAV_SEQFIX LCBP_PAIR2CAVin_SEQFIX LCBP_PAIR2CAV_SEQRND LCBP_PAIR2CAVin_SEQRND LCBP_PAIR2CAV_NONE LCBP_PAIR2CAVin_NONE LCBP_UNICAV_SEQFIX LCBP_UNICAV_SEQRND BBP CBP
index 4707b5f..000e249 100644 (file)
@@ -221,7 +221,7 @@ TRWBP                                       9.483e-02       3.078e-02       +2.969e-01      1.000e-09
 # ({x13}, (5.266e-01, 4.734e-01))
 # ({x14}, (6.033e-01, 3.967e-01))
 # ({x15}, (1.558e-01, 8.442e-01))
 # ({x13}, (5.266e-01, 4.734e-01))
 # ({x14}, (6.033e-01, 3.967e-01))
 # ({x15}, (1.558e-01, 8.442e-01))
-MF_SEQRND                                      3.607e-01       1.904e-01       -1.608e+00      1.000e-09       
+MF                                             3.607e-01       1.904e-01       -1.608e+00      1.000e-09       
 # ({x0}, (2.053e-01, 7.947e-01))
 # ({x1}, (9.163e-01, 8.373e-02))
 # ({x2}, (1.579e-01, 8.421e-01))
 # ({x0}, (2.053e-01, 7.947e-01))
 # ({x1}, (9.163e-01, 8.373e-02))
 # ({x2}, (1.579e-01, 8.421e-01))