Merged var.h and varset.h from SVN head
authorJoris Mooij <jorism@osun.tuebingen.mpg.de>
Sat, 13 Sep 2008 16:49:18 +0000 (18:49 +0200)
committerJoris Mooij <jorism@osun.tuebingen.mpg.de>
Sat, 13 Sep 2008 16:49:18 +0000 (18:49 +0200)
- Merged var.h from SVN head
- Merged varset.h from SVN head, which uses vector<Var> as implementation
  for a VarSet instead of a set<Var>, which yields a 30% speed improvement
  for testregression

ChangeLog
include/dai/factor.h
include/dai/var.h
include/dai/varset.h
src/factorgraph.cpp
src/hak.cpp
src/jtree.cpp
src/mf.cpp
src/regiongraph.cpp
utils/createfg.cpp

index 7795ce4..b1e6f67 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -2,12 +2,12 @@ libDAI-0.2.2 (2008-??-??)
 -------------------------
 
 * Pervasive change of BipartiteGraph implementation (based on an idea by
 -------------------------
 
 * Pervasive change of BipartiteGraph implementation (based on an idea by
-  Giuseppe Passino). BipartiteGraph no longer stores the node properties (former
-  _V1 and _V2), nor does it store a dense adjacency matrix anymore, nor an edge
-  list. Instead, it stores the graph structure as lists of neighboring nodes.
-  This yields a significant memory/speed improvement for large factor graphs, and
-  is more elegant as well. Iterating over neighbors is made easy by using
-  boost::foreach.
+  Giuseppe Passino). BipartiteGraph no longer stores the node properties 
+  (former _V1 and _V2), nor does it store a dense adjacency matrix anymore, 
+  nor an edge list. Instead, it stores the graph structure as lists of 
+  neighboring nodes. This yields a significant memory/speed improvement for 
+  large factor graphs, and is more elegant as well. Iterating over neighbors is
+  made easy by using boost::foreach.
 * Improved index.h (merged from SVN head):
   - Renamed Index -> IndexFor
   - Added some .reserve()'s to IndexFor methods which yields a 
 * Improved index.h (merged from SVN head):
   - Renamed Index -> IndexFor
   - Added some .reserve()'s to IndexFor methods which yields a 
@@ -42,9 +42,17 @@ libDAI-0.2.2 (2008-??-??)
 * FactorGraph constructors no longer check for short loops (huge speed
   increase for large factor graphs), nor for negative entries. Also, the 
   normtype is now Prob::NORMPROB by default.
 * FactorGraph constructors no longer check for short loops (huge speed
   increase for large factor graphs), nor for negative entries. Also, the 
   normtype is now Prob::NORMPROB by default.
+* VarSet is now implemented using a std::vector<Var> instead of a
+  std::set<Var>, which yields a significant speed improvement.
 * Small optimization in Diffs
 * Interface changes:
 * Small optimization in Diffs
 * Interface changes:
-  - VarSet::stateSpace() -> VarSet::states()
+  - VarSet::
+      stateSpace() -> states()
+      VarSet( const std::set<Var> ) -> VarSet( begin, end, sizeHint=0 )
+      VarSet( const std::vector<Var> ) -> VarSet( begin, end, sizeHint=0 )
+      removed bool operator||
+      operator&&(const VarSet&) -> intersects(const VarSet&)
+      operator&&(const Var&) -> contains(const Var&)
   - FactorGraph::
       delta(const Var &) -> delta(size_t)
       Delta(const Var &) -> Delta(size_t)
   - FactorGraph::
       delta(const Var &) -> delta(size_t)
       Delta(const Var &) -> Delta(size_t)
@@ -68,6 +76,7 @@ libDAI-0.2.2 (2008-??-??)
   - Prob::max() -> Prob::maxVal()
   - Factor::max() -> Factor::maxVal()
   - toc() in util.h now returns seconds as a double
   - Prob::max() -> Prob::maxVal()
   - Factor::max() -> Factor::maxVal()
   - toc() in util.h now returns seconds as a double
+  - VarSet::operator&&
 * Added possibility to build for Windows in Makefile
 
 
 * Added possibility to build for Windows in Makefile
 
 
