Fixed a bug (introduced in commit 64db6bc3...) and another one in Factors2mx
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
14 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
15 /// add similar heuristics which are designed for factor graphs.
16
17
18 #ifndef __defined_libdai_clustergraph_h
19 #define __defined_libdai_clustergraph_h
20
21
#include <algorithm>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
27
28
29 namespace dai {
30
31
32 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
33 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
34 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
35 * \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
36 */
37 class ClusterGraph {
38 public:
39 /// Shorthand for BipartiteGraph::Neighbor
40 typedef BipartiteGraph::Neighbor Neighbor;
41
42 /// Shorthand for BipartiteGraph::Edge
43 typedef BipartiteGraph::Edge Edge;
44
45 private:
46 /// Stores the neighborhood structure
47 BipartiteGraph _G;
48
49 /// Stores the variables corresponding to the nodes
50 std::vector<Var> _vars;
51
52 /// Stores the clusters corresponding to the hyperedges
53 std::vector<VarSet> _clusters;
54
55 public:
56 /// \name Constructors and destructors
57 //@{
58 /// Default constructor
59 ClusterGraph() : _G(), _vars(), _clusters() {}
60
61 /// Construct from vector of VarSet 's
62 ClusterGraph( const std::vector<VarSet>& cls );
63
64 /// Construct from a factor graph
65 /** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
66 * and only the maximal factors in \a fg if \a onlyMaximal == \c true.
67 */
68 ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
69 //@}
70
71 /// \name Queries
72 //@{
73 /// Returns a constant reference to the graph structure
74 const BipartiteGraph& bipGraph() const { return _G; }
75
76 /// Returns number of variables
77 size_t nrVars() const { return _vars.size(); }
78
79 /// Returns a constant reference to the variables
80 const std::vector<Var>& vars() const { return _vars; }
81
82 /// Returns a constant reference to the \a i'th variable
83 const Var& var( size_t i ) const {
84 DAI_DEBASSERT( i < nrVars() );
85 return _vars[i];
86 }
87
88 /// Returns number of clusters
89 size_t nrClusters() const { return _clusters.size(); }
90
91 /// Returns a constant reference to the clusters
92 const std::vector<VarSet>& clusters() const { return _clusters; }
93
94 /// Returns a constant reference to the \a I'th cluster
95 const VarSet& cluster( size_t I ) const {
96 DAI_DEBASSERT( I < nrClusters() );
97 return _clusters[I];
98 }
99
100 /// Returns the index of variable \a n
101 size_t findVar( const Var& n ) const {
102 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
103 }
104
105 /// Returns the index of a cluster \a cl
106 size_t findCluster( const VarSet& cl ) const {
107 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
108 }
109
110 /* /// Returns the index of a cluster \a _cl
111 size_t findCluster( const SmallSet<size_t>& _cl ) const {
112 if( _cl.size() == 0 ) {
113 for( size_t I = 0; I < nrClusters(); I++ )
114 if( cluster(I).size() == 0 )
115 return I;
116 } else {
117 size_t i = _cl.front();
118 foreach( const Neighbor& I, _G.nb1(i) )
119 if( _G.nb2Set(I) == _cl )
120 return I;
121 }
122 return nrClusters();
123 }*/
124
125 /// Returns union of clusters that contain the \a i 'th variable
126 VarSet Delta( size_t i ) const {
127 VarSet result;
128 foreach( const Neighbor& I, _G.nb1(i) )
129 result |= _clusters[I];
130 return result;
131 }
132
133 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
134 VarSet delta( size_t i ) const {
135 return Delta( i ) / _vars[i];
136 }
137
138 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
139 bool adj( size_t i1, size_t i2 ) const {
140 if( i1 == i2 )
141 return false;
142 bool result = false;
143 foreach( const Neighbor& I, _G.nb1(i1) )
144 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
145 result = true;
146 break;
147 }
148 return result;
149 }
150
151 /// Returns \c true if cluster \a I is not contained in a larger cluster
152 bool isMaximal( size_t I ) const {
153 DAI_DEBASSERT( I < _G.nrNodes2() );
154 const VarSet & clI = _clusters[I];
155 bool maximal = true;
156 // The following may not be optimal, since it may repeatedly test the same cluster *J
157 foreach( const Neighbor& i, _G.nb2(I) ) {
158 foreach( const Neighbor& J, _G.nb1(i) )
159 if( (J != I) && (clI << _clusters[J]) ) {
160 maximal = false;
161 break;
162 }
163 if( !maximal )
164 break;
165 }
166 return maximal;
167 }
168 //@}
169
170 /// \name Operations
171 //@{
172 /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
173 /** \note This function could be better optimized if the index of one variable in \a cl would be known.
174 * If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
175 */
176 size_t insert( const VarSet& cl ) {
177 size_t index = findCluster( cl ); // OPTIMIZE ME
178 if( index == _clusters.size() ) {
179 _clusters.push_back( cl );
180 // add variables (if necessary) and calculate neighborhood of new cluster
181 std::vector<size_t> nbs;
182 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
183 size_t iter = findVar( *n ); // OPTIMIZE ME
184 nbs.push_back( iter );
185 if( iter == _vars.size() ) {
186 _G.addNode1();
187 _vars.push_back( *n );
188 }
189 }
190 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
191 }
192 return index;
193 }
194
195 /* /// Inserts a cluster (if it does not already exist), assuming no new variables have to be created
196 size_t insert( const SmallSet<size_t>& _cl ) {
197 size_t index = findCluster( _cl );
198 if( index == _clusters.size() ) {
199 VarSet cl;
200 foreach( size_t i, _cl )
201 cl |= var(i);
202 _clusters.push_back( cl );
203 _G.addNode2( _cl.begin(), _cl.end(), _cl.size() );
204 }
205 return index;
206 }*/
207
208 /// Erases all clusters that are not maximal
209 ClusterGraph& eraseNonMaximal() {
210 for( size_t I = 0; I < _G.nrNodes2(); ) {
211 if( !isMaximal(I) ) {
212 _clusters.erase( _clusters.begin() + I );
213 _G.eraseNode2(I);
214 } else
215 I++;
216 }
217 return *this;
218 }
219
220 /// Erases all clusters that contain the \a i 'th variable
221 ClusterGraph& eraseSubsuming( size_t i ) {
222 DAI_ASSERT( i < nrVars() );
223 while( _G.nb1(i).size() ) {
224 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
225 _G.eraseNode2( _G.nb1(i)[0] );
226 }
227 return *this;
228 }
229
230 /// Eliminates variable with index \a i, without deleting the variable itself
231 /** \note This function can be better optimized
232 */
233 VarSet elimVar( size_t i ) {
234 DAI_ASSERT( i < nrVars() );
235 VarSet Di = Delta( i );
236
237 // if( 1 ) { // unoptimized, transparent code
238 VarSet di = delta( i );
239 insert( di );
240 eraseSubsuming( i );
241 eraseNonMaximal();
242 /* } else { // partially optimized code
243 SmallSet<size_t> nbI = _G.delta1( i, false );
244 size_t I = insert( nbI );
245
246 while( _G.nb1(i).size() ) {
247 size_t J = _G.nb1(i,0);
248 _clusters.erase( _clusters.begin() + J );
249 _G.eraseNode2( J );
250 if( I > J )
251 I--;
252 }
253
254 bool di_maximal = true;
255 foreach( size_t j, nbI ) {
256 for( size_t _J = 0; _J < _G.nb1(j).size(); ) {
257 size_t J = _G.nb1(j,_J);
258 SmallSet<size_t> indJ = _G.nb2Set( J );
259 if( indJ << nbI && indJ.size() != nbI.size() ) {
260 _clusters.erase( _clusters.begin() + J );
261 _G.eraseNode2( J );
262 if( I > J )
263 I--;
264 } else {
265 if( di_maximal && indJ >> nbI && indJ.size() != nbI.size() )
266 di_maximal = false;
267 _J++;
268 }
269 }
270 }
271 if( !di_maximal ) {
272 _clusters.erase( _clusters.begin() + I );
273 _G.eraseNode2( I );
274 }
275 }*/
276
277 return Di;
278 }
279 //@}
280
281 /// \name Input/Ouput
282 //@{
283 /// Writes a ClusterGraph to an output stream
284 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
285 os << cl.clusters();
286 return os;
287 }
288 //@}
289
290 /// \name Variable elimination
291 //@{
292 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
293 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
294 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
295 * \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
296 * \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
297 * \return A set of elimination "cliques".
298 */
299 template<class EliminationChoice>
300 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
301 // Make a copy
302 ClusterGraph cl(*this);
303 cl.eraseNonMaximal();
304
305 ClusterGraph result;
306
307 // Construct set of variable indices
308 std::set<size_t> varindices;
309 for( size_t i = 0; i < _vars.size(); ++i )
310 varindices.insert( i );
311
312 // Do variable elimination
313 long double totalStates = 0.0;
314 while( !varindices.empty() ) {
315 size_t i = f( cl, varindices );
316 VarSet Di = cl.elimVar( i );
317 result.insert( Di );
318 if( maxStates ) {
319 totalStates += Di.nrStates();
320 if( totalStates > maxStates )
321 DAI_THROW(OUT_OF_MEMORY);
322 }
323 varindices.erase( i );
324 }
325
326 return result;
327 }
328 //@}
329 };
330
331
332 /// Helper object for dai::ClusterGraph::VarElim()
333 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
334 */
335 class sequentialVariableElimination {
336 private:
337 /// The variable elimination sequence
338 std::vector<Var> seq;
339 /// Counter
340 size_t i;
341
342 public:
343 /// Construct from vector of variables
344 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
345
346 /// Returns next variable in sequence
347 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
348 };
349
350
351 /// Helper object for dai::ClusterGraph::VarElim()
352 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
353 * a given heuristic cost function.
354 */
355 class greedyVariableElimination {
356 public:
357 /// Type of cost functions to be used for greedy variable elimination: takes a cluster graph and a variable index and returns a cost
358 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
359
360 private:
361 /// Pointer to the cost function used (set once by the constructor)
362 eliminationCostFunction heuristic;
363
364 public:
365 /// Construct from cost function
366 /** \param h pointer to the cost function; it is stored as given, without any check.
 * \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
367 */
368 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
369
370 /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
371 /** This function calculates the cost for eliminating each variable in \a remainingVars and returns the variable which has lowest cost.
372 */
373 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
374 };
375
376
377 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
378 /** The cost is measured as "number of neighboring nodes in the current adjacency graph",
379 * where the adjacency graph has the variables as its nodes and connects
380 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 * \param cl cluster graph in which the elimination is considered
 * \param i index of the candidate variable in \a cl
 * \return the elimination cost; the signature matches greedyVariableElimination::eliminationCostFunction
381 */
382 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
383
384
385 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
386 /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
387 * where the adjacency graph has the variables as its nodes and connects
388 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
389 * The weight of a node is the number of states of the corresponding variable.
 * \param cl cluster graph in which the elimination is considered
 * \param i index of the candidate variable in \a cl
 * \return the elimination cost; the signature matches greedyVariableElimination::eliminationCostFunction
390 */
391 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
392
393
394 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
395 /** The cost is measured as "number of added edges in the adjacency graph",
396 * where the adjacency graph has the variables as its nodes and connects
397 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 * \param cl cluster graph in which the elimination is considered
 * \param i index of the candidate variable in \a cl
 * \return the elimination cost; the signature matches greedyVariableElimination::eliminationCostFunction
398 */
399 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
400
401
402 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
403 /** The cost is measured as "total weight of added edges in the adjacency graph",
404 * where the adjacency graph has the variables as its nodes and connects
405 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
406 * The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
 * \param cl cluster graph in which the elimination is considered
 * \param i index of the candidate variable in \a cl
 * \return the elimination cost; the signature matches greedyVariableElimination::eliminationCostFunction
407 */
408 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
409
410
411 } // end of namespace dai
412
413
414 #endif