libDAI version 0.3.2
[libdai.git] / include / dai / cobwebgraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2012, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class CobwebGraph, which implements a type of region graph used by GLC.
11
12
13 #ifndef __defined_libdai_cobwebgraph_h
14 #define __defined_libdai_cobwebgraph_h
15
16
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <dai/factorgraph.h>
#include <dai/weightedgraph.h>
#include <dai/smallset.h>
#include <dai/enum.h>
25
26
27 namespace dai {
28
29
30 /// A CobwebGraph is a special type of region graph used by the GLC algorithm
31 /** \author Siamak Ravanbakhsh
32 * \todo Implement unit test for Cobwebgraph
33 */
34 class CobwebGraph : public FactorGraph {
35 public:
36 /// The information in connection between two regions
37 struct Connection {
38 /// Index of the first region (p)
39 size_t my;
40 /// Index of the second region (q)
41 size_t his;
42 /// Index of this connection in the connections of the first region
43 size_t iter;
44 /// Index of the mirror of this connection in the connections of the second region. (reference of this message in _outM)
45 size_t dual;
46 /// The message sent from region q (his) to p (my)
47 Factor msg;
48 /// "Index" of variables in the message
49 std::vector<size_t> varinds;
50 /// Used as a temporary factor only
51 /** \todo Remove CobwebGraph::Connection::newmsg
52 */
53 Factor newmsg;
54 /// Index of factors in common
55 std::vector<size_t> fc;
56 /// Regions rho that are descendents of msg from q to p in p's msg-region graph (not used in partitioning case)
57 std::vector<VarSet> subregions;
58 };
59
60 /** Indicates what happens if a subset of variables in the boundary of a region (\ominus r_p) is shared by
61 * some neighbors such that one (\ominus r_{p,q1}) is a subset of another (\ominus r_{p,q2}).
62 * - ALL all such neighbors are included in the the updates.
63 * - TOP only (\ominus r_{p,q2}) is included unless (\ominus r_{p,q2} = \ominus r_{p,q1}) in which case both are included
64 * - CLOSEST (default): similar to TOP but in case of a tie the the region r_q with largest r_q \cap r_p is considered
65 * \note Not important in perfomance!
66 */
67 DAI_ENUM(NeighborType,ALL,TOP,CLOSEST);
68
69 protected:
70 /// Vector of variable indices internal to each region (r)
71 std::vector<SmallSet<size_t> > _INRs;
72 /// Vector of variable indices on the boundary of each region (\ominus r)
73 std::vector<SmallSet<size_t> > _EXRs;
74 /// Index of factors in each region
75 std::vector<SmallSet<size_t> > _Rfs;
76 /// Index of factors internal to each region, i.e., all its variables are internal to the region
77 std::vector<SmallSet<size_t> > _Rifs;
78 /// Index of factors that bridge each region, i.e., not all its variables are internal to the region
79 std::vector<SmallSet<size_t> > _Rxfs;
80 /// The vector of domain of messages leaving each region (\ominus r_{p,q})
81 std::vector<std::vector<VarSet> > _outM;
82 /// Vector of all connections to each region
83 std::vector<std::vector<Connection> > _M;
84
85 /// For each i the index of (cobweb) regions that contain variable i
86 std::vector<std::vector<size_t> > _var2CW;
87
88 /** For each region r_p a mapping from all variables rho in its msg-region graph to a pair:
89 * first: counting number
90 * second: the index of top-regions that contain rho
91 */
92 std::vector<std::map<VarSet, std::pair<int,std::vector<size_t> > > > _cn;
93
94 /// Whether a given set of region vars makes a partitioning or not
95 bool isPartition;
96
97 public:
98 /// \name Constructors and destructors
99 //@{
100 /// Default constructor
101 CobwebGraph() : FactorGraph(), _INRs(), _EXRs(),_Rfs(),_Rifs(),_Rxfs(), _M(), _var2CW(), _cn(), isPartition(true) {}
102
103 /// Constructs a cobweb graph from a factor graph
104 CobwebGraph( const FactorGraph& fg ): FactorGraph(), _INRs(), _EXRs(),_Rfs(), _M(), _var2CW(), _cn(), isPartition(true) {
105 // Copy factor graph structure
106 FactorGraph::operator=( fg );
107 }
108
109 /// Clone \c *this (virtual copy constructor)
110 virtual CobwebGraph* clone() const { return new CobwebGraph(*this); }
111 //@}
112
113 /// \name Accessors and mutators
114 //@{
115 /// Returns the number of regions
116 size_t nrCWs() const { return _INRs.size(); }
117
118 /// Returns constant reference to the _cn for region \a R
119 const std::map<VarSet, std::pair<int,std::vector<size_t> > >& cn( size_t R ) const {
120 DAI_DEBASSERT( R < nrCWs() );
121 return _cn[R];
122 }
123
124 /// Returns a reference to the _cn for region \a R
125 std::map<VarSet, std::pair<int,std::vector<size_t> > >& cn( size_t R ) {
126 DAI_DEBASSERT( R < nrCWs() );
127 return _cn[R];
128 }
129
130 /// Returns constant reference the vector of domain of all outgoing messages from region \a R
131 const std::vector<VarSet>& outM( size_t R ) const {
132 DAI_DEBASSERT( R < _outM.size() );
133 return _outM[R];
134 }
135
136 /// Returns reference the vector of domain of all outgoing messages from region \a R
137 std::vector<VarSet>& outM( size_t R ) {
138 DAI_DEBASSERT( R < _outM.size() );
139 return _outM[R];
140 }
141
142 /// Returns constant reference to the variables in the outgoing message from region \a R to its \a j'th neighbor
143 /** \a j corresponds to dual in the connection construct
144 */
145 const VarSet& outM( size_t R, size_t j ) const {
146 DAI_DEBASSERT( R < _outM.size() );
147 DAI_DEBASSERT( j < _outM[R].size() );
148 return _outM[R][j];
149 }
150
151 /// Returns a reference to the variables in the outgoing message from region \a R to its \a j'th neighbor
152 /** \a j corresponds to dual in the connection construct
153 */
154 VarSet& outM( size_t R, size_t j ) {
155 DAI_DEBASSERT( R < _outM.size() );
156 DAI_DEBASSERT( j < _outM[R].size() );
157 return _outM[R][j];
158 }
159
160 /// Returns constant reference to the index of factors of region \a R
161 const SmallSet<size_t>& Rfs( size_t R ) const {
162 DAI_DEBASSERT( R < _Rfs.size() );
163 return _Rfs[R];
164 }
165
166 /// Returns reference to the index of factors of region \a R
167 SmallSet<size_t>& Rfs( size_t R ) {
168 DAI_DEBASSERT( R < _Rfs.size() );
169 return _Rfs[R];
170 }
171
172 /// Returns constant reference to the index of variables on the boundary of region \a R (\ominus r)
173 const SmallSet<size_t>& EXRs( size_t R ) const {
174 DAI_DEBASSERT( R < _EXRs.size() );
175 return _EXRs[R];
176 }
177
178 /// Returns reference to the index of variables on the boundary of region \a R (\ominus r)
179 SmallSet<size_t>& EXRs( size_t R ) {
180 DAI_DEBASSERT( R < _EXRs.size() );
181 return _EXRs[R];
182 }
183
184 /// Returns constant reference to the index of variables inside region \a R (r)
185 const SmallSet<size_t>& INRs( size_t R ) const {
186 DAI_DEBASSERT( R < _INRs.size() );
187 return _INRs[R];
188 }
189
190 /// Returns reference to the index of variables inside region \a R (r)
191 SmallSet<size_t>& INRs( size_t R ) {
192 DAI_DEBASSERT( R < _INRs.size() );
193 return _INRs[R];
194 }
195
196 /// Returns constant reference to the connection structure from region \a R to its \a i'th neighbour
197 const Connection& M( size_t R, size_t i ) const {
198 DAI_DEBASSERT(R < _M.size());
199 DAI_DEBASSERT(i < _M[R].size());
200 return _M[R][i]; }
201
202 /// Returns reference to the connection structure from region \a R to its \a i'th neighbour
203 Connection& M( size_t R, size_t i ) {
204 DAI_DEBASSERT(R < _M.size());
205 DAI_DEBASSERT(i < _M[R].size());
206 return _M[R][i]; }
207
208 /// Returns constant reference to the vector of connection structure from region \a R to all its neighbours
209 const std::vector<Connection>& M( size_t R ) const {
210 DAI_DEBASSERT(R < _M.size());
211 return _M[R];
212 }
213
214 /// Returns vector of all connections to region \a R
215 std::vector<Connection>& M( size_t R ) { return _M[R]; }
216
217 /// Returns the vector of region indices that contain \a i as internal variable
218 const std::vector<size_t>& var2CW( size_t i ) const { return _var2CW[i]; }
219 //@}
220
221 /// \name Operations
222 //@{
223 /// Sets up all the regions and messages
224 void setRgn( std::vector<SmallSet<size_t> > regions, NeighborType neighbors, bool debugging = false );
225 //@}
226
227 /// \name Input/output
228 //@{
229 /// Reads a cobweb graph from a file
230 /** \note Not implemented yet
231 */
232 virtual void ReadFromFile( const char* /*filename*/ ) {
233 DAI_THROW(NOT_IMPLEMENTED);
234 }
235
236 /// Writes a cobweb graph to a file
237 /** \note Not implemented yet
238 */
239 virtual void WriteToFile( const char* /*filename*/, size_t /*precision*/=15 ) const {
240 DAI_THROW(NOT_IMPLEMENTED);
241 }
242
243 /// Writes a cobweb graph to an output stream
244 friend std::ostream& operator<< ( std::ostream& /*os*/, const CobwebGraph& /*rg*/ ) {
245 DAI_THROW(NOT_IMPLEMENTED);
246 }
247
248 /// Formats a cobweb graph as a string
249 std::string toString() const {
250 std::stringstream ss;
251 ss << *this;
252 return ss.str();
253 }
254
255 /// Writes a cobweb graph to a GraphViz .dot file
256 /** \note Not implemented yet
257 */
258 virtual void printDot( std::ostream& /*os*/ ) const {
259 DAI_THROW(NOT_IMPLEMENTED);
260 }
261 //@}
262
263
264 protected:
265 /// The function to check for partitioning
266 bool checkPartition( const std::vector<SmallSet<size_t> >& regions ) const;
267
268 /// Helper function that sets the regions containing each variable using the values in _INRs
269 void setVar2CW();
270
271 /// Setup the _INRs, _EXRs and all the factor indices (e.g. Rifs)
272 void setExtnFact();
273
274 /// Helper function that setups the msgs (_M, _outM) using the values in _INRs and _EXRs
275 void setMSGs( NeighborType neighbors );
276
277 /// Sets _cn
278 void setCountingNumbers( bool debugging = false );
279
280 /// For the given set of variables for each region, removes the regions that are non-maximal
281 void eraseNonMaximal( std::vector<SmallSet<size_t> >& regions );
282 };
283
284
285 } // end of namespace dai
286
287
288 #endif