index f18f039..41b5eaf 100644 (file)
@@ -317,8 +317,8 @@ template<typename T> Complex KL_dist(const TFactor<T> & P, const TFactor<T> & Q)
 // calculate N(psi, i, j)
 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
 #ifdef DAI_DEBUG
 // calculate N(psi, i, j)
 template<typename T> T TFactor<T>::strength( const Var &i, const Var &j ) const {
 #ifdef DAI_DEBUG
-    assert( _vs && i );
-    assert( _vs && j );
+    assert( _vs.contains( i ) );
+    assert( _vs.contains( j ) );
     assert( i != j );
 #endif
     VarSet ij = i | j;
     assert( i != j );
 #endif
     VarSet ij = i | j;
index d0a9950..e9deda1 100644 (file)
 namespace dai {
 
 
 namespace dai {
 
 
-/// Represents a discrete variable
+/// Represents a discrete variable.
+/*  It contains the label of the variable, an integer-valued
+ *  unique ID for that variable, and the number of possible 
+ *  values ("states") of the variable.
+ */
 class Var {
     private:
         /// Internal label of the variable
 class Var {
     private:
         /// Internal label of the variable
@@ -41,37 +45,37 @@ class Var {
         
     public:
         /// Default constructor
         
     public:
         /// Default constructor
-        Var() : _label(-1), _states(0) {};
+        Var() : _label(-1), _states(0) {}
         /// Constructor
         /// Constructor
-        Var(long label, size_t states) : _label(label), _states(states) {};
+        Var( long label, size_t states ) : _label(label), _states(states) {}
 
 
-        /// Read access to label
-        long label() const { return _label; };
-        /// Access to label
-        long & label() { return _label; };
+        /// Gets the label
+        long label() const { return _label; }
+        /// Returns reference to label
+        long & label() { return _label; }
 
 
-        /// Read access to states
-        size_t states () const { return _states; };
-        /// Access to states
-        size_t& states () { return _states; };
+        /// Gets the number of states
+        size_t states () const { return _states; }
+        /// Returns reference to number of states
+        size_t& states () { return _states; }
 
         /// Smaller-than operator (compares labels)
 
         /// Smaller-than operator (compares labels)
-        bool operator <  (const Var& n) const { return( _label <  n._label ); };
+        bool operator < ( const Var& n ) const { return( _label <  n._label ); }
         /// Larger-than operator (compares labels)
         /// Larger-than operator (compares labels)
-        bool operator >  (const Var& n) const { return( _label >  n._label ); };
+        bool operator > ( const Var& n ) const { return( _label >  n._label ); }
         /// Smaller-than-or-equal-to operator (compares labels)
         /// Smaller-than-or-equal-to operator (compares labels)
-        bool operator <= (const Var& n) const { return( _label <= n._label ); };
+        bool operator <= ( const Var& n ) const { return( _label <= n._label ); }
         /// Larger-than-or-equal-to operator (compares labels)
         /// Larger-than-or-equal-to operator (compares labels)
-        bool operator >= (const Var& n) const { return( _label >= n._label ); };
+        bool operator >= ( const Var& n ) const { return( _label >= n._label ); }
         /// Not-equal-to operator (compares labels)
         /// Not-equal-to operator (compares labels)
-        bool operator != (const Var& n) const { return( _label != n._label ); };
+        bool operator != ( const Var& n ) const { return( _label != n._label ); }
         /// Equal-to operator (compares labels)
         /// Equal-to operator (compares labels)
-        bool operator == (const Var& n) const { return( _label == n._label ); };
+        bool operator == ( const Var& n ) const { return( _label == n._label ); }
 
 
-        /// Stream output operator
-        friend std::ostream& operator << (std::ostream& os, const Var& n) {
+        /// Write this Var to stream
+        friend std::ostream& operator << ( std::ostream& os, const Var& n ) {
             return( os << "[" << n.label() << "]" );
             return( os << "[" << n.label() << "]" );
-        };
+        }
 };
 
 
 };
 
 
index 606c3f8..94ae3eb 100644 (file)
 #define __defined_libdai_varset_h
 
 
 #define __defined_libdai_varset_h
 
 
-#include <set>
+#include <vector>
+#include <map>
 #include <algorithm>
 #include <iostream>
 #include <cassert>
 #include <dai/var.h>
 #include <algorithm>
 #include <iostream>
 #include <cassert>
 #include <dai/var.h>
+#include <dai/util.h>
 
 
 namespace dai {
 
 
 
 
 namespace dai {
 
 
-/// VarSet represents a set of variables and is a descendant of set<Var>. 
-/// In addition, it provides an easy interface for set-theoretic operations
-/// by operator overloading.
-class VarSet : private std::set<Var> {
+/// Represents a set of variables.
+/**
+ *  It is implemented as an ordered vector<Var> for efficiency reasons
+ *  (this is more efficient than a set<Var>). In addition, it provides 
+ *  an interface for common set-theoretic operations.
+ */
+class VarSet {
     protected:
     protected:
-        /// Product of number of states of all contained variables
-        size_t _statespace;
-
-        /// Check whether ns is a subset
-        bool includes( const VarSet& ns ) const {
-            return std::includes( begin(), end(), ns.begin(), ns.end() );
-        }
-
-        /// Calculate statespace
-        size_t calcStateSpace() {
-            _statespace = 1;
-            for( const_iterator i = begin(); i != end(); ++i )
-                _statespace *= i->states();
-            return _statespace;
-        }
+        /// The variables in this set
+        std::vector<Var> _vars;
 
 
+        /// Product of number of states of all contained variables
+        size_t _states;
 
     public:
         /// Default constructor
 
     public:
         /// Default constructor
-        VarSet() : _statespace(0) {};
+        VarSet() : _vars(), _states(1) {};
 
 
-        /// Construct a VarSet with one variable
-        VarSet( const Var &n ) : _statespace( n.states() ) { 
-            insert( n ); 
+        /// Construct a VarSet from one variable
+        VarSet( const Var &n ) : _vars(), _states( n.states() ) { 
+            _vars.push_back( n );
         }
 
         }
 
-        /// Construct a VarSet with two variables
+        /// Construct a VarSet from two variables
         VarSet( const Var &n1, const Var &n2 ) { 
         VarSet( const Var &n1, const Var &n2 ) { 
-            insert( n1 ); 
-            insert( n2 ); 
-            calcStateSpace();
+            if( n1 < n2 ) {
+                _vars.push_back( n1 );
+                _vars.push_back( n2 );
+            } else if( n1 > n2 ) {
+                _vars.push_back( n2 );
+                _vars.push_back( n1 );
+            } else
+                _vars.push_back( n1 );
+            calcStates();
         }
 
         }
 
-        /// Construct from a set<Var>
-        VarSet( const std::set<Var> &ns ) {
-            std::set<Var>::operator=( ns );
-            calcStateSpace();
-        }
-
-        /// Construct from a vector<Var>
-        VarSet( const std::vector<Var> &ns ) {
-            for( std::vector<Var>::const_iterator n = ns.begin(); n != ns.end(); n++ )
-                insert( *n );
-            calcStateSpace();
+        /// Construct from a range of iterators
+        /*  The value_type of the VarIterator should be Var.
+         *  For efficiency, the number of variables can be
+         *  speficied by sizeHint.
+         */
+        template <typename VarIterator>
+        VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) {
+            _vars.reserve( sizeHint );
+            _vars.insert( _vars.begin(), begin, end );
+            std::sort( _vars.begin(), _vars.end() );
+            std::vector<Var>::iterator new_end = std::unique( _vars.begin(), _vars.end() );
+            _vars.erase( new_end, _vars.end() );
+            calcStates();
         }
 
         /// Copy constructor
         }
 
         /// Copy constructor
-        VarSet( const VarSet &x ) : std::set<Var>( x ), _statespace( x._statespace ) {}
+        VarSet( const VarSet &x ) : _vars( x._vars ), _states( x._states ) {}
 
         /// Assignment operator
         VarSet & operator=( const VarSet &x ) {
             if( this != &x ) {
 
         /// Assignment operator
         VarSet & operator=( const VarSet &x ) {
             if( this != &x ) {
-                std::set<Var>::operator=( x );
-                _statespace = x._statespace;
+                _vars = x._vars;
+                _states = x._states;
             }
             return *this;
         }
         
 
             }
             return *this;
         }
         
 
-        /// Return statespace, i.e. the product of the number of states of each variable
+        /// Return the product of the number of states of each variable in this set
         size_t states() const { 
         size_t states() const { 
-#ifdef DAI_DEBUG
-            size_t x = 1;
-            for( const_iterator i = begin(); i != end(); ++i )
-                x *= i->states();
-            assert( x == _statespace );
-#endif
-            return _statespace; 
+            return _states; 
         }
         
 
         }
         
 
-        /// Erase one variable
-        VarSet& operator/= (const Var& n) { 
-            erase( n ); 
-            calcStateSpace();
-            return *this; 
-        }
-
-        /// Add one variable
-        VarSet& operator|= (const Var& n) {
-            insert( n ); 
-            calcStateSpace();
-            return *this;
-        }
-
         /// Setminus operator (result contains all variables except those in ns)
         /// Setminus operator (result contains all variables except those in ns)
-        VarSet operator/ (const VarSet& ns) const {
+        VarSet operator/ ( const VarSet& ns ) const {
             VarSet res;
             VarSet res;
-            std::set_difference( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
-            res.calcStateSpace();
+            std::set_difference( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+            res.calcStates();
             return res;
         }
 
         /// Set-union operator (result contains all variables plus those in ns)
             return res;
         }
 
         /// Set-union operator (result contains all variables plus those in ns)
-        VarSet operator| (const VarSet& ns) const {
+        VarSet operator| ( const VarSet& ns ) const {
             VarSet res;
             VarSet res;
-            std::set_union( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
-            res.calcStateSpace();
+            std::set_union( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+            res.calcStates();
             return res;
         }
 
         /// Set-intersection operator (result contains all variables that are also contained in ns)
             return res;
         }
 
         /// Set-intersection operator (result contains all variables that are also contained in ns)
-        VarSet operator& (const VarSet& ns) const {
+        VarSet operator& ( const VarSet& ns ) const {
             VarSet res;
             VarSet res;
-            std::set_intersection( begin(), end(), ns.begin(), ns.end(), inserter( res, res.begin() ) );
-            res.calcStateSpace();
+            std::set_intersection( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
+            res.calcStates();
             return res;
         }
         
         /// Erases from *this all variables in ns
             return res;
         }
         
         /// Erases from *this all variables in ns
-        VarSet& operator/= (const VarSet& ns) {
+        VarSet& operator/= ( const VarSet& ns ) {
             return (*this = (*this / ns));
         }
 
             return (*this = (*this / ns));
         }
 
-        /// Adds to *this all variables in ns
-        VarSet& operator|= (const VarSet& ns) {
-            return (*this = (*this | ns));
+        /// Erase one variable
+        VarSet& operator/= ( const Var& n ) { 
+            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
+            if( pos != _vars.end() )
+                if( *pos == n ) { // found variable, delete it
+                    _vars.erase( pos ); 
+                    _states /= n.states();
+                }
+            return *this; 
         }
 
         }
 
-        /// Erases from *this all variables not in ns
-        VarSet& operator&= (const VarSet& ns) { 
-            return (*this = (*this & ns)); 
+        /// Adds to *this all variables in ns
+        VarSet& operator|= ( const VarSet& ns ) {
+            return( *this = (*this | ns) );
         }
         }
-        
 
 
-        /// Returns false if both *this and ns are empty
-        bool operator|| (const VarSet& ns) const { 
-            return !( this->empty() && ns.empty() );
+        /// Add one variable
+        VarSet& operator|= ( const Var& n ) {
+            std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
+            if( pos == _vars.end() || *pos != n ) { // insert it
+                _vars.insert( pos, n );
+                _states *= n.states();
+            }
+            return *this;
         }
 
         }
 
-        /// Returns true if *this and ns contain common variables
-        bool operator&& (const VarSet& ns) const { 
-            return !( (*this & ns).empty() ); 
+
+        /// Erases from *this all variables not in ns
+        VarSet& operator&= ( const VarSet& ns ) { 
+            return (*this = (*this & ns));
         }
 
         }
 
+
         /// Returns true if *this is a subset of ns
         /// Returns true if *this is a subset of ns
-        bool operator<< (const VarSet& ns) const { 
-            return ns.includes( *this ); 
+        bool operator<< ( const VarSet& ns ) const { 
+            return std::includes( ns._vars.begin(), ns._vars.end(), _vars.begin(), _vars.end() );
         }
 
         /// Returns true if ns is a subset of *this
         }
 
         /// Returns true if ns is a subset of *this
-        bool operator>> (const VarSet& ns) const { 
-            return includes( ns ); 
+        bool operator>> ( const VarSet& ns ) const { 
+            return std::includes( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end() );
+        }
+
+        /// Returns true if *this and ns contain common variables
+        bool intersects( const VarSet& ns ) const { 
+            return( (*this & ns).size() > 0 ); 
         }
 
         /// Returns true if *this contains the variable n
         }
 
         /// Returns true if *this contains the variable n
-        bool operator&& (const Var& n) const { 
-            return( find( n ) == end() ? false : true ); 
+        bool contains( const Var& n ) const { 
+            return std::binary_search( _vars.begin(), _vars.end(), n );
         }
 
         }
 
-        
         /// Sends a VarSet to an output stream
         friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
         /// Sends a VarSet to an output stream
         friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
-            for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++)
-                os << *n;
+            foreach( const Var &n, ns._vars )
+                os << n;
             return( os );
         }
 
             return( os );
         }
 
+        // Constant iterator
+        typedef std::vector<Var>::const_iterator const_iterator;
+        /// Iterator
+        typedef std::vector<Var>::iterator iterator;
+        // Constant reverse iterator
+        typedef std::vector<Var>::const_reverse_iterator const_reverse_iterator;
+        /// Reverse Iterator
+        typedef std::vector<Var>::reverse_iterator reverse_iterator;
         
         
-/*      The following makes part of the public interface of set<Var> available.
- *      It is important to note that insert functions have to be overloaded,
- *      because they have to recalculate the statespace. A different approach
- *      would be to publicly inherit from set<Var> and only overload the insert
- *      methods.
- */
-        
-        // Additional interface from set<Var> that has to be provided
-        using std::set<Var>::const_iterator;
-        using std::set<Var>::iterator;
-        using std::set<Var>::const_reference;
-        using std::set<Var>::begin;
-        using std::set<Var>::end;
-        using std::set<Var>::size;
-        using std::set<Var>::empty;
-
-        /// Copy of set<Var>::insert which additionally calculates the new statespace
-        std::pair<iterator, bool> insert( const Var& x ) {
-            std::pair<iterator, bool> result = std::set<Var>::insert( x );
-            calcStateSpace();
-            return result;
-        }
+        /// Returns iterator that points to the first variable
+        iterator begin() { return _vars.begin(); }
+        /// Returns constant iterator that points to the first variable
+        const_iterator begin() const { return _vars.begin(); }
 
 
-        /// Copy of set<Var>::insert which additionally calculates the new statespace
-        iterator insert( iterator pos, const value_type& x ) {
-            iterator result = std::set<Var>::insert( pos, x );
-            calcStateSpace();
-            return result;
-        }
+        /// Returns iterator that points beyond the last variable
+        iterator end() { return _vars.end(); }
+        /// Returns constant iterator that points beyond the last variable
+        const_iterator end() const { return _vars.end(); }
 
 
-        /// Test for equality (ignore _statespace member)
+        /// Returns reverse iterator that points to the last variable
+        reverse_iterator rbegin() { return _vars.rbegin(); }
+        /// Returns constant reverse iterator that points to the last variable
+        const_reverse_iterator rbegin() const { return _vars.rbegin(); }
+
+        /// Returns reverse iterator that points beyond the first variable
+        reverse_iterator rend() { return _vars.rend(); }
+        /// Returns constant reverse iterator that points beyond the first variable
+        const_reverse_iterator rend() const { return _vars.rend(); }
+
+
+        /// Returns number of variables
+        std::vector<Var>::size_type size() const { return _vars.size(); }
+
+
+        /// Returns whether the set is empty
+        bool empty() const { return _vars.size() == 0; }
+
+
+        /// Test for equality of variable labels
         friend bool operator==( const VarSet &a, const VarSet &b ) {
         friend bool operator==( const VarSet &a, const VarSet &b ) {
-            return operator==( (std::set<Var>)a, (std::set<Var>)b );
+            return (a._vars == b._vars);
         }
 
         }
 
-        /// Test for inequality (ignore _statespace member)
+        /// Test for inequality of variable labels
         friend bool operator!=( const VarSet &a, const VarSet &b ) {
         friend bool operator!=( const VarSet &a, const VarSet &b ) {
-            return operator!=( (std::set<Var>)a, (std::set<Var>)b );
+            return !(a._vars == b._vars);
         }
 
         }
 
+        /// Lexicographical comparison of variable labels
         friend bool operator<( const VarSet &a, const VarSet &b ) {
         friend bool operator<( const VarSet &a, const VarSet &b ) {
-            return operator<( (std::set<Var>)a, (std::set<Var>)b );
+            return a._vars < b._vars;
+        }
+
+        /// calcState calculates the linear index of this VarSet that corresponds
+        /// to the states of the variables given in states, implicitly assuming
+        /// states[m] = 0 for all m in this VarSet which are not in states.
+        size_t calcState( const std::map<Var, size_t> &states ) const {
+            size_t prod = 1;
+            size_t state = 0;
+            foreach( const Var &n, *this ) {
+                std::map<Var, size_t>::const_iterator m = states.find( n );
+                if( m != states.end() )
+                    state += prod * m->second;
+                prod *= n.states();
+            }
+            return state;
+        }
+
+    protected:
+        /// Calculate number of states
+        size_t calcStates() {
+            _states = 1;
+            foreach( Var &i, _vars )
+                _states *= i.states();
+            return _states;
         }
 };
 
         }
 };
 
index 5d650e8..b887316 100644 (file)
@@ -171,7 +171,7 @@ istream& operator >> (istream& is, FactorGraph& fg) {
             // add the Factor
             VarSet I_vars;
             for( size_t mi = 0; mi < nr_members; mi++ )
             // add the Factor
             VarSet I_vars;
             for( size_t mi = 0; mi < nr_members; mi++ )
-                I_vars.insert( Var(labels[mi], dims[mi]) );
+                I_vars |= Var(labels[mi], dims[mi]);
             factors.push_back(Factor(I_vars,0.0));
             
             // calculate permutation sigma (internally, members are sorted)
             factors.push_back(Factor(I_vars,0.0));
             
             // calculate permutation sigma (internally, members are sorted)
@@ -394,7 +394,7 @@ void FactorGraph::clamp( const Var & n, size_t i ) {
 
     // For all factors that contain n
     for( size_t I = 0; I < nrFactors(); I++ ) 
 
     // For all factors that contain n
     for( size_t I = 0; I < nrFactors(); I++ ) 
-        if( factor(I).vars() && n )
+        if( factor(I).vars().contains( n ) )
             // Multiply it with a delta function
             factor(I) *= delta_n_i;
 
             // Multiply it with a delta function
             factor(I) *= delta_n_i;
 
@@ -423,14 +423,14 @@ void FactorGraph::saveProbs( const VarSet &ns ) {
     if( !_undoProbs.empty() )
         cout << "FactorGraph::saveProbs:  WARNING: _undoProbs not empy!" << endl;
     for( size_t I = 0; I < nrFactors(); I++ )
     if( !_undoProbs.empty() )
         cout << "FactorGraph::saveProbs:  WARNING: _undoProbs not empy!" << endl;
     for( size_t I = 0; I < nrFactors(); I++ )
-        if( factor(I).vars() && ns )
+        if( factor(I).vars().intersects( ns ) )
             _undoProbs[I] = factor(I).p();
 }
 
 
 void FactorGraph::undoProbs( const VarSet &ns ) {
     for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
             _undoProbs[I] = factor(I).p();
 }
 
 
 void FactorGraph::undoProbs( const VarSet &ns ) {
     for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
-        if( factor((*uI).first).vars() && ns ) {
+        if( factor((*uI).first).vars().intersects( ns ) ) {
 //          cout << "undoing " << factor((*uI).first).vars() << endl;
 //          cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
             factor((*uI).first).p() = (*uI).second;
 //          cout << "undoing " << factor((*uI).first).vars() << endl;
 //          cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
             factor((*uI).first).p() = (*uI).second;
index f7087a5..094d3fa 100644 (file)
@@ -156,11 +156,11 @@ string HAK::identify() const {
 
 void HAK::init( const VarSet &ns ) {
     for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
 
 void HAK::init( const VarSet &ns ) {
     for( vector<Factor>::iterator alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
-        if( alpha->vars() && ns )
+        if( alpha->vars().intersects( ns ) )
             alpha->fill( 1.0 / alpha->states() );
 
     for( size_t beta = 0; beta < nrIRs(); beta++ )
             alpha->fill( 1.0 / alpha->states() );
 
     for( size_t beta = 0; beta < nrIRs(); beta++ )
-        if( IR(beta) && ns ) {
+        if( IR(beta).intersects( ns ) ) {
             _Qb[beta].fill( 1.0 );
             foreach( const Neighbor &alpha, nbIR(beta) ) {
                 size_t _beta = alpha.dual;
             _Qb[beta].fill( 1.0 );
             foreach( const Neighbor &alpha, nbIR(beta) ) {
                 size_t _beta = alpha.dual;
index b846d35..c7652e1 100644 (file)
@@ -338,7 +338,7 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
         // find first occurence of *n in the tree, which is closest to the root
         size_t e = 0;
         for( ; e != newTree.size(); e++ ) {
         // find first occurence of *n in the tree, which is closest to the root
         size_t e = 0;
         for( ; e != newTree.size(); e++ ) {
-            if( OR(newTree[e].n2).vars() && *n )
+            if( OR(newTree[e].n2).vars().contains( *n ) )
                 break;
         }
         assert( e != newTree.size() );
                 break;
         }
         assert( e != newTree.size() );
index 1dd0cfd..1ffc360 100644 (file)
@@ -190,7 +190,7 @@ Complex MF::logZ() const {
 
 void MF::init( const VarSet &ns ) {
     for( size_t i = 0; i < nrVars(); i++ ) {
 
 void MF::init( const VarSet &ns ) {
     for( size_t i = 0; i < nrVars(); i++ ) {
-        if( ns && var(i) )
+        if( ns.contains(var(i) ) )
             _beliefs[i].fill( 1.0 );
     }
 }
             _beliefs[i].fill( 1.0 );
     }
 }
index 72b3eb9..79fb5f9 100644 (file)
@@ -180,10 +180,10 @@ bool RegionGraph::Check_Counting_Numbers() {
     for( vector<Var>::const_iterator n = vars.begin(); n != vars.end(); n++ ) {
         double c_n = 0.0;
         for( size_t alpha = 0; alpha < nrORs(); alpha++ )
     for( vector<Var>::const_iterator n = vars.begin(); n != vars.end(); n++ ) {
         double c_n = 0.0;
         for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-            if( OR(alpha).vars() && *n )
+            if( OR(alpha).vars().contains( *n ) )
                 c_n += OR(alpha).c();
         for( size_t beta = 0; beta < nrIRs(); beta++ )
                 c_n += OR(alpha).c();
         for( size_t beta = 0; beta < nrIRs(); beta++ )
-            if( IR(beta) && *n )
+            if( IR(beta).contains( *n ) )
                 c_n += IR(beta).c();
         if( fabs(c_n - 1.0) > 1e-15 ) {
             all_valid = false;
                 c_n += IR(beta).c();
         if( fabs(c_n - 1.0) > 1e-15 ) {
             all_valid = false;
@@ -206,11 +206,11 @@ void RegionGraph::RecomputeORs() {
 
 void RegionGraph::RecomputeORs( const VarSet &ns ) {
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
 
 void RegionGraph::RecomputeORs( const VarSet &ns ) {
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        if( OR(alpha).vars() && ns )
+        if( OR(alpha).vars().intersects( ns ) )
             OR(alpha).fill( 1.0 );
     for( size_t I = 0; I < nrFactors(); I++ )
         if( fac2OR[I] != -1U )
             OR(alpha).fill( 1.0 );
     for( size_t I = 0; I < nrFactors(); I++ )
         if( fac2OR[I] != -1U )
-            if( OR( fac2OR[I] ).vars() && ns )
+            if( OR( fac2OR[I] ).vars().intersects( ns ) )
                 OR( fac2OR[I] ) *= factor( I );
 }
 
                 OR( fac2OR[I] ) *= factor( I );
 }
 
index e5f78f4..cc5fe80 100644 (file)
@@ -45,7 +45,7 @@ void MakeHOIFG( size_t N, size_t M, size_t k, double sigma, FactorGraph &fg ) {
                        do {
                                size_t newind = (size_t)(N * rnd_uniform());
                                Var newvar = Var(newind, 2);
                        do {
                                size_t newind = (size_t)(N * rnd_uniform());
                                Var newvar = Var(newind, 2);
-                               if( !(vars && newvar) ) {
+                               if( !(vars.contains( newvar )) ) {
                                        vars |= newvar;
                                        break;
                                }
                                        vars |= newvar;
                                        break;
                                }