Small changes
[libdai.git] / include / dai / varset.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Copyright (C) 2002 Martijn Leisink [martijn@mbfys.kun.nl]
3 Radboud University Nijmegen, The Netherlands
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_varset_h
24 #define __defined_libdai_varset_h
25
26
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <map>
#include <vector>
#include <dai/var.h>
#include <dai/util.h>
34
35
36 namespace dai {
37
38
39 /// A VarSet represents a set of variables.
40 /**
41 * It is implemented as an ordered std::vector<Var> for efficiency reasons
42 * (indeed, it was found that a std::set<Var> usually has more overhead).
43 * In addition, it provides an interface for common set-theoretic operations.
44 */
45 class VarSet {
46 private:
47 /// The variables in this set
48 std::vector<Var> _vars;
49
50 /// Product of number of states of all contained variables
51 size_t _states;
52
53 public:
54 /// Default constructor
55 VarSet() : _vars(), _states(1) {};
56
57 /// Construct a VarSet from one variable
58 VarSet( const Var &n ) : _vars(), _states( n.states() ) {
59 _vars.push_back( n );
60 }
61
62 /// Construct a VarSet from two variables
63 VarSet( const Var &n1, const Var &n2 ) {
64 if( n1 < n2 ) {
65 _vars.push_back( n1 );
66 _vars.push_back( n2 );
67 } else if( n1 > n2 ) {
68 _vars.push_back( n2 );
69 _vars.push_back( n1 );
70 } else
71 _vars.push_back( n1 );
72 calcStates();
73 }
74
75 /// Construct from a range of iterators
76 /** The value_type of the VarIterator should be Var.
77 * For efficiency, the number of variables can be
78 * speficied by sizeHint.
79 */
80 template <typename VarIterator>
81 VarSet( VarIterator begin, VarIterator end, size_t sizeHint=0 ) {
82 _vars.reserve( sizeHint );
83 _vars.insert( _vars.begin(), begin, end );
84 std::sort( _vars.begin(), _vars.end() );
85 std::vector<Var>::iterator new_end = std::unique( _vars.begin(), _vars.end() );
86 _vars.erase( new_end, _vars.end() );
87 calcStates();
88 }
89
90 /// Copy constructor
91 VarSet( const VarSet &x ) : _vars( x._vars ), _states( x._states ) {}
92
93 /// Assignment operator
94 VarSet & operator=( const VarSet &x ) {
95 if( this != &x ) {
96 _vars = x._vars;
97 _states = x._states;
98 }
99 return *this;
100 }
101
102
103 /// Returns the product of the number of states of each variable in this set
104 size_t states() const {
105 return _states;
106 }
107
108
109 /// Setminus operator (result contains all variables in *this, except those in ns)
110 VarSet operator/ ( const VarSet& ns ) const {
111 VarSet res;
112 std::set_difference( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
113 res.calcStates();
114 return res;
115 }
116
117 /// Set-union operator (result contains all variables in *this, plus those in ns)
118 VarSet operator| ( const VarSet& ns ) const {
119 VarSet res;
120 std::set_union( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
121 res.calcStates();
122 return res;
123 }
124
125 /// Set-intersection operator (result contains all variables in *this that are also contained in ns)
126 VarSet operator& ( const VarSet& ns ) const {
127 VarSet res;
128 std::set_intersection( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end(), inserter( res._vars, res._vars.begin() ) );
129 res.calcStates();
130 return res;
131 }
132
133 /// Erases from *this all variables in ns
134 VarSet& operator/= ( const VarSet& ns ) {
135 return (*this = (*this / ns));
136 }
137
138 /// Erase one variable
139 VarSet& operator/= ( const Var& n ) {
140 std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
141 if( pos != _vars.end() )
142 if( *pos == n ) { // found variable, delete it
143 _vars.erase( pos );
144 _states /= n.states();
145 }
146 return *this;
147 }
148
149 /// Adds to *this all variables in ns
150 VarSet& operator|= ( const VarSet& ns ) {
151 return( *this = (*this | ns) );
152 }
153
154 /// Add one variable
155 VarSet& operator|= ( const Var& n ) {
156 std::vector<Var>::iterator pos = lower_bound( _vars.begin(), _vars.end(), n );
157 if( pos == _vars.end() || *pos != n ) { // insert it
158 _vars.insert( pos, n );
159 _states *= n.states();
160 }
161 return *this;
162 }
163
164
165 /// Erases from *this all variables not in ns
166 VarSet& operator&= ( const VarSet& ns ) {
167 return (*this = (*this & ns));
168 }
169
170
171 /// Returns true if *this is a subset of ns
172 bool operator<< ( const VarSet& ns ) const {
173 return std::includes( ns._vars.begin(), ns._vars.end(), _vars.begin(), _vars.end() );
174 }
175
176 /// Returns true if ns is a subset of *this
177 bool operator>> ( const VarSet& ns ) const {
178 return std::includes( _vars.begin(), _vars.end(), ns._vars.begin(), ns._vars.end() );
179 }
180
181 /// Returns true if *this and ns contain common variables
182 bool intersects( const VarSet& ns ) const {
183 return( (*this & ns).size() > 0 );
184 }
185
186 /// Returns true if *this contains the variable n
187 bool contains( const Var& n ) const {
188 return std::binary_search( _vars.begin(), _vars.end(), n );
189 }
190
191 /// Sends a VarSet to an output stream
192 friend std::ostream& operator<< (std::ostream & os, const VarSet& ns) {
193 foreach( const Var &n, ns._vars )
194 os << n;
195 return( os );
196 }
197
198 /// Constant iterator over Vars
199 typedef std::vector<Var>::const_iterator const_iterator;
200 /// Iterator over Vars
201 typedef std::vector<Var>::iterator iterator;
202 /// Constant reverse iterator over Vars
203 typedef std::vector<Var>::const_reverse_iterator const_reverse_iterator;
204 /// Reverse iterator over Vars
205 typedef std::vector<Var>::reverse_iterator reverse_iterator;
206
207 /// Returns iterator that points to the first variable
208 iterator begin() { return _vars.begin(); }
209 /// Returns constant iterator that points to the first variable
210 const_iterator begin() const { return _vars.begin(); }
211
212 /// Returns iterator that points beyond the last variable
213 iterator end() { return _vars.end(); }
214 /// Returns constant iterator that points beyond the last variable
215 const_iterator end() const { return _vars.end(); }
216
217 /// Returns reverse iterator that points to the last variable
218 reverse_iterator rbegin() { return _vars.rbegin(); }
219 /// Returns constant reverse iterator that points to the last variable
220 const_reverse_iterator rbegin() const { return _vars.rbegin(); }
221
222 /// Returns reverse iterator that points beyond the first variable
223 reverse_iterator rend() { return _vars.rend(); }
224 /// Returns constant reverse iterator that points beyond the first variable
225 const_reverse_iterator rend() const { return _vars.rend(); }
226
227
228 /// Returns number of variables
229 std::vector<Var>::size_type size() const { return _vars.size(); }
230
231
232 /// Returns whether the VarSet is empty
233 bool empty() const { return _vars.size() == 0; }
234
235
236 /// Test for equality of variable labels
237 friend bool operator==( const VarSet &a, const VarSet &b ) {
238 return (a._vars == b._vars);
239 }
240
241 /// Test for inequality of variable labels
242 friend bool operator!=( const VarSet &a, const VarSet &b ) {
243 return !(a._vars == b._vars);
244 }
245
246 /// Lexicographical comparison of variable labels
247 friend bool operator<( const VarSet &a, const VarSet &b ) {
248 return a._vars < b._vars;
249 }
250
251 /// calcState calculates the linear index of this VarSet that corresponds
252 /// to the states of the variables given in states, implicitly assuming
253 /// states[m] = 0 for all m in this VarSet which are not in states.
254 size_t calcState( const std::map<Var, size_t> &states ) const {
255 size_t prod = 1;
256 size_t state = 0;
257 foreach( const Var &n, *this ) {
258 std::map<Var, size_t>::const_iterator m = states.find( n );
259 if( m != states.end() )
260 state += prod * m->second;
261 prod *= n.states();
262 }
263 return state;
264 }
265
266 private:
267 /// Calculates the number of states
268 size_t calcStates() {
269 _states = 1;
270 foreach( Var &i, _vars )
271 _states *= i.states();
272 return _states;
273 }
274 };
275
276
277 /// For two Vars n1 and n2, the expression n1 | n2 gives the Varset containing n1 and n2
278 inline VarSet operator| (const Var& n1, const Var& n2) {
279 return( VarSet(n1, n2) );
280 }
281
282
283 } // end of namespace dai
284
285
286 #endif