Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / src / mr.cpp
index 78ba6ab..eb9ab83 100644 (file)
@@ -39,30 +39,34 @@ using namespace std;
 const char *MR::Name = "MR";
 
 
-bool MR::checkProperties() {
-    if( !HasProperty("updates") )
-        return false;
-    if( !HasProperty("inits") )
-        return false;
-    if( !HasProperty("verbose") )
-        return false;
-    if( !HasProperty("tol") )
-        return false;
+void MR::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("tol") );
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("updates") );
+    assert( opts.hasKey("inits") );
     
-    ConvertPropertyTo<UpdateType>("updates");
-    ConvertPropertyTo<InitType>("inits");
-    ConvertPropertyTo<size_t>("verbose");
-    ConvertPropertyTo<double>("tol");
+    props.tol = opts.getStringAs<double>("tol");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+    props.inits = opts.getStringAs<Properties::InitType>("inits");
+}
+
 
-    return true;
+PropertySet MR::getProperties() const {
+    PropertySet opts;
+    opts.Set( "tol", props.tol );
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "updates", props.updates );
+    opts.Set( "inits", props.inits );
+    return opts;
 }
 
 
 // init N, con, nb, tJ, theta
-void MR::init(size_t _N, double *_w, double *_th) {
+void MR::init(size_t Nin, double *_w, double *_th) {
     size_t i,j;
     
-    N = _N;
+    N = Nin;
 
     con.resize(N);
     nb.resize(N);
@@ -199,9 +203,9 @@ double MR::init_cor_resp() {
                         }
                     }
                 }
-            } while((md > Tol())&&(runx<runs)); // Precision condition reached -> BP and RP finished
+            } while((md > props.tol)&&(runx<runs)); // Precision condition reached -> BP and RP finished
             if(runx==runs)
-                if( Verbose() >= 2 )
+                if( props.verbose >= 2 )
                     cout << "init_cor_resp: Convergence not reached (md=" << md << ")..." << endl;
             if(md > maxdev)
                 maxdev = md;
@@ -409,7 +413,7 @@ void MR::solvemcav() {
                 assert( nb[j][_i] == i );
 
                 double newM = 0.0;
-                if( Updates() == UpdateType::FULL ) {
+                if( props.updates == Properties::UpdateType::FULL ) {
                     // find indices in nb[j] that do not correspond with i
                     sub_nb _nbj_min_i(con[j]);
                     _nbj_min_i -= kindex[i][_j];
@@ -429,7 +433,7 @@ void MR::solvemcav() {
                         numer += tJ[i][_k] * cors[i][_j][_k] * (tanh(theta[i]) * sum_even + sum_odd);
                     }
                     newM -= numer / denom;
-                } else if( Updates() == UpdateType::LINEAR ) {
+                } else if( props.updates == Properties::UpdateType::LINEAR ) {
                     newM = T(j,_i);
                     for(size_t _l=0; _l<con[i]; _l++) if( _l != _j )
                         newM -= Omega(i,_j,_l) * tJ[i][_l] * cors[i][_j][_l];
@@ -449,12 +453,13 @@ void MR::solvemcav() {
                 M[i][_j] = newM;
             }
         }
-    } while((maxdev>Tol())&&(run<maxruns));
+    } while((maxdev>props.tol)&&(run<maxruns));
 
-    updateMaxDiff( maxdev );
+    if( maxdev > maxdiff )
+        maxdiff = maxdev;
 
     if(run==maxruns){
-        if( Verbose() >= 1 )
+        if( props.verbose >= 1 )
             cout << "solve_mcav: Convergence not reached (maxdev=" << maxdev << ")..." << endl;
     }
 }
