75ab339bfe794a7fa726272593da068d9ce38d37
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
11 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
12 /// add similar heuristics which are designed for factor graphs.
13
14
15 #ifndef __defined_libdai_clustergraph_h
16 #define __defined_libdai_clustergraph_h
17
18
#include <algorithm>
#include <ostream>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
24
25
26 namespace dai {
27
28
29 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
30 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
31 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
32 * \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
33 */
34 class ClusterGraph {
35 private:
36 /// Stores the neighborhood structure
37 BipartiteGraph _G;
38
39 /// Stores the variables corresponding to the nodes
40 std::vector<Var> _vars;
41
42 /// Stores the clusters corresponding to the hyperedges
43 std::vector<VarSet> _clusters;
44
45 public:
46 /// \name Constructors and destructors
47 //@{
48 /// Default constructor
49 ClusterGraph() : _G(), _vars(), _clusters() {}
50
51 /// Construct from vector of VarSet 's
52 ClusterGraph( const std::vector<VarSet>& cls );
53
54 /// Construct from a factor graph
55 /** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
56 * and only the maximal factors in \a fg if \a onlyMaximal == \c true.
57 */
58 ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
59 //@}
60
61 /// \name Queries
62 //@{
63 /// Returns a constant reference to the graph structure
64 const BipartiteGraph& bipGraph() const { return _G; }
65
66 /// Returns number of variables
67 size_t nrVars() const { return _vars.size(); }
68
69 /// Returns a constant reference to the variables
70 const std::vector<Var>& vars() const { return _vars; }
71
72 /// Returns a constant reference to the \a i'th variable
73 const Var& var( size_t i ) const {
74 DAI_DEBASSERT( i < nrVars() );
75 return _vars[i];
76 }
77
78 /// Returns number of clusters
79 size_t nrClusters() const { return _clusters.size(); }
80
81 /// Returns a constant reference to the clusters
82 const std::vector<VarSet>& clusters() const { return _clusters; }
83
84 /// Returns a constant reference to the \a I'th cluster
85 const VarSet& cluster( size_t I ) const {
86 DAI_DEBASSERT( I < nrClusters() );
87 return _clusters[I];
88 }
89
90 /// Returns the index of variable \a n
91 size_t findVar( const Var& n ) const {
92 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
93 }
94
95 /// Returns the index of a cluster \a cl
96 size_t findCluster( const VarSet& cl ) const {
97 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
98 }
99
100 /// Returns union of clusters that contain the \a i 'th variable
101 VarSet Delta( size_t i ) const {
102 VarSet result;
103 bforeach( const Neighbor& I, _G.nb1(i) )
104 result |= _clusters[I];
105 return result;
106 }
107
108 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
109 VarSet delta( size_t i ) const {
110 return Delta( i ) / _vars[i];
111 }
112
113 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
114 bool adj( size_t i1, size_t i2 ) const {
115 if( i1 == i2 )
116 return false;
117 bool result = false;
118 bforeach( const Neighbor& I, _G.nb1(i1) )
119 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
120 result = true;
121 break;
122 }
123 return result;
124 }
125
126 /// Returns \c true if cluster \a I is not contained in a larger cluster
127 bool isMaximal( size_t I ) const {
128 DAI_DEBASSERT( I < _G.nrNodes2() );
129 const VarSet & clI = _clusters[I];
130 bool maximal = true;
131 // The following may not be optimal, since it may repeatedly test the same cluster *J
132 bforeach( const Neighbor& i, _G.nb2(I) ) {
133 bforeach( const Neighbor& J, _G.nb1(i) )
134 if( (J != I) && (clI << _clusters[J]) ) {
135 maximal = false;
136 break;
137 }
138 if( !maximal )
139 break;
140 }
141 return maximal;
142 }
143 //@}
144
145 /// \name Operations
146 //@{
147 /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
148 /** \note This function could be better optimized if the index of one variable in \a cl would be known.
149 * If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
150 */
151 size_t insert( const VarSet& cl ) {
152 size_t index = findCluster( cl ); // OPTIMIZE ME
153 if( index == _clusters.size() ) {
154 _clusters.push_back( cl );
155 // add variables (if necessary) and calculate neighborhood of new cluster
156 std::vector<size_t> nbs;
157 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
158 size_t iter = findVar( *n ); // OPTIMIZE ME
159 nbs.push_back( iter );
160 if( iter == _vars.size() ) {
161 _G.addNode1();
162 _vars.push_back( *n );
163 }
164 }
165 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
166 }
167 return index;
168 }
169
170 /// Erases all clusters that are not maximal
171 ClusterGraph& eraseNonMaximal() {
172 for( size_t I = 0; I < _G.nrNodes2(); ) {
173 if( !isMaximal(I) ) {
174 _clusters.erase( _clusters.begin() + I );
175 _G.eraseNode2(I);
176 } else
177 I++;
178 }
179 return *this;
180 }
181
182 /// Erases all clusters that contain the \a i 'th variable
183 ClusterGraph& eraseSubsuming( size_t i ) {
184 DAI_ASSERT( i < nrVars() );
185 while( _G.nb1(i).size() ) {
186 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
187 _G.eraseNode2( _G.nb1(i)[0] );
188 }
189 return *this;
190 }
191
192 /// Eliminates variable with index \a i, without deleting the variable itself
193 /** \note This function can be better optimized
194 */
195 VarSet elimVar( size_t i ) {
196 DAI_ASSERT( i < nrVars() );
197 VarSet Di = Delta( i );
198 insert( Di / var(i) );
199 eraseSubsuming( i );
200 eraseNonMaximal();
201 return Di;
202 }
203 //@}
204
205 /// \name Input/Ouput
206 //@{
207 /// Writes a ClusterGraph to an output stream
208 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
209 os << cl.clusters();
210 return os;
211 }
212 //@}
213
214 /// \name Variable elimination
215 //@{
216 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
217 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
218 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
219 * \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
220 * \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
221 * \return A set of elimination "cliques".
222 */
223 template<class EliminationChoice>
224 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
225 // Make a copy
226 ClusterGraph cl(*this);
227 cl.eraseNonMaximal();
228
229 ClusterGraph result;
230
231 // Construct set of variable indices
232 std::set<size_t> varindices;
233 for( size_t i = 0; i < _vars.size(); ++i )
234 varindices.insert( i );
235
236 // Do variable elimination
237 BigInt totalStates = 0;
238 while( !varindices.empty() ) {
239 size_t i = f( cl, varindices );
240 VarSet Di = cl.elimVar( i );
241 result.insert( Di );
242 if( maxStates ) {
243 totalStates += Di.nrStates();
244 if( totalStates > (BigInt)maxStates )
245 DAI_THROW(OUT_OF_MEMORY);
246 }
247 varindices.erase( i );
248 }
249
250 return result;
251 }
252 //@}
253 };
254
255
256 /// Helper object for dai::ClusterGraph::VarElim()
257 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
258 */
259 class sequentialVariableElimination {
260 private:
261 /// The variable elimination sequence
262 std::vector<Var> seq;
263 /// Counter
264 size_t i;
265
266 public:
267 /// Construct from vector of variables
268 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
269
270 /// Returns next variable in sequence
271 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
272 };
273
274
275 /// Helper object for dai::ClusterGraph::VarElim()
276 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
277 * a given heuristic cost function.
278 */
279 class greedyVariableElimination {
280 public:
281 /// Type of cost functions to be used for greedy variable elimination
282 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
283
284 private:
285 /// Pointer to the cost function used
286 eliminationCostFunction heuristic;
287
288 public:
289 /// Construct from cost function
290 /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
291 */
292 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
293
294 /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
295 /** This function calculates the cost for eliminating each variable in \a remaingVars and returns the variable which has lowest cost.
296 */
297 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
298 };
299
300
301 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
302 /** The cost is measured as "number of neigboring nodes in the current adjacency graph",
303 * where the adjacency graph has the variables as its nodes and connects
304 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
305 */
306 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
307
308
309 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
310 /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
311 * where the adjacency graph has the variables as its nodes and connects
312 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
313 * The weight of a node is the number of states of the corresponding variable.
314 */
315 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
316
317
318 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
319 /** The cost is measured as "number of added edges in the adjacency graph",
320 * where the adjacency graph has the variables as its nodes and connects
321 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
322 */
323 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
324
325
326 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
327 /** The cost is measured as "total weight of added edges in the adjacency graph",
328 * where the adjacency graph has the variables as its nodes and connects
329 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
330 * The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
331 */
332 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
333
334
335 } // end of namespace dai
336
337
338 #endif