e5f78f45cdd6789fdade03370fcf4e1d409ee74a
[libdai.git] / utils / createfg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/factorgraph.h>
#include <dai/weightedgraph.h>
#include <dai/util.h>


using namespace std;
using namespace dai;
namespace po = boost::program_options;
33
34
35 void MakeHOIFG( size_t N, size_t M, size_t k, double sigma, FactorGraph &fg ) {
36 vector<Var> vars;
37 vector<Factor> factors;
38
39 for( size_t i = 0; i < N; i++ )
40 vars.push_back(Var(i,2));
41
42 for( size_t I = 0; I < M; I++ ) {
43 VarSet vars;
44 while( vars.size() < k ) {
45 do {
46 size_t newind = (size_t)(N * rnd_uniform());
47 Var newvar = Var(newind, 2);
48 if( !(vars && newvar) ) {
49 vars |= newvar;
50 break;
51 }
52 } while( 1 );
53 }
54 Factor newfac(vars);
55 for( size_t t = 0; t < newfac.states(); t++ )
56 newfac[t] = exp(rnd_stdnormal() * sigma);
57 factors.push_back(newfac);
58 }
59
60 fg = FactorGraph(factors);
61 };
62
63
64 void MakeFullFG( size_t N, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
65 vector<Var> vars;
66 vector<Factor> factors;
67
68 double w[N][N];
69 double th[N];
70 double buf[4];
71
72 for( size_t i = 0; i < N; i++ )
73 vars.push_back(Var(i,2));
74
75 for( size_t i = 0; i < N; i++ )
76 for( size_t j = 0; j < N; j++ )
77 w[i][j] = 0.0;
78
79 for( size_t i = 0; i < N; i++ )
80 for( size_t j = i+1; j < N; j++ ) {
81 w[i][j] = rnd_stdnormal() * sigma_w;
82 if( type == "fe" )
83 w[i][j] = fabs(w[i][j]);
84 else if( type == "af" )
85 w[i][j] = -fabs(w[i][j]);
86 w[j][i] = w[i][j];
87 buf[0] = (buf[3] = exp(w[i][j]));
88 buf[1] = (buf[2] = exp(-w[i][j]));
89 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
90 }
91
92 for( size_t i = 0; i < N; i++ ) {
93 th[i] = rnd_stdnormal() * sigma_th;
94 buf[0] = exp(th[i]);
95 buf[1] = exp(-th[i]);
96 factors.push_back(Factor(vars[i],buf));
97 }
98
99 fg = FactorGraph(factors);
100 };
101
102
103 void MakeGridFG( long periodic, long n, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
104 vector<Var> vars;
105 vector<Factor> factors;
106
107 long N = n*n;
108
109 double w[N][N];
110 double th[N];
111 double buf[4];
112
113 for( long i = 0; i < N; i++ )
114 vars.push_back(Var(i,2));
115
116 for( long i = 0; i < N; i++ )
117 for( long j = 0; j < N; j++ )
118 w[i][j] = 0.0;
119
120 for( long i = 0; i < n; i++ )
121 for( long j = 0; j < n; j++ ) {
122 if( i+1 < n || periodic )
123 w[i*n+j][((i+1)%n)*n+j] = 1.0;
124 if( i > 0 || periodic )
125 w[i*n+j][((i+n-1)%n)*n+j] = 1.0;
126 if( j+1 < n || periodic )
127 w[i*n+j][i*n+((j+1)%n)] = 1.0;
128 if( j > 0 || periodic )
129 w[i*n+j][i*n+((j+n-1)%n)] = 1.0;
130 }
131
132 for( long i = 0; i < N; i++ )
133 for( long j = i+1; j < N; j++ )
134 if( w[i][j] ) {
135 w[i][j] = rnd_stdnormal() * sigma_w;
136 if( type == "fe" )
137 w[i][j] = fabs(w[i][j]);
138 else if( type == "af" )
139 w[i][j] = -fabs(w[i][j]);
140 w[j][i] = w[i][j];
141 buf[0] = (buf[3] = exp(w[i][j]));
142 buf[1] = (buf[2] = exp(-w[i][j]));
143 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
144 }
145
146 for( long i = 0; i < N; i++ ) {
147 th[i] = rnd_stdnormal() * sigma_th;
148 buf[0] = exp(th[i]);
149 buf[1] = exp(-th[i]);
150 factors.push_back(Factor(vars[i],buf));
151 }
152
153 fg = FactorGraph(factors);
154 };
155
156
157 void MakeDRegFG( size_t N, size_t d, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
158 vector<Var> vars;
159 vector<Factor> factors;
160
161 double w[N][N];
162 double th[N];
163 double buf[4];
164
165 for( size_t i = 0; i < N; i++ )
166 vars.push_back(Var(i,2));
167
168 for( size_t i = 0; i < N; i++ )
169 for( size_t j = 0; j < N; j++ )
170 w[i][j] = 0.0;
171
172 UEdgeVec g = RandomDRegularGraph( N, d );
173 for( size_t i = 0; i < g.size(); i++ ) {
174 w[g[i].n1][g[i].n2] = 1.0;
175 w[g[i].n2][g[i].n1] = 1.0;
176 }
177
178 for( size_t i = 0; i < N; i++ )
179 for( size_t j = i+1; j < N; j++ )
180 if( w[i][j] ) {
181 w[i][j] = rnd_stdnormal() * sigma_w;
182 if( type == "fe" )
183 w[i][j] = fabs(w[i][j]);
184 else if( type == "af" )
185 w[i][j] = -fabs(w[i][j]);
186 w[j][i] = w[i][j];
187 buf[0] = (buf[3] = exp(w[i][j]));
188 buf[1] = (buf[2] = exp(-w[i][j]));
189 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
190 }
191
192 for( size_t i = 0; i < N; i++ ) {
193 th[i] = rnd_stdnormal() * sigma_th;
194 buf[0] = exp(th[i]);
195 buf[1] = exp(-th[i]);
196 factors.push_back(Factor(vars[i],buf));
197 }
198
199 fg = FactorGraph(factors);
200 };
201
202
// Names of the supported factor graph types, as accepted by the --type option.
// The pointers themselves are const, so they cannot be reassigned at run time.
const char * const HOITYPE  = "hoi";
const char * const FULLTYPE = "full";
const char * const GRIDTYPE = "grid";
const char * const DREGTYPE = "dreg";
207
208
209 // Old usages:
210 // create_full_fg <N> <sigma_w> <sigma_th> <subtype>
211 // create_grid_fg <periodic> <n> <sigma_w> <sigma_th> <subtype>
212 // create_dreg_fg <d> <N> <sigma_w> <sigma_th> <subtype>
213
214
215 int main( int argc, char *argv[] ) {
216 try {
217 size_t N, M, k, d;
218 size_t periodic;
219 size_t seed;
220 double beta, sigma_w, sigma_th;
221 string type, subtype;
222
223 // Declare the supported options.
224 po::options_description desc("Allowed options");
225 desc.add_options()
226 ("help", "produce help message")
227 ("type", po::value<string>(&type), "factor graph type:\n\t'full', 'grid', 'dreg' or 'hoi'")
228 ("seed", po::value<size_t>(&seed), "random number seed")
229 ("subtype", po::value<string>(&subtype), "interactions type:\n\t'sg', 'fe' or 'af'\n\t(ignored for type=='hoi')")
230 ("N", po::value<size_t>(&N), "number of (binary) variables")
231 ("M", po::value<size_t>(&M), "number of factors\n\t(only for type=='hoi')")
232 ("k", po::value<size_t>(&k), "connectivity of the factors\n\t(only for type=='hoi')")
233 ("d", po::value<size_t>(&d), "variable connectivity\n\t(only for type=='dreg')")
234 ("beta", po::value<double>(&beta), "stddev of log-factor entries\n\t(only for type=='hoi')")
235 ("sigma_w", po::value<double>(&sigma_w), "stddev of pairwise interactions w_{ij}\n\t(ignored for type=='hoi')")
236 ("sigma_th", po::value<double>(&sigma_th), "stddev of singleton interactions th_i\n\t(ignored for type=='hoi')")
237 ("periodic", po::value<size_t>(&periodic), "0/1 corresponding to nonperiodic/periodic grid\n\t(only for type=='grid')")
238 ;
239
240 po::variables_map vm;
241 po::store(po::parse_command_line(argc, argv, desc), vm);
242 po::notify(vm);
243
244 if( vm.count("help") || !vm.count("type") ) {
245 if( vm.count("type") ) {
246 if( type == HOITYPE ) {
247 cout << "Creates a random factor graph of <N> binary variables and" << endl;
248 cout << "<M> factors, each factor being an interaction of <k> variables." << endl;
249 cout << "The entries of the factors are exponentials of i.i.d. Gaussian" << endl;
250 cout << "variables with mean 0 and standard deviation <beta>." << endl;
251 } else if( type == FULLTYPE ) {
252 cout << "Creates fully connected pairwise graphical model of <N> variables;" << endl;
253 } else if( type == GRIDTYPE ) {
254 cout << "Creates 2D Ising grid (periodic if <periodic>!=0) of (approx.) <N> variables;" << endl;
255 } else if( type == DREGTYPE ) {
256 cout << "Creates random d-regular graph of <N> nodes with uniform degree <d>" << endl;
257 cout << "(where <d><N> should be even);" << endl;
258 } else
259 cerr << "Unknown type (should be one of 'full', 'grid', 'dreg' or 'hoi')" << endl;
260
261 if( type == FULLTYPE || type == GRIDTYPE || type == DREGTYPE ) {
262 cout << "singleton interactions are Gaussian with mean 0 and standard" << endl;
263 cout << "deviation <sigma_th>; pairwise interactions are Gaussian with mean 0" << endl;
264 cout << "and standard deviation <sigma_w> if <subtype>=='sg', absolute value" << endl;
265 cout << "is taken if <subtype>=='fe' and a minus sign is added if <subtype>=='af'." << endl;
266 }
267 }
268 cout << endl << desc << endl;
269 return 1;
270 }
271
272 if( !vm.count("seed") )
273 throw "Please specify random number seed.";
274 rnd_seed( seed );
275 // srand( gsl_rng_default_seed );
276
277 FactorGraph fg;
278
279 cout << "# Factor graph made by " << argv[0] << endl;
280 cout << "# type = " << type << endl;
281
282 if( type == HOITYPE ) {
283 if( !vm.count("N") || !vm.count("M") || !vm.count("k") || !vm.count("beta") )
284 throw "Please specify all required arguments";
285 do {
286 MakeHOIFG( N, M, k, beta, fg );
287 } while( !fg.G.isConnected() );
288
289 cout << "# N = " << N << endl;
290 cout << "# M = " << M << endl;
291 cout << "# k = " << k << endl;
292 cout << "# beta = " << beta << endl;
293 } else if( type == FULLTYPE ) {
294 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") )
295 throw "Please specify all required arguments";
296 MakeFullFG( N, sigma_w, sigma_th, subtype, fg );
297
298 cout << "# N = " << N << endl;
299 cout << "# sigma_w = " << sigma_w << endl;
300 cout << "# sigma_th = " << sigma_th << endl;
301 cout << "# subtype = " << subtype << endl;
302 } else if( type == GRIDTYPE ) {
303 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") || !vm.count("periodic") )
304 throw "Please specify all required arguments";
305
306 size_t n = (size_t)sqrt((long double)N);
307 N = n * n;
308
309 MakeGridFG( periodic, n, sigma_w, sigma_th, subtype, fg );
310
311 cout << "# periodic = " << periodic << endl;
312 cout << "# n = " << n << endl;
313 cout << "# N = " << N << endl;
314 cout << "# sigma_w = " << sigma_w << endl;
315 cout << "# sigma_th = " << sigma_th << endl;
316 cout << "# subtype = " << subtype << endl;
317 } else if( type == DREGTYPE ) {
318 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") || !vm.count("d") )
319 throw "Please specify all required arguments";
320
321 MakeDRegFG( N, d, sigma_w, sigma_th, subtype, fg );
322
323 cout << "# N = " << N << endl;
324 cout << "# d = " << d << endl;
325 cout << "# sigma_w = " << sigma_w << endl;
326 cout << "# sigma_th = " << sigma_th << endl;
327 cout << "# subtype = " << subtype << endl;
328 }
329
330 cout << "# seed = " << seed << endl;
331 cout << fg;
332 }
333 catch(exception& e) {
334 cerr << "Error: " << e.what() << endl;
335 return 1;
336 }
337 catch(const char * e) {
338 cerr << "Error: " << e << endl;
339 return 1;
340 }
341 catch(...) {
342 cerr << "Exception of unknown type!" << endl;
343 }
344
345 return 0;
346 }