Improved WeightedGraph code and added unit tests
[libdai.git] / include / dai / weightedgraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 /** \file
13 * \brief Defines some utility functions for (weighted) undirected graphs, trees and rooted trees.
14 * \todo Improve general support for graphs and trees.
15 */
16
17
18 #ifndef __defined_libdai_weightedgraph_h
19 #define __defined_libdai_weightedgraph_h
20
21
#include <vector>
#include <map>
#include <iostream>
#include <iterator>
#include <set>
#include <limits>
#include <climits> // Work-around for bug in boost graph library
#include <dai/util.h>
#include <dai/exceptions.h>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/prim_minimum_spanning_tree.hpp>
#include <boost/graph/kruskal_min_spanning_tree.hpp>
34
35
36 namespace dai {
37
38
/// Represents a directed edge
/** The source and target node indices are each exposed under two names
 *  (\c n1 / \c first and \c n2 / \c second) by means of anonymous unions.
 */
class DEdge {
    public:
        /// First node index (source of edge)
        union {
            size_t n1;
            size_t first;   ///< alias
        };

        /// Second node index (target of edge)
        union {
            size_t n2;
            size_t second;  ///< alias
        };

        /// Default constructor; both node indices are set to zero
        DEdge() : n1(0), n2(0) {}

        /// Constructs a directed edge pointing from \a m1 to \a m2
        DEdge( size_t m1, size_t m2 ) : n1(m1), n2(m2) {}

        /// Tests for equality (source and target both have to match)
        bool operator==( const DEdge &x ) const {
            if( n1 != x.n1 )
                return false;
            return n2 == x.n2;
        }

        /// Smaller-than operator (performs lexicographical comparison, source first, then target)
        bool operator<( const DEdge &x ) const {
            // The sources decide, unless they are equal
            if( n1 != x.n1 )
                return n1 < x.n1;
            return n2 < x.n2;
        }

        /// Writes a directed edge to an output stream, formatted as "(n1->n2)"
        friend std::ostream & operator << (std::ostream & os, const DEdge & e) {
            return os << "(" << e.n1 << "->" << e.n2 << ")";
        }
};
74
75
76 /// Represents an undirected edge
77 class UEdge {
78 public:
79 /// First node index
80 union {
81 size_t n1;
82 size_t first; /// alias
83 };
84 /// Second node index
85 union {
86 size_t n2;
87 size_t second; /// alias
88 };
89
90 /// Default constructor
91 UEdge() : n1(0), n2(0) {}
92
93 /// Constructs an undirected edge between \a m1 and \a m2
94 UEdge( size_t m1, size_t m2 ) : n1(m1), n2(m2) {}
95
96 /// Construct from DEdge
97 UEdge( const DEdge &e ) : n1(e.n1), n2(e.n2) {}
98
99 /// Tests for inequality (disregarding the ordering of the nodes)
100 bool operator==( const UEdge &x ) {
101 return ((n1 == x.n1) && (n2 == x.n2)) || ((n1 == x.n2) && (n2 == x.n1));
102 }
103
104 /// Smaller-than operator
105 bool operator<( const UEdge &x ) const {
106 size_t s = n1, l = n2;
107 if( s > l )
108 std::swap( s, l );
109 size_t xs = x.n1, xl = x.n2;
110 if( xs > xl )
111 std::swap( xs, xl );
112 return( (s < xs) || ((s == xs) && (l < xl)) );
113 }
114
115 /// Writes an undirected edge to an output stream
116 friend std::ostream & operator << (std::ostream & os, const UEdge & e) {
117 if( e.n1 < e.n2 )
118 os << "{" << e.n1 << "--" << e.n2 << "}";
119 else
120 os << "{" << e.n2 << "--" << e.n1 << "}";
121 return os;
122 }
123 };
124
125
126 /// Represents an undirected graph, implemented as a std::set of undirected edges
127 class GraphEL : public std::set<UEdge> {
128 public:
129 /// Default constructor
130 GraphEL() {}
131
132 /// Construct from range of objects that can be cast to UEdge
133 template <class InputIterator>
134 GraphEL( InputIterator begin, InputIterator end ) {
135 insert( begin, end );
136 }
137 };
138
139
140 /// Represents an undirected weighted graph, with weights of type \a T, implemented as a std::map mapping undirected edges to weights
141 template<class T> class WeightedGraph : public std::map<UEdge, T> {};
142
143
144 /// Represents a rooted tree, implemented as a vector of directed edges
145 /** By convention, the edges are stored such that they point away from
146 * the root and such that edges nearer to the root come before edges
147 * farther away from the root.
148 */
149 class RootedTree : public std::vector<DEdge> {
150 public:
151 /// Default constructor
152 RootedTree() {}
153
154 /// Constructs a rooted tree from a tree and a root
155 /** \pre T has no cycles and contains node \a Root
156 */
157 RootedTree( const GraphEL &T, size_t Root );
158 };
159
160
161 /// Constructs a minimum spanning tree from the (non-negatively) weighted graph \a G.
162 /** \param usePrim If true, use Prim's algorithm (complexity O(E log(V))), otherwise, use Kruskal's algorithm (complexity O(E log(E)))
163 * \note Uses implementation from Boost Graph Library.
164 * \note The vertices of \a G must be in the range [0,N) where N is the number of vertices of \a G.
165 */
166 template<typename T> RootedTree MinSpanningTree( const WeightedGraph<T> &G, bool usePrim ) {
167 RootedTree result;
168 if( G.size() > 0 ) {
169 using namespace boost;
170 using namespace std;
171 typedef adjacency_list< listS, vecS, undirectedS, no_property, property<edge_weight_t, double> > boostGraph;
172
173 set<size_t> nodes;
174 vector<UEdge> edges;
175 vector<double> weights;
176 edges.reserve( G.size() );
177 weights.reserve( G.size() );
178 for( typename WeightedGraph<T>::const_iterator e = G.begin(); e != G.end(); e++ ) {
179 weights.push_back( e->second );
180 edges.push_back( e->first );
181 nodes.insert( e->first.n1 );
182 nodes.insert( e->first.n2 );
183 }
184
185 size_t N = nodes.size();
186 for( set<size_t>::const_iterator it = nodes.begin(); it != nodes.end(); it++ )
187 if( *it >= N )
188 DAI_THROWE(RUNTIME_ERROR,"Vertices must be in range [0..N) where N is the number of vertices.");
189
190 boostGraph g( edges.begin(), edges.end(), weights.begin(), nodes.size() );
191 size_t root = *(nodes.begin());
192 GraphEL tree;
193 if( usePrim ) {
194 // Prim's algorithm
195 vector< graph_traits< boostGraph >::vertex_descriptor > p( num_vertices(g) );
196 prim_minimum_spanning_tree( g, &(p[0]) );
197
198 // Store tree edges in result
199 for( size_t i = 0; i != p.size(); i++ ) {
200 if( p[i] != i )
201 tree.insert( UEdge( p[i], i ) );
202 }
203 } else {
204 // Kruskal's algorithm
205 vector< graph_traits< boostGraph >::edge_descriptor > p( num_vertices(g) );
206 kruskal_minimum_spanning_tree( g, &(p[0]) );
207
208 // Store tree edges in result
209 for( size_t i = 0; i != p.size(); i++ ) {
210 size_t v1 = source( p[i], g );
211 size_t v2 = target( p[i], g );
212 if( v1 != v2 )
213 tree.insert( UEdge( v1, v2 ) );
214 }
215 }
216
217 // Direct edges in order to obtain a rooted tree
218 result = RootedTree( tree, root );
219 }
220 return result;
221 }
222
223
224 /// Constructs a minimum spanning tree from the (non-negatively) weighted graph \a G.
225 /** \param usePrim If true, use Prim's algorithm (complexity O(E log(V))), otherwise, use Kruskal's algorithm (complexity O(E log(E)))
226 * \note Uses implementation from Boost Graph Library.
227 * \note The vertices of \a G must be in the range [0,N) where N is the number of vertices of \a G.
228 */
229 template<typename T> RootedTree MaxSpanningTree( const WeightedGraph<T> &G, bool usePrim ) {
230 if( G.size() == 0 )
231 return RootedTree();
232 else {
233 T maxweight = G.begin()->second;
234 for( typename WeightedGraph<T>::const_iterator it = G.begin(); it != G.end(); it++ )
235 if( it->second > maxweight )
236 maxweight = it->second;
237 // make a copy of the graph
238 WeightedGraph<T> gr( G );
239 // invoke MinSpanningTree with negative weights
240 // (which have to be shifted to satisfy positivity criterion)
241 for( typename WeightedGraph<T>::iterator it = gr.begin(); it != gr.end(); it++ )
242 it->second = maxweight - it->second;
243 return MinSpanningTree( gr, usePrim );
244 }
245 }
246
247
248 } // end of namespace dai
249
250
251 #endif