Miscellaneous improvements in FactorGraph, Permute, HAK
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Sun, 16 May 2010 18:27:02 +0000 (20:27 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Sun, 16 May 2010 18:27:02 +0000 (20:27 +0200)
* Added FactorGraph::isMaximal(size_t) and FactorGraph::maximalFactor(size_t)
* Added optional reverse argument to Permute::Permute( const std::vector<Var>& )
  constructor
* Added Permute::ranges() accessor
* Added Permute::inverse() method
* Optimized region graph construction for HAK/GBP with clusters=BETHE

ChangeLog
Makefile
include/dai/bp.h
include/dai/factorgraph.h
include/dai/index.h
src/factorgraph.cpp
src/hak.cpp
tests/unit/factorgraph_test.cpp
tests/unit/index_test.cpp
utils/uai2fg.cpp

index 9315858..5d8da46 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,11 @@
 git HEAD
 --------
+* Added FactorGraph::isMaximal(size_t) and FactorGraph::maximalFactor(size_t).
+* Added optional reverse argument to Permute::Permute( const std::vector<Var>& )
+  constructor.
+* Added Permute::ranges() accessor.
+* Added Permute::inverse() method.
+* Optimized region graph construction for HAK/GBP with clusters=BETHE.
 * Fixed a problem with isnan() on FreeBSD; it is now in the dai namespace.
 * Workaround for older g++ compilers (e.g. version 4.0.0 on Darwin 9.8.0)
   which have problems when comparing const_reverse_iterator with
index 23d1201..e5402e7 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -19,7 +19,7 @@ DAI_VERSION="git HEAD"
 DAI_DATE="May 12, 2010, or later"
 
 # Directories of libDAI sources
-# Location libDAI headers
+# Location of libDAI headers
 INC=include/dai
 # Location of libDAI source files
 SRC=src
index ec56dec..dc133af 100644 (file)
@@ -11,6 +11,7 @@
 
 /// \file
 /// \brief Defines class BP, which implements (Loopy) Belief Propagation
+/// \todo Consider using a priority_queue for maximum residual schedule
 
 
 #ifndef __defined_libdai_bp_h
index bc3336d..66d1da1 100644 (file)
@@ -259,6 +259,18 @@ class FactorGraph {
          */
         GraphAL MarkovGraph() const;
 
+        /// Returns whether the \a I 'th factor is maximal
+        /** \note A factor (domain) is \a maximal if and only if it is not a
+         *  strict subset of another factor domain.
+         */
+        bool isMaximal( size_t I ) const;
+
+        /// Returns the index of a maximal factor that contains the \a I 'th factor
+        /** \note A factor (domain) is \a maximal if and only if it is not a
+         *  strict subset of another factor domain.
+         */
+        size_t maximalFactor( size_t I ) const;
+
         /// Returns the maximal factor domains in this factorgraph
         /** \note A factor domain is \a maximal if and only if it is not a
          *  strict subset of another factor domain.
index 12250fd..e301b24 100644 (file)
@@ -157,14 +157,32 @@ class Permute {
         /// Construct from vector of variables.
         /** The implied permutation maps the index of each variable in \a vars according to the canonical ordering 
          *  (i.e., sorted ascendingly according to their label) to the index it has in \a vars.
+         *  If \a reverse == \c true, reverses the indexing in \a vars first.
          */
-        Permute( const std::vector<Var> &vars ) : _ranges(vars.size()), _sigma(vars.size()) {
-            for( size_t i = 0; i < vars.size(); ++i )
-                _ranges[i] = vars[i].states();
-            VarSet vs( vars.begin(), vars.end(), vars.size() );
-            VarSet::const_iterator vs_i = vs.begin();
-            for( size_t i = 0; i < vs.size(); ++i, ++vs_i )
-                _sigma[i] = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
+        Permute( const std::vector<Var> &vars, bool reverse=false ) : _ranges(), _sigma() {
+            size_t N = vars.size();
+
+            // construct ranges
+            _ranges.reserve( N );
+            for( size_t i = 0; i < N; ++i )
+                if( reverse )
+                    _ranges.push_back( vars[N - 1 - i].states() );
+                else
+                    _ranges.push_back( vars[i].states() );
+
+            // construct VarSet out of vars
+            VarSet vs( vars.begin(), vars.end(), N );
+            DAI_ASSERT( vs.size() == N );
+            
+            // construct sigma
+            _sigma.reserve( N );
+            for( VarSet::const_iterator vs_i = vs.begin(); vs_i != vs.end(); ++vs_i ) {
+                size_t ind = find( vars.begin(), vars.end(), *vs_i ) - vars.begin();
+                if( reverse )
+                    _sigma.push_back( N - 1 - ind );
+                else
+                    _sigma.push_back( ind );
+            }
         }
 
         /// Calculates a permuted linear index.
@@ -195,12 +213,15 @@ class Permute {
             return sigma_li;
         }
 
-        /// Returns const reference to the permutation
+        /// Returns constant reference to the permutation
         const std::vector<size_t>& sigma() const { return _sigma; }
 
         /// Returns reference to the permutation
         std::vector<size_t>& sigma() { return _sigma; }
 
+        /// Returns constant reference to the dimensionality vector
+        const std::vector<size_t>& ranges() { return _ranges; }
+
         /// Returns the result of applying the permutation on \a i
         size_t operator[]( size_t i ) const {
 #ifdef DAI_DEBUG
@@ -209,6 +230,18 @@ class Permute {
             return _sigma[i];
 #endif
         }
+
+        /// Returns the inverse permutation
+        Permute inverse() const {
+            size_t N = _ranges.size();
+            std::vector<size_t> invRanges( N, 0 );
+            std::vector<size_t> invSigma( N, 0 );
+            for( size_t i = 0; i < N; i++ ) {
+                invSigma[_sigma[i]] = i;
+                invRanges[i] = _ranges[_sigma[i]];
+            }
+            return Permute( invRanges, invSigma );
+        }
 };
 
 
index 4b961f9..905bb34 100644 (file)
@@ -278,26 +278,58 @@ GraphAL FactorGraph::MarkovGraph() const {
 }
 
 
-vector<VarSet> FactorGraph::maximalFactorDomains() const {
-    vector<VarSet> result;
+bool FactorGraph::isMaximal( size_t I ) const {
+    const VarSet& I_vars = factor(I).vars();
+    size_t I_size = I_vars.size();
+
+    if( I_size == 0 ) {
+        for( size_t J = 0; J < nrFactors(); J++ ) 
+            if( J != I )
+                if( factor(J).vars().size() > 0 )
+                    return false;
+        return true;
+    } else {
+        foreach( const Neighbor& i, nbF(I) ) {
+            foreach( const Neighbor& J, nbV(i) ) {
+                if( J != I )
+                    if( (factor(J).vars() >> I_vars) && (factor(J).vars().size() != I_size) )
+                        return false;
+            }
+        }
+        return true;
+    }
+}
 
-    for( size_t I = 0; I < nrFactors(); I++ ) {
-        bool maximal = true;
-        const VarSet& I_vars = factor(I).vars();
-        size_t I_size = I_vars.size();
 
-        if( I_size == 0 )
-            maximal = false;
+size_t FactorGraph::maximalFactor( size_t I ) const {
+    const VarSet& I_vars = factor(I).vars();
+    size_t I_size = I_vars.size();
+
+    if( I_size == 0 ) {
+        for( size_t J = 0; J < nrFactors(); J++ )
+            if( J != I )
+                if( factor(J).vars().size() > 0 )
+                    return maximalFactor( J );
+        return I;
+    } else {
         foreach( const Neighbor& i, nbF(I) ) {
             foreach( const Neighbor& J, nbV(i) ) {
                 if( J != I )
                     if( (factor(J).vars() >> I_vars) && (factor(J).vars().size() != I_size) )
-                        maximal = false;
+                        return maximalFactor( J );
             }
         }
-        if( maximal )
-            result.push_back( factor(I).vars() );
+        return I;
     }
+}
+
+
+vector<VarSet> FactorGraph::maximalFactorDomains() const {
+    vector<VarSet> result;
+
+    for( size_t I = 0; I < nrFactors(); I++ )
+        if( isMaximal( I ) )
+            result.push_back( factor(I).vars() );
 
     if( result.size() == 0 )
         result.push_back( VarSet() );
index ebf6758..8f8c5fe 100644 (file)
@@ -104,18 +104,24 @@ string HAK::printProperties() const {
 
 void HAK::construct() {
     // Create outer beliefs
+    if( props.verbose >= 3 )
+        cerr << "Constructing outer beliefs" << endl;
     _Qa.clear();
     _Qa.reserve(nrORs());
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
         _Qa.push_back( Factor( OR(alpha) ) );
 
     // Create inner beliefs
+    if( props.verbose >= 3 )
+        cerr << "Constructing inner beliefs" << endl;
     _Qb.clear();
     _Qb.reserve(nrIRs());
     for( size_t beta = 0; beta < nrIRs(); beta++ )
         _Qb.push_back( Factor( IR(beta) ) );
 
     // Create messages
+    if( props.verbose >= 3 )
+        cerr << "Constructing messages" << endl;
     _muab.clear();
     _muab.reserve( nrORs() );
     _muba.clear();
@@ -154,6 +160,9 @@ void HAK::findLoopClusters( const FactorGraph & fg, std::set<VarSet> &allcl, Var
 HAK::HAK(const FactorGraph & fg, const PropertySet &opts) : DAIAlgRG(), _Qa(), _Qb(), _muab(), _muba(), _maxdiff(0.0), _iters(0U), props() {
     setProperties( opts );
 
+    if( props.verbose >= 3 )
+        cerr << "Constructing clusters" << endl;
+
     vector<VarSet> cl;
     if( props.clusters == Properties::ClustersType::MIN ) {
         cl = fg.maximalFactorDomains();
@@ -180,80 +189,64 @@ HAK::HAK(const FactorGraph & fg, const PropertySet &opts) : DAIAlgRG(), _Qa(), _
         }
         constructCVM( fg, cl );
     } else if( props.clusters == Properties::ClustersType::BETHE ) {
-/*      if( props.verbose >= 3 )
-            cerr << "Copy factor graph" << endl;
         // Copy factor graph structure
+        if( props.verbose >= 3 )
+            cerr << "Copying factor graph" << endl;
         FactorGraph::operator=( fg );
 
+        // Construct inner regions (single variables)
         if( props.verbose >= 3 )
-            cerr << "Copy region graph from factor graph" << endl;
-        // Copy bipartite graph
-        _G = fg.bipGraph();
+            cerr << "Constructing inner regions" << endl;
+        _IRs.reserve( fg.nrVars() );
+        for( size_t i = 0; i < fg.nrVars(); i++ )
+            _IRs.push_back( Region( fg.var(i), 1.0 ) );
 
-        // Throw away non-maximal regions
+        // Construct graph
         if( props.verbose >= 3 )
-            cerr << "Throw away non-maximal regions" << endl;
-        for( size_t OR = 0; OR < _G.nrNodes1(); ) {
-            // check if it is maximal
-            bool maximal = true;
-            size_t OR_size = _G.nb1(OR).size();
-            if( OR_size == 0 )
-                maximal = false;
-            size_t OR2;
-            foreach( OR2, _G.delta1(OR, false) )
-                if( (_G.nb1(OR2).size() > OR_size) && (_G.nb1Set(OR2) >> _G.nb1Set(OR1)) ) {
-                    maximal = false;
-                    break;
-                }
-            if( !maximal ) {
-                // if not maximal, throw away and assign factor to OR2
-                _G.eraseNode1( OR );
+            cerr << "Constructing graph" << endl;
+        _G = BipartiteGraph( 0, nrIRs() );
+
+        // Construct outer regions:
+        // maximal factors become new outer regions
+        // non-maximal factors are assigned an outer region that contains them
+        if( props.verbose >= 3 )
+            cerr << "Construct outer regions" << endl;
+        _fac2OR.reserve( nrFactors() );
+        queue<pair<size_t, size_t> > todo;
+        for( size_t I = 0; I < fg.nrFactors(); I++ ) {
+            size_t J = fg.maximalFactor( I );
+            if( J == I ) {
+                // I is maximal; add it to the outer regions
+                _fac2OR.push_back( nrORs() );
+                // Construct outer region (with counting number 1.0)
+                _ORs.push_back( FRegion( fg.factor(I), 1.0 ) );
+                // Add node and edges to graph
+                SmallSet<size_t> irs = fg.bipGraph().nb2Set( I );
+                _G.addNode1( irs.begin(), irs.end(), irs.size() );
+            } else if( J < I ) {
+                // J is larger and has already been assigned to an outer region
+                // so I should belong to the same outer region as J
+                _fac2OR.push_back( _fac2OR[J] );
+                _ORs[_fac2OR[J]] *= fg.factor(I);
             } else {
-                OR++;
+                // J is larger but has not yet been assigned to an outer region
+                // we handle this case later
+                _fac2OR.push_back( -1 );
+                todo.push( make_pair( I, J ) );
             }
         }
-
-        // For each factor, find an outer region that subsumes that factor.
-        // Then, multiply the outer region with that factor.
-        _fac2OR.clear();
-        _fac2OR.reserve( nrFactors() );
-        for( size_t I = 0; I < nrFactors(); I++ ) {
-            size_t alpha;
-            for( alpha = 0; alpha < nrORs(); alpha++ )
-                if( OR(alpha).vars() >> factor(I).vars() ) {
-                    _fac2OR.push_back( alpha );
-                    break;
-                }
-            DAI_ASSERT( alpha != nrORs() );
+        // finish the construction
+        while( !todo.empty() ) {
+            size_t I = todo.front().first;
+            size_t J = todo.front().second;
+            todo.pop();
+            _fac2OR[I] = _fac2OR[J];
+            _ORs[_fac2OR[J]] *= fg.factor(I);
         }
 
-        if( props.verbose >= 3 )
-            cerr << "Build outer regions" << endl;*/
-        // build outer regions (the maximal factor domains)
-        cl = fg.maximalFactorDomains();
-        size_t nrEdges = 0;
-        for( size_t c = 0; c < cl.size(); c++ )
-            nrEdges += cl[c].size();
-
-        // build inner regions (single variables)
-        vector<Region> irs;
-        irs.reserve( fg.nrVars() );
-        for( size_t i = 0; i < fg.nrVars(); i++ )
-            irs.push_back( Region( fg.var(i), 1.0 ) );
-
-        // build edges (an outer and inner region are connected if the outer region contains the inner one)
-        // and calculate counting number for inner regions
-        vector<std::pair<size_t, size_t> > edges;
-        edges.reserve( nrEdges );
-        for( size_t c = 0; c < cl.size(); c++ )
-            for( size_t i = 0; i < irs.size(); i++ )
-                if( cl[c].contains( fg.var(i) ) ) {
-                    edges.push_back( make_pair( c, i ) );
-                    irs[i].c() -= 1.0;
-                }
-
-        // build region graph
-        RegionGraph::construct( fg, cl, irs, edges );
+        // Calculate inner regions' counting numbers
+        for( size_t beta = 0; beta < nrIRs(); beta++ )
+            _IRs[beta].c() = 1.0 - _G.nb2(beta).size();
     } else
         DAI_THROW(UNKNOWN_ENUM_VALUE);
 
index 3268792..5ad84a7 100644 (file)
@@ -198,6 +198,12 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK( G1.isPairwise() );
     BOOST_CHECK( G1.MarkovGraph() == H );
     BOOST_CHECK( G1.bipGraph() == K );
+    BOOST_CHECK(  G1.isMaximal( 0 ) );
+    BOOST_CHECK(  G1.isMaximal( 1 ) );
+    BOOST_CHECK( !G1.isMaximal( 2 ) );
+    BOOST_CHECK_EQUAL( G1.maximalFactor( 0 ), 0 );
+    BOOST_CHECK_EQUAL( G1.maximalFactor( 1 ), 1 );
+    BOOST_CHECK_EQUAL( G1.maximalFactor( 2 ), 0 );
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains().size(), 2 );
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G1.maximalFactorDomains()[1], v12 );
@@ -233,6 +239,14 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK( G2.isPairwise() );
     BOOST_CHECK( G2.MarkovGraph() == H );
     BOOST_CHECK( G2.bipGraph() == K );
+    BOOST_CHECK(  G2.isMaximal( 0 ) );
+    BOOST_CHECK(  G2.isMaximal( 1 ) );
+    BOOST_CHECK( !G2.isMaximal( 2 ) );
+    BOOST_CHECK(  G2.isMaximal( 3 ) );
+    BOOST_CHECK_EQUAL( G2.maximalFactor( 0 ), 0 );
+    BOOST_CHECK_EQUAL( G2.maximalFactor( 1 ), 1 );
+    BOOST_CHECK_EQUAL( G2.maximalFactor( 2 ), 0 );
+    BOOST_CHECK_EQUAL( G2.maximalFactor( 3 ), 3 );
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains().size(), 3 );
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G2.maximalFactorDomains()[1], v12 );
@@ -283,6 +297,16 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK( G3.isPairwise() );
     BOOST_CHECK( G3.MarkovGraph() == H );
     BOOST_CHECK( G3.bipGraph() == K );
+    BOOST_CHECK(  G3.isMaximal( 0 ) );
+    BOOST_CHECK(  G3.isMaximal( 1 ) );
+    BOOST_CHECK( !G3.isMaximal( 2 ) );
+    BOOST_CHECK(  G3.isMaximal( 3 ) );
+    BOOST_CHECK(  G3.isMaximal( 4 ) );
+    BOOST_CHECK_EQUAL( G3.maximalFactor( 0 ), 0 );
+    BOOST_CHECK_EQUAL( G3.maximalFactor( 1 ), 1 );
+    BOOST_CHECK_EQUAL( G3.maximalFactor( 2 ), 0 );
+    BOOST_CHECK_EQUAL( G3.maximalFactor( 3 ), 3 );
+    BOOST_CHECK_EQUAL( G3.maximalFactor( 4 ), 4 );
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains().size(), 4 );
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G3.maximalFactorDomains()[1], v12 );
@@ -328,6 +352,18 @@ BOOST_AUTO_TEST_CASE( QueriesTest ) {
     BOOST_CHECK( !G4.isPairwise() );
     BOOST_CHECK( G4.MarkovGraph() == H );
     BOOST_CHECK( G4.bipGraph() == K );
+    BOOST_CHECK(  G4.isMaximal( 0 ) );
+    BOOST_CHECK( !G4.isMaximal( 1 ) );
+    BOOST_CHECK( !G4.isMaximal( 2 ) );
+    BOOST_CHECK(  G4.isMaximal( 3 ) );
+    BOOST_CHECK( !G4.isMaximal( 4 ) );
+    BOOST_CHECK(  G4.isMaximal( 5 ) );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 0 ), 0 );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 1 ), 5 );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 2 ), 0 );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 3 ), 3 );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 4 ), 5 );
+    BOOST_CHECK_EQUAL( G4.maximalFactor( 5 ), 5 );
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains().size(), 3 );
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[0], v01 );
     BOOST_CHECK_EQUAL( G4.maximalFactorDomains()[1], v02 );
