Merge branch 'eaton'
[libdai.git] / src / factorgraph.cpp
index 4f585eb..d8b9fea 100644 (file)
@@ -21,6 +21,7 @@
 
 
 #include <iostream>
+#include <iomanip>
 #include <iterator>
 #include <map>
 #include <set>
@@ -100,11 +101,8 @@ ostream& operator << (ostream& os, const FactorGraph& fg) {
                 nr_nonzeros++;
         os << nr_nonzeros << endl;
         for( size_t k = 0; k < fg.factor(I).states(); k++ )
-            if( fg.factor(I)[k] != 0.0 ) {
-                char buf[20];
-                sprintf(buf,"%18.14g", fg.factor(I)[k]);
-                os << k << " " << buf << endl;
-            }
+            if( fg.factor(I)[k] != 0.0 )
+                os << k << " " << setw(os.precision()+4) << fg.factor(I)[k] << endl;
     }
 
     return(os);
@@ -125,7 +123,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
     if( is.fail() )
         DAI_THROW(INVALID_FACTORGRAPH_FILE);
     if( verbose >= 2 )
-        cout << "Reading " << nr_Factors << " factors..." << endl;
+        cerr << "Reading " << nr_Factors << " factors..." << endl;
 
     getline (is,line);
     if( is.fail() )
@@ -134,13 +132,13 @@ istream& operator >> (istream& is, FactorGraph& fg) {
     map<long,size_t> vardims;
     for( size_t I = 0; I < nr_Factors; I++ ) {
         if( verbose >= 3 )
-            cout << "Reading factor " << I << "..." << endl;
+            cerr << "Reading factor " << I << "..." << endl;
         size_t nr_members;
         while( (is.peek()) == '#' )
             getline(is,line);
         is >> nr_members;
         if( verbose >= 3 )
-            cout << "  nr_members: " << nr_members << endl;
+            cerr << "  nr_members: " << nr_members << endl;
 
         vector<long> labels;
         for( size_t mi = 0; mi < nr_members; mi++ ) {
@@ -151,7 +149,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             labels.push_back(mi_label);
         }
         if( verbose >= 3 )
-            cout << "  labels: " << labels << endl;
+            cerr << "  labels: " << labels << endl;
 
         vector<size_t> dims;
         for( size_t mi = 0; mi < nr_members; mi++ ) {
@@ -162,7 +160,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             dims.push_back(mi_dim);
         }
         if( verbose >= 3 )
-            cout << "  dimensions: " << dims << endl;
+            cerr << "  dimensions: " << dims << endl;
 
         // add the Factor
         VarSet I_vars;
@@ -187,7 +185,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             sigma[mi] = j_loc - labels.begin();
         }
         if( verbose >= 3 )
-            cout << "  sigma: " << sigma << endl;
+            cerr << "  sigma: " << sigma << endl;
 
         // calculate multindices
         Permute permindex( dims, sigma );
@@ -198,7 +196,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             getline(is,line);
         is >> nr_nonzeros;
         if( verbose >= 3 ) 
-            cout << "  nonzeroes: " << nr_nonzeros << endl;
+            cerr << "  nonzeroes: " << nr_nonzeros << endl;
         for( size_t k = 0; k < nr_nonzeros; k++ ) {
             size_t li;
             double val;
@@ -216,7 +214,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
     }
 
     if( verbose >= 3 )
-        cout << "factors:" << facs << endl;
+        cerr << "factors:" << facs << endl;
 
     fg = FactorGraph(facs);
 
@@ -268,10 +266,11 @@ void FactorGraph::ReadFromFile( const char *filename ) {
 }
 
 
-void FactorGraph::WriteToFile( const char *filename ) const {
+void FactorGraph::WriteToFile( const char *filename, size_t precision ) const {
     ofstream outfile;
     outfile.open( filename );
     if( outfile.is_open() ) {
+        outfile.precision( precision );
         outfile << *this;
         outfile.close();
     } else
@@ -331,6 +330,37 @@ void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
 }
 
 
+void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
+    Var n = var(i);
+    Factor mask_n( n, 0.0 );
+
+    foreach( size_t i, is ) {
+        assert( i <= n.states() );
+        mask_n[i] = 1.0;
+    }
+
+    map<size_t, Factor> newFacs;
+    for( size_t I = 0; I < nrFactors(); I++ ) 
+        if( factor(I).vars().contains( n ) ) {
+            newFacs[I] = factor(I) * mask_n;
+        }
+    setFactors( newFacs, backup );
+}
+
+
+void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
+    size_t st = factor(I).states();
+    Factor newF( factor(I).vars(), 0.0 );
+
+    foreach( size_t i, is ) { 
+        assert( i <= st ); 
+        newF[i] = factor(I)[i];
+    }
+
+    setFactor( I, newF, backup );
+}
+
+
 void FactorGraph::backupFactor( size_t I ) {
     map<size_t,Factor>::iterator it = _backup.find( I );
     if( it != _backup.end() )
@@ -373,6 +403,7 @@ void FactorGraph::restoreFactors() {
     _backup.clear();
 }
 
+
 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
     for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
         backupFactor( *fac );