Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / tests / test.cpp
index be355ad..d33880f 100644 (file)
@@ -25,6 +25,7 @@
 #include <numeric>
 #include <cmath>
 #include <cstdlib>
+#include <cstring>
 #include <boost/program_options.hpp>
 #include <dai/util.h>
 #include <dai/alldai.h>
@@ -45,10 +46,11 @@ class TestAI {
         vector<Factor>  q;
         double          logZ;
         double          maxdiff;
-        clock_t         time;
+        double          time;
+        bool            has_logZ;
 
-        TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
-            clock_t tic = toc();
+        TestAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), has_logZ(true) {
+            double tic = toc();
             obj = newInfAlg( name, fg, opts );
             time += toc() - tic;
 /*
@@ -90,13 +92,13 @@ class TestAI {
 
         vector<Factor> allBeliefs() {
             vector<Factor> result;
-            for( size_t i = 0; i < obj->nrVars(); i++ )
-                result.push_back( obj->belief( obj->var(i) ) );
+            for( size_t i = 0; i < obj->fg().nrVars(); i++ )
+                result.push_back( obj->belief( obj->fg().var(i) ) );
             return result;
         }
 
         void doAI() {
-            clock_t tic = toc();
+            double tic = toc();
 //            if( name == "EXACT" ) {
 //                // calculation has already been done
 //            } 
@@ -104,8 +106,13 @@ class TestAI {
                 obj->init();
                 obj->run();
                 time += toc() - tic;
-                logZ = real(obj->logZ());
-                maxdiff = obj->MaxDiff();
+                try {
+                    logZ = obj->logZ();
+                    has_logZ = true;
+                } catch( Exception &e ) {
+                    has_logZ = false;
+                }
+                maxdiff = obj->maxDiff();
                 q = allBeliefs();
             };
         }
@@ -134,24 +141,24 @@ class TestAI {
 };
 
 
-pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
+pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
     string s = _s;
     if( aliases.find(_s) != aliases.end() )
         s = aliases.find(_s)->second;
 
-    pair<string, Properties> result;
+    pair<string, PropertySet> result;
     string & name = result.first;
-    Properties & opts = result.second;
+    PropertySet & opts = result.second;
 
     string::size_type pos = s.find_first_of('[');
     name = s.substr( 0, pos );
     if( pos == string::npos )
         throw "Malformed method";
     size_t n = 0;
-    for( ; n < sizeof(DAINames) / sizeof(string); n++ )
+    for( ; strlen( DAINames[n] ) != 0; n++ )
         if( name == DAINames[n] )
             break;
-    if( n == sizeof(DAINames) / sizeof(string) )
+    if( strlen( DAINames[n] ) == 0 )
         throw "Unknown inference algorithm";
 
     stringstream ss;
@@ -178,6 +185,7 @@ int main( int argc, char *argv[] ) {
         double tol;
         size_t maxiter;
         size_t verbose;
+        bool report_time = true;
 
         po::options_description opts_required("Required options");
         opts_required.add_options()
@@ -192,6 +200,7 @@ int main( int argc, char *argv[] ) {
             ("tol", po::value< double >(&tol), "Override tolerance")
             ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
             ("verbose", po::value< size_t >(&verbose), "Override verbosity")
+            ("report-time", po::value< bool >(&report_time), "Report calculation time")
         ;
 
         po::options_description cmdline_options;
@@ -203,7 +212,7 @@ int main( int argc, char *argv[] ) {
 
         if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
             cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
-            cout << "inference algorithms <method*>, reporting clocks, max and average" << endl;
+            cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
             cout << "error and relative logZ error (comparing with the results of" << endl;
             cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
             cout << opts_required << opts_optional << endl;
@@ -250,8 +259,10 @@ int main( int argc, char *argv[] ) {
             cout << "# " << filename << endl;
             cout.width( 40 );
             cout << left << "# METHOD" << "  ";
-            cout.width( 10 );
-            cout << right << "CLOCKS" << "    ";
+            if( report_time ) {
+                cout.width( 10 );
+                cout << right << "SECONDS" << "   ";
+            }
             cout.width( 10 );
             cout << "MAX ERROR" << "  ";
             cout.width( 10 );
@@ -262,7 +273,7 @@ int main( int argc, char *argv[] ) {
             cout << "MAXDIFF" << endl;
 
             for( size_t m = 0; m < methods.size(); m++ ) {
-                pair<string, Properties> meth = parseMethod( methods[m], Aliases );
+                pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
 
                 if( vm.count("tol") )
                     meth.second.Set("tol",tol);
@@ -281,8 +292,10 @@ int main( int argc, char *argv[] ) {
                 cout.width( 40 );
 //                cout << left << piet.identify() << "  ";
                 cout << left << methods[m] << "  ";
-                cout.width( 10 );
-                cout << right << piet.time << "    ";
+                if( report_time ) {
+                    cout.width( 10 );
+                    cout << right << piet.time << "    ";
+                }
 
                 if( m > 0 ) {
                     cout.setf( ios_base::scientific );
@@ -294,8 +307,12 @@ int main( int argc, char *argv[] ) {
                     double ae = clipdouble( piet.avgErr(), 1e-9 );
                     cout << ae << "  ";
                     cout.width( 10 );
-                    double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
-                    cout << le << "  ";
+                    if( piet.has_logZ ) {
+                        double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
+                        cout << le << "  ";
+                    } else {
+                        cout << "N/A         ";
+                    }
                     cout.width( 10 );
                     double md = clipdouble( piet.maxdiff, 1e-9 );
                     if( isnan( me ) )