Merged SVN head ...
[libdai.git] / tests / test.cpp
index be355ad..5f24776 100644 (file)
@@ -25,6 +25,7 @@
 #include <numeric>
 #include <cmath>
 #include <cstdlib>
+#include <cstring>
 #include <boost/program_options.hpp>
 #include <dai/util.h>
 #include <dai/alldai.h>
@@ -35,7 +36,7 @@ using namespace dai;
 namespace po = boost::program_options;
 
 
-class TestAI {
+class TestDAI {
     protected:
         InfAlg          *obj;
         string          name;
@@ -45,38 +46,31 @@ class TestAI {
         vector<Factor>  q;
         double          logZ;
         double          maxdiff;
-        clock_t         time;
-
-        TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
-            clock_t tic = toc();
-            obj = newInfAlg( name, fg, opts );
+        double          time;
+        size_t          iters;
+        bool            has_logZ;
+        bool            has_maxdiff;
+        bool            has_iters;
+
+        TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
+            double tic = toc();
+            if( name == "LDPC" ) {
+                double zero[2] = {1.0, 0.0};
+                q.clear();
+                for( size_t i = 0; i < fg.nrVars(); i++ )
+                    q.push_back( Factor(Var(i,2), zero) );
+                logZ = 0.0;
+                maxdiff = 0.0;
+                iters = 1;
+                has_logZ = false;
+                has_maxdiff = false;
+                has_iters = false;
+            } else
+                obj = newInfAlg( name, fg, opts );
             time += toc() - tic;
-/*
-            } else if( method.substr(0,5) == "EXACT" ) { // EXACT
-                // Look if the network is small enough to do brute-force exact method
-                bool toolarge = false;
-                size_t total_statespace = 1;
-                for( size_t i = 0; i < fg.nrVars(); i++ ) {
-                    total_statespace *= fg.var(i).states();
-                    if( total_statespace > (1UL << 16) )
-                        toolarge = true;
-                }
-
-                if( !toolarge ) {
-                    Factor piet;
-                    for( size_t I = 0; I < fg.nrFactors(); I++ )
-                        piet *= fg.factor( I );
-                    for( size_t i = 0; i < fg.nrVars(); i++ )
-                        q.push_back(piet.marginal(fg.var(i)));
-                    time += toc() - tic;
-                    logZ = fg.ExactlogZ();
-                } else
-                    throw "Network too large for EXACT method";
-            }
-*/
         }
 
-        ~TestAI() { 
+        ~TestDAI() { 
             if( obj != NULL )
                 delete obj;
         }
@@ -90,27 +84,40 @@ class TestAI {
 
         vector<Factor> allBeliefs() {
             vector<Factor> result;
-            for( size_t i = 0; i < obj->nrVars(); i++ )
-                result.push_back( obj->belief( obj->var(i) ) );
+            for( size_t i = 0; i < obj->fg().nrVars(); i++ )
+                result.push_back( obj->belief( obj->fg().var(i) ) );
             return result;
         }
 
-        void doAI() {
-            clock_t tic = toc();
-//            if( name == "EXACT" ) {
-//                // calculation has already been done
-//            } 
+        void doDAI() {
+            double tic = toc();
             if( obj != NULL ) {
                 obj->init();
                 obj->run();
                 time += toc() - tic;
-                logZ = real(obj->logZ());
-                maxdiff = obj->MaxDiff();
+                try {
+                    logZ = obj->logZ();
+                    has_logZ = true;
+                } catch( Exception &e ) {
+                    has_logZ = false;
+                }
+                try {
+                    maxdiff = obj->maxDiff();
+                    has_maxdiff = true;
+                } catch( Exception &e ) {
+                    has_maxdiff = false;
+                }
+                try {
+                    iters = obj->Iterations();
+                    has_iters = true;
+                } catch( Exception &e ) {
+                    has_iters = false;
+                }
                 q = allBeliefs();
             };
         }
 
-        void calcErrs( const TestAI &x ) {
+        void calcErrs( const TestDAI &x ) {
             err.clear();
             err.reserve( q.size() );
             for( size_t i = 0; i < q.size(); i++ )
@@ -134,25 +141,41 @@ class TestAI {
 };
 
 
-pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
-    string s = _s;
-    if( aliases.find(_s) != aliases.end() )
-        s = aliases.find(_s)->second;
+pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
+    // s = first part of _s, until '['
+    string::size_type pos = _s.find_first_of('[');
+    string s;
+    if( pos == string::npos )
+        s = _s;
+    else
+        s = _s.substr(0,pos);
+
+    // if the first part is an alias, substitute
+    if( aliases.find(s) != aliases.end() )
+        s = aliases.find(s)->second;
+
+    // attach second part, merging properties if necessary
+    if( pos != string::npos ) {
+        if( s.at(s.length()-1) == ']' ) {
+            s = s.erase(s.length()-1,1) + ',' + _s.substr(pos+1);
+        } else
+            s = s + _s.substr(pos);
+    }
 
-    pair<string, Properties> result;
+    pair<string, PropertySet> result;
     string & name = result.first;
-    Properties & opts = result.second;
+    PropertySet & opts = result.second;
 
-    string::size_type pos = s.find_first_of('[');
-    name = s.substr( 0, pos );
+    pos = s.find_first_of('[');
     if( pos == string::npos )
         throw "Malformed method";
+    name = s.substr( 0, pos );
     size_t n = 0;
-    for( ; n < sizeof(DAINames) / sizeof(string); n++ )
+    for( ; strlen( DAINames[n] ) != 0; n++ )
         if( name == DAINames[n] )
             break;
-    if( n == sizeof(DAINames) / sizeof(string) )
-        throw "Unknown inference algorithm";
+    if( strlen( DAINames[n] ) == 0 && (name != "LDPC") )
+        DAI_THROW(UNKNOWN_DAI_ALGORITHM);
 
     stringstream ss;
     ss << s.substr(pos,s.length());
@@ -178,11 +201,13 @@ int main( int argc, char *argv[] ) {
         double tol;
         size_t maxiter;
         size_t verbose;
+        bool marginals = false;
+        bool report_time = true;
 
         po::options_description opts_required("Required options");
         opts_required.add_options()
             ("filename", po::value< string >(&filename), "Filename of FactorGraph")
-            ("methods", po::value< vector<string> >(&methods)->multitoken(), "AI methods to test")
+            ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
         ;
 
         po::options_description opts_optional("Allowed options");
@@ -192,6 +217,8 @@ int main( int argc, char *argv[] ) {
             ("tol", po::value< double >(&tol), "Override tolerance")
             ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
             ("verbose", po::value< size_t >(&verbose), "Override verbosity")
+            ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
+            ("report-time", po::value< bool >(&report_time), "Report calculation time")
         ;
 
         po::options_description cmdline_options;
@@ -203,7 +230,7 @@ int main( int argc, char *argv[] ) {
 
         if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
             cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
-            cout << "inference algorithms <method*>, reporting clocks, max and average" << endl;
+            cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
             cout << "error and relative logZ error (comparing with the results of" << endl;
             cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
             cout << opts_required << opts_optional << endl;
@@ -240,71 +267,93 @@ int main( int argc, char *argv[] ) {
         }
 
         FactorGraph fg;
-        if( fg.ReadFromFile(filename.c_str()) ) {
-            cout << "Error reading " << filename << endl;
-            return 2;
-        } else {
-            vector<Factor> q0;
-            double logZ0 = 0.0;
-
-            cout << "# " << filename << endl;
-            cout.width( 40 );
-            cout << left << "# METHOD" << "  ";
-            cout.width( 10 );
-            cout << right << "CLOCKS" << "    ";
-            cout.width( 10 );
-            cout << "MAX ERROR" << "  ";
-            cout.width( 10 );
-            cout << "AVG ERROR" << "  ";
-            cout.width( 10 );
-            cout << "LOGZ ERROR" << "  ";
+        fg.ReadFromFile( filename.c_str() );
+
+        vector<Factor> q0;
+        double logZ0 = 0.0;
+
+        cout << "# " << filename << endl;
+        cout.width( 40 );
+        cout << left << "# METHOD" << "  ";
+        if( report_time ) {
             cout.width( 10 );
-            cout << "MAXDIFF" << endl;
-
-            for( size_t m = 0; m < methods.size(); m++ ) {
-                pair<string, Properties> meth = parseMethod( methods[m], Aliases );
-
-                if( vm.count("tol") )
-                    meth.second.Set("tol",tol);
-                if( vm.count("maxiter") )
-                    meth.second.Set("maxiter",maxiter);
-                if( vm.count("verbose") )
-                    meth.second.Set("verbose",verbose);
-                TestAI piet(fg, meth.first, meth.second );
-                piet.doAI();
-                if( m == 0 ) {
-                    q0 = piet.q;
-                    logZ0 = piet.logZ;
-                }
-                piet.calcErrs(q0);
+            cout << right << "SECONDS" << "   ";
+        }
+        cout.width( 10 );
+        cout << "MAX ERROR" << "  ";
+        cout.width( 10 );
+        cout << "AVG ERROR" << "  ";
+        cout.width( 10 );
+        cout << "LOGZ ERROR" << "  ";
+        cout.width( 10 );
+        cout << "MAXDIFF" << "  ";
+        cout.width( 10 );
+        cout << "ITERS" << endl;
+
+        for( size_t m = 0; m < methods.size(); m++ ) {
+            pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
+
+            if( vm.count("tol") )
+                meth.second.Set("tol",tol);
+            if( vm.count("maxiter") )
+                meth.second.Set("maxiter",maxiter);
+            if( vm.count("verbose") )
+                meth.second.Set("verbose",verbose);
+            TestDAI piet(fg, meth.first, meth.second );
+            piet.doDAI();
+            if( m == 0 ) {
+                q0 = piet.q;
+                logZ0 = piet.logZ;
+            }
+            piet.calcErrs(q0);
 
-                cout.width( 40 );
+            cout.width( 40 );
 //                cout << left << piet.identify() << "  ";
-                cout << left << methods[m] << "  ";
+            cout << left << methods[m] << "  ";
+            if( report_time ) {
                 cout.width( 10 );
                 cout << right << piet.time << "    ";
+            }
 
-                if( m > 0 ) {
-                    cout.setf( ios_base::scientific );
-                    cout.precision( 3 );
-                    cout.width( 10 ); 
-                    double me = clipdouble( piet.maxErr(), 1e-9 );
-                    cout << me << "  ";
-                    cout.width( 10 );
-                    double ae = clipdouble( piet.avgErr(), 1e-9 );
-                    cout << ae << "  ";
-                    cout.width( 10 );
+            if( m > 0 ) {
+                cout.setf( ios_base::scientific );
+                cout.precision( 3 );
+                cout.width( 10 ); 
+                double me = clipdouble( piet.maxErr(), 1e-9 );
+                cout << me << "  ";
+                cout.width( 10 );
+                double ae = clipdouble( piet.avgErr(), 1e-9 );
+                cout << ae << "  ";
+                cout.width( 10 );
+                if( piet.has_logZ ) {
                     double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
                     cout << le << "  ";
-                    cout.width( 10 );
+                } else {
+                    cout << "N/A         ";
+                }
+                cout.width( 10 );
+                if( piet.has_maxdiff ) {
                     double md = clipdouble( piet.maxdiff, 1e-9 );
                     if( isnan( me ) )
                         md = me;
                     if( isnan( ae ) )
                         md = ae;
-                    cout << md << endl;
-                } else
-                    cout << endl;
+                    cout << md << "  ";
+                } else {
+                    cout << "N/A         ";
+                }
+                cout.width( 10 );
+                if( piet.has_iters ) {
+                    cout << piet.iters << "  ";
+                } else {
+                    cout << "N/A         ";
+                }
+            }
+            cout << endl;
+
+            if( marginals ) {
+                for( size_t i = 0; i < piet.q.size(); i++ )
+                    cout << "# " << piet.q[i] << endl;
             }
         }
     } catch(const char *e) {