index a60d308..bef7f04 100644 (file)
@@ -68,12 +68,17 @@ BOOST_AUTO_TEST_CASE( PermuteTest ) {
     V.push_back( x0 );
     VarSet X( V.begin(), V.end() );
     Permute sigma(V);
+    BOOST_CHECK_EQUAL( sigma.sigma().size(), 3 );
     BOOST_CHECK_EQUAL( sigma.sigma()[0], 2 );
     BOOST_CHECK_EQUAL( sigma.sigma()[1], 0 );
     BOOST_CHECK_EQUAL( sigma.sigma()[2], 1 );
     BOOST_CHECK_EQUAL( sigma[0], 2 );
     BOOST_CHECK_EQUAL( sigma[1], 0 );
     BOOST_CHECK_EQUAL( sigma[2], 1 );
+    BOOST_CHECK_EQUAL( sigma.ranges().size(), 3 );
+    BOOST_CHECK_EQUAL( sigma.ranges()[0], 3 );
+    BOOST_CHECK_EQUAL( sigma.ranges()[1], 2 );
+    BOOST_CHECK_EQUAL( sigma.ranges()[2], 2 );
     BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 0 ), 0 );
     BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 1 ), 2 );
     BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 2 ), 4 );
@@ -87,15 +92,41 @@ BOOST_AUTO_TEST_CASE( PermuteTest ) {
     BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 10 ), 9 );
     BOOST_CHECK_EQUAL( sigma.convertLinearIndex( 11 ), 11 );
 
