Merge branch 'eaton'
[libdai.git] / src / factorgraph.cpp
index c28e864..d8b9fea 100644 (file)
@@ -1,6 +1,7 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
-    Radboud University Nijmegen, The Netherlands
-    
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -20,6 +21,7 @@
 
 
 #include <iostream>
+#include <iomanip>
 #include <iterator>
 #include <map>
 #include <set>
@@ -40,31 +42,30 @@ using namespace std;
 
 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
     // add factors, obtain variables
-    set<Var> _vars;
+    set<Var> varset;
     _factors.reserve( P.size() );
     size_t nrEdges = 0;
     for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
         _factors.push_back( *p2 );
-        copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
+        copy( p2->vars().begin(), p2->vars().end(), inserter( varset, varset.begin() ) );
         nrEdges += p2->vars().size();
     }
 
-    // add _vars
-    vars.reserve( _vars.size() );
-    for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
-        vars.push_back( *p1 );
+    // add vars
+    _vars.reserve( varset.size() );
+    for( set<Var>::const_iterator p1 = varset.begin(); p1 != varset.end(); p1++ )
+        _vars.push_back( *p1 );
 
     // create graph structure
     constructGraph( nrEdges );
 }
 
 
-/// Part of constructors (creates edges, neighbours and adjacency matrix)
 void FactorGraph::constructGraph( size_t nrEdges ) {
     // create a mapping for indices
     hash_map<size_t, size_t> hashmap;
     
-    for( size_t i = 0; i < vars.size(); i++ )
+    for( size_t i = 0; i < vars().size(); i++ )
         hashmap[var(i).label()] = i;
     
     // create edge list
@@ -81,6 +82,7 @@ void FactorGraph::constructGraph( size_t nrEdges ) {
 }
 
 
+/// Writes a FactorGraph to an output stream
 ostream& operator << (ostream& os, const FactorGraph& fg) {
     os << fg.nrFactors() << endl;
 
@@ -99,129 +101,123 @@ ostream& operator << (ostream& os, const FactorGraph& fg) {
                 nr_nonzeros++;
         os << nr_nonzeros << endl;
         for( size_t k = 0; k < fg.factor(I).states(); k++ )
-            if( fg.factor(I)[k] != 0.0 ) {
-                char buf[20];
-                sprintf(buf,"%18.14g", fg.factor(I)[k]);
-                os << k << " " << buf << endl;
-            }
+            if( fg.factor(I)[k] != 0.0 )
+                os << k << " " << setw(os.precision()+4) << fg.factor(I)[k] << endl;
     }
 
     return(os);
 }
 
 
+/// Reads a FactorGraph from an input stream
 istream& operator >> (istream& is, FactorGraph& fg) {
     long verbose = 0;
 
-    try {
-        vector<Factor> facs;
-        size_t nr_Factors;
-        string line;
-        
+    vector<Factor> facs;
+    size_t nr_Factors;
+    string line;
+    
+    while( (is.peek()) == '#' )
+        getline(is,line);
+    is >> nr_Factors;
+    if( is.fail() )
+        DAI_THROW(INVALID_FACTORGRAPH_FILE);
+    if( verbose >= 2 )
+        cerr << "Reading " << nr_Factors << " factors..." << endl;
+
+    getline (is,line);
+    if( is.fail() )
+        DAI_THROW(INVALID_FACTORGRAPH_FILE);
+
+    map<long,size_t> vardims;
+    for( size_t I = 0; I < nr_Factors; I++ ) {
+        if( verbose >= 3 )
+            cerr << "Reading factor " << I << "..." << endl;
+        size_t nr_members;
         while( (is.peek()) == '#' )
             getline(is,line);
-        is >> nr_Factors;
-        if( is.fail() )
-            DAI_THROW(INVALID_FACTORGRAPH_FILE);
-        if( verbose >= 2 )
-            cout << "Reading " << nr_Factors << " factors..." << endl;
-
-        getline (is,line);
-        if( is.fail() )
-            DAI_THROW(INVALID_FACTORGRAPH_FILE);
-
-        map<long,size_t> vardims;
-        for( size_t I = 0; I < nr_Factors; I++ ) {
-            if( verbose >= 3 )
-                cout << "Reading factor " << I << "..." << endl;
-            size_t nr_members;
+        is >> nr_members;
+        if( verbose >= 3 )
+            cerr << "  nr_members: " << nr_members << endl;
+
+        vector<long> labels;
+        for( size_t mi = 0; mi < nr_members; mi++ ) {
+            long mi_label;
             while( (is.peek()) == '#' )
                 getline(is,line);
-            is >> nr_members;
-            if( verbose >= 3 )
-                cout << "  nr_members: " << nr_members << endl;
-
-            vector<long> labels;
-            for( size_t mi = 0; mi < nr_members; mi++ ) {
-                long mi_label;
-                while( (is.peek()) == '#' )
-                    getline(is,line);
-                is >> mi_label;
-                labels.push_back(mi_label);
-            }
-            if( verbose >= 3 )
-                cout << "  labels: " << labels << endl;
-
-            vector<size_t> dims;
-            for( size_t mi = 0; mi < nr_members; mi++ ) {
-                size_t mi_dim;
-                while( (is.peek()) == '#' )
-                    getline(is,line);
-                is >> mi_dim;
-                dims.push_back(mi_dim);
-            }
-            if( verbose >= 3 )
-                cout << "  dimensions: " << dims << endl;
-
-            // add the Factor
-            VarSet I_vars;
-            for( size_t mi = 0; mi < nr_members; mi++ ) {
-                map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
-                if( vdi != vardims.end() ) {
-                    // check whether dimensions are consistent
-                    if( vdi->second != dims[mi] )
-                        DAI_THROW(INVALID_FACTORGRAPH_FILE);
-                } else
-                    vardims[labels[mi]] = dims[mi];
-                I_vars |= Var(labels[mi], dims[mi]);
-            }
-            facs.push_back( Factor( I_vars, 0.0 ) );
-            
-            // calculate permutation sigma (internally, members are sorted)
-            vector<size_t> sigma(nr_members,0);
-            VarSet::iterator j = I_vars.begin();
-            for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
-                long search_for = j->label();
-                vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
-                sigma[mi] = j_loc - labels.begin();
-            }
-            if( verbose >= 3 )
-                cout << "  sigma: " << sigma << endl;
-
-            // calculate multindices
-            Permute permindex( dims, sigma );
-            
-            // read values
-            size_t nr_nonzeros;
+            is >> mi_label;
+            labels.push_back(mi_label);
+        }
+        if( verbose >= 3 )
+            cerr << "  labels: " << labels << endl;
+
+        vector<size_t> dims;
+        for( size_t mi = 0; mi < nr_members; mi++ ) {
+            size_t mi_dim;
             while( (is.peek()) == '#' )
                 getline(is,line);
-            is >> nr_nonzeros;
-            if( verbose >= 3 ) 
-                cout << "  nonzeroes: " << nr_nonzeros << endl;
-            for( size_t k = 0; k < nr_nonzeros; k++ ) {
-                size_t li;
-                double val;
-                while( (is.peek()) == '#' )
-                    getline(is,line);
-                is >> li;
-                while( (is.peek()) == '#' )
-                    getline(is,line);
-                is >> val;
-
-                // store value, but permute indices first according
-                // to internal representation
-                facs.back()[permindex.convert_linear_index( li  )] = val;
-            }
+            is >> mi_dim;
+            dims.push_back(mi_dim);
+        }
+        if( verbose >= 3 )
+            cerr << "  dimensions: " << dims << endl;
+
+        // add the Factor
+        VarSet I_vars;
+        for( size_t mi = 0; mi < nr_members; mi++ ) {
+            map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
+            if( vdi != vardims.end() ) {
+                // check whether dimensions are consistent
+                if( vdi->second != dims[mi] )
+                    DAI_THROW(INVALID_FACTORGRAPH_FILE);
+            } else
+                vardims[labels[mi]] = dims[mi];
+            I_vars |= Var(labels[mi], dims[mi]);
+        }
+        facs.push_back( Factor( I_vars, 0.0 ) );
+        
+        // calculate permutation sigma (internally, members are sorted)
+        vector<size_t> sigma(nr_members,0);
+        VarSet::iterator j = I_vars.begin();
+        for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
+            long search_for = j->label();
+            vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
+            sigma[mi] = j_loc - labels.begin();
         }
-
         if( verbose >= 3 )
-            cout << "factors:" << facs << endl;
+            cerr << "  sigma: " << sigma << endl;
+
+        // calculate multindices
+        Permute permindex( dims, sigma );
+        
+        // read values
+        size_t nr_nonzeros;
+        while( (is.peek()) == '#' )
+            getline(is,line);
+        is >> nr_nonzeros;
+        if( verbose >= 3 ) 
+            cerr << "  nonzeroes: " << nr_nonzeros << endl;
+        for( size_t k = 0; k < nr_nonzeros; k++ ) {
+            size_t li;
+            double val;
+            while( (is.peek()) == '#' )
+                getline(is,line);
+            is >> li;
+            while( (is.peek()) == '#' )
+                getline(is,line);
+            is >> val;
 
-        fg = FactorGraph(facs);
-    } catch (char *e) {
-        cout << e << endl;
+            // store value, but permute indices first according
+            // to internal representation
+            facs.back()[permindex.convert_linear_index( li  )] = val;
+        }
     }
 
+    if( verbose >= 3 )
+        cerr << "factors:" << facs << endl;
+
+    fg = FactorGraph(facs);
+
     return is;
 }
 
@@ -270,10 +266,11 @@ void FactorGraph::ReadFromFile( const char *filename ) {
 }
 
 
-void FactorGraph::WriteToFile( const char *filename ) const {
+void FactorGraph::WriteToFile( const char *filename, size_t precision ) const {
     ofstream outfile;
     outfile.open( filename );
     if( outfile.is_open() ) {
+        outfile.precision( precision );
         outfile << *this;
         outfile.close();
     } else
@@ -333,6 +330,37 @@ void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
 }
 
 
+void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
+    Var n = var(i);
+    Factor mask_n( n, 0.0 );
+
+    foreach( size_t i, is ) {
+        assert( i <= n.states() );
+        mask_n[i] = 1.0;
+    }
+
+    map<size_t, Factor> newFacs;
+    for( size_t I = 0; I < nrFactors(); I++ ) 
+        if( factor(I).vars().contains( n ) ) {
+            newFacs[I] = factor(I) * mask_n;
+        }
+    setFactors( newFacs, backup );
+}
+
+
+void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
+    size_t st = factor(I).states();
+    Factor newF( factor(I).vars(), 0.0 );
+
+    foreach( size_t i, is ) { 
+        assert( i <= st ); 
+        newF[i] = factor(I)[i];
+    }
+
+    setFactor( I, newF, backup );
+}
+
+
 void FactorGraph::backupFactor( size_t I ) {
     map<size_t,Factor>::iterator it = _backup.find( I );
     if( it != _backup.end() )
@@ -375,6 +403,7 @@ void FactorGraph::restoreFactors() {
     _backup.clear();
 }
 
+
 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
     for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
         backupFactor( *fac );
@@ -451,7 +480,7 @@ FactorGraph FactorGraph::maximalFactors() const {
     for( size_t I = 0; I < nrFactors(); I++ )
         facs[newindex[maxfac[I]]] *= factor(I);
 
-    return FactorGraph( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), nrVars() );
+    return FactorGraph( facs.begin(), facs.end(), vars().begin(), vars().end(), facs.size(), nrVars() );
 }