Cleaned up variable elimination code in ClusterGraph
[libdai.git] / include / dai / clustergraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /// \file
13 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
14
15
16 #ifndef __defined_libdai_clustergraph_h
17 #define __defined_libdai_clustergraph_h
18
19
#include <algorithm>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
24
25
26 namespace dai {
27
28
// Forward declaration: the free functions below take a ClusterGraph by reference,
// but the class itself is only defined further down in this header.
class ClusterGraph;

/// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
/** The cost is measured as "number of added edges in the adjacency graph",
 *  where the adjacency graph has the variables as its nodes and connects
 *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 *  \param cl The cluster graph in which the elimination is considered (it is not modified)
 *  \param i Index of the variable in \a cl.vars
 *  \return The "MinFill" cost of eliminating variable \a i
 *  \note NOTE(review): the definition lives outside this header; confirm there how
 *        already-adjacent neighbour pairs are counted.
 */
size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i );

/// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl according to the "MinFill" criterion.
/** This function invokes eliminationCost_MinFill() for each variable in \a remainingVars, and returns
 *  the variable which has lowest cost according to eliminationCost_MinFill().
 *  \param cl The current cluster graph (it is not modified)
 *  \param remainingVars Indices (into \a cl.vars) of the variables that have not been eliminated yet
 *  \return Index of the chosen variable. NOTE(review): which index wins in case of equal costs
 *          is decided by the implementation, which is not in this header — confirm before relying on it.
 *  \note This function can be passed to ClusterGraph::VarElim().
 *  \note ClusterGraph::VarElim() only calls it with a non-empty \a remainingVars; behaviour for an
 *        empty set is not specified here.
 */
