Replaced ENUM2,ENUM3,ENUM4,ENUM5,ENUM6 by single DAI_ENUM macro.
[libdai.git] / src / factorgraph.cpp
index 0b65e58..4df27c8 100644 (file)
@@ -27,9 +27,9 @@
 #include <string>
 #include <algorithm>
 #include <functional>
-#include <tr1/unordered_map>
 #include <dai/factorgraph.h>
 #include <dai/util.h>
+#include <dai/exceptions.h>
 
 
 namespace dai {
@@ -38,7 +38,7 @@ namespace dai {
 using namespace std;
 
 
-FactorGraph::FactorGraph( const vector<Factor> &P ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
+FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _undoProbs() {
     // add factors, obtain variables
     set<Var> _vars;
     factors.reserve( P.size() );
@@ -62,13 +62,12 @@ FactorGraph::FactorGraph( const vector<Factor> &P ) : G(), _undoProbs(), _normty
 /// Part of constructors (creates edges, neighbours and adjacency matrix)
 void FactorGraph::createGraph( size_t nrEdges ) {
     // create a mapping for indices
-    std::tr1::unordered_map<size_t, size_t> hashmap;
+    hash_map<size_t, size_t> hashmap;
     
     for( size_t i = 0; i < vars.size(); i++ )
         hashmap[var(i).label()] = i;
     
     // create edge list
-    typedef pair<unsigned,unsigned> Edge;
     vector<Edge> edges;
     edges.reserve( nrEdges );
     for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
@@ -123,13 +122,13 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             getline(is,line);
         is >> nr_f;
         if( is.fail() )
-            throw "ReadFromFile: unable to read number of Factors";
+            DAI_THROW(INVALID_FACTORGRAPH_FILE);
         if( verbose >= 2 )
             cout << "Reading " << nr_f << " factors..." << endl;
 
         getline (is,line);
         if( is.fail() )
-            throw "ReadFromFile: empty line expected";
+            DAI_THROW(INVALID_FACTORGRAPH_FILE);
 
         for( size_t I = 0; I < nr_f; I++ ) {
             if( verbose >= 3 )
@@ -172,11 +171,11 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             // add the Factor
             VarSet I_vars;
             for( size_t mi = 0; mi < nr_members; mi++ )
-                I_vars.insert( Var(labels[mi], dims[mi]) );
+                I_vars |= Var(labels[mi], dims[mi]);
             factors.push_back(Factor(I_vars,0.0));
             
             // calculate permutation sigma (internally, members are sorted)
-            vector<long> sigma(nr_members,0);
+            vector<size_t> sigma(nr_members,0);
             VarSet::iterator j = I_vars.begin();
             for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
                 long search_for = j->label();
@@ -190,22 +189,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             }
 
             // calculate multindices
-            vector<size_t> sdims(nr_members,0);
-            for( size_t k = 0; k < nr_members; k++ ) {
-                sdims[k] = dims[sigma[k]];
-            }
-            multind mi(dims);
-            multind smi(sdims);
-            if( verbose >= 3 ) {
-                cout << "  mi.max(): " << mi.max() << endl;
-                cout << "       ";
-                for( size_t k=0; k < nr_members; k++ ) 
-                    cout << labels[k] << " ";
-                cout << "   ";
-                for( size_t k=0; k < nr_members; k++ ) 
-                    cout << labels[sigma[k]] << " ";
-                cout << endl;
-            }
+            Permute permindex( dims, sigma );
             
             // read values
             size_t nr_nonzeros;
@@ -224,19 +208,9 @@ istream& operator >> (istream& is, FactorGraph& fg) {
                     getline(is,line);
                 is >> val;
 
-                vector<size_t> vi = mi.vi(li);
-                vector<size_t> svi(vi.size(),0);
-                for( size_t k = 0; k < vi.size(); k++ )
-                    svi[k] = vi[sigma[k]];
-                size_t sli = smi.li(svi);
-                if( verbose >= 3 ) {
-                    cout << "    " << li << ": ";
-                    copy( vi.begin(), vi.end(), ostream_iterator<size_t>(cout," "));
-                    cout << "-> ";
-                    copy( svi.begin(), svi.end(), ostream_iterator<size_t>(cout," "));
-                    cout << ": " << sli << endl;
-                }
-                factors.back()[sli] = val;
+                // store value, but permute indices first according
+                // to internal representation
+                factors.back()[permindex.convert_linear_index( li  )] = val;
             }
         }
 
@@ -420,7 +394,7 @@ void FactorGraph::clamp( const Var & n, size_t i ) {
 
     // For all factors that contain n
     for( size_t I = 0; I < nrFactors(); I++ ) 
-        if( factor(I).vars() && n )
+        if( factor(I).vars().contains( n ) )
             // Multiply it with a delta function
             factor(I) *= delta_n_i;
 
@@ -449,14 +423,14 @@ void FactorGraph::saveProbs( const VarSet &ns ) {
     if( !_undoProbs.empty() )
         cout << "FactorGraph::saveProbs:  WARNING: _undoProbs not empy!" << endl;
     for( size_t I = 0; I < nrFactors(); I++ )
-        if( factor(I).vars() && ns )
+        if( factor(I).vars().intersects( ns ) )
             _undoProbs[I] = factor(I).p();
 }
 
 
 void FactorGraph::undoProbs( const VarSet &ns ) {
     for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
-        if( factor((*uI).first).vars() && ns ) {
+        if( factor((*uI).first).vars().intersects( ns ) ) {
 //          cout << "undoing " << factor((*uI).first).vars() << endl;
 //          cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
             factor((*uI).first).p() = (*uI).second;