fc68e1bc22d437b968ca768d6b1353b807f96e85
[libdai.git] / include / dai / cobwebgraph.h
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2012, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 /// \file
10 /// \brief Defines class CobwebGraph, which implements a type of region graph used by GLC.
11
12
13 #ifndef __defined_libdai_cobwebgraph_h
14 #define __defined_libdai_cobwebgraph_h
15
16
17 #include <iostream>
18 #include <dai/factorgraph.h>
19 #include <dai/weightedgraph.h>
20 #include <dai/smallset.h>
21 #include <dai/enum.h>
22 #include <algorithm>
23 #include <map>
24 #include <set>
25
26
27 namespace dai {
28
29
30 /// A CobwebGraph is a special type of region graph used by the GLC algorithm
31 /** \author Siamak Ravanbakhsh
32 */
33 class CobwebGraph : public FactorGraph {
34 protected:
35 /// Vector of variable indices internal to each region (r)
36 std::vector<SmallSet<size_t> > _INRs;
37 /// Vector of variable indices on the boundary of each region (\ominus r)
38 std::vector<SmallSet<size_t> > _EXRs;
39 /// Index of factors in each region
40 std::vector<SmallSet<size_t> > _Rfs;
41 /// Index of factors internal to each region, i.e., all its variables are internal to the region
42 std::vector<SmallSet<size_t> > _Rifs;
43 /// Index of factors that bridge each region, i.e., not all its variables are internal to the region
44 std::vector<SmallSet<size_t> > _Rxfs;
45 /// The vector of domain of messages leaving each region (\ominus r_{p,q})
46 std::vector<std::vector<VarSet> > _outM;
47
48 /// The information in connection between two regions
49 struct Connection {
50 /// Index of the first region (p)
51 size_t my;
52 /// Index of the second region (q)
53 size_t his;
54 /// Index of this connection in the connections of the first region
55 size_t iter;
56 /// Index of the mirror of this connection in the connections of the second region. (reference of this message in _outM)
57 size_t dual;
58 /// The message sent from region q (his) to p (my)
59 Factor msg;
60 /// "Index" of variables in the message
61 std::vector<size_t> varinds;
62 /// Used as a temporary factor only
63 /** \todo Remove CobwebGraph::Connection::newmsg
64 */
65 Factor newmsg;
66 /// Index of factors in common
67 std::vector<size_t> fc;
68 /// Regions rho that are descendents of msg from q to p in p's msg-region graph (not used in partitioning case)
69 std::vector<VarSet> subregions;
70 };
71
72 /** Indicates what happens if a subset of variables in the boundary of a region (\ominus r_p) is shared by
73 * some neighbors such that one (\ominus r_{p,q1}) is a subset of another (\ominus r_{p,q2}).
74 * - ALL all such neighbors are included in the the updates.
75 * - TOP only (\ominus r_{p,q2}) is included unless (\ominus r_{p,q2} = \ominus r_{p,q1}) in which case both are included
76 * - CLOSEST (default): similar to TOP but in case of a tie the the region r_q with largest r_q \cap r_p is considered
77 * \note Not important in perfomance!
78 */
79 DAI_ENUM(NeighborType,ALL,TOP,CLOSEST);
80
81 /// Vector of all connections to each region
82 std::vector<std::vector<Connection> > _M;
83
84 /// For each i the index of (cobweb) regions that contain variable i
85 std::vector<std::vector<size_t> > _var2CW;
86
87 /** For each region r_p a mapping from all variables rho in its msg-region graph to a pair:
88 * first: counting number
89 * second: the index of top-regions that contain rho
90 */
91 std::vector<std::map<VarSet, std::pair<int,std::vector<size_t> > > > _cn;
92
93 /// Whether a given set of region vars makes a partitioning or not
94 bool isPartition;
95
96
97 public:
98 /// \name Constructors and destructors
99 //@{
100 /// Default constructor
101 CobwebGraph() : FactorGraph(), _INRs(), _EXRs(),_Rfs(),_Rifs(),_Rxfs(), _M(), _var2CW(), _cn(), isPartition(true) {}
102
103 /// Constructs a cobweb graph from a factor graph
104 CobwebGraph( const FactorGraph& fg ): FactorGraph(), _INRs(), _EXRs(),_Rfs(), _M(), _var2CW(), _cn(), isPartition(true) {
105 // Copy factor graph structure
106 FactorGraph::operator=( fg );
107 }
108
109 /// Clone \c *this (virtual copy constructor)
110 virtual CobwebGraph* clone() const { return new CobwebGraph(*this); }
111 //@}
112
113 /// \name Accessors and mutators
114 //@{
115 /// Returns the number of regions
116 size_t nrCWs() const { return _INRs.size(); }
117
118 /// Returns constant reference to the _cn for region \a R
119 const std::map<VarSet, std::pair<int,std::vector<size_t> > >& cn( size_t R ) const {
120 DAI_DEBASSERT( R < nrCWs() );
121 return _cn[R];
122 }
123
124 /// Returns a reference to the _cn for region \a R
125 std::map<VarSet, std::pair<int,std::vector<size_t> > >& cn( size_t R ) {
126 DAI_DEBASSERT( R < nrCWs() );
127 return _cn[R];
128 }
129
130 /// Returns constant reference the vector of domain of all outgoing messages from region \a R
131 const std::vector<VarSet>& outM( size_t R ) const {
132 DAI_DEBASSERT( R < _outM.size() );
133 return _outM[R];
134 }
135
136 /// Returns reference the vector of domain of all outgoing messages from region \a R
137 std::vector<VarSet>& outM( size_t R ) {
138 DAI_DEBASSERT( R < _outM.size() );
139 return _outM[R];
140 }
141
142 /// Returns constant reference to the variables in the outgoing message from region \a R to its \a j'th neighbor
143 /** \a j corresponds to dual in the connection construct
144 */
145 const VarSet& outM( size_t R, size_t j ) const {
146 DAI_DEBASSERT( R < _outM.size() );
147 DAI_DEBASSERT( j < _outM[R].size() );
148 return _outM[R][j];
149 }
150
151 /// Returns a reference to the variables in the outgoing message from region \a R to its \a j'th neighbor
152 /** \a j corresponds to dual in the connection construct
153 */
154 VarSet& outM( size_t R, size_t j ) {
155 DAI_DEBASSERT( R < _outM.size() );
156 DAI_DEBASSERT( j < _outM[R].size() );
157 return _outM[R][j];
158 }
159
160 /// Returns constant reference to the index of factors of region \a R
161 const SmallSet<size_t>& Rfs( size_t R ) const {
162 DAI_DEBASSERT( R < _Rfs.size() );
163 return _Rfs[R];
164 }
165
166 /// Returns reference to the index of factors of region \a R
167 SmallSet<size_t>& Rfs( size_t R ) {
168 DAI_DEBASSERT( R < _Rfs.size() );
169 return _Rfs[R];
170 }
171
172 /// Returns constant reference to the index of variables on the boundary of region \a R (\ominus r)
173 const SmallSet<size_t>& EXRs( size_t R ) const {
174 DAI_DEBASSERT( R < _EXRs.size() );
175 return _EXRs[R];
176 }
177
178 /// Returns reference to the index of variables on the boundary of region \a R (\ominus r)
179 SmallSet<size_t>& EXRs( size_t R ) {
180 DAI_DEBASSERT( R < _EXRs.size() );
181 return _EXRs[R];
182 }
183
184 /// Returns constant reference to the index of variables inside region \a R (r)
185 const SmallSet<size_t>& INRs( size_t R ) const {
186 DAI_DEBASSERT( R < _INRs.size() );
187 return _INRs[R];
188 }
189
190 /// Returns reference to the index of variables inside region \a R (r)
191 SmallSet<size_t>& INRs( size_t R ) {
192 DAI_DEBASSERT( R < _INRs.size() );
193 return _INRs[R];
194 }
195
196 /// Returns constant reference to the connection structure from region \a R to its \a i'th neighbour
197 const Connection& M( size_t R, size_t i ) const {
198 DAI_DEBASSERT(R < _M.size());
199 DAI_DEBASSERT(i < _M[R].size());
200 return _M[R][i]; }
201
202 /// Returns reference to the connection structure from region \a R to its \a i'th neighbour
203 Connection& M( size_t R, size_t i ) {
204 DAI_DEBASSERT(R < _M.size());
205 DAI_DEBASSERT(i < _M[R].size());
206 return _M[R][i]; }
207
208 /// Returns constant reference to the vector of connection structure from region \a R to all its neighbours
209 const std::vector<Connection>& M( size_t R ) const {
210 DAI_DEBASSERT(R < _M.size());
211 return _M[R];
212 }
213
214 /// Returns vector of all connections to region \a R
215 std::vector<Connection>& M( size_t R ) { return _M[R]; }
216
217 /// Returns the vector of region indices that contain \a i as internal variable
218 const std::vector<size_t>& var2CW( size_t i ) const { return _var2CW[i]; }
219 //@}
220
221 /// \name Operations
222 //@{
223 /// Sets up all the regions and messages
224 void setRgn( std::vector<SmallSet<size_t> > regions, NeighborType neighbors, bool debugging = false );
225 //@}
226
227 /// \name Input/output
228 //@{
229 /// Reads a cobweb graph from a file
230 /** \note Not implemented yet
231 */
232 virtual void ReadFromFile( const char* /*filename*/ ) {
233 DAI_THROW(NOT_IMPLEMENTED);
234 }
235
236 /// Writes a cobweb graph to a file
237 /** \note Not implemented yet
238 */
239 virtual void WriteToFile( const char* /*filename*/, size_t /*precision*/=15 ) const {
240 DAI_THROW(NOT_IMPLEMENTED);
241 }
242
243 /// Writes a cobweb graph to an output stream
244 friend std::ostream& operator<< ( std::ostream& os, const CobwebGraph& rg );
245
246 /// Writes a cobweb graph to a GraphViz .dot file
247 /** \note Not implemented yet
248 */
249 virtual void printDot( std::ostream& /*os*/ ) const {
250 DAI_THROW(NOT_IMPLEMENTED);
251 }
252 //@}
253
254
255 protected:
256 /// The function to check for partitioning
257 bool checkPartition( const std::vector<SmallSet<size_t> >& regions ) const;
258
259 /// Helper function that sets the regions containing each variable using the values in _INRs
260 void setVar2CW();
261
262 /// Setup the _INRs, _EXRs and all the factor indices (e.g. Rifs)
263 void setExtnFact();
264
265 /// Helper function that setups the msgs (_M, _outM) using the values in _INRs and _EXRs
266 void setMSGs( NeighborType neighbors );
267
268 /// Sets _cn
269 void setCountingNumbers( bool debugging = false );
270
271 /// For the given set of variables for each region, removes the regions that are non-maximal
272 void eraseNonMaximal( std::vector<SmallSet<size_t> >& regions );
273 };
274
275
276 } // end of namespace dai
277
278
279 #endif