size_t eliminationChoice_MinFill( const ClusterGraph &cl, const std::set<size_t> &remainingVars );
44
45
46 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
47 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
48 */
49 class ClusterGraph {
50 public:
51 /// Stores the neighborhood structure
52 BipartiteGraph G;
53
54 /// Stores the variables corresponding to the nodes
55 std::vector<Var> vars;
56
57 /// Stores the clusters corresponding to the hyperedges
58 std::vector<VarSet> clusters;
59
60 /// Shorthand for BipartiteGraph::Neighbor
61 typedef BipartiteGraph::Neighbor Neighbor;
62
63 /// Shorthand for BipartiteGraph::Edge
64 typedef BipartiteGraph::Edge Edge;
65
66 public:
67 /// \name Constructors and destructors
68 //@{
69 /// Default constructor
70 ClusterGraph() : G(), vars(), clusters() {}
71
72 /// Construct from vector of VarSet 's
73 ClusterGraph( const std::vector<VarSet> & cls );
74 //@}
75
76 /// \name Queries
77 //@{
78 /// Returns a constant reference to the clusters
79 const std::vector<VarSet> & toVector() const { return clusters; }
80
81 /// Returns number of clusters
82 size_t size() const {
83 return G.nrNodes2();
84 }
85
86 /// Returns the index of variable \a n
87 size_t findVar( const Var &n ) const {
88 return find( vars.begin(), vars.end(), n ) - vars.begin();
89 }
90
91 /// Returns union of clusters that contain the \a i 'th variable
92 VarSet Delta( size_t i ) const {
93 VarSet result;
94 foreach( const Neighbor &I, G.nb1(i) )
95 result |= clusters[I];
96 return result;
97 }
98
99 /// Returns union of clusters that contain the \a i 'th (except this variable itself)
100 VarSet delta( size_t i ) const {
101 return Delta( i ) / vars[i];
102 }
103
104 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
105 bool adj( size_t i1, size_t i2 ) const {
106 bool result = false;
107 foreach( const Neighbor &I, G.nb1(i1) )
108 if( find( G.nb2(I).begin(), G.nb2(I).end(), i2 ) != G.nb2(I).end() ) {
109 result = true;
110 break;
111 }
112 return result;
113 }
114
115 /// Returns \c true if cluster \a I is not contained in a larger cluster
116 bool isMaximal( size_t I ) const {
117 DAI_DEBASSERT( I < G.nrNodes2() );
118 const VarSet & clI = clusters[I];
119 bool maximal = true;
120 // The following may not be optimal, since it may repeatedly test the same cluster *J
121 foreach( const Neighbor &i, G.nb2(I) ) {
122 foreach( const Neighbor &J, G.nb1(i) )
123 if( (J != I) && (clI << clusters[J]) ) {
124 maximal = false;
125 break;
126 }
127 if( !maximal )
128 break;
129 }
130 return maximal;
131 }
132 //@}
133
134 /// \name Operations
135 //@{
136 /// Inserts a cluster (if it does not already exist)
137 void insert( const VarSet &cl ) {
138 if( find( clusters.begin(), clusters.end(), cl ) == clusters.end() ) {
139 clusters.push_back( cl );
140 // add variables (if necessary) and calculate neighborhood of new cluster
141 std::vector<size_t> nbs;
142 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
143 size_t iter = find( vars.begin(), vars.end(), *n ) - vars.begin();
144 nbs.push_back( iter );
145 if( iter == vars.size() ) {
146 G.addNode1();
147 vars.push_back( *n );
148 }
149 }
150 G.addNode2( nbs.begin(), nbs.end(), nbs.size() );
151 }
152 }
153
154 /// Erases all clusters that are not maximal
155 ClusterGraph& eraseNonMaximal() {
156 for( size_t I = 0; I < G.nrNodes2(); ) {
157 if( !isMaximal(I) ) {
158 clusters.erase( clusters.begin() + I );
159 G.eraseNode2(I);
160 } else
161 I++;
162 }
163 return *this;
164 }
165
166 /// Erases all clusters that contain the \a i 'th variable
167 ClusterGraph& eraseSubsuming( size_t i ) {
168 while( G.nb1(i).size() ) {
169 clusters.erase( clusters.begin() + G.nb1(i)[0] );
170 G.eraseNode2( G.nb1(i)[0] );
171 }
172 return *this;
173 }
174 //@}
175
176 /// \name Input/Ouput
177 //@{
178 /// Writes a ClusterGraph to an output stream
179 friend std::ostream & operator << ( std::ostream & os, const ClusterGraph & cl ) {
180 os << cl.toVector();
181 return os;
182 }
183 //@}
184
185 /// \name Variable elimination
186 //@{
187 /// Calculates cost of eliminating the \a i 'th variable.
188 /** The cost is measured as "number of added edges in the adjacency graph",
189 * where the adjacency graph has the variables as its nodes and connects
190 * nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
191 * \deprecated Please use dai::eliminationCost_MinFill() instead.
192 */
193 size_t eliminationCost( size_t i ) const {
194 return eliminationCost_MinFill( *this, i );
195 }
196
197 /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
198 /** \param ElimSeq The sequence in which to eliminate the variables
199 * \return A set of elimination "cliques"
200 * \deprecated Not used; if necessary, dai::ClusterGraph::VarElim( EliminationChoice & ) can be used instead.
201 */
202 ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
203
204 /// Performs Variable Elimination using the "MinFill" heuristic
205 /** The "MinFill" heuristic greedily minimizes the cost of eliminating a variable,
206 * measured with eliminationCost().
207 * \return A set of elimination "cliques"
208 * \deprecated Please use dai::ClusterGraph::VarElim( eliminationChoice_MinFill ) instead.
209 */
210 ClusterGraph VarElim_MinFill() const {
211 return VarElim( eliminationChoice_MinFill );
212 }
213
214 /// Performs Variable Elimination, only keeping track of the interactions that are created along the way.
215 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
216 * \param f function object which returns the next variable index to eliminate; for example, eliminationChoice_MinFill()
217 * \return A set of elimination "cliques"
218 */
219 template<class EliminationChoice>
220 ClusterGraph VarElim( EliminationChoice &f ) const {
221 // Make a copy
222 ClusterGraph cl(*this);
223 cl.eraseNonMaximal();
224
225 ClusterGraph result;
226
227 // Construct set of variable indices
228 std::set<size_t> varindices;
229 for( size_t i = 0; i < vars.size(); ++i )
230 varindices.insert( i );
231
232 // Do variable elimination
233 while( !varindices.empty() ) {
234 size_t i = f( cl, varindices );
235 result.insert( cl.Delta( i ) );
236 cl.insert( cl.delta( i ) );
237 cl.eraseSubsuming( i );
238 cl.eraseNonMaximal();
239 varindices.erase( i );
240 }
241
242 return result;
243 }
244 //@}
245 };
246
247
248 } // end of namespace dai
249
250
251 #endif