Added init parameter to HAK/GBP to allow for random initialization
[libdai.git] / src / hak.cpp
index 43a368d..7232615 100644 (file)
@@ -45,6 +45,10 @@ void HAK::setProperties( const PropertySet &opts ) {
         props.damping = opts.getStringAs<double>("damping");
     else
         props.damping = 0.0;
+    if( opts.hasKey("init") )
+        props.init = opts.getStringAs<Properties::InitType>("init");
+    else
+        props.init = Properties::InitType::UNIFORM;
 }
 
 
@@ -55,6 +59,7 @@ PropertySet HAK::getProperties() const {
     opts.Set( "verbose", props.verbose );
     opts.Set( "doubleloop", props.doubleloop );
     opts.Set( "clusters", props.clusters );
+    opts.Set( "init", props.init );
     opts.Set( "loopdepth", props.loopdepth );
     opts.Set( "damping", props.damping );
     return opts;
@@ -69,6 +74,7 @@ string HAK::printProperties() const {
     s << "verbose=" << props.verbose << ",";
     s << "doubleloop=" << props.doubleloop << ",";
     s << "clusters=" << props.clusters << ",";
+    s << "init=" << props.init << ",";
     s << "loopdepth=" << props.loopdepth << ",";
     s << "damping=" << props.damping << "]";
     return s.str();
@@ -168,16 +174,28 @@ string HAK::identify() const {
 
 void HAK::init( const VarSet &ns ) {
     for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
-        if( alpha->vars().intersects( ns ) )
-            alpha->fill( 1.0 / alpha->states() );
+        if( alpha->vars().intersects( ns ) ) {
+            if( props.init == Properties::InitType::UNIFORM )
+                alpha->fill( 1.0 / alpha->states() );
+            else
+                alpha->randomize();
+        }
 
     for( size_t beta = 0; beta < nrIRs(); beta++ )
         if( IR(beta).intersects( ns ) ) {
-            _Qb[beta].fill( 1.0 );
+            if( props.init == Properties::InitType::UNIFORM )
+                _Qb[beta].fill( 1.0 );
+            else
+                _Qb[beta].randomize();
             foreach( const Neighbor &alpha, nbIR(beta) ) {
                 size_t _beta = alpha.dual;
-                muab( alpha, _beta ).fill( 1.0 );
-                muba( alpha, _beta ).fill( 1.0 );
+                if( props.init == Properties::InitType::UNIFORM ) {
+                    muab( alpha, _beta ).fill( 1.0 );
+                    muba( alpha, _beta ).fill( 1.0 );
+                } else {
+                    muab( alpha, _beta ).randomize();
+                    muba( alpha, _beta ).randomize();
+                }
             }
         }
 }
@@ -185,16 +203,27 @@ void HAK::init( const VarSet &ns ) {
 
 void HAK::init() {
     for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
-        alpha->fill( 1.0 / alpha->states() );
+        if( props.init == Properties::InitType::UNIFORM )
+            alpha->fill( 1.0 / alpha->states() );
+        else
+            alpha->randomize();
 
     for( vector<Factor>::iterator beta = _Qb.begin(); beta != _Qb.end(); beta++ )
-        beta->fill( 1.0 / beta->states() );
+        if( props.init == Properties::InitType::UNIFORM )
+            beta->fill( 1.0 / beta->states() );
+        else
+            beta->randomize();
 
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
         foreach( const Neighbor &beta, nbOR(alpha) ) {
             size_t _beta = beta.iter;
-            muab( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
-            muba( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
+            if( props.init == Properties::InitType::UNIFORM ) {
+                muab( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
+                muba( alpha, _beta ).fill( 1.0 / muab( alpha, _beta ).states() );
+            } else {
+                muab( alpha, _beta ).randomize();
+                muba( alpha, _beta ).randomize();
+            }
         }
 }