@@ -462,7 +467,7 @@ void MR::solvemcav() {
  
 void MR::solveM() { 
     for(size_t i=0; i<N; i++) {
-        if( Updates() == UpdateType::FULL ) {
+        if( props.updates == Properties::UpdateType::FULL ) {
             // find indices in nb[i]
             sub_nb _nbi(con[i]);
 
@@ -472,7 +477,7 @@ void MR::solveM() {
 
             Mag[i] = (tanh(theta[i]) * sum_even + sum_odd) / (sum_even + tanh(theta[i]) * sum_odd);
 
-        } else if( Updates() == UpdateType::LINEAR ) {
+        } else if( props.updates == Properties::UpdateType::LINEAR ) {
             sub_nb empty(con[i]);
             empty.clear();
             Mag[i] = T(i,empty);
@@ -490,14 +495,14 @@ void MR::solveM() {
 void MR::init_cor() {
     for( size_t i = 0; i < nrVars(); i++ ) {
         vector<Factor> pairq;
-        if( Inits() == InitType::CLAMPING ) {
-            BP bpcav(*this, Properties()("updates",string("SEQMAX"))("tol", string("1e-9"))("maxiter", string("1000UL"))("verbose", string("0UL")));
-            bpcav.makeCavity( var(i) );
-            pairq = calcPairBeliefs( bpcav, delta(var(i)), false );
-        } else if( Inits() == InitType::EXACT ) {
-            JTree jtcav(*this, Properties()("updates",string("HUGIN"))("verbose", string("0UL")) );
-            jtcav.makeCavity( var(i) );
-            pairq = calcPairBeliefs( jtcav, delta(var(i)), false );
+        if( props.inits == Properties::InitType::CLAMPING ) {
+            BP bpcav(*this, PropertySet()("updates",string("SEQMAX"))("tol", string("1e-9"))("maxiter", string("1000UL"))("verbose", string("0UL"))("logdomain", string("0")));
+            bpcav.makeCavity( i );
+            pairq = calcPairBeliefs( bpcav, delta(i), false );
+        } else if( props.inits == Properties::InitType::EXACT ) {
+            JTree jtcav(*this, PropertySet()("updates",string("HUGIN"))("verbose", 0UL) );
+            jtcav.makeCavity( i );
+            pairq = calcPairBeliefs( jtcav, delta(i), false );
         }
         for( size_t jk = 0; jk < pairq.size(); jk++ ) {
             VarSet::const_iterator kit = pairq[jk].vars().begin();
@@ -517,17 +522,17 @@ void MR::init_cor() {
 
 string MR::identify() const { 
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
 
 double MR::run() {
     if( supported ) {
-        if( Verbose() >= 1 )
+        if( props.verbose >= 1 )
             cout << "Starting " << identify() << "...";
 
-        clock_t tic = toc();
+        double tic = toc();
 //        Diffs diffs(nrVars(), 1.0);
 
         M.resize(N);
@@ -545,11 +550,13 @@ double MR::run() {
         for(size_t i=0; i<N; i++)
           kindex[i].resize(kmax);
 
-        if( Inits() == InitType::RESPPROP )
-            updateMaxDiff( init_cor_resp() );
-        else if( Inits() == InitType::EXACT )
+        if( props.inits == Properties::InitType::RESPPROP ) {
+            double md = init_cor_resp();
+            if( md > maxdiff )
+                maxdiff = md;
+        } else if( props.inits == Properties::InitType::EXACT )
             init_cor(); // FIXME no MaxDiff() calculation
-        else if( Inits() == InitType::CLAMPING )
+        else if( props.inits == Properties::InitType::CLAMPING )
             init_cor(); // FIXME no MaxDiff() calculation
 
         solvemcav();
@@ -557,7 +564,7 @@ double MR::run() {
         Mag.resize(N);
         solveM();
 
-        if( Verbose() >= 1 )
+        if( props.verbose >= 1 )
             cout << "MR needed " << toc() - tic << " clocks." << endl;
 
         return 0.0;
@@ -601,11 +608,13 @@ vector<Factor> MR::beliefs() const {
 
 
 
-MR::MR( const FactorGraph &fg, const Properties &opts ) : DAIAlgFG(fg, opts), supported(true) {
+MR::MR( const FactorGraph &fg, const PropertySet &opts ) : DAIAlgFG(fg), supported(true), maxdiff(0.0) {
+    setProperties( opts );
+
     // check whether all vars in fg are binary
     // check whether connectivity is <= kmax
     for( size_t i = 0; i < fg.nrVars(); i++ )
-        if( (fg.var(i).states() > 2) || (fg.delta(fg.var(i)).size() > kmax) ) {
+        if( (fg.var(i).states() > 2) || (fg.delta(i).size() > kmax) ) {
             supported = false;
             break;
         }
@@ -624,15 +633,15 @@ MR::MR( const FactorGraph &fg, const Properties &opts ) : DAIAlgFG(fg, opts), su
         return;
 
     // create w and th
-    size_t _N = fg.nrVars();
+    size_t Nin = fg.nrVars();
 
-    double *w = new double[_N*_N];
-    double *th = new double[_N];
+    double *w = new double[Nin*Nin];
+    double *th = new double[Nin];
     
-    for( size_t i = 0; i < _N; i++ ) {
+    for( size_t i = 0; i < Nin; i++ ) {
         th[i] = 0.0;
-        for( size_t j = 0; j < _N; j++ )
-            w[i*_N+j] = 0.0;
+        for( size_t j = 0; j < Nin; j++ )
+            w[i*Nin+j] = 0.0;
     }
 
     for( size_t I = 0; I < fg.nrFactors(); I++ ) {
@@ -645,15 +654,15 @@ MR::MR( const FactorGraph &fg, const Properties &opts ) : DAIAlgFG(fg, opts), su
             VarSet::const_iterator jit = psi.vars().begin();
             size_t j = fg.findVar( *(++jit) );
 
-            w[i*_N+j] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1])); 
-            w[j*_N+i] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1])); 
+            w[i*Nin+j] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1])); 
+            w[j*Nin+i] += 0.25 * log(psi[3] * psi[0] / (psi[2] * psi[1])); 
 
             th[i] += 0.25 * log(psi[3] / psi[2] * psi[1] / psi[0]);
             th[j] += 0.25 * log(psi[3] / psi[1] * psi[2] / psi[0]);
         }
     }
     
-    init(_N, w, th);
+    init(Nin, w, th);
 
     delete th;
     delete w;