3e84bd0616f41e241002993eeb32298d774c8496
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
14 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
15 /// add similar heuristics which are designed for factor graphs.
16
17
18 #ifndef __defined_libdai_clustergraph_h
19 #define __defined_libdai_clustergraph_h
20
21
#include <algorithm>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
27
28
29 namespace dai {
30
31
32 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
33 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
34 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
35 * \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
36 */
37 class ClusterGraph {
38 public:
39 /// Shorthand for BipartiteGraph::Neighbor
40 typedef BipartiteGraph::Neighbor Neighbor;
41
42 /// Shorthand for BipartiteGraph::Edge
43 typedef BipartiteGraph::Edge Edge;
44
45 private:
46 /// Stores the neighborhood structure
47 BipartiteGraph _G;
48
49 /// Stores the variables corresponding to the nodes
50 std::vector<Var> _vars;
51
52 /// Stores the clusters corresponding to the hyperedges
53 std::vector<VarSet> _clusters;
54
55 public:
56 /// \name Constructors and destructors
57 //@{
58 /// Default constructor
59 ClusterGraph() : _G(), _vars(), _clusters() {}
60
61 /// Construct from vector of VarSet 's
62 ClusterGraph( const std::vector<VarSet>& cls );
63
64 /// Construct from a factor graph
65 /** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
66 * and only the maximal factors in \a fg if \a onlyMaximal == \c true.
67 */
68 ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
69 //@}
70
71 /// \name Queries
72 //@{
73 /// Returns a constant reference to the graph structure
74 const BipartiteGraph& bipGraph() const { return _G; }
75
76 /// Returns number of variables
77 size_t nrVars() const { return _vars.size(); }
78
79 /// Returns a constant reference to the variables
80 const std::vector<Var>& vars() const { return _vars; }
81
82 /// Returns a constant reference to the \a i'th variable
83 const Var& var( size_t i ) const {
84 DAI_DEBASSERT( i < nrVars() );
85 return _vars[i];
86 }
87
88 /// Returns number of clusters
89 size_t nrClusters() const { return _clusters.size(); }
90
91 /// Returns a constant reference to the clusters
92 const std::vector<VarSet>& clusters() const { return _clusters; }
93
94 /// Returns a constant reference to the \a I'th cluster
95 const VarSet& cluster( size_t I ) const {
96 DAI_DEBASSERT( I < nrClusters() );
97 return _clusters[I];
98 }
99
100 /// Returns the index of variable \a n
101 size_t findVar( const Var& n ) const {
102 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
103 }
104
105 /// Returns the index of a cluster \a cl
106 size_t findCluster( const VarSet& cl ) const {
107 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
108 }
109
110 /// Returns union of clusters that contain the \a i 'th variable
111 VarSet Delta( size_t i ) const {
112 VarSet result;
113 foreach( const Neighbor& I, _G.nb1(i) )
114 result |= _clusters[I];
115 return result;
116 }
117
118 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
119 VarSet delta( size_t i ) const {
120 return Delta( i ) / _vars[i];
121 }
122
123 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
124 bool adj( size_t i1, size_t i2 ) const {
125 if( i1 == i2 )
126 return false;
127 bool result = false;
128 foreach( const Neighbor& I, _G.nb1(i1) )
129 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
130 result = true;
131 break;
132 }
133 return result;
134 }
135
136 /// Returns \c true if cluster \a I is not contained in a larger cluster
137 bool isMaximal( size_t I ) const {
138 DAI_DEBASSERT( I < _G.nrNodes2() );
139 const VarSet & clI = _clusters[I];
140 bool maximal = true;
141 // The following may not be optimal, since it may repeatedly test the same cluster *J
142 foreach( const Neighbor& i, _G.nb2(I) ) {
143 foreach( const Neighbor& J, _G.nb1(i) )
144 if( (J != I) && (clI << _clusters[J]) ) {
145 maximal = false;
146 break;
147 }
148 if( !maximal )
149 break;
150 }
151 return maximal;
152 }
153 //@}
154
155 /// \name Operations
156 //@{
157 /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
158 /** \note This function could be better optimized if the index of one variable in \a cl would be known.
159 * If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
160 */
161 size_t insert( const VarSet& cl ) {
162 size_t index = findCluster( cl ); // OPTIMIZE ME
163 if( index == _clusters.size() ) {
164 _clusters.push_back( cl );
165 // add variables (if necessary) and calculate neighborhood of new cluster
166 std::vector<size_t> nbs;
167 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
168 size_t iter = findVar( *n ); // OPTIMIZE ME
169 nbs.push_back( iter );
170 if( iter == _vars.size() ) {
171 _G.addNode1();
172 _vars.push_back( *n );
173 }
174 }
175 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
176 }
177 return index;
178 }
179
180 /// Erases all clusters that are not maximal
181 ClusterGraph& eraseNonMaximal() {
182 for( size_t I = 0; I < _G.nrNodes2(); ) {
183 if( !isMaximal(I) ) {
184 _clusters.erase( _clusters.begin() + I );
185 _G.eraseNode2(I);
186 } else
187 I++;
188 }
189 return *this;
190 }
191
192 /// Erases all clusters that contain the \a i 'th variable
193 ClusterGraph& eraseSubsuming( size_t i ) {
194 DAI_ASSERT( i < nrVars() );
195 while( _G.nb1(i).size() ) {
196 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
197 _G.eraseNode2( _G.nb1(i)[0] );
198 }
199 return *this;
200 }
201
202 /// Eliminates variable with index \a i, without deleting the variable itself
203 /** \note This function can be better optimized
204 */
205 VarSet elimVar( size_t i ) {
206 DAI_ASSERT( i < nrVars() );
207 VarSet Di = Delta( i );
208 insert( Di / var(i) );
209 eraseSubsuming( i );
210 eraseNonMaximal();
211 return Di;
212 }
213 //@}
214
215 /// \name Input/Ouput
216 //@{
217 /// Writes a ClusterGraph to an output stream
218 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
219 os << cl.clusters();
220 return os;
221 }
222 //@}
223
224 /// \name Variable elimination
225 //@{
226 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
227 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
228 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
229 * \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
230 * \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
231 * \return A set of elimination "cliques".
232 */
233 template<class EliminationChoice>
234 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
235 // Make a copy
236 ClusterGraph cl(*this);
237 cl.eraseNonMaximal();
238
239 ClusterGraph result;
240
241 // Construct set of variable indices
242 std::set<size_t> varindices;
243 for( size_t i = 0; i < _vars.size(); ++i )
244 varindices.insert( i );
245
246 // Do variable elimination
247 long double totalStates = 0.0;
248 while( !varindices.empty() ) {
249 size_t i = f( cl, varindices );
250 VarSet Di = cl.elimVar( i );
251 result.insert( Di );
252 if( maxStates ) {
253 totalStates += Di.nrStates();
254 if( totalStates > maxStates )
255 DAI_THROW(OUT_OF_MEMORY);
256 }
257 varindices.erase( i );
258 }
259
260 return result;
261 }
262 //@}
263 };
264
265
266 /// Helper object for dai::ClusterGraph::VarElim()
267 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
268 */
269 class sequentialVariableElimination {
270 private:
271 /// The variable elimination sequence
272 std::vector<Var> seq;
273 /// Counter
274 size_t i;
275
276 public:
277 /// Construct from vector of variables
278 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
279
280 /// Returns next variable in sequence
281 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
282 };
283
284
285 /// Helper object for dai::ClusterGraph::VarElim()
286 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
287 * a given heuristic cost function.
288 */
289 class greedyVariableElimination {
290 public:
291 /// Type of cost functions to be used for greedy variable elimination
292 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
293
294 private:
295 /// Pointer to the cost function used
296 eliminationCostFunction heuristic;
297
298 public:
299 /// Construct from cost function
300 /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
301 */
302 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
303
304 /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
305 /** This function calculates the cost for eliminating each variable in \a remaingVars and returns the variable which has lowest cost.
306 */
307 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
308 };
309
310
311 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
312 /** The cost is measured as "number of neigboring nodes in the current adjacency graph",
313 * where the adjacency graph has the variables as its nodes and connects
314 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
315 */
316 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
317
318
319 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
320 /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
321 * where the adjacency graph has the variables as its nodes and connects
322 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
323 * The weight of a node is the number of states of the corresponding variable.
324 */
325 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
326
327
328 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
329 /** The cost is measured as "number of added edges in the adjacency graph",
330 * where the adjacency graph has the variables as its nodes and connects
331 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
332 */
333 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
334
335
336 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
337 /** The cost is measured as "total weight of added edges in the adjacency graph",
338 * where the adjacency graph has the variables as its nodes and connects
339 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
340 * The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
341 */
342 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
343
344
345 } // end of namespace dai
346
347
348 #endif