Merge branch 'joris'
[libdai.git] / include / dai / clustergraph.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 /// \file
24 /// \brief Defines class ClusterGraph
25 /// \todo Improve documentation
26
27
28 #ifndef __defined_libdai_clustergraph_h
29 #define __defined_libdai_clustergraph_h
30
31
#include <algorithm>
#include <cassert>
#include <ostream>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
36
37
38 namespace dai {
39
40
41 /// A ClusterGraph is a hypergraph with VarSets as nodes.
42 /** It is implemented as bipartite graph with variable (Var) nodes
43 * and cluster (VarSet) nodes.
44 */
45 class ClusterGraph {
46 public:
47 /// Stores the neighborhood structure
48 BipartiteGraph G;
49
50 /// Stores the variables corresponding to the nodes
51 std::vector<Var> vars;
52
53 /// Stores the clusters corresponding to the hyperedges
54 std::vector<VarSet> clusters;
55
56 /// Shorthand for BipartiteGraph::Neighbor
57 typedef BipartiteGraph::Neighbor Neighbor;
58
59 /// Shorthand for BipartiteGraph::Edge
60 typedef BipartiteGraph::Edge Edge;
61
62 public:
63 /// Default constructor
64 ClusterGraph() : G(), vars(), clusters() {}
65
66 /// Construct from vector<VarSet>
67 ClusterGraph( const std::vector<VarSet> & cls );
68
69 /// Copy constructor
70 ClusterGraph( const ClusterGraph &x ) : G(x.G), vars(x.vars), clusters(x.clusters) {}
71
72 /// Assignment operator
73 ClusterGraph& operator=( const ClusterGraph &x ) {
74 if( this != &x ) {
75 G = x.G;
76 vars = x.vars;
77 clusters = x.clusters;
78 }
79 return *this;
80 }
81
82 /// Returns true if cluster I is not contained in a larger cluster
83 bool isMaximal( size_t I ) const {
84 #ifdef DAI_DEBUG
85 assert( I < G.nr2() );
86 #endif
87 const VarSet & clI = clusters[I];
88 bool maximal = true;
89 // The following may not be optimal, since it may repeatedly test the same cluster *J
90 foreach( const Neighbor &i, G.nb2(I) ) {
91 foreach( const Neighbor &J, G.nb1(i) )
92 if( (J != I) && (clI << clusters[J]) ) {
93 maximal = false;
94 break;
95 }
96 if( !maximal )
97 break;
98 }
99 return maximal;
100 }
101
102 /// Erases all VarSets that are not maximal
103 ClusterGraph& eraseNonMaximal() {
104 for( size_t I = 0; I < G.nr2(); ) {
105 if( !isMaximal(I) ) {
106 clusters.erase( clusters.begin() + I );
107 G.erase2(I);
108 } else
109 I++;
110 }
111 return *this;
112 }
113
114 /// Returns number of clusters
115 size_t size() const {
116 return G.nr2();
117 }
118
119 /// Returns index of variable n
120 size_t findVar( const Var &n ) const {
121 return find( vars.begin(), vars.end(), n ) - vars.begin();
122 }
123
124 /// Returns true if vars with indices i1 and i2 are adjacent, i.e., both contained in the same cluster
125 bool adj( size_t i1, size_t i2 ) {
126 bool result = false;
127 foreach( const Neighbor &I, G.nb1(i1) )
128 if( find( G.nb2(I).begin(), G.nb2(I).end(), i2 ) != G.nb2(I).end() ) {
129 result = true;
130 break;
131 }
132 return result;
133 }
134
135 /// Returns union of clusters that contain the variable with index i
136 VarSet Delta( size_t i ) const {
137 VarSet result;
138 foreach( const Neighbor &I, G.nb1(i) )
139 result |= clusters[I];
140 return result;
141 }
142
143 /// Inserts a cluster (if it does not already exist)
144 void insert( const VarSet &cl ) {
145 if( find( clusters.begin(), clusters.end(), cl ) == clusters.end() ) {
146 clusters.push_back( cl );
147 // add variables (if necessary) and calculate neighborhood of new cluster
148 std::vector<size_t> nbs;
149 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
150 size_t iter = find( vars.begin(), vars.end(), *n ) - vars.begin();
151 nbs.push_back( iter );
152 if( iter == vars.size() ) {
153 G.add1();
154 vars.push_back( *n );
155 }
156 }
157 G.add2( nbs.begin(), nbs.end(), nbs.size() );
158 }
159 }
160
161 /// Returns union of clusters that contain variable with index i, minus this variable
162 VarSet delta( size_t i ) const {
163 return Delta( i ) / vars[i];
164 }
165
166 /// Erases all clusters that contain n where n is the variable with index i
167 ClusterGraph& eraseSubsuming( size_t i ) {
168 while( G.nb1(i).size() ) {
169 clusters.erase( clusters.begin() + G.nb1(i)[0] );
170 G.erase2( G.nb1(i)[0] );
171 }
172 return *this;
173 }
174
175 /// Returns a const reference to the clusters
176 const std::vector<VarSet> & toVector() const { return clusters; }
177
178 /// Calculates cost of eliminating the variable with index i.
179 /** The cost is measured as "number of added edges in the adjacency graph",
180 * where the adjacency graph has the variables as its nodes and
181 * connects nodes i1 and i2 iff i1 and i2 occur in some common cluster.
182 */
183 size_t eliminationCost( size_t i ) {
184 std::vector<size_t> id_n = G.delta1( i );
185
186 size_t cost = 0;
187
188 // for each unordered pair {i1,i2} adjacent to n
189 for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
190 for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
191 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
192 if( !adj(id_n[_i1], id_n[_i2]) )
193 cost++;
194 }
195
196 return cost;
197 }
198
199 /// Performs Variable Elimination without Probs, i.e. only keeping track of
200 /* the interactions that are created along the way.
201 * \param ElimSeq A set of outer clusters and an elimination sequence
202 * \return A set of elimination "cliques"
203 * \todo Variable elimination should be implemented generically using a function
204 * object that tells you which variable to delete.
205 */
206 ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
207
208 /// Performs Variable Eliminiation using the MinFill heuristic
209 ClusterGraph VarElim_MinFill() const;
210
211 /// Writes a ClusterGraph to an output stream
212 friend std::ostream & operator << ( std::ostream & os, const ClusterGraph & cl ) {
213 os << cl.toVector();
214 return os;
215 }
216 };
217
218
219 } // end of namespace dai
220
221
222 #endif