+    Permute sigmar(V, true);
+    BOOST_CHECK_EQUAL( sigmar.sigma().size(), 3 );
+    BOOST_CHECK_EQUAL( sigmar.sigma()[0], 0 );
+    BOOST_CHECK_EQUAL( sigmar.sigma()[1], 2 );
+    BOOST_CHECK_EQUAL( sigmar.sigma()[2], 1 );
+    BOOST_CHECK_EQUAL( sigmar[0], 0 );
+    BOOST_CHECK_EQUAL( sigmar[1], 2 );
+    BOOST_CHECK_EQUAL( sigmar[2], 1 );
+    BOOST_CHECK_EQUAL( sigmar.ranges().size(), 3 );
+    BOOST_CHECK_EQUAL( sigmar.ranges()[0], 2 );
+    BOOST_CHECK_EQUAL( sigmar.ranges()[1], 2 );
+    BOOST_CHECK_EQUAL( sigmar.ranges()[2], 3 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 0 ), 0 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 1 ), 1 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 2 ), 6 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 3 ), 7 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 4 ), 2 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 5 ), 3 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 6 ), 8 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 7 ), 9 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 8 ), 4 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 9 ), 5 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 10 ), 10 );
+    BOOST_CHECK_EQUAL( sigmar.convertLinearIndex( 11 ), 11 );
+
     std::vector<size_t> rs, sig;
