Miscellaneous improvements in FactorGraph, Permute, HAK
[libdai.git] / utils / uai2fg.cpp
index 47bc025..f447366 100644 (file)
@@ -56,10 +56,10 @@ map<size_t, size_t> ReadUAIEvidenceFile( char* filename ) {
 
 
 /// Reads factor graph (as a pair of a variable vector and factor vector) from a UAI factor graph file
-pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t verbose ) {
-    pair<vector<Var>, vector<Factor> > result;
-    vector<Var>& vars = result.first;
-    vector<Factor>& factors = result.second;
+void ReadUAIFGFile( const char *filename, size_t verbose, vector<Var>& vars, vector<Factor>& factors, vector<Permute>& permutations ) {
+    vars.clear();
+    factors.clear();
+    permutations.clear();
 
     // open file
     ifstream is;
@@ -100,9 +100,9 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
             cout << "Reading " << nrFacs << " factors..." << endl;
 
         // for each factor, read the variables on which it depends
-        vector<vector<long> > labels;
+        vector<vector<Var> > factorVars;
         factors.reserve( nrFacs );
-        labels.reserve( nrFacs );
+        factorVars.reserve( nrFacs );
         for( size_t I = 0; I < nrFacs; I++ ) {
             if( verbose >= 3 )
                 cout << "Reading factor " << I << "..." << endl;
@@ -115,12 +115,12 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
             if( verbose >= 3 )
                 cout << "  which depends on " << I_nrVars << " variables" << endl;
 
-            // for each of the variables, read its label and number of states
+            // read the variable labels
             vector<long> I_labels;
             vector<size_t> I_dims;
-            VarSet I_vars;
             I_labels.reserve( I_nrVars );
             I_dims.reserve( I_nrVars );
+            factorVars[I].reserve( I_nrVars );
             for( size_t _i = 0; _i < I_nrVars; _i++ ) {
                 long label;
                 is >> label;
@@ -128,45 +128,25 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
                     DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read variable labels for " + toString(I) + "'th factor");
                 I_labels.push_back( label );
                 I_dims.push_back( vars[label].states() );
-                I_vars |= vars[label];
+                factorVars[I].push_back( vars[label] );
             }
             if( verbose >= 3 )
                 cout << "  labels: " << I_labels << ", dimensions " << I_dims << endl;
 
             // add the factor and the labels
-            factors.push_back( Factor(I_vars,0.0) );
-            labels.push_back( I_labels );
+            factors.push_back( Factor( VarSet( factorVars[I].begin(), factorVars[I].end(), factorVars[I].size() ), (Real)0 ) );
         }
 
         // for each factor, read its values
+        permutations.reserve( nrFacs );
         for( size_t I = 0; I < nrFacs; I++ ) {
             if( verbose >= 3 )
                 cout << "Reading factor " << I << "..." << endl;
 
-            // last label is least significant, so we reverse the label vector
-            reverse( labels[I].begin(), labels[I].end() );
+            // calculate permutation object, reversing the indexing in factorVars[I] first
+            Permute permindex( factorVars[I], true );
+            permutations.push_back( permindex );
 
-            // prepare a vector containing the dimensionalities of the variables for this factor
-            size_t I_nrVars = factors[I].vars().size();
-            vector<size_t> I_dims;
-            I_dims.reserve( I_nrVars );
-            for( size_t _i = 0; _i < I_nrVars; _i++ )
-                I_dims.push_back( vars[labels[I][_i]].states() );
-            if( verbose >= 3 )
-                cout << "  labels: " << labels[I] << ", dimensions " << I_dims << endl;
-
-            // calculate permutation sigma (internally, members are sorted canonically, 
-            // which may be different from the way they are sorted in the file)
-            vector<size_t> sigma( I_nrVars, 0 );
-            VarSet::const_iterator j = factors[I].vars().begin();
-            for( size_t mi = 0; mi < I_nrVars; mi++, j++ )
-                sigma[mi] = distance( labels[I].begin(), find( labels[I].begin(), labels[I].end(), j->label() ) );
-            if( verbose >= 3 )
-                cout << "  permutation: " << sigma << endl;
-
-            // construct permutation object
-            Permute permindex( I_dims, sigma );
-            
             // read factor values
             size_t nrNonZeros;
             is >> nrNonZeros;
@@ -181,9 +161,13 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
                 if( is.fail() )
                     DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read factor values of " + toString(I) + "'th factor");
                 // assign value after calculating its linear index corresponding to the permutation
+                if( verbose >= 4 )
+                    cout << "  " << li << "'th value " << val << " corresponds with index " << permindex.convertLinearIndex(li) << endl;
                 factors[I][permindex.convertLinearIndex( li )] = val;
             }
         }
+        if( verbose >= 3 )
+            cout << "variables:" << vars << endl;
         if( verbose >= 3 )
             cout << "factors:" << factors << endl;
 
@@ -191,8 +175,6 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
         is.close();
     } else
         DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
-
-    return result;
 }
 
 
@@ -212,11 +194,14 @@ int main( int argc, char *argv[] ) {
         long type = atoi( argv[4] );
         bool run_jtree = atoi( argv[5] );
 
-        // read factor graph and evidence
-        pair<vector<Var>, vector<Factor> > varfacs = ReadUAIFGFile( argv[1], verbose );
+        // read factor graph
+        vector<Var> vars;
+        vector<Factor> facs;
+        vector<Permute> permutations;
+        ReadUAIFGFile( argv[1], verbose, vars, facs, permutations );
+
+        // read evidence
         map<size_t,size_t> evid = ReadUAIEvidenceFile( argv[2] );
-        vector<Var>& vars = varfacs.first;
-        vector<Factor>& facs = varfacs.second;
 
         // construct unclamped factor graph
         FactorGraph fg0( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );