Even more cleanup of BBP code
[libdai.git] / tests / testdai.cpp
index 57d8342..8e47385 100644 (file)
@@ -76,7 +76,7 @@ class TestDAI {
                 delete obj;
         }
 
-        string identify() 
+        string identify() const {
             if( obj != NULL )
                 return obj->identify(); 
             else
@@ -96,24 +96,37 @@ class TestDAI {
                 obj->init();
                 obj->run();
                 time += toc() - tic;
+                
                 try {
                     logZ = obj->logZ();
                     has_logZ = true;
                 } catch( Exception &e ) {
-                    has_logZ = false;
+                    if( e.code() == Exception::NOT_IMPLEMENTED )
+                        has_logZ = false;
+                    else
+                        throw;
                 }
+
                 try {
                     maxdiff = obj->maxDiff();
                     has_maxdiff = true;
                 } catch( Exception &e ) {
-                    has_maxdiff = false;
+                    if( e.code() == Exception::NOT_IMPLEMENTED )
+                        has_maxdiff = false;
+                    else
+                        throw;
                 }
+                
                 try {
                     iters = obj->Iterations();
                     has_iters = true;
                 } catch( Exception &e ) {
-                    has_iters = false;
+                    if( e.code() == Exception::NOT_IMPLEMENTED )
+                        has_iters = false;
+                    else
+                        throw;
                 }
+                
                 q = allBeliefs();
             };
         }
@@ -142,47 +155,51 @@ class TestDAI {
 };
 
 
-pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
-    // s = first part of _s, until '['
-    string::size_type pos = _s.find_first_of('[');
-    string s;
-    if( pos == string::npos )
-        s = _s;
-    else
-        s = _s.substr(0,pos);
+pair<string, PropertySet> parseMethodRaw( const string &s ) {
+    string::size_type pos = s.find_first_of('[');
+    string name;
+    PropertySet opts;
+    if( pos == string::npos ) {
+        name = s;
+    } else {
+        name = s.substr(0,pos);
 
-    // if the first part is an alias, substitute
-    if( aliases.find(s) != aliases.end() )
-        s = aliases.find(s)->second;
-
-    // attach second part, merging properties if necessary
-    if( pos != string::npos ) {
-        if( s.at(s.length()-1) == ']' ) {
-            s = s.erase(s.length()-1,1) + ',' + _s.substr(pos+1);
-        } else
-            s = s + _s.substr(pos);
+        stringstream ss;
+        ss << s.substr(pos,s.length());
+        ss >> opts;
     }
+    return make_pair(name,opts);
+}
 
-    pair<string, PropertySet> result;
-    string & name = result.first;
-    PropertySet & opts = result.second;
 
-    pos = s.find_first_of('[');
-    if( pos == string::npos )
-        throw "Malformed method";
-    name = s.substr( 0, pos );
+pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
+    // break string into method[properties]
+    pair<string,PropertySet> ps = parseMethodRaw(_s);
+    bool looped = false;
+
+    // as long as 'method' is an alias, update:
+    while( aliases.find(ps.first) != aliases.end() && !looped ) {
+        string astr = aliases.find(ps.first)->second;
+        pair<string,PropertySet> aps = parseMethodRaw(astr);
+        if( aps.first == ps.first )
+            looped = true;
+        // override aps properties by ps properties
+        aps.second.Set( ps.second );
+        // replace ps by aps
+        ps = aps;
+        // repeat until method name == alias name ('looped'), or
+        // there is no longer an alias 'method'
+    }
+    
+    // check whether name is valid
     size_t n = 0;
     for( ; strlen( DAINames[n] ) != 0; n++ )
-        if( name == DAINames[n] )
+        if( ps.first == DAINames[n] )
             break;
-    if( strlen( DAINames[n] ) == 0 && (name != "LDPC") )
-        DAI_THROW(UNKNOWN_DAI_ALGORITHM);
+    if( strlen( DAINames[n] ) == 0 && (ps.first != "LDPC") )
+        throw std::runtime_error(string("Unknown DAI algorithm \"") + ps.first + string("\" in \"") + _s + string("\""));
 
-    stringstream ss;
-    ss << s.substr(pos,s.length());
-    ss >> opts;
-    
-    return result;
+    return ps;
 }
 
 
