New git HEAD version
[libdai.git] / src / regiongraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <algorithm>
10 #include <cmath>
11 #include <boost/dynamic_bitset.hpp>
12 #include <dai/regiongraph.h>
13 #include <dai/factorgraph.h>
14 #include <dai/clustergraph.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 void RegionGraph::construct( const FactorGraph &fg, const std::vector<VarSet> &ors, const std::vector<Region> &irs, const std::vector<std::pair<size_t,size_t> > &edges ) {
24 // Copy factor graph structure
25 FactorGraph::operator=( fg );
26
27 // Copy inner regions
28 _IRs = irs;
29
30 // Construct outer regions (giving them counting number 1.0)
31 _ORs.clear();
32 _ORs.reserve( ors.size() );
33 bforeach( const VarSet &alpha, ors )
34 _ORs.push_back( FRegion(Factor(alpha, 1.0), 1.0) );
35
36 // For each factor, find an outer region that subsumes that factor.
37 // Then, multiply the outer region with that factor.
38 _fac2OR.clear();
39 _fac2OR.reserve( nrFactors() );
40 for( size_t I = 0; I < nrFactors(); I++ ) {
41 size_t alpha;
42 for( alpha = 0; alpha < nrORs(); alpha++ )
43 if( OR(alpha).vars() >> factor(I).vars() ) {
44 _fac2OR.push_back( alpha );
45 break;
46 }
47 DAI_ASSERT( alpha != nrORs() );
48 }
49 recomputeORs();
50
51 // Create bipartite graph
52 _G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
53 }
54
55
56 void RegionGraph::constructCVM( const FactorGraph &fg, const std::vector<VarSet> &cl, size_t verbose ) {
57 if( verbose )
58 cerr << "constructCVM called (" << fg.nrVars() << " vars, " << fg.nrFactors() << " facs, " << cl.size() << " clusters)" << endl;
59
60 // Retain only maximal clusters
61 if( verbose )
62 cerr << " Constructing ClusterGraph" << endl;
63 ClusterGraph cg( cl );
64 if( verbose )
65 cerr << " Erasing non-maximal clusters" << endl;
66 cg.eraseNonMaximal();
67
68 // Create inner regions - first pass
69 if( verbose )
70 cerr << " Creating inner regions (first pass)" << endl;
71 set<VarSet> betas;
72 for( size_t alpha = 0; alpha < cg.nrClusters(); alpha++ )
73 for( size_t alpha2 = alpha; (++alpha2) != cg.nrClusters(); ) {
74 VarSet intersection = cg.cluster(alpha) & cg.cluster(alpha2);
75 if( intersection.size() > 0 )
76 betas.insert( intersection );
77 }
78
79 // Create inner regions - subsequent passes
80 if( verbose )
81 cerr << " Creating inner regions (next passes)" << endl;
82 set<VarSet> new_betas;
83 do {
84 new_betas.clear();
85 for( set<VarSet>::const_iterator gamma = betas.begin(); gamma != betas.end(); gamma++ )
86 for( set<VarSet>::const_iterator gamma2 = gamma; (++gamma2) != betas.end(); ) {
87 VarSet intersection = (*gamma) & (*gamma2);
88 if( (intersection.size() > 0) && (betas.count(intersection) == 0) )
89 new_betas.insert( intersection );
90 }
91 betas.insert(new_betas.begin(), new_betas.end());
92 } while( new_betas.size() );
93
94 // Create inner regions - final phase
95 if( verbose )
96 cerr << " Creating inner regions (final phase)" << endl;
97 vector<Region> irs;
98 irs.reserve( betas.size() );
99 for( set<VarSet>::const_iterator beta = betas.begin(); beta != betas.end(); beta++ )
100 irs.push_back( Region(*beta,0.0) );
101
102 // Create edges
103 if( verbose )
104 cerr << " Creating edges" << endl;
105 vector<pair<size_t,size_t> > edges;
106 for( size_t beta = 0; beta < irs.size(); beta++ )
107 for( size_t alpha = 0; alpha < cg.nrClusters(); alpha++ )
108 if( cg.cluster(alpha) >> irs[beta] )
109 edges.push_back( pair<size_t,size_t>(alpha,beta) );
110
111 // Construct region graph
112 if( verbose )
113 cerr << " Constructing region graph" << endl;
114 construct( fg, cg.clusters(), irs, edges );
115
116 // Calculate counting numbers
117 if( verbose )
118 cerr << " Calculating counting numbers" << endl;
119 calcCVMCountingNumbers();
120
121 if( verbose )
122 cerr << "Done." << endl;
123 }
124
125
126 void RegionGraph::calcCVMCountingNumbers() {
127 // Calculates counting numbers of inner regions based upon counting numbers of outer regions
128
129 vector<vector<size_t> > ancestors(nrIRs());
130 boost::dynamic_bitset<> assigned(nrIRs());
131 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
132 IR(beta).c() = 0.0;
133 for( size_t beta2 = 0; beta2 < nrIRs(); beta2++ )
134 if( (beta2 != beta) && IR(beta2) >> IR(beta) )
135 ancestors[beta].push_back(beta2);
136 }
137
138 bool new_counting;
139 do {
140 new_counting = false;
141 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
142 if( !assigned[beta] ) {
143 bool has_unassigned_ancestor = false;
144 for( vector<size_t>::const_iterator beta2 = ancestors[beta].begin(); (beta2 != ancestors[beta].end()) && !has_unassigned_ancestor; beta2++ )
145 if( !assigned[*beta2] )
146 has_unassigned_ancestor = true;
147 if( !has_unassigned_ancestor ) {
148 Real c = 1.0;
149 bforeach( const Neighbor &alpha, nbIR(beta) )
150 c -= OR(alpha).c();
151 for( vector<size_t>::const_iterator beta2 = ancestors[beta].begin(); beta2 != ancestors[beta].end(); beta2++ )
152 c -= IR(*beta2).c();
153 IR(beta).c() = c;
154 assigned.set(beta, true);
155 new_counting = true;
156 }
157 }
158 }
159 } while( new_counting );
160 }
161
162
163 bool RegionGraph::checkCountingNumbers() const {
164 // Checks whether the counting numbers satisfy the fundamental relation
165
166 bool all_valid = true;
167 for( vector<Var>::const_iterator n = vars().begin(); n != vars().end(); n++ ) {
168 Real c_n = 0.0;
169 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
170 if( OR(alpha).vars().contains( *n ) )
171 c_n += OR(alpha).c();
172 for( size_t beta = 0; beta < nrIRs(); beta++ )
173 if( IR(beta).contains( *n ) )
174 c_n += IR(beta).c();
175 if( fabs(c_n - 1.0) > 1e-15 ) {
176 all_valid = false;
177 cerr << "WARNING: counting numbers do not satisfy relation for " << *n << "(c_n = " << c_n << ")." << endl;
178 }
179 }
180
181 return all_valid;
182 }
183
184
185 void RegionGraph::recomputeORs() {
186 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
187 OR(alpha).fill( 1.0 );
188 for( size_t I = 0; I < nrFactors(); I++ )
189 if( fac2OR(I) != -1U )
190 OR( fac2OR(I) ) *= factor( I );
191 }
192
193
194 void RegionGraph::recomputeORs( const VarSet &ns ) {
195 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
196 if( OR(alpha).vars().intersects( ns ) )
197 OR(alpha).fill( 1.0 );
198 for( size_t I = 0; I < nrFactors(); I++ )
199 if( fac2OR(I) != -1U )
200 if( OR( fac2OR(I) ).vars().intersects( ns ) )
201 OR( fac2OR(I) ) *= factor( I );
202 }
203
204
205 void RegionGraph::recomputeOR( size_t I ) {
206 DAI_ASSERT( I < nrFactors() );
207 if( fac2OR(I) != -1U ) {
208 size_t alpha = fac2OR(I);
209 OR(alpha).fill( 1.0 );
210 for( size_t J = 0; J < nrFactors(); J++ )
211 if( fac2OR(J) == alpha )
212 OR(alpha) *= factor( J );
213 }
214 }
215
216
217 /// Send RegionGraph to output stream
218 ostream & operator << (ostream & os, const RegionGraph & rg) {
219 os << "digraph RegionGraph {" << endl;
220 os << "node[shape=box];" << endl;
221 for( size_t alpha = 0; alpha < rg.nrORs(); alpha++ )
222 os << "\ta" << alpha << " [label=\"a" << alpha << ": " << rg.OR(alpha).vars() << ", c=" << rg.OR(alpha).c() << "\"];" << endl;
223 os << "node[shape=ellipse];" << endl;
224 for( size_t beta = 0; beta < rg.nrIRs(); beta++ )
225 os << "\tb" << beta << " [label=\"b" << beta << ": " << (VarSet)rg.IR(beta) << ", c=" << rg.IR(beta).c() << "\"];" << endl;
226 for( size_t alpha = 0; alpha < rg.nrORs(); alpha++ )
227 bforeach( const Neighbor &beta, rg.nbOR(alpha) )
228 os << "\ta" << alpha << " -> b" << beta << ";" << endl;
229 os << "}" << endl;
230 return os;
231 }
232
233
234 } // end of namespace dai