Merge branch 'eaton'
[libdai.git] / src / factorgraph.cpp
index 57614c2..d8b9fea 100644 (file)
@@ -21,6 +21,7 @@
 
 
 #include <iostream>
+#include <iomanip>
 #include <iterator>
 #include <map>
 #include <set>
@@ -100,11 +101,8 @@ ostream& operator << (ostream& os, const FactorGraph& fg) {
                 nr_nonzeros++;
         os << nr_nonzeros << endl;
         for( size_t k = 0; k < fg.factor(I).states(); k++ )
-            if( fg.factor(I)[k] != 0.0 ) {
-                char buf[20];
-                sprintf(buf,"%18.14g", fg.factor(I)[k]);
-                os << k << " " << buf << endl;
-            }
+            if( fg.factor(I)[k] != 0.0 )
+                os << k << " " << setw(os.precision()+4) << fg.factor(I)[k] << endl;
     }
 
     return(os);
@@ -268,10 +266,11 @@ void FactorGraph::ReadFromFile( const char *filename ) {
 }
 
 
-void FactorGraph::WriteToFile( const char *filename ) const {
+void FactorGraph::WriteToFile( const char *filename, size_t precision ) const {
     ofstream outfile;
     outfile.open( filename );
     if( outfile.is_open() ) {
+        outfile.precision( precision );
         outfile << *this;
         outfile.close();
     } else
@@ -331,6 +330,37 @@ void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
 }
 
 
+void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
+    Var n = var(i);
+    Factor mask_n( n, 0.0 );
+
+    foreach( size_t i, is ) {
+        assert( i <= n.states() );
+        mask_n[i] = 1.0;
+    }
+
+    map<size_t, Factor> newFacs;
+    for( size_t I = 0; I < nrFactors(); I++ ) 
+        if( factor(I).vars().contains( n ) ) {
+            newFacs[I] = factor(I) * mask_n;
+        }
+    setFactors( newFacs, backup );
+}
+
+
+void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
+    size_t st = factor(I).states();
+    Factor newF( factor(I).vars(), 0.0 );
+
+    foreach( size_t i, is ) { 
+        assert( i <= st ); 
+        newF[i] = factor(I)[i];
+    }
+
+    setFactor( I, newF, backup );
+}
+
+
 void FactorGraph::backupFactor( size_t I ) {
     map<size_t,Factor>::iterator it = _backup.find( I );
     if( it != _backup.end() )
@@ -373,6 +403,7 @@ void FactorGraph::restoreFactors() {
     _backup.clear();
 }
 
+
 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
     for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
         backupFactor( *fac );