@@ -195,175 +212,164 @@ double clipdouble( double x, double minabs ) {
 
 
 int main( int argc, char *argv[] ) {
-    try {
-        string filename;
-        string aliases;
-        vector<string> methods;
-        double tol;
-        size_t maxiter;
-        size_t verbose;
-        bool marginals = false;
-        bool report_iters = true;
-        bool report_time = true;
-
-        po::options_description opts_required("Required options");
-        opts_required.add_options()
-            ("filename", po::value< string >(&filename), "Filename of FactorGraph")
-            ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
-        ;
-
-        po::options_description opts_optional("Allowed options");
-        opts_optional.add_options()
-            ("help", "produce help message")
-            ("aliases", po::value< string >(&aliases), "Filename for aliases")
-            ("tol", po::value< double >(&tol), "Override tolerance")
-            ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
-            ("verbose", po::value< size_t >(&verbose), "Override verbosity")
-            ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
-            ("report-time", po::value< bool >(&report_time), "Report calculation time")
-            ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
-        ;
-
-        po::options_description cmdline_options;
-        cmdline_options.add(opts_required).add(opts_optional);
-
-        po::variables_map vm;
-        po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
-        po::notify(vm);
-
-        if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
-            cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
-            cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
-            cout << "error and relative logZ error (comparing with the results of" << endl;
-            cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
-            cout << opts_required << opts_optional << endl;
-            return 1;
-        }
+    string filename;
+    string aliases;
+    vector<string> methods;
+    double tol;
+    size_t maxiter;
+    size_t verbose;
+    bool marginals = false;
+    bool report_iters = true;
+    bool report_time = true;
+
+    po::options_description opts_required("Required options");
+    opts_required.add_options()
+        ("filename", po::value< string >(&filename), "Filename of FactorGraph")
+        ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
+    ;
+
+    po::options_description opts_optional("Allowed options");
+    opts_optional.add_options()
+        ("help", "produce help message")
+        ("aliases", po::value< string >(&aliases), "Filename for aliases")
+        ("tol", po::value< double >(&tol), "Override tolerance")
+        ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
+        ("verbose", po::value< size_t >(&verbose), "Override verbosity")
+        ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
+        ("report-time", po::value< bool >(&report_time), "Report calculation time")
+        ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
+    ;
+
+    po::options_description cmdline_options;
+    cmdline_options.add(opts_required).add(opts_optional);
+
+    po::variables_map vm;
+    po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
+    po::notify(vm);
+
+    if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
+        cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
+        cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
+        cout << "error and relative logZ error (comparing with the results of" << endl;
+        cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
+        cout << opts_required << opts_optional << endl;
+        return 1;
+    }
 
-        // Read aliases
-        map<string,string> Aliases;
-        if( !aliases.empty() ) {
-            ifstream infile;
-            infile.open (aliases.c_str());
-            if (infile.is_open()) {
-                while( true ) {
-                    string line;
-                    getline(infile,line);
-                    if( infile.fail() )
-                        break;
-                    if( (!line.empty()) && (line[0] != '#') ) {
-                        string::size_type pos = line.find(':',0);
-                        if( pos == string::npos )
-                            throw "Invalid alias";
-                        else {
-                            string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
-                            string key = line.substr(0, posl + 1);
-                            string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
-                            string val = line.substr(pos + 1 + posr, line.length());
-                            Aliases[key] = val;
-                        }
+    // Read aliases
+    map<string,string> Aliases;
+    if( !aliases.empty() ) {
+        ifstream infile;
+        infile.open (aliases.c_str());
+        if (infile.is_open()) {
+            while( true ) {
+                string line;
+                getline(infile,line);
+                if( infile.fail() )
+                    break;
+                if( (!line.empty()) && (line[0] != '#') ) {
+                    string::size_type pos = line.find(':',0);
+                    if( pos == string::npos )
+                        throw "Invalid alias";
+                    else {
+                        string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
+                        string key = line.substr(0, posl + 1);
+                        string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
+                        string val = line.substr(pos + 1 + posr, line.length());
+                        Aliases[key] = val;
                     }
                 }
-                infile.close();
-            } else
-                throw "Error opening aliases file";
-        }
-
-        FactorGraph fg;
-        fg.ReadFromFile( filename.c_str() );
-
-        vector<Factor> q0;
-        double logZ0 = 0.0;
+            }
+            infile.close();
+        } else
+            throw "Error opening aliases file";
+    }
 
-        cout.setf( ios_base::scientific );
-        cout.precision( 3 );
+    FactorGraph fg;
+    fg.ReadFromFile( filename.c_str() );
+
+    vector<Factor> q0;
+    double logZ0 = 0.0;
+
+    cout.setf( ios_base::scientific );
+    cout.precision( 3 );
+
+    cout << "# " << filename << endl;
+    cout.width( 39 );
+    cout << left << "# METHOD" << "\t";
+    if( report_time )
+        cout << right << "SECONDS  " << "\t";
+    if( report_iters )
+        cout << "ITERS" << "\t";
+    cout << "MAX ERROR" << "\t";
+    cout << "AVG ERROR" << "\t";
+    cout << "LOGZ ERROR" << "\t";
+    cout << "MAXDIFF" << "\t";
+    cout << endl;
+
+    for( size_t m = 0; m < methods.size(); m++ ) {
+        pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
+
+        if( vm.count("tol") )
+            meth.second.Set("tol",tol);
+        if( vm.count("maxiter") )
+            meth.second.Set("maxiter",maxiter);
+        if( vm.count("verbose") )
+            meth.second.Set("verbose",verbose);
+        TestDAI piet(fg, meth.first, meth.second );
+        piet.doDAI();
+        if( m == 0 ) {
+            q0 = piet.q;
+            logZ0 = piet.logZ;
+        }
+        piet.calcErrs(q0);
 
-        cout << "# " << filename << endl;
         cout.width( 39 );
-        cout << left << "# METHOD" << "\t";
+        cout << left << methods[m] << "\t";
         if( report_time )
-            cout << right << "SECONDS  " << "\t";
-        if( report_iters )
-            cout << "ITERS" << "\t";
-        cout << "MAX ERROR" << "\t";
-        cout << "AVG ERROR" << "\t";
-        cout << "LOGZ ERROR" << "\t";
-        cout << "MAXDIFF" << "\t";
-        cout << endl;
-
-        for( size_t m = 0; m < methods.size(); m++ ) {
-            pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
-
-            if( vm.count("tol") )
-                meth.second.Set("tol",tol);
-            if( vm.count("maxiter") )
-                meth.second.Set("maxiter",maxiter);
-            if( vm.count("verbose") )
-                meth.second.Set("verbose",verbose);
-            TestDAI piet(fg, meth.first, meth.second );
-            piet.doDAI();
-            if( m == 0 ) {
-                q0 = piet.q;
-                logZ0 = piet.logZ;
-            }
-            piet.calcErrs(q0);
-
-            cout.width( 39 );
-            cout << left << methods[m] << "\t";
-            if( report_time )
-                cout << right << piet.time << "\t";
-            if( report_iters ) {
-                if( piet.has_iters ) {
-                    cout << piet.iters << "\t";
-                } else {
-                    cout << "N/A  \t";
-                }
+            cout << right << piet.time << "\t";
+        if( report_iters ) {
+            if( piet.has_iters ) {
+                cout << piet.iters << "\t";
+            } else {
+                cout << "N/A  \t";
             }
+        }
 
-            if( m > 0 ) {
-                cout.setf( ios_base::scientific );
-                cout.precision( 3 );
-                
-                double me = clipdouble( piet.maxErr(), 1e-9 );
-                cout << me << "\t";
-                
-                double ae = clipdouble( piet.avgErr(), 1e-9 );
-                cout << ae << "\t";
-                
-                if( piet.has_logZ ) {
-                    cout.setf( ios::showpos );
-                    double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
-                    cout << le << "\t";
-                    cout.unsetf( ios::showpos );
-                } else
-                    cout << "N/A       \t";
-
-                if( piet.has_maxdiff ) {
-                    double md = clipdouble( piet.maxdiff, 1e-9 );
-                    if( isnan( me ) )
-                        md = me;
-                    if( isnan( ae ) )
-                        md = ae;
-                    cout << md << "\t";
-                } else
-                    cout << "N/A    \t";
-            }
-            cout << endl;
+        if( m > 0 ) {
+            cout.setf( ios_base::scientific );
+            cout.precision( 3 );
+            
+            double me = clipdouble( piet.maxErr(), 1e-9 );
+            cout << me << "\t";
+            
+            double ae = clipdouble( piet.avgErr(), 1e-9 );
+            cout << ae << "\t";
+            
+            if( piet.has_logZ ) {
+                cout.setf( ios::showpos );
+                double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
+                cout << le << "\t";
+                cout.unsetf( ios::showpos );
+            } else
+                cout << "N/A       \t";
+
+            if( piet.has_maxdiff ) {
+                double md = clipdouble( piet.maxdiff, 1e-9 );
+                if( isnan( me ) )
+                    md = me;
+                if( isnan( ae ) )
+                    md = ae;
+                cout << md << "\t";
+            } else
+                cout << "N/A    \t";
+        }
+        cout << endl;
 
-            if( marginals ) {
-                for( size_t i = 0; i < piet.q.size(); i++ )
-                    cout << "# " << piet.q[i] << endl;
-            }
+        if( marginals ) {
+            for( size_t i = 0; i < piet.q.size(); i++ )
+                cout << "# " << piet.q[i] << endl;
         }
-    } catch(const char *e) {
-        cerr << "Exception: " << e << endl;
-        return 1;
-    } catch(exception& e) {
-        cerr << "Exception: " << e.what() << endl;
-        return 1;
-    }
-    catch(...) {
-        cerr << "Exception of unknown type!" << endl;
     }
 
     return 0;