Added BipartiteGraph::eraseEdge and some small cleanup
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 20 Jul 2009 13:23:35 +0000 (15:23 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Mon, 20 Jul 2009 13:23:35 +0000 (15:23 +0200)
include/dai/bipgraph.h
include/dai/factor.h
tests/testdai.cpp

index f202f99..17b39a7 100644 (file)
@@ -258,6 +258,22 @@ class BipartiteGraph {
         /// Removes node n2 of type 2 and all incident edges.
         void erase2( size_t n2 );
 
+        /// Removes edge between node n1 of type 1 and node n2 of type 2.
+        void eraseEdge( size_t n1, size_t n2 ) {
+            assert( n1 < nr1() );
+            assert( n2 < nr2() );
+            for( Neighbors::iterator i1 = _nb1[n1].begin(); i1 != _nb1[n1].end(); i1++ )
+                if( i1->node == n2 ) {
+                    _nb1[n1].erase( i1 );
+                    break;
+                }
+            for( Neighbors::iterator i2 = _nb2[n2].begin(); i2 != _nb2[n2].end(); i2++ )
+                if( i2->node == n1 ) {
+                    _nb2[n2].erase( i2 );
+                    break;
+                }
+        }
+
         /// Adds an edge between node n1 of type 1 and node n2 of type 2.
         /** If check == true, only adds the edge if it does not exist already.
          */
index d719e7a..42df4ce 100644 (file)
@@ -562,7 +562,7 @@ template<typename T> Real MutualInfo(const TFactor<T> &f) {
     VarSet::const_iterator it = f.vars().begin();
     Var i = *it; it++; Var j = *it;
     TFactor<T> projection = f.marginal(i) * f.marginal(j);
-    return real( dist( f.normalized(), projection, Prob::DISTKL ) );
+    return dist( f.normalized(), projection, Prob::DISTKL );
 }
 
 
index f153b72..d3d340f 100644 (file)
@@ -259,121 +259,126 @@ int main( int argc, char *argv[] ) {
         return 1;
     }
 
-    // Read aliases
-    map<string,string> Aliases;
-    if( !aliases.empty() ) {
-        ifstream infile;
-        infile.open (aliases.c_str());
-        if (infile.is_open()) {
-            while( true ) {
-                string line;
-                getline(infile,line);
-                if( infile.fail() )
-                    break;
-                if( (!line.empty()) && (line[0] != '#') ) {
-                    string::size_type pos = line.find(':',0);
-                    if( pos == string::npos )
-                        throw "Invalid alias";
-                    else {
-                        string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
-                        string key = line.substr(0, posl + 1);
-                        string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
-                        string val = line.substr(pos + 1 + posr, line.length());
-                        Aliases[key] = val;
+    try {
+        // Read aliases
+        map<string,string> Aliases;
+        if( !aliases.empty() ) {
+            ifstream infile;
+            infile.open (aliases.c_str());
+            if (infile.is_open()) {
+                while( true ) {
+                    string line;
+                    getline(infile,line);
+                    if( infile.fail() )
+                        break;
+                    if( (!line.empty()) && (line[0] != '#') ) {
+                        string::size_type pos = line.find(':',0);
+                        if( pos == string::npos )
+                            throw string("Invalid alias");
+                        else {
+                            string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
+                            string key = line.substr(0, posl + 1);
+                            string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
+                            string val = line.substr(pos + 1 + posr, line.length());
+                            Aliases[key] = val;
+                        }
                     }
                 }
-            }
-            infile.close();
-        } else
-            throw "Error opening aliases file";
-    }
-
-    FactorGraph fg;
-    fg.ReadFromFile( filename.c_str() );
-
-    vector<Factor> q0;
-    double logZ0 = 0.0;
-
-    cout.setf( ios_base::scientific );
-    cout.precision( 3 );
-
-    cout << "# " << filename << endl;
-    cout.width( 39 );
-    cout << left << "# METHOD" << "\t";
-    if( report_time )
-        cout << right << "SECONDS  " << "\t";
-    if( report_iters )
-        cout << "ITERS" << "\t";
-    cout << "MAX ERROR" << "\t";
-    cout << "AVG ERROR" << "\t";
-    cout << "LOGZ ERROR" << "\t";
-    cout << "MAXDIFF" << "\t";
-    cout << endl;
-
-    for( size_t m = 0; m < methods.size(); m++ ) {
-        pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
-
-        if( vm.count("tol") )
-            meth.second.Set("tol",tol);
-        if( vm.count("maxiter") )
-            meth.second.Set("maxiter",maxiter);
-        if( vm.count("verbose") )
-            meth.second.Set("verbose",verbose);
-        TestDAI piet(fg, meth.first, meth.second );
-        piet.doDAI();
-        if( m == 0 ) {
-            q0 = piet.q;
-            logZ0 = piet.logZ;
+                infile.close();
+            } else
+                throw string("Error opening aliases file");
         }
-        piet.calcErrs(q0);
 
+        FactorGraph fg;
+        fg.ReadFromFile( filename.c_str() );
+
+        vector<Factor> q0;
+        double logZ0 = 0.0;
+
+        cout.setf( ios_base::scientific );
+        cout.precision( 3 );
+
+        cout << "# " << filename << endl;
         cout.width( 39 );
-        cout << left << methods[m] << "\t";
+        cout << left << "# METHOD" << "\t";
         if( report_time )
-            cout << right << piet.time << "\t";
-        if( report_iters ) {
-            if( piet.has_iters ) {
-                cout << piet.iters << "\t";
-            } else {
-                cout << "N/A  \t";
+            cout << right << "SECONDS  " << "\t";
+        if( report_iters )
+            cout << "ITERS" << "\t";
+        cout << "MAX ERROR" << "\t";
+        cout << "AVG ERROR" << "\t";
+        cout << "LOGZ ERROR" << "\t";
+        cout << "MAXDIFF" << "\t";
+        cout << endl;
+
+        for( size_t m = 0; m < methods.size(); m++ ) {
+            pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
+
+            if( vm.count("tol") )
+                meth.second.Set("tol",tol);
+            if( vm.count("maxiter") )
+                meth.second.Set("maxiter",maxiter);
+            if( vm.count("verbose") )
+                meth.second.Set("verbose",verbose);
+            TestDAI piet(fg, meth.first, meth.second );
+            piet.doDAI();
+            if( m == 0 ) {
+                q0 = piet.q;
+                logZ0 = piet.logZ;
+            }
+            piet.calcErrs(q0);
+
+            cout.width( 39 );
+            cout << left << methods[m] << "\t";
+            if( report_time )
+                cout << right << piet.time << "\t";
+            if( report_iters ) {
+                if( piet.has_iters ) {
+                    cout << piet.iters << "\t";
+                } else {
+                    cout << "N/A  \t";
+                }
             }
-        }
 
-        if( m > 0 ) {
-            cout.setf( ios_base::scientific );
-            cout.precision( 3 );
-            
-            double me = clipdouble( piet.maxErr(), 1e-9 );
-            cout << me << "\t";
-            
-            double ae = clipdouble( piet.avgErr(), 1e-9 );
-            cout << ae << "\t";
-            
-            if( piet.has_logZ ) {
-                cout.setf( ios::showpos );
-                double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
-                cout << le << "\t";
-                cout.unsetf( ios::showpos );
-            } else
-                cout << "N/A       \t";
-
-            if( piet.has_maxdiff ) {
-                double md = clipdouble( piet.maxdiff, 1e-9 );
-                if( isnan( me ) )
-                    md = me;
-                if( isnan( ae ) )
-                    md = ae;
-                cout << md << "\t";
-            } else
-                cout << "N/A    \t";
-        }
-        cout << endl;
+            if( m > 0 ) {
+                cout.setf( ios_base::scientific );
+                cout.precision( 3 );
+                
+                double me = clipdouble( piet.maxErr(), 1e-9 );
+                cout << me << "\t";
+                
+                double ae = clipdouble( piet.avgErr(), 1e-9 );
+                cout << ae << "\t";
+                
+                if( piet.has_logZ ) {
+                    cout.setf( ios::showpos );
+                    double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
+                    cout << le << "\t";
+                    cout.unsetf( ios::showpos );
+                } else
+                    cout << "N/A       \t";
+
+                if( piet.has_maxdiff ) {
+                    double md = clipdouble( piet.maxdiff, 1e-9 );
+                    if( isnan( me ) )
+                        md = me;
+                    if( isnan( ae ) )
+                        md = ae;
+                    cout << md << "\t";
+                } else
+                    cout << "N/A    \t";
+            }
+            cout << endl;
 
-        if( marginals ) {
-            for( size_t i = 0; i < piet.q.size(); i++ )
-                cout << "# " << piet.q[i] << endl;
+            if( marginals ) {
+                for( size_t i = 0; i < piet.q.size(); i++ )
+                    cout << "# " << piet.q[i] << endl;
+            }
         }
-    }
 
-    return 0;
+        return 0;
+    } catch( string &s ) {
+        cerr << "Exception: " << s << endl;
+        return 2;
+    }
 }