libDAI version 0.3.2
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
11 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
12 /// add similar heuristics which are designed for factor graphs.
13
14
15 #ifndef __defined_libdai_clustergraph_h
16 #define __defined_libdai_clustergraph_h
17
18
#include <algorithm>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
#include <dai/factorgraph.h>
24
25
26 namespace dai {
27
28
29 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
30 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
31 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
32 * \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
33 */
34 class ClusterGraph {
35 private:
36 /// Stores the neighborhood structure
37 BipartiteGraph _G;
38
39 /// Stores the variables corresponding to the nodes
40 std::vector<Var> _vars;
41
42 /// Stores the clusters corresponding to the hyperedges
43 std::vector<VarSet> _clusters;
44
45 public:
46 /// \name Constructors and destructors
47 //@{
48 /// Default constructor
49 ClusterGraph() : _G(), _vars(), _clusters() {}
50
51 /// Construct from vector of VarSet 's
52 ClusterGraph( const std::vector<VarSet>& cls );
53
54 /// Construct from a factor graph
55 /** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
56 * and only the maximal factors in \a fg if \a onlyMaximal == \c true.
57 */
58 ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
59 //@}
60
61 /// \name Queries
62 //@{
63 /// Returns a constant reference to the graph structure
64 const BipartiteGraph& bipGraph() const { return _G; }
65
66 /// Returns number of variables
67 size_t nrVars() const { return _vars.size(); }
68
69 /// Returns a constant reference to the variables
70 const std::vector<Var>& vars() const { return _vars; }
71
72 /// Returns a constant reference to the \a i'th variable
73 const Var& var( size_t i ) const {
74 DAI_DEBASSERT( i < nrVars() );
75 return _vars[i];
76 }
77
78 /// Returns number of clusters
79 size_t nrClusters() const { return _clusters.size(); }
80
81 /// Returns a constant reference to the clusters
82 const std::vector<VarSet>& clusters() const { return _clusters; }
83
84 /// Returns a constant reference to the \a I'th cluster
85 const VarSet& cluster( size_t I ) const {
86 DAI_DEBASSERT( I < nrClusters() );
87 return _clusters[I];
88 }
89
90 /// Returns the index of variable \a n
91 size_t findVar( const Var& n ) const {
92 return find( _vars.begin(), _vars.end(), n ) - _vars.begin();
93 }
94
95 /// Returns the index of a cluster \a cl
96 size_t findCluster( const VarSet& cl ) const {
97 return find( _clusters.begin(), _clusters.end(), cl ) - _clusters.begin();
98 }
99
100 /// Returns union of clusters that contain the \a i 'th variable
101 VarSet Delta( size_t i ) const {
102 VarSet result;
103 bforeach( const Neighbor& I, _G.nb1(i) )
104 result |= _clusters[I];
105 return result;
106 }
107
108 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
109 VarSet delta( size_t i ) const {
110 return Delta( i ) / _vars[i];
111 }
112
113 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
114 bool adj( size_t i1, size_t i2 ) const {
115 if( i1 == i2 )
116 return false;
117 bool result = false;
118 bforeach( const Neighbor& I, _G.nb1(i1) )
119 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
120 result = true;
121 break;
122 }
123 return result;
124 }
125
126 /// Returns \c true if cluster \a I is not contained in a larger cluster
127 bool isMaximal( size_t I ) const {
128 DAI_DEBASSERT( I < _G.nrNodes2() );
129 const VarSet & clI = _clusters[I];
130 bool maximal = true;
131 // The following may not be optimal, since it may repeatedly test the same cluster *J
132 bforeach( const Neighbor& i, _G.nb2(I) ) {
133 bforeach( const Neighbor& J, _G.nb1(i) )
134 if( (J != I) && (clI << _clusters[J]) ) {
135 maximal = false;
136 break;
137 }
138 if( !maximal )
139 break;
140 }
141 return maximal;
142 }
143 //@}
144
145 /// \name Operations
146 //@{
147 /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
148 /** \note This function could be better optimized if the index of one variable in \a cl would be known.
149 * If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
150 */
151 size_t insert( const VarSet& cl ) {
152 size_t index = findCluster( cl ); // OPTIMIZE ME
153 if( index == _clusters.size() ) {
154 _clusters.push_back( cl );
155 // add variables (if necessary) and calculate neighborhood of new cluster
156 std::vector<size_t> nbs;
157 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
158 size_t iter = findVar( *n ); // OPTIMIZE ME
159 nbs.push_back( iter );
160 if( iter == _vars.size() ) {
161 _G.addNode1();
162 _vars.push_back( *n );
163 }
164 }
165 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
166 }
167 return index;
168 }
169
170 /// Erases all clusters that are not maximal
171 ClusterGraph& eraseNonMaximal() {
172 for( size_t I = 0; I < _G.nrNodes2(); ) {
173 if( !isMaximal(I) ) {
174 _clusters.erase( _clusters.begin() + I );
175 _G.eraseNode2(I);
176 } else
177 I++;
178 }
179 return *this;
180 }
181
182 /// Erases all clusters that contain the \a i 'th variable
183 ClusterGraph& eraseSubsuming( size_t i ) {
184 DAI_ASSERT( i < nrVars() );
185 while( _G.nb1(i).size() ) {
186 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
187 _G.eraseNode2( _G.nb1(i)[0] );
188 }
189 return *this;
190 }
191
192 /// Eliminates variable with index \a i, without deleting the variable itself
193 /** \note This function can be better optimized
194 */
195 VarSet elimVar( size_t i ) {
196 DAI_ASSERT( i < nrVars() );
197 VarSet Di = Delta( i );
198 insert( Di / var(i) );
199 eraseSubsuming( i );
200 eraseNonMaximal();
201 return Di;
202 }
203 //@}
204
205 /// \name Input/Ouput
206 //@{
207 /// Writes a ClusterGraph to an output stream
208 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
209 os << cl.clusters();
210 return os;
211 }
212
213 /// Formats a ClusterGraph as a string
214 std::string toString() const {
215 std::stringstream ss;
216 ss << *this;
217 return ss.str();
218 }
219 //@}
220
221 /// \name Variable elimination
222 //@{
223 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
224 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
225 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
226 * \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
227 * \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
228 * \return A set of elimination "cliques".
229 */
230 template<class EliminationChoice>
231 ClusterGraph VarElim( EliminationChoice f, size_t maxStates=0 ) const {
232 // Make a copy
233 ClusterGraph cl(*this);
234 cl.eraseNonMaximal();
235
236 ClusterGraph result;
237
238 // Construct set of variable indices
239 std::set<size_t> varindices;
240 for( size_t i = 0; i < _vars.size(); ++i )
241 varindices.insert( i );
242
243 // Do variable elimination
244 BigInt totalStates = 0;
245 while( !varindices.empty() ) {
246 size_t i = f( cl, varindices );
247 VarSet Di = cl.elimVar( i );
248 result.insert( Di );
249 if( maxStates ) {
250 totalStates += Di.nrStates();
251 if( totalStates > (BigInt)maxStates )
252 DAI_THROW(OUT_OF_MEMORY);
253 }
254 varindices.erase( i );
255 }
256
257 return result;
258 }
259 //@}
260 };
261
262
263 /// Helper object for dai::ClusterGraph::VarElim()
264 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
265 */
266 class sequentialVariableElimination {
267 private:
268 /// The variable elimination sequence
269 std::vector<Var> seq;
270 /// Counter
271 size_t i;
272
273 public:
274 /// Construct from vector of variables
275 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
276
277 /// Returns next variable in sequence
278 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
279 };
280
281
282 /// Helper object for dai::ClusterGraph::VarElim()
283 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
284 * a given heuristic cost function.
285 */
286 class greedyVariableElimination {
287 public:
288 /// Type of cost functions to be used for greedy variable elimination
289 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
290
291 private:
292 /// Pointer to the cost function used
293 eliminationCostFunction heuristic;
294
295 public:
296 /// Construct from cost function
297 /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
298 */
299 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
300
301 /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
302 /** This function calculates the cost for eliminating each variable in \a remaingVars and returns the variable which has lowest cost.
303 */
304 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
305 };
306
307
308 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
309 /** The cost is measured as "number of neigboring nodes in the current adjacency graph",
310 * where the adjacency graph has the variables as its nodes and connects
311 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
312 */
313 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
314
315
316 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
317 /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
318 * where the adjacency graph has the variables as its nodes and connects
319 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
320 * The weight of a node is the number of states of the corresponding variable.
321 */
322 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
323
324
325 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
326 /** The cost is measured as "number of added edges in the adjacency graph",
327 * where the adjacency graph has the variables as its nodes and connects
328 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
329 */
330 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
331
332
333 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
334 /** The cost is measured as "total weight of added edges in the adjacency graph",
335 * where the adjacency graph has the variables as its nodes and connects
336 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
337 * The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
338 */
339 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
340
341
342 } // end of namespace dai
343
344
345 #endif