Initial commit of libDAI-0.2.1
[libdai.git] / utils / createfg.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include "factorgraph.h"
#include "weightedgraph.h"
#include "util.h"
28
29
30 using namespace std;
31 namespace po = boost::program_options;
32
33
34 void MakeHOIFG( size_t N, size_t M, size_t k, double sigma, FactorGraph &fg ) {
35 vector<Var> vars;
36 vector<Factor> factors;
37
38 for( size_t i = 0; i < N; i++ )
39 vars.push_back(Var(i,2));
40
41 for( size_t I = 0; I < M; I++ ) {
42 VarSet vars;
43 while( vars.size() < k ) {
44 do {
45 size_t newind = (size_t)(N * rnd_uniform());
46 Var newvar = Var(newind, 2);
47 if( !(vars && newvar) ) {
48 vars |= newvar;
49 break;
50 }
51 } while( 1 );
52 }
53 Factor newfac(vars);
54 for( size_t t = 0; t < newfac.stateSpace(); t++ )
55 newfac[t] = exp(rnd_stdnormal() * sigma);
56 factors.push_back(newfac);
57 }
58
59 fg = FactorGraph(factors);
60 };
61
62
63 void MakeFullFG( size_t N, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
64 vector<Var> vars;
65 vector<Factor> factors;
66
67 double w[N][N];
68 double th[N];
69 double buf[4];
70
71 for( size_t i = 0; i < N; i++ )
72 vars.push_back(Var(i,2));
73
74 for( size_t i = 0; i < N; i++ )
75 for( size_t j = 0; j < N; j++ )
76 w[i][j] = 0.0;
77
78 for( size_t i = 0; i < N; i++ )
79 for( size_t j = i+1; j < N; j++ ) {
80 w[i][j] = rnd_stdnormal() * sigma_w;
81 if( type == "fe" )
82 w[i][j] = fabs(w[i][j]);
83 else if( type == "af" )
84 w[i][j] = -fabs(w[i][j]);
85 w[j][i] = w[i][j];
86 buf[0] = (buf[3] = exp(w[i][j]));
87 buf[1] = (buf[2] = exp(-w[i][j]));
88 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
89 }
90
91 for( size_t i = 0; i < N; i++ ) {
92 th[i] = rnd_stdnormal() * sigma_th;
93 buf[0] = exp(th[i]);
94 buf[1] = exp(-th[i]);
95 factors.push_back(Factor(vars[i],buf));
96 }
97
98 fg = FactorGraph(factors);
99 };
100
101
102 void MakeGridFG( long periodic, long n, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
103 vector<Var> vars;
104 vector<Factor> factors;
105
106 long N = n*n;
107
108 double w[N][N];
109 double th[N];
110 double buf[4];
111
112 for( long i = 0; i < N; i++ )
113 vars.push_back(Var(i,2));
114
115 for( long i = 0; i < N; i++ )
116 for( long j = 0; j < N; j++ )
117 w[i][j] = 0.0;
118
119 for( long i = 0; i < n; i++ )
120 for( long j = 0; j < n; j++ ) {
121 if( i+1 < n || periodic )
122 w[i*n+j][((i+1)%n)*n+j] = 1.0;
123 if( i > 0 || periodic )
124 w[i*n+j][((i+n-1)%n)*n+j] = 1.0;
125 if( j+1 < n || periodic )
126 w[i*n+j][i*n+((j+1)%n)] = 1.0;
127 if( j > 0 || periodic )
128 w[i*n+j][i*n+((j+n-1)%n)] = 1.0;
129 }
130
131 for( long i = 0; i < N; i++ )
132 for( long j = i+1; j < N; j++ )
133 if( w[i][j] ) {
134 w[i][j] = rnd_stdnormal() * sigma_w;
135 if( type == "fe" )
136 w[i][j] = fabs(w[i][j]);
137 else if( type == "af" )
138 w[i][j] = -fabs(w[i][j]);
139 w[j][i] = w[i][j];
140 buf[0] = (buf[3] = exp(w[i][j]));
141 buf[1] = (buf[2] = exp(-w[i][j]));
142 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
143 }
144
145 for( long i = 0; i < N; i++ ) {
146 th[i] = rnd_stdnormal() * sigma_th;
147 buf[0] = exp(th[i]);
148 buf[1] = exp(-th[i]);
149 factors.push_back(Factor(vars[i],buf));
150 }
151
152 fg = FactorGraph(factors);
153 };
154
155
156 void MakeDRegFG( size_t N, size_t d, double sigma_w, double sigma_th, string type, FactorGraph &fg ) {
157 vector<Var> vars;
158 vector<Factor> factors;
159
160 double w[N][N];
161 double th[N];
162 double buf[4];
163
164 for( size_t i = 0; i < N; i++ )
165 vars.push_back(Var(i,2));
166
167 for( size_t i = 0; i < N; i++ )
168 for( size_t j = 0; j < N; j++ )
169 w[i][j] = 0.0;
170
171 UEdgeVec g = RandomDRegularGraph( N, d );
172 for( size_t i = 0; i < g.size(); i++ ) {
173 w[g[i].n1][g[i].n2] = 1.0;
174 w[g[i].n2][g[i].n1] = 1.0;
175 }
176
177 for( size_t i = 0; i < N; i++ )
178 for( size_t j = i+1; j < N; j++ )
179 if( w[i][j] ) {
180 w[i][j] = rnd_stdnormal() * sigma_w;
181 if( type == "fe" )
182 w[i][j] = fabs(w[i][j]);
183 else if( type == "af" )
184 w[i][j] = -fabs(w[i][j]);
185 w[j][i] = w[i][j];
186 buf[0] = (buf[3] = exp(w[i][j]));
187 buf[1] = (buf[2] = exp(-w[i][j]));
188 factors.push_back(Factor(VarSet(vars[i],vars[j]),buf));
189 }
190
191 for( size_t i = 0; i < N; i++ ) {
192 th[i] = rnd_stdnormal() * sigma_th;
193 buf[0] = exp(th[i]);
194 buf[1] = exp(-th[i]);
195 factors.push_back(Factor(vars[i],buf));
196 }
197
198 fg = FactorGraph(factors);
199 };
200
201
// Accepted values of the --type option. The pointers themselves are const so
// that these names cannot be re-pointed accidentally at runtime.
const char * const HOITYPE  = "hoi";
const char * const FULLTYPE = "full";
const char * const GRIDTYPE = "grid";
const char * const DREGTYPE = "dreg";
206
207
208 // Old usages:
209 // create_full_fg <N> <sigma_w> <sigma_th> <subtype>
210 // create_grid_fg <periodic> <n> <sigma_w> <sigma_th> <subtype>
211 // create_dreg_fg <d> <N> <sigma_w> <sigma_th> <subtype>
212
213
214 int main( int argc, char *argv[] ) {
215 try {
216 size_t N, M, k, d;
217 size_t periodic;
218 size_t seed;
219 double beta, sigma_w, sigma_th;
220 string type, subtype;
221
222 // Declare the supported options.
223 po::options_description desc("Allowed options");
224 desc.add_options()
225 ("help", "produce help message")
226 ("type", po::value<string>(&type), "factor graph type:\n\t'full', 'grid', 'dreg' or 'hoi'")
227 ("seed", po::value<size_t>(&seed), "random number seed")
228 ("subtype", po::value<string>(&subtype), "interactions type:\n\t'sg', 'fe' or 'af'\n\t(ignored for type=='hoi')")
229 ("N", po::value<size_t>(&N), "number of (binary) variables")
230 ("M", po::value<size_t>(&M), "number of factors\n\t(only for type=='hoi')")
231 ("k", po::value<size_t>(&k), "connectivity of the factors\n\t(only for type=='hoi')")
232 ("d", po::value<size_t>(&d), "variable connectivity\n\t(only for type=='dreg')")
233 ("beta", po::value<double>(&beta), "stddev of log-factor entries\n\t(only for type=='hoi')")
234 ("sigma_w", po::value<double>(&sigma_w), "stddev of pairwise interactions w_{ij}\n\t(ignored for type=='hoi')")
235 ("sigma_th", po::value<double>(&sigma_th), "stddev of singleton interactions th_i\n\t(ignored for type=='hoi')")
236 ("periodic", po::value<size_t>(&periodic), "0/1 corresponding to nonperiodic/periodic grid\n\t(only for type=='grid')")
237 ;
238
239 po::variables_map vm;
240 po::store(po::parse_command_line(argc, argv, desc), vm);
241 po::notify(vm);
242
243 if( vm.count("help") || !vm.count("type") ) {
244 if( vm.count("type") ) {
245 if( type == HOITYPE ) {
246 cout << "Creates a random factor graph of <N> binary variables and" << endl;
247 cout << "<M> factors, each factor being an interaction of <k> variables." << endl;
248 cout << "The entries of the factors are exponentials of i.i.d. Gaussian" << endl;
249 cout << "variables with mean 0 and standard deviation <beta>." << endl;
250 } else if( type == FULLTYPE ) {
251 cout << "Creates fully connected pairwise graphical model of <N> variables;" << endl;
252 } else if( type == GRIDTYPE ) {
253 cout << "Creates 2D Ising grid (periodic if <periodic>!=0) of (approx.) <N> variables;" << endl;
254 } else if( type == DREGTYPE ) {
255 cout << "Creates random d-regular graph of <N> nodes with uniform degree <d>" << endl;
256 cout << "(where <d><N> should be even);" << endl;
257 } else
258 cerr << "Unknown type (should be one of 'full', 'grid', 'dreg' or 'hoi')" << endl;
259
260 if( type == FULLTYPE || type == GRIDTYPE || type == DREGTYPE ) {
261 cout << "singleton interactions are Gaussian with mean 0 and standard" << endl;
262 cout << "deviation <sigma_th>; pairwise interactions are Gaussian with mean 0" << endl;
263 cout << "and standard deviation <sigma_w> if <subtype>=='sg', absolute value" << endl;
264 cout << "is taken if <subtype>=='fe' and a minus sign is added if <subtype>=='af'." << endl;
265 }
266 }
267 cout << endl << desc << endl;
268 return 1;
269 }
270
271 if( !vm.count("seed") )
272 throw "Please specify random number seed.";
273 rnd_seed( seed );
274 // srand( gsl_rng_default_seed );
275
276 FactorGraph fg;
277
278 cout << "# Factor graph made by " << argv[0] << endl;
279 cout << "# type = " << type << endl;
280
281 if( type == HOITYPE ) {
282 if( !vm.count("N") || !vm.count("M") || !vm.count("k") || !vm.count("beta") )
283 throw "Please specify all required arguments";
284 do {
285 MakeHOIFG( N, M, k, beta, fg );
286 } while( !fg.isConnected() );
287
288 cout << "# N = " << N << endl;
289 cout << "# M = " << M << endl;
290 cout << "# k = " << k << endl;
291 cout << "# beta = " << beta << endl;
292 } else if( type == FULLTYPE ) {
293 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") )
294 throw "Please specify all required arguments";
295 MakeFullFG( N, sigma_w, sigma_th, subtype, fg );
296
297 cout << "# N = " << N << endl;
298 cout << "# sigma_w = " << sigma_w << endl;
299 cout << "# sigma_th = " << sigma_th << endl;
300 cout << "# subtype = " << subtype << endl;
301 } else if( type == GRIDTYPE ) {
302 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") || !vm.count("periodic") )
303 throw "Please specify all required arguments";
304
305 size_t n = (size_t)sqrt((long double)N);
306 N = n * n;
307
308 MakeGridFG( periodic, n, sigma_w, sigma_th, subtype, fg );
309
310 cout << "# periodic = " << periodic << endl;
311 cout << "# n = " << n << endl;
312 cout << "# N = " << N << endl;
313 cout << "# sigma_w = " << sigma_w << endl;
314 cout << "# sigma_th = " << sigma_th << endl;
315 cout << "# subtype = " << subtype << endl;
316 } else if( type == DREGTYPE ) {
317 if( !vm.count("N") || !vm.count("sigma_w") || !vm.count("sigma_th") || !vm.count("subtype") || !vm.count("d") )
318 throw "Please specify all required arguments";
319
320 MakeDRegFG( N, d, sigma_w, sigma_th, subtype, fg );
321
322 cout << "# N = " << N << endl;
323 cout << "# d = " << d << endl;
324 cout << "# sigma_w = " << sigma_w << endl;
325 cout << "# sigma_th = " << sigma_th << endl;
326 cout << "# subtype = " << subtype << endl;
327 }
328
329 cout << "# seed = " << seed << endl;
330 cout << fg;
331 }
332 catch(exception& e) {
333 cerr << "Error: " << e.what() << endl;
334 return 1;
335 }
336 catch(const char * e) {
337 cerr << "Error: " << e << endl;
338 return 1;
339 }
340 catch(...) {
341 cerr << "Exception of unknown type!" << endl;
342 }
343
344 return 0;
345 }