Some small documentation updates
[libdai.git] / src / regiongraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <algorithm>
13 #include <cmath>
14 #include <boost/dynamic_bitset.hpp>
15 #include <dai/regiongraph.h>
16 #include <dai/factorgraph.h>
17 #include <dai/clustergraph.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 void RegionGraph::construct( const FactorGraph &fg, const std::vector<VarSet> &ors, const std::vector<Region> &irs, const std::vector<std::pair<size_t,size_t> > &edges ) {
27 // Copy factor graph structure
28 FactorGraph::operator=( fg );
29
30 // Copy inner regions
31 _IRs = irs;
32
33 // Construct outer regions (giving them counting number 1.0)
34 _ORs.clear();
35 _ORs.reserve( ors.size() );
36 foreach( const VarSet &alpha, ors )
37 _ORs.push_back( FRegion(Factor(alpha, 1.0), 1.0) );
38
39 // For each factor, find an outer region that subsumes that factor.
40 // Then, multiply the outer region with that factor.
41 _fac2OR.clear();
42 _fac2OR.reserve( nrFactors() );
43 for( size_t I = 0; I < nrFactors(); I++ ) {
44 size_t alpha;
45 for( alpha = 0; alpha < nrORs(); alpha++ )
46 if( OR(alpha).vars() >> factor(I).vars() ) {
47 _fac2OR.push_back( alpha );
48 break;
49 }
50 DAI_ASSERT( alpha != nrORs() );
51 }
52 recomputeORs();
53
54 // Create bipartite graph
55 _G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
56 }
57
58
59 void RegionGraph::constructCVM( const FactorGraph &fg, const std::vector<VarSet> &cl, size_t verbose ) {
60 if( verbose )
61 cerr << "constructCVM called (" << fg.nrVars() << " vars, " << fg.nrFactors() << " facs, " << cl.size() << " clusters)" << endl;
62
63 // Retain only maximal clusters
64 if( verbose )
65 cerr << " Constructing ClusterGraph" << endl;
66 ClusterGraph cg( cl );
67 if( verbose )
68 cerr << " Erasing non-maximal clusters" << endl;
69 cg.eraseNonMaximal();
70
71 // Create inner regions - first pass
72 if( verbose )
73 cerr << " Creating inner regions (first pass)" << endl;
74 set<VarSet> betas;
75 for( size_t alpha = 0; alpha < cg.nrClusters(); alpha++ )
76 for( size_t alpha2 = alpha; (++alpha2) != cg.nrClusters(); ) {
77 VarSet intersection = cg.cluster(alpha) & cg.cluster(alpha2);
78 if( intersection.size() > 0 )
79 betas.insert( intersection );
80 }
81
82 // Create inner regions - subsequent passes
83 if( verbose )
84 cerr << " Creating inner regions (next passes)" << endl;
85 set<VarSet> new_betas;
86 do {
87 new_betas.clear();
88 for( set<VarSet>::const_iterator gamma = betas.begin(); gamma != betas.end(); gamma++ )
89 for( set<VarSet>::const_iterator gamma2 = gamma; (++gamma2) != betas.end(); ) {
90 VarSet intersection = (*gamma) & (*gamma2);
91 if( (intersection.size() > 0) && (betas.count(intersection) == 0) )
92 new_betas.insert( intersection );
93 }
94 betas.insert(new_betas.begin(), new_betas.end());
95 } while( new_betas.size() );
96
97 // Create inner regions - final phase
98 if( verbose )
99 cerr << " Creating inner regions (final phase)" << endl;
100 vector<Region> irs;
101 irs.reserve( betas.size() );
102 for( set<VarSet>::const_iterator beta = betas.begin(); beta != betas.end(); beta++ )
103 irs.push_back( Region(*beta,0.0) );
104
105 // Create edges
106 if( verbose )
107 cerr << " Creating edges" << endl;
108 vector<pair<size_t,size_t> > edges;
109 for( size_t beta = 0; beta < irs.size(); beta++ )
110 for( size_t alpha = 0; alpha < cg.nrClusters(); alpha++ )
111 if( cg.cluster(alpha) >> irs[beta] )
112 edges.push_back( pair<size_t,size_t>(alpha,beta) );
113
114 // Construct region graph
115 if( verbose )
116 cerr << " Constructing region graph" << endl;
117 construct( fg, cg.clusters(), irs, edges );
118
119 // Calculate counting numbers
120 if( verbose )
121 cerr << " Calculating counting numbers" << endl;
122 calcCVMCountingNumbers();
123
124 if( verbose )
125 cerr << "Done." << endl;
126 }
127
128
129 void RegionGraph::calcCVMCountingNumbers() {
130 // Calculates counting numbers of inner regions based upon counting numbers of outer regions
131
132 vector<vector<size_t> > ancestors(nrIRs());
133 boost::dynamic_bitset<> assigned(nrIRs());
134 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
135 IR(beta).c() = 0.0;
136 for( size_t beta2 = 0; beta2 < nrIRs(); beta2++ )
137 if( (beta2 != beta) && IR(beta2) >> IR(beta) )
138 ancestors[beta].push_back(beta2);
139 }
140
141 bool new_counting;
142 do {
143 new_counting = false;
144 for( size_t beta = 0; beta < nrIRs(); beta++ ) {
145 if( !assigned[beta] ) {
146 bool has_unassigned_ancestor = false;
147 for( vector<size_t>::const_iterator beta2 = ancestors[beta].begin(); (beta2 != ancestors[beta].end()) && !has_unassigned_ancestor; beta2++ )
148 if( !assigned[*beta2] )
149 has_unassigned_ancestor = true;
150 if( !has_unassigned_ancestor ) {
151 Real c = 1.0;
152 foreach( const Neighbor &alpha, nbIR(beta) )
153 c -= OR(alpha).c();
154 for( vector<size_t>::const_iterator beta2 = ancestors[beta].begin(); beta2 != ancestors[beta].end(); beta2++ )
155 c -= IR(*beta2).c();
156 IR(beta).c() = c;
157 assigned.set(beta, true);
158 new_counting = true;
159 }
160 }
161 }
162 } while( new_counting );
163 }
164
165
166 bool RegionGraph::checkCountingNumbers() const {
167 // Checks whether the counting numbers satisfy the fundamental relation
168
169 bool all_valid = true;
170 for( vector<Var>::const_iterator n = vars().begin(); n != vars().end(); n++ ) {
171 Real c_n = 0.0;
172 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
173 if( OR(alpha).vars().contains( *n ) )
174 c_n += OR(alpha).c();
175 for( size_t beta = 0; beta < nrIRs(); beta++ )
176 if( IR(beta).contains( *n ) )
177 c_n += IR(beta).c();
178 if( fabs(c_n - 1.0) > 1e-15 ) {
179 all_valid = false;
180 cerr << "WARNING: counting numbers do not satisfy relation for " << *n << "(c_n = " << c_n << ")." << endl;
181 }
182 }
183
184 return all_valid;
185 }
186
187
188 void RegionGraph::recomputeORs() {
189 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
190 OR(alpha).fill( 1.0 );
191 for( size_t I = 0; I < nrFactors(); I++ )
192 if( fac2OR(I) != -1U )
193 OR( fac2OR(I) ) *= factor( I );
194 }
195
196
197 void RegionGraph::recomputeORs( const VarSet &ns ) {
198 for( size_t alpha = 0; alpha < nrORs(); alpha++ )
199 if( OR(alpha).vars().intersects( ns ) )
200 OR(alpha).fill( 1.0 );
201 for( size_t I = 0; I < nrFactors(); I++ )
202 if( fac2OR(I) != -1U )
203 if( OR( fac2OR(I) ).vars().intersects( ns ) )
204 OR( fac2OR(I) ) *= factor( I );
205 }
206
207
208 void RegionGraph::recomputeOR( size_t I ) {
209 DAI_ASSERT( I < nrFactors() );
210 if( fac2OR(I) != -1U ) {
211 size_t alpha = fac2OR(I);
212 OR(alpha).fill( 1.0 );
213 for( size_t J = 0; J < nrFactors(); J++ )
214 if( fac2OR(J) == alpha )
215 OR(alpha) *= factor( J );
216 }
217 }
218
219
220 /// Send RegionGraph to output stream
221 ostream & operator << (ostream & os, const RegionGraph & rg) {
222 os << "digraph RegionGraph {" << endl;
223 os << "node[shape=box];" << endl;
224 for( size_t alpha = 0; alpha < rg.nrORs(); alpha++ )
225 os << "\ta" << alpha << " [label=\"a" << alpha << ": " << rg.OR(alpha).vars() << ", c=" << rg.OR(alpha).c() << "\"];" << endl;
226 os << "node[shape=ellipse];" << endl;
227 for( size_t beta = 0; beta < rg.nrIRs(); beta++ )
228 os << "\tb" << beta << " [label=\"b" << beta << ": " << (VarSet)rg.IR(beta) << ", c=" << rg.IR(beta).c() << "\"];" << endl;
229 for( size_t alpha = 0; alpha < rg.nrORs(); alpha++ )
230 foreach( const RegionGraph::Neighbor &beta, rg.nbOR(alpha) )
231 os << "\ta" << alpha << " -> b" << beta << ";" << endl;
232 os << "}" << endl;
233 return os;
234 }
235
236
237 } // end of namespace dai