a8c1596bfb98cd13b2ae29d05fa692c7fc6d3748
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
14 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
15 /// add similar heuristics which are designed for factor graphs.
16
17
18 #ifndef __defined_libdai_clustergraph_h
19 #define __defined_libdai_clustergraph_h
20
21
22 #include <set>
23 #include <vector>
24 #include <dai/varset.h>
25 #include <dai/bipgraph.h>
26
27
28 namespace dai {
29
30
31 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
32 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
33 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
34 */
35 class ClusterGraph {
36 public:
37 /// Shorthand for BipartiteGraph::Neighbor
38 typedef BipartiteGraph::Neighbor Neighbor;
39
40 /// Shorthand for BipartiteGraph::Edge
41 typedef BipartiteGraph::Edge Edge;
42
43 private:
44 /// Stores the neighborhood structure
45 BipartiteGraph _G;
46
47 /// Stores the variables corresponding to the nodes
48 std::vector<Var> _vars;
49
50 /// Stores the clusters corresponding to the hyperedges
51 std::vector<VarSet> _clusters;
52
53 public:
54 /// \name Constructors and destructors
55 //@{
56 /// Default constructor
57 ClusterGraph() : _G(), _vars(), _clusters() {}
58
59 /// Construct from vector of VarSet 's
60 ClusterGraph( const std::vector<VarSet>& cls );
61 //@}
62
63 /// \name Queries
64 //@{
65 /// Returns a constant reference to the graph structure
66 const BipartiteGraph& bipGraph() const { return _G; }
67
68 /// Returns number of variables
69 size_t nrVars() const { return _vars.size(); }
70
71 /// Returns a constant reference to the variables
72 const std::vector<Var>& vars() const { return _vars; }
73
74 /// Returns a constant reference to the \a i'th variable
75 const Var& var( size_t i ) const {
76 DAI_DEBASSERT( i < nrVars() );
77 return _vars[i];
78 }
79
80 /// Returns number of clusters
81 size_t nrClusters() const { return _clusters.size(); }
82
83 /// Returns a constant reference to the clusters
84 const std::vector<VarSet>& clusters() const { return _clusters; }
85
86 /// Returns a constant reference to the \a I'th cluster
87 const VarSet& cluster( size_t I ) const {
88 DAI_DEBASSERT( I < nrClusters() );
89 return _clusters[I];
90 }
91
92 /// Returns the index of variable \a n
93 /** \throw OBJECT_NOT_FOUND if the variable does not occur in the cluster graph
94 */
95 size_t findVar( const Var& n ) const {
96 size_t r = find( _vars.begin(), _vars.end(), n ) - _vars.begin();
97 if( r == _vars.size() )
98 DAI_THROW(OBJECT_NOT_FOUND);
99 return r;
100 }
101
102 /// Returns union of clusters that contain the \a i 'th variable
103 VarSet Delta( size_t i ) const {
104 VarSet result;
105 foreach( const Neighbor &I, _G.nb1(i) )
106 result |= _clusters[I];
107 return result;
108 }
109
110 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
111 VarSet delta( size_t i ) const {
112 return Delta( i ) / _vars[i];
113 }
114
115 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
116 bool adj( size_t i1, size_t i2 ) const {
117 if( i1 == i2 )
118 return false;
119 bool result = false;
120 foreach( const Neighbor &I, _G.nb1(i1) )
121 if( find( _G.nb2(I).begin(), _G.nb2(I).end(), i2 ) != _G.nb2(I).end() ) {
122 result = true;
123 break;
124 }
125 return result;
126 }
127
128 /// Returns \c true if cluster \a I is not contained in a larger cluster
129 bool isMaximal( size_t I ) const {
130 DAI_DEBASSERT( I < _G.nrNodes2() );
131 const VarSet & clI = _clusters[I];
132 bool maximal = true;
133 // The following may not be optimal, since it may repeatedly test the same cluster *J
134 foreach( const Neighbor &i, _G.nb2(I) ) {
135 foreach( const Neighbor &J, _G.nb1(i) )
136 if( (J != I) && (clI << _clusters[J]) ) {
137 maximal = false;
138 break;
139 }
140 if( !maximal )
141 break;
142 }
143 return maximal;
144 }
145 //@}
146
147 /// \name Operations
148 //@{
149 /// Inserts a cluster (if it does not already exist)
150 void insert( const VarSet& cl ) {
151 if( find( _clusters.begin(), _clusters.end(), cl ) == _clusters.end() ) {
152 _clusters.push_back( cl );
153 // add variables (if necessary) and calculate neighborhood of new cluster
154 std::vector<size_t> nbs;
155 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
156 size_t iter = find( _vars.begin(), _vars.end(), *n ) - _vars.begin();
157 nbs.push_back( iter );
158 if( iter == _vars.size() ) {
159 _G.addNode1();
160 _vars.push_back( *n );
161 }
162 }
163 _G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
164 }
165 }
166
167 /// Erases all clusters that are not maximal
168 ClusterGraph& eraseNonMaximal() {
169 for( size_t I = 0; I < _G.nrNodes2(); ) {
170 if( !isMaximal(I) ) {
171 _clusters.erase( _clusters.begin() + I );
172 _G.eraseNode2(I);
173 } else
174 I++;
175 }
176 return *this;
177 }
178
179 /// Erases all clusters that contain the \a i 'th variable
180 ClusterGraph& eraseSubsuming( size_t i ) {
181 while( _G.nb1(i).size() ) {
182 _clusters.erase( _clusters.begin() + _G.nb1(i)[0] );
183 _G.eraseNode2( _G.nb1(i)[0] );
184 }
185 return *this;
186 }
187 //@}
188
189 /// \name Input/Ouput
190 //@{
191 /// Writes a ClusterGraph to an output stream
192 friend std::ostream& operator << ( std::ostream& os, const ClusterGraph& cl ) {
193 os << cl.clusters();
194 return os;
195 }
196 //@}
197
198 /// \name Variable elimination
199 //@{
200 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
201 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
202 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
203 * \return A set of elimination "cliques".
204 */
205 template<class EliminationChoice>
206 ClusterGraph VarElim( EliminationChoice f ) const {
207 // Make a copy
208 ClusterGraph cl(*this);
209 cl.eraseNonMaximal();
210
211 ClusterGraph result;
212
213 // Construct set of variable indices
214 std::set<size_t> varindices;
215 for( size_t i = 0; i < _vars.size(); ++i )
216 varindices.insert( i );
217
218 // Do variable elimination
219 while( !varindices.empty() ) {
220 size_t i = f( cl, varindices );
221 DAI_ASSERT( i < _vars.size() );
222 result.insert( cl.Delta( i ) );
223 cl.insert( cl.delta( i ) );
224 cl.eraseSubsuming( i );
225 cl.eraseNonMaximal();
226 varindices.erase( i );
227 }
228
229 return result;
230 }
231 //@}
232 };
233
234
235 /// Helper object for dai::ClusterGraph::VarElim()
236 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
237 */
238 class sequentialVariableElimination {
239 private:
240 /// The variable elimination sequence
241 std::vector<Var> seq;
242 /// Counter
243 size_t i;
244
245 public:
246 /// Construct from vector of variables
247 sequentialVariableElimination( const std::vector<Var> s ) : seq(s), i(0) {}
248
249 /// Returns next variable in sequence
250 size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
251 };
252
253
254 /// Helper object for dai::ClusterGraph::VarElim()
255 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
256 * a given heuristic cost function.
257 */
258 class greedyVariableElimination {
259 public:
260 /// Type of cost functions to be used for greedy variable elimination
261 typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
262
263 private:
264 /// Pointer to the cost function used
265 eliminationCostFunction heuristic;
266
267 public:
268 /// Construct from cost function
269 /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
270 */
271 greedyVariableElimination( eliminationCostFunction h ) : heuristic(h) {}
272
273 /// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
274 /** This function calculates the cost for eliminating each variable in \a remaingVars and returns the variable which has lowest cost.
275 */
276 size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
277 };
278
279
280 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
281 /** The cost is measured as "number of neigboring nodes in the current adjacency graph",
282 * where the adjacency graph has the variables as its nodes and connects
283 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
284 */
285 size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
286
287
288 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
289 /** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
290 * where the adjacency graph has the variables as its nodes and connects
291 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
292 * The weight of a node is the number of states of the corresponding variable.
293 */
294 size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
295
296
297 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
298 /** The cost is measured as "number of added edges in the adjacency graph",
299 * where the adjacency graph has the variables as its nodes and connects
300 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
301 */
302 size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
303
304
305 /// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
306 /** The cost is measured as "total weight of added edges in the adjacency graph",
307 * where the adjacency graph has the variables as its nodes and connects
308 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
309 * The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
310 */
311 size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
312
313
314 } // end of namespace dai
315
316
317 #endif