Improved ClusterGraph implementation and MaxSpanningTreePrims implementation.
[libdai.git] / include / dai / clustergraph.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_clustergraph_h
23 #define __defined_libdai_clustergraph_h
24
25
#include <algorithm>
#include <cassert>
#include <ostream>
#include <set>
#include <vector>
#include <dai/varset.h>
#include <dai/bipgraph.h>
30
31
32 namespace dai {
33
34
35 /// A ClusterGraph is a hypergraph with VarSets as nodes.
36 /// It is implemented as bipartite graph with variable nodes
37 /// and cluster (VarSet) nodes. The additional
38 /// functionality compared to a simple set<VarSet> is
39 /// finding maximal clusters, finding cliques, etc...
40 class ClusterGraph {
41 public:
42 BipartiteGraph G;
43 std::vector<Var> vars;
44 std::vector<VarSet> clusters;
45
46 typedef BipartiteGraph::Neighbor Neighbor;
47 typedef BipartiteGraph::Edge Edge;
48
49 public:
50 /// Default constructor
51 ClusterGraph() : G(), vars(), clusters() {}
52
53 /// Construct from vector<VarSet>
54 ClusterGraph( const std::vector<VarSet> & cls );
55
56 /// Copy constructor
57 ClusterGraph( const ClusterGraph &x ) : G(x.G), vars(x.vars), clusters(x.clusters) {}
58
59 /// Assignment operator
60 ClusterGraph& operator=( const ClusterGraph &x ) {
61 if( this != &x ) {
62 G = x.G;
63 vars = x.vars;
64 clusters = x.clusters;
65 }
66 return *this;
67 }
68
69 /// Returns true if cluster I is not contained in a larger cluster
70 bool isMaximal( size_t I ) const {
71 #ifdef DAI_DEBUG
72 assert( I < G.nr2() );
73 #endif
74 const VarSet & clI = clusters[I];
75 bool maximal = true;
76 // The following may not be optimal, since it may repeatedly test the same cluster *J
77 foreach( const Neighbor &i, G.nb2(I) ) {
78 foreach( const Neighbor &J, G.nb1(i) )
79 if( (J != I) && (clI << clusters[J]) ) {
80 maximal = false;
81 break;
82 }
83 if( !maximal )
84 break;
85 }
86 return maximal;
87 }
88
89 /// Erase all VarSets that are not maximal
90 ClusterGraph& eraseNonMaximal() {
91 for( size_t I = 0; I < G.nr2(); ) {
92 if( !isMaximal(I) ) {
93 clusters.erase( clusters.begin() + I );
94 G.erase2(I);
95 } else
96 I++;
97 }
98 return *this;
99 }
100
101 /// Return number of clusters
102 size_t size() const {
103 return G.nr2();
104 }
105
106 /// Return index of variable n
107 size_t findVar( const Var &n ) const {
108 return find( vars.begin(), vars.end(), n ) - vars.begin();
109 }
110
111 /// Returns true if vars with indices i1 and i2 are adjacent, i.e., both contained in the same cluster
112 bool adj( size_t i1, size_t i2 ) {
113 bool result = false;
114 foreach( const Neighbor &I, G.nb1(i1) )
115 if( find( G.nb2(I).begin(), G.nb2(I).end(), i2 ) != G.nb2(I).end() ) {
116 result = true;
117 break;
118 }
119 return result;
120 }
121
122 /// Returns union of clusters that contain the variable with index i
123 VarSet Delta( size_t i ) const {
124 VarSet result;
125 foreach( const Neighbor &I, G.nb1(i) )
126 result |= clusters[I];
127 return result;
128 }
129
130 /// Inserts a cluster (if it does not already exist)
131 void insert( const VarSet &cl ) {
132 if( find( clusters.begin(), clusters.end(), cl ) == clusters.end() ) {
133 clusters.push_back( cl );
134 // add variables (if necessary) and calculate neighborhood of new cluster
135 std::vector<size_t> nbs;
136 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
137 size_t iter = find( vars.begin(), vars.end(), *n ) - vars.begin();
138 nbs.push_back( iter );
139 if( iter == vars.size() ) {
140 G.add1();
141 vars.push_back( *n );
142 }
143 }
144 G.add2( nbs.begin(), nbs.end(), nbs.size() );
145 }
146 }
147
148 /// Returns union of clusters that contain variable with index i, minus this variable
149 VarSet delta( size_t i ) const {
150 return Delta( i ) / vars[i];
151 }
152
153 /// Erases all clusters that contain n where n is the variable with index i
154 ClusterGraph& eraseSubsuming( size_t i ) {
155 while( G.nb1(i).size() ) {
156 clusters.erase( clusters.begin() + G.nb1(i)[0] );
157 G.erase2( G.nb1(i)[0] );
158 }
159 return *this;
160 }
161
162 /// Provide read access to clusters
163 const std::vector<VarSet> & toVector() const { return clusters; }
164
165 /// Calculate cost of eliminating the variable with index i,
166 /// using as a measure "number of added edges in the adjacency graph"
167 /// where the adjacency graph has the variables as its nodes and
168 /// connects nodes i1 and i2 iff i1 and i2 occur in some common cluster
169 size_t eliminationCost( size_t i ) {
170 std::vector<size_t> id_n = G.delta1( i );
171
172 size_t cost = 0;
173
174 // for each unordered pair {i1,i2} adjacent to n
175 for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
176 for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
177 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
178 if( !adj(id_n[_i1], id_n[_i2]) )
179 cost++;
180 }
181
182 return cost;
183 }
184
185 /// Perform Variable Elimination without Probs, i.e. only keeping track of
186 /// the interactions that are created along the way.
187 /// Input: a set of outer clusters and an elimination sequence
188 /// Output: a set of elimination "cliques"
189 ClusterGraph VarElim( const std::vector<Var> &ElimSeq ) const;
190
191 /// Perform Variable Eliminiation using the MinFill heuristic
192 ClusterGraph VarElim_MinFill() const;
193
194 /// Send to output stream
195 friend std::ostream & operator << ( std::ostream & os, const ClusterGraph & cl ) {
196 os << cl.toVector();
197 return os;
198 }
199 };
200
201
202 } // end of namespace dai
203
204
205 #endif