[Ofer Meshi] Added a script to convert from FastInf fileformat to libDAI fileformat
[libdai.git] / tests / testdai.cpp
index f47ef4d..93416e0 100644 (file)
@@ -26,6 +26,16 @@ using namespace dai;
 namespace po = boost::program_options;
 
 
+std::vector<Real> calcDists( const vector<Factor> &x, const vector<Factor> &y ) {
+    vector<Real> errs;
+    errs.reserve( x.size() );
+    DAI_ASSERT( x.size() == y.size() );
+    for( size_t i = 0; i < x.size(); i++ )
+        errs.push_back( dist( x[i], y[i], DISTTV ) );
+    return errs;
+}
+
+
 /// Wrapper class for DAI approximate inference algorithms
 class TestDAI {
     protected:
@@ -34,11 +44,15 @@ class TestDAI {
         /// Stores the name of the InfAlg algorithm
         string          name;
         /// Stores the total variation distances of the variable marginals
-        vector<Real>    err;
+        vector<Real>    varErr;
+        /// Stores the total variation distances of the factor marginals
+        vector<Real>    facErr;
 
     public:
         /// Stores the variable marginals
         vector<Factor>  varMarginals;
+        /// Stores the factor marginals
+        vector<Factor>  facMarginals;
         /// Stores all marginals
         vector<Factor>  allMarginals;
         /// Stores the logarithm of the partition sum
@@ -57,7 +71,7 @@ class TestDAI {
         bool            has_iters;
 
         /// Construct from factor graph \a fg, name \a _name, and set of properties \a opts
-        TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), varMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
+        TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), varErr(), facErr(), varMarginals(), facMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
             double tic = toc();
 
             if( name == "LDPC" ) {
@@ -143,35 +157,47 @@ class TestDAI {
                 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
                     varMarginals.push_back( obj->beliefV( i ) );
 
+                // Store factor marginals
+                facMarginals.clear();
+                for( size_t I = 0; I < obj->fg().nrFactors(); I++ )
+                    try {
+                        facMarginals.push_back( obj->beliefF( I ) );
+                    } catch( Exception &e ) {
+                        if( e.code() == Exception::BELIEF_NOT_AVAILABLE )
+                            facMarginals.push_back( Factor( obj->fg().factor(I).vars(), INFINITY ) );
+                        else
+                            throw;
+                    }
+
                 // Store all marginals calculated by the method
                 allMarginals = obj->beliefs();
             };
         }
 
-        /// Calculate total variation distance of variable marginals with respect to those in \a x
-        void calcErrs( const TestDAI &x ) {
-            err.clear();
-            err.reserve( varMarginals.size() );
-            for( size_t i = 0; i < varMarginals.size(); i++ )
-                err.push_back( dist( varMarginals[i], x.varMarginals[i], Prob::DISTTV ) );
+        /// Calculate total variation distance of variable and factor marginals with respect to those in \a varMargs and \a facMargs
+        void calcErrors( const vector<Factor>& varMargs, const vector<Factor>& facMargs ) {
+            varErr = calcDists( varMarginals, varMargs );
+            facErr = calcDists( facMarginals, facMargs );
+        }
+
+        /// Return maximum variable error
+        Real maxVarErr() {
+            return( *max_element( varErr.begin(), varErr.end() ) );
         }
 
-        /// Calculate total variation distance of variable marginals with respect to those in \a x
-        void calcErrs( const vector<Factor> &x ) {
-            err.clear();
-            err.reserve( varMarginals.size() );
-            for( size_t i = 0; i < varMarginals.size(); i++ )
-                err.push_back( dist( varMarginals[i], x[i], Prob::DISTTV ) );
+        /// Return average variable error
+        Real avgVarErr() {
+            return( accumulate( varErr.begin(), varErr.end(), 0.0 ) / varErr.size() );
         }
 
-        /// Return maximum error
-        Real maxErr() {
-            return( *max_element( err.begin(), err.end() ) );
+        /// Return maximum factor error
+        Real maxFacErr() {
+            return( *max_element( facErr.begin(), facErr.end() ) );
         }
 
-        /// Return average error
-        Real avgErr() {
-            return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
+        /// Return average factor error
+        Real avgFacErr() {
+            return( accumulate( facErr.begin(), facErr.end(), 0.0 ) / facErr.size() );
         }
 };
 
@@ -185,8 +211,8 @@ Real clipReal( Real x, Real minabs ) {
 }
 
 
-/// Whether to output no marginals, only variable marginals, or all calculated marginals
-DAI_ENUM(MarginalsOutputType,NONE,VAR,ALL);
+/// Which marginals to outpu (none, only variable, only factor, variable and factor, all)
+DAI_ENUM(MarginalsOutputType,NONE,VAR,FAC,VARFAC,ALL);
 
 
 /// Main function
@@ -217,7 +243,7 @@ int main( int argc, char *argv[] ) {
     opts_optional.add_options()
         ("help", "Produce help message")
         ("aliases", po::value< string >(&aliases), "Filename for aliases")
-        ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/ALL, default=NONE)")
+        ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/FAC/VARFAC/ALL, default=NONE)")
         ("report-time", po::value< bool >(&report_time), "Output calculation time (default==1)?")
         ("report-iters", po::value< bool >(&report_iters), "Output iterations needed (default==1)?")
     ;
@@ -241,11 +267,14 @@ int main( int argc, char *argv[] ) {
         cout << "  o the number of iterations needed (if report-iters == 1);" << endl;
         cout << "  o the maximum (over all variables) total variation error in the variable marginals;" << endl;
         cout << "  o the average (over all variables) total variation error in the variable marginals;" << endl;
+        cout << "  o the maximum (over all factors) total variation error in the factor marginals;" << endl;
+        cout << "  o the average (over all factors) total variation error in the factor marginals;" << endl;
         cout << "  o the error (difference) of the logarithm of the partition sums;" << endl << endl;
         cout << "All errors are calculated by comparing the results of the current method with" << endl; 
         cout << "the results of the first method (the base method). If marginals==VAR, additional" << endl;
-        cout << "output consists of the variable marginals, and if marginals==ALL, all marginals" << endl;
-        cout << "calculated by the method are reported." << endl << endl;
+        cout << "output consists of the variable marginals, if marginals==FAC, the factor marginals" << endl;
+        cout << "if marginals==VARFAC, both variable and factor marginals, and if marginals==ALL, all" << endl;
+        cout << "marginals calculated by the method are reported." << endl << endl;
         cout << "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl << endl;
         cout << "    name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl << endl;
         cout << "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl;
@@ -273,8 +302,9 @@ int main( int argc, char *argv[] ) {
         FactorGraph fg;
         fg.ReadFromFile( filename.c_str() );
 
-        // Declare variables used for storing variable marginals and log partition sum of base method
+        // Declare variables used for storing variable factor marginals and log partition sum of base method
         vector<Factor> varMarginals0;
+        vector<Factor> facMarginals0;
         Real logZ0 = 0.0;
 
         // Output header
@@ -287,8 +317,10 @@ int main( int argc, char *argv[] ) {
             cout << right << "SECONDS  " << "\t";
         if( report_iters )
             cout << "ITERS" << "\t";
-        cout << "MAX ERROR" << "\t";
-        cout << "AVG ERROR" << "\t";
+        cout << "MAX VAR ERR" << "\t";
+        cout << "AVG VAR ERR" << "\t";
+        cout << "MAX FAC ERR" << "\t";
+        cout << "AVG FAC ERR" << "\t";
         cout << "LOGZ ERROR" << "\t";
         cout << "MAXDIFF" << "\t";
         cout << endl;
@@ -315,11 +347,12 @@ int main( int argc, char *argv[] ) {
             // For the base method, store its variable marginals and logarithm of the partition sum
             if( m == 0 ) {
                 varMarginals0 = testdai.varMarginals;
+                facMarginals0 = testdai.facMarginals;
                 logZ0 = testdai.logZ;
             }
 
             // Calculate errors relative to base method
-            testdai.calcErrs( varMarginals0 );
+            testdai.calcErrors( varMarginals0, facMarginals0 );
 
             // Output method name
             cout.width( 39 );
@@ -342,12 +375,26 @@ int main( int argc, char *argv[] ) {
                 cout.precision( 3 );
 
                 // Output maximum error in variable marginals
-                Real me = clipReal( testdai.maxErr(), 1e-9 );
-                cout << me << "\t";
+                Real mev = clipReal( testdai.maxVarErr(), 1e-9 );
+                cout << mev << "\t";
 
                 // Output average error in variable marginals
-                Real ae = clipReal( testdai.avgErr(), 1e-9 );
-                cout << ae << "\t";
+                Real aev = clipReal( testdai.avgVarErr(), 1e-9 );
+                cout << aev << "\t";
+
+                // Output maximum error in factor marginals
+                Real mef = clipReal( testdai.maxFacErr(), 1e-9 );
+                if( mef == INFINITY )
+                    cout << "N/A       \t";
+                else
+                    cout << mef << "\t";
+
+                // Output average error in factor marginals
+                Real aef = clipReal( testdai.avgFacErr(), 1e-9 );
+                if( aef == INFINITY )
+                    cout << "N/A       \t";
+                else
+                    cout << aef << "\t";
 
                 // Output error in log partition sum
                 if( testdai.has_logZ ) {
@@ -361,10 +408,10 @@ int main( int argc, char *argv[] ) {
                 // Output maximum difference in last iteration
                 if( testdai.has_maxdiff ) {
                     Real md = clipReal( testdai.maxdiff, 1e-9 );
-                    if( isnan( me ) )
-                        md = me;
-                    if( isnan( ae ) )
-                        md = ae;
+                    if( isnan( mev ) )
+                        md = mev;
+                    if( isnan( aev ) )
+                        md = aev;
                     if( md == INFINITY )
                         md = 1.0;
                     cout << md << "\t";
@@ -374,13 +421,15 @@ int main( int argc, char *argv[] ) {
             cout << endl;
 
             // Output marginals, if requested
-            if( marginals == MarginalsOutputType::VAR ) {
+            if( marginals == MarginalsOutputType::VAR || marginals == MarginalsOutputType::VARFAC )
                 for( size_t i = 0; i < testdai.varMarginals.size(); i++ )
                     cout << "# " << testdai.varMarginals[i] << endl;
-            } else if( marginals == MarginalsOutputType::ALL ) {
+            if( marginals == MarginalsOutputType::FAC || marginals == MarginalsOutputType::VARFAC )
+                for( size_t I = 0; I < testdai.facMarginals.size(); I++ )
+                    cout << "# " << testdai.facMarginals[I] << endl;
+            if( marginals == MarginalsOutputType::ALL )
                 for( size_t I = 0; I < testdai.allMarginals.size(); I++ )
                     cout << "# " << testdai.allMarginals[I] << endl;
-            }
         }
 
         return 0;