-    rs.push_back(2);
     rs.push_back(3);
     rs.push_back(2);
+    rs.push_back(2);
     sig.push_back(2);
     sig.push_back(0);
     sig.push_back(1);
     Permute tau( rs, sig );
     BOOST_CHECK( tau.sigma() == sig );
+    BOOST_CHECK( tau.ranges() == rs );
     BOOST_CHECK_EQUAL( tau[0], 2 );
     BOOST_CHECK_EQUAL( tau[1], 0 );
     BOOST_CHECK_EQUAL( tau[2], 1 );
@@ -111,6 +142,20 @@ BOOST_AUTO_TEST_CASE( PermuteTest ) {
     BOOST_CHECK_EQUAL( tau.convertLinearIndex( 9 ), 7 );
     BOOST_CHECK_EQUAL( tau.convertLinearIndex( 10 ), 9 );
     BOOST_CHECK_EQUAL( tau.convertLinearIndex( 11 ), 11 );
+
+    Permute tauinv = tau.inverse();
+    BOOST_CHECK_EQUAL( tauinv.sigma().size(), 3 );
+    BOOST_CHECK_EQUAL( tauinv.ranges().size(), 3 );
+    BOOST_CHECK_EQUAL( tauinv[0], 1 );
+    BOOST_CHECK_EQUAL( tauinv[1], 2 );
+    BOOST_CHECK_EQUAL( tauinv[2], 0 );
+    BOOST_CHECK_EQUAL( tauinv.ranges()[0], 2 );
+    BOOST_CHECK_EQUAL( tauinv.ranges()[1], 3 );
+    BOOST_CHECK_EQUAL( tauinv.ranges()[2], 2 );
+    for( size_t i = 0; i < 12; i++ ) {
+        BOOST_CHECK_EQUAL( tau.convertLinearIndex( tauinv.convertLinearIndex( i ) ), i );
+        BOOST_CHECK_EQUAL( tauinv.convertLinearIndex( tau.convertLinearIndex( i ) ), i );
+    }
 }
 
 
index 47bc025..f447366 100644 (file)
@@ -56,10 +56,10 @@ map<size_t, size_t> ReadUAIEvidenceFile( char* filename ) {
 
 
 /// Reads factor graph (as a pair of a variable vector and factor vector) from a UAI factor graph file
-pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t verbose ) {
-    pair<vector<Var>, vector<Factor> > result;
-    vector<Var>& vars = result.first;
-    vector<Factor>& factors = result.second;
+void ReadUAIFGFile( const char *filename, size_t verbose, vector<Var>& vars, vector<Factor>& factors, vector<Permute>& permutations ) {
+    vars.clear();
+    factors.clear();
+    permutations.clear();
 
     // open file
     ifstream is;
@@ -100,9 +100,9 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
             cout << "Reading " << nrFacs << " factors..." << endl;
 
         // for each factor, read the variables on which it depends
-        vector<vector<long> > labels;
+        vector<vector<Var> > factorVars;
         factors.reserve( nrFacs );
-        labels.reserve( nrFacs );
+        factorVars.reserve( nrFacs );
         for( size_t I = 0; I < nrFacs; I++ ) {
             if( verbose >= 3 )
                 cout << "Reading factor " << I << "..." << endl;
@@ -115,12 +115,12 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
             if( verbose >= 3 )
                 cout << "  which depends on " << I_nrVars << " variables" << endl;
 
-            // for each of the variables, read its label and number of states
+            // read the variable labels
             vector<long> I_labels;
             vector<size_t> I_dims;
-            VarSet I_vars;
             I_labels.reserve( I_nrVars );
             I_dims.reserve( I_nrVars );
+            factorVars[I].reserve( I_nrVars );
             for( size_t _i = 0; _i < I_nrVars; _i++ ) {
                 long label;
                 is >> label;
@@ -128,45 +128,25 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
                     DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read variable labels for " + toString(I) + "'th factor");
                 I_labels.push_back( label );
                 I_dims.push_back( vars[label].states() );
-                I_vars |= vars[label];
+                factorVars[I].push_back( vars[label] );
             }
             if( verbose >= 3 )
                 cout << "  labels: " << I_labels << ", dimensions " << I_dims << endl;
 
             // add the factor and the labels
-            factors.push_back( Factor(I_vars,0.0) );
-            labels.push_back( I_labels );
+            factors.push_back( Factor( VarSet( factorVars[I].begin(), factorVars[I].end(), factorVars[I].size() ), (Real)0 ) );
         }
 
         // for each factor, read its values
+        permutations.reserve( nrFacs );
         for( size_t I = 0; I < nrFacs; I++ ) {
             if( verbose >= 3 )
                 cout << "Reading factor " << I << "..." << endl;
 
-            // last label is least significant, so we reverse the label vector
-            reverse( labels[I].begin(), labels[I].end() );
+            // calculate permutation object, reversing the indexing in factorVars[I] first
+            Permute permindex( factorVars[I], true );
+            permutations.push_back( permindex );
 
-            // prepare a vector containing the dimensionalities of the variables for this factor
-            size_t I_nrVars = factors[I].vars().size();
-            vector<size_t> I_dims;
-            I_dims.reserve( I_nrVars );
-            for( size_t _i = 0; _i < I_nrVars; _i++ )
-                I_dims.push_back( vars[labels[I][_i]].states() );
-            if( verbose >= 3 )
-                cout << "  labels: " << labels[I] << ", dimensions " << I_dims << endl;
-
-            // calculate permutation sigma (internally, members are sorted canonically, 
-            // which may be different from the way they are sorted in the file)
-            vector<size_t> sigma( I_nrVars, 0 );
-            VarSet::const_iterator j = factors[I].vars().begin();
-            for( size_t mi = 0; mi < I_nrVars; mi++, j++ )
-                sigma[mi] = distance( labels[I].begin(), find( labels[I].begin(), labels[I].end(), j->label() ) );
-            if( verbose >= 3 )
-                cout << "  permutation: " << sigma << endl;
-
-            // construct permutation object
-            Permute permindex( I_dims, sigma );
-            
             // read factor values
             size_t nrNonZeros;
             is >> nrNonZeros;
@@ -181,9 +161,13 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
                 if( is.fail() )
                     DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read factor values of " + toString(I) + "'th factor");
                 // assign value after calculating its linear index corresponding to the permutation
+                if( verbose >= 4 )
+                    cout << "  " << li << "'th value " << val << " corresponds with index " << permindex.convertLinearIndex(li) << endl;
                 factors[I][permindex.convertLinearIndex( li )] = val;
             }
         }
+        if( verbose >= 3 )
+            cout << "variables:" << vars << endl;
         if( verbose >= 3 )
             cout << "factors:" << factors << endl;
 
@@ -191,8 +175,6 @@ pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t v
         is.close();
     } else
         DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
-
-    return result;
 }
 
 
@@ -212,11 +194,14 @@ int main( int argc, char *argv[] ) {
         long type = atoi( argv[4] );
         bool run_jtree = atoi( argv[5] );
 
-        // read factor graph and evidence
-        pair<vector<Var>, vector<Factor> > varfacs = ReadUAIFGFile( argv[1], verbose );
+        // read factor graph
+        vector<Var> vars;
+        vector<Factor> facs;
+        vector<Permute> permutations;
+        ReadUAIFGFile( argv[1], verbose, vars, facs, permutations );
+
+        // read evidence
         map<size_t,size_t> evid = ReadUAIEvidenceFile( argv[2] );
-        vector<Var>& vars = varfacs.first;
-        vector<Factor>& facs = varfacs.second;
 
         // construct unclamped factor graph
         FactorGraph fg0( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );