1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
24 #include <boost/program_options.hpp>
25 #include <dai/factorgraph.h>
26 #include <dai/weightedgraph.h>
// Short alias for boost::program_options; main() uses it (po::options_description,
// po::variables_map, po::store, po::parse_command_line) for command-line parsing.
32 namespace po
= boost::program_options
;
// MakeHOIFG: fill `fg` with a random factor graph with higher-order interactions.
//
// Parameters (as used by the visible code):
//   N     number of variables; variable i is created as Var(i, 2), i.e. binary.
//   M     number of factors to generate (outer loop over I < M).
//   k     target number of variables per factor (the inner while-loop runs
//         until vars.size() reaches k).
//   sigma scale of the log-entries: every factor entry is set to
//         exp(rnd_stdnormal() * sigma); main() passes its <beta> option here.
//   fg    output; overwritten with FactorGraph(factors).
//
// NOTE(review): this copy of the file has lost several original lines (e.g. the
// declaration of `vars`, the code that adds `newvar` to the factor's variable
// set, the construction of `newfac`, and the closing braces), so it does not
// compile as-is. The comments below describe only what is visible.
35 void MakeHOIFG( size_t N
, size_t M
, size_t k
, double sigma
, FactorGraph
&fg
) {
37 vector
<Factor
> factors
;
// Create the N binary variables with labels 0..N-1.
39 for( size_t i
= 0; i
< N
; i
++ )
40 vars
.push_back(Var(i
,2));
// Generate M factors, each over a randomly chosen set of variables.
42 for( size_t I
= 0; I
< M
; I
++ ) {
44 while( vars
.size() < k
) {
// Pick a candidate variable uniformly at random.
// NOTE(review): assumes rnd_uniform() returns values in [0, 1) so that
// newind < N — TODO confirm against the RNG helper.
46 size_t newind
= (size_t)(N
* rnd_uniform());
47 Var newvar
= Var(newind
, 2);
// Only accept a variable that is not already in the set (presumably
// `vars && newvar` tests overlap with the current VarSet — verify).
// NOTE(review): if duplicates are rejected, k > N can never be reached and
// this loop would not terminate; callers should ensure k <= N.
48 if( !(vars
&& newvar
) ) {
// Fill every entry of the new factor with exp(sigma * standard-normal sample).
55 for( size_t t
= 0; t
< newfac
.states(); t
++ )
56 newfac
[t
] = exp(rnd_stdnormal() * sigma
);
57 factors
.push_back(newfac
);
// Replace the output factor graph with one built from the generated factors.
60 fg
= FactorGraph(factors
);
// MakeFullFG: fill `fg` with a fully connected pairwise model over N binary variables.
//
//   N        number of variables; variable i is Var(i, 2).
//   sigma_w  scale of the pairwise couplings: w[i][j] = rnd_stdnormal() * sigma_w
//            for every pair i < j.
//   sigma_th scale of the single-variable fields: th[i] = rnd_stdnormal() * sigma_th.
//   type     subtype string. Per the help text in main(), "sg" keeps the sampled
//            sign, "fe" takes the absolute value and "af" forces a negative sign.
//   fg       output; overwritten with FactorGraph(factors).
//
// NOTE(review): several original lines are missing from this copy: the
// declarations of vars/w/th/buf, the body of the first N x N loop, the
// `if( type == "fe" )` line, the singleton buf assignments and closing braces.
64 void MakeFullFG( size_t N
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
66 vector
<Factor
> factors
;
// Create the N binary variables with labels 0..N-1.
72 for( size_t i
= 0; i
< N
; i
++ )
73 vars
.push_back(Var(i
,2));
// NOTE(review): the body of this N x N loop is missing here; presumably it
// initialises the coupling matrix w — confirm against the original source.
75 for( size_t i
= 0; i
< N
; i
++ )
76 for( size_t j
= 0; j
< N
; j
++ )
// Sample a coupling for every unordered pair i < j and add a pairwise factor.
79 for( size_t i
= 0; i
< N
; i
++ )
80 for( size_t j
= i
+1; j
< N
; j
++ ) {
81 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
// Subtype "fe": make the coupling non-negative (the guarding `if` line is
// missing from this copy); subtype "af": make it non-positive.
83 w
[i
][j
] = fabs(w
[i
][j
]);
84 else if( type
== "af" )
85 w
[i
][j
] = -fabs(w
[i
][j
]);
// Pairwise table: entries 0 and 3 get exp(w), entries 1 and 2 get exp(-w).
// NOTE(review): assuming linear indices 0/3 are the joint states where both
// variables agree, this is an Ising-style coupling — confirm against Factor.
87 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
88 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
89 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
// Add one single-variable factor per variable, driven by the field th[i].
92 for( size_t i
= 0; i
< N
; i
++ ) {
93 th
[i
] = rnd_stdnormal() * sigma_th
;
// NOTE(review): the lines that fill buf from th[i] are missing here.
96 factors
.push_back(Factor(vars
[i
],buf
));
// Replace the output factor graph with one built from all generated factors.
99 fg
= FactorGraph(factors
);
// MakeGridFG: fill `fg` with a pairwise model on an n x n grid of binary variables.
//
//   periodic nonzero => neighbour indices wrap around the grid edges
//            ((i+1)%n, (i+n-1)%n, ...); zero => open boundaries.
//   n        grid side length; grid cell (i, j) maps to variable index i*n + j.
//            main() passes n = (size_t)sqrt(N), i.e. floor(sqrt(N)).
//   sigma_w  scale of the couplings between grid neighbours.
//   sigma_th scale of the single-variable fields.
//   type     subtype string. Per the help text in main(), "sg" keeps the sampled
//            sign, "fe" takes the absolute value and "af" forces a negative sign.
//   fg       output; overwritten with FactorGraph(factors).
//
// NOTE(review): this copy is missing the definition of N (presumably n*n), the
// declarations of vars/w/th/buf, the `if( type == "fe" )` line, and several
// closing braces.
103 void MakeGridFG( long periodic
, long n
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
105 vector
<Factor
> factors
;
// Create the N binary variables with labels 0..N-1.
113 for( long i
= 0; i
< N
; i
++ )
114 vars
.push_back(Var(i
,2));
// NOTE(review): the body of this N x N loop is missing here; presumably it
// clears the adjacency/coupling matrix w — confirm.
116 for( long i
= 0; i
< N
; i
++ )
117 for( long j
= 0; j
< N
; j
++ )
// Mark the up/down/right/left neighbours of cell (i, j) in w; without
// `periodic`, neighbours that would fall off the grid edge are skipped.
120 for( long i
= 0; i
< n
; i
++ )
121 for( long j
= 0; j
< n
; j
++ ) {
122 if( i
+1 < n
|| periodic
)
123 w
[i
*n
+j
][((i
+1)%n
)*n
+j
] = 1.0;
124 if( i
> 0 || periodic
)
125 w
[i
*n
+j
][((i
+n
-1)%n
)*n
+j
] = 1.0;
126 if( j
+1 < n
|| periodic
)
127 w
[i
*n
+j
][i
*n
+((j
+1)%n
)] = 1.0;
128 if( j
> 0 || periodic
)
129 w
[i
*n
+j
][i
*n
+((j
+n
-1)%n
)] = 1.0;
// For each pair i < j, sample a coupling and add a pairwise factor.
// NOTE(review): presumably a missing line restricts this to the adjacent
// pairs marked in w above; as visible it would cover every pair.
132 for( long i
= 0; i
< N
; i
++ )
133 for( long j
= i
+1; j
< N
; j
++ )
135 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
// Subtype "fe": non-negative coupling (guarding `if` line missing here);
// subtype "af": non-positive coupling.
137 w
[i
][j
] = fabs(w
[i
][j
]);
138 else if( type
== "af" )
139 w
[i
][j
] = -fabs(w
[i
][j
]);
// Pairwise table: entries 0 and 3 get exp(w), entries 1 and 2 get exp(-w).
141 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
142 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
143 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
// Add one single-variable factor per variable, driven by the field th[i].
146 for( long i
= 0; i
< N
; i
++ ) {
147 th
[i
] = rnd_stdnormal() * sigma_th
;
// Entry 1 of the singleton table is exp(-th[i]); the line setting entry 0
// is missing from this copy (presumably exp(th[i])).
149 buf
[1] = exp(-th
[i
]);
150 factors
.push_back(Factor(vars
[i
],buf
));
// Replace the output factor graph with one built from all generated factors.
153 fg
= FactorGraph(factors
);
// MakeDRegFG: fill `fg` with a pairwise model on a random d-regular graph.
//
//   N        number of variables; variable i is Var(i, 2).
//   d        degree of every node, passed to RandomDRegularGraph(N, d); the help
//            text in main() notes that d*N should be even.
//   sigma_w  scale of the edge couplings.
//   sigma_th scale of the single-variable fields.
//   type     subtype string. Per the help text in main(), "sg" keeps the sampled
//            sign, "fe" takes the absolute value and "af" forces a negative sign.
//   fg       output; overwritten with FactorGraph(factors).
//
// NOTE(review): this copy is missing the declarations of vars/w/th/buf, the
// body of the first N x N loop, the guard restricting factors to graph edges,
// the `if( type == "fe" )` line, the singleton buf[0] line and closing braces.
157 void MakeDRegFG( size_t N
, size_t d
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
159 vector
<Factor
> factors
;
// Create the N binary variables with labels 0..N-1.
165 for( size_t i
= 0; i
< N
; i
++ )
166 vars
.push_back(Var(i
,2));
// NOTE(review): the body of this N x N loop is missing here; presumably it
// clears the adjacency/coupling matrix w — confirm.
168 for( size_t i
= 0; i
< N
; i
++ )
169 for( size_t j
= 0; j
< N
; j
++ )
// Draw the random d-regular graph and mark each edge (n1, n2) in w in both
// directions.
172 UEdgeVec g
= RandomDRegularGraph( N
, d
);
173 for( size_t i
= 0; i
< g
.size(); i
++ ) {
174 w
[g
[i
].n1
][g
[i
].n2
] = 1.0;
175 w
[g
[i
].n2
][g
[i
].n1
] = 1.0;
// For each pair i < j, sample a coupling and add a pairwise factor.
// NOTE(review): presumably a missing line restricts this to the edges marked
// in w above; as visible it would cover every pair.
178 for( size_t i
= 0; i
< N
; i
++ )
179 for( size_t j
= i
+1; j
< N
; j
++ )
181 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
// Subtype "fe": non-negative coupling (guarding `if` line missing here);
// subtype "af": non-positive coupling.
183 w
[i
][j
] = fabs(w
[i
][j
]);
184 else if( type
== "af" )
185 w
[i
][j
] = -fabs(w
[i
][j
]);
// Pairwise table: entries 0 and 3 get exp(w), entries 1 and 2 get exp(-w).
187 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
188 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
189 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
// Add one single-variable factor per variable, driven by the field th[i].
192 for( size_t i
= 0; i
< N
; i
++ ) {
193 th
[i
] = rnd_stdnormal() * sigma_th
;
// Entry 1 of the singleton table is exp(-th[i]); the line setting entry 0
// is missing from this copy (presumably exp(th[i])).
195 buf
[1] = exp(-th
[i
]);
196 factors
.push_back(Factor(vars
[i
],buf
));
// Replace the output factor graph with one built from all generated factors.
199 fg
= FactorGraph(factors
);
// Accepted values of the --type option. main() compares its `type` string
// against these to pick the generator (MakeHOIFG / MakeFullFG / MakeGridFG /
// MakeDRegFG) and the matching help text.
203 const char *HOITYPE
= "hoi";
204 const char *FULLTYPE
= "full";
205 const char *GRIDTYPE
= "grid";
206 const char *DREGTYPE
= "dreg";
210 // create_full_fg <N> <sigma_w> <sigma_th> <subtype>
211 // create_grid_fg <periodic> <n> <sigma_w> <sigma_th> <subtype>
212 // create_dreg_fg <d> <N> <sigma_w> <sigma_th> <subtype>
215 int main( int argc
, char *argv
[] ) {
220 double beta
, sigma_w
, sigma_th
;
221 string type
, subtype
;
223 // Declare the supported options.
224 po::options_description
desc("Allowed options");
226 ("help", "produce help message")
227 ("type", po::value
<string
>(&type
), "factor graph type:\n\t'full', 'grid', 'dreg' or 'hoi'")
228 ("seed", po::value
<size_t>(&seed
), "random number seed")
229 ("subtype", po::value
<string
>(&subtype
), "interactions type:\n\t'sg', 'fe' or 'af'\n\t(ignored for type=='hoi')")
230 ("N", po::value
<size_t>(&N
), "number of (binary) variables")
231 ("M", po::value
<size_t>(&M
), "number of factors\n\t(only for type=='hoi')")
232 ("k", po::value
<size_t>(&k
), "connectivity of the factors\n\t(only for type=='hoi')")
233 ("d", po::value
<size_t>(&d
), "variable connectivity\n\t(only for type=='dreg')")
234 ("beta", po::value
<double>(&beta
), "stddev of log-factor entries\n\t(only for type=='hoi')")
235 ("sigma_w", po::value
<double>(&sigma_w
), "stddev of pairwise interactions w_{ij}\n\t(ignored for type=='hoi')")
236 ("sigma_th", po::value
<double>(&sigma_th
), "stddev of singleton interactions th_i\n\t(ignored for type=='hoi')")
237 ("periodic", po::value
<size_t>(&periodic
), "0/1 corresponding to nonperiodic/periodic grid\n\t(only for type=='grid')")
240 po::variables_map vm
;
241 po::store(po::parse_command_line(argc
, argv
, desc
), vm
);
244 if( vm
.count("help") || !vm
.count("type") ) {
245 if( vm
.count("type") ) {
246 if( type
== HOITYPE
) {
247 cout
<< "Creates a random factor graph of <N> binary variables and" << endl
;
248 cout
<< "<M> factors, each factor being an interaction of <k> variables." << endl
;
249 cout
<< "The entries of the factors are exponentials of i.i.d. Gaussian" << endl
;
250 cout
<< "variables with mean 0 and standard deviation <beta>." << endl
;
251 } else if( type
== FULLTYPE
) {
252 cout
<< "Creates fully connected pairwise graphical model of <N> variables;" << endl
;
253 } else if( type
== GRIDTYPE
) {
254 cout
<< "Creates 2D Ising grid (periodic if <periodic>!=0) of (approx.) <N> variables;" << endl
;
255 } else if( type
== DREGTYPE
) {
256 cout
<< "Creates random d-regular graph of <N> nodes with uniform degree <d>" << endl
;
257 cout
<< "(where <d><N> should be even);" << endl
;
259 cerr
<< "Unknown type (should be one of 'full', 'grid', 'dreg' or 'hoi')" << endl
;
261 if( type
== FULLTYPE
|| type
== GRIDTYPE
|| type
== DREGTYPE
) {
262 cout
<< "singleton interactions are Gaussian with mean 0 and standard" << endl
;
263 cout
<< "deviation <sigma_th>; pairwise interactions are Gaussian with mean 0" << endl
;
264 cout
<< "and standard deviation <sigma_w> if <subtype>=='sg', absolute value" << endl
;
265 cout
<< "is taken if <subtype>=='fe' and a minus sign is added if <subtype>=='af'." << endl
;
268 cout
<< endl
<< desc
<< endl
;
272 if( !vm
.count("seed") )
273 throw "Please specify random number seed.";
275 // srand( gsl_rng_default_seed );
279 cout
<< "# Factor graph made by " << argv
[0] << endl
;
280 cout
<< "# type = " << type
<< endl
;
282 if( type
== HOITYPE
) {
283 if( !vm
.count("N") || !vm
.count("M") || !vm
.count("k") || !vm
.count("beta") )
284 throw "Please specify all required arguments";
286 MakeHOIFG( N
, M
, k
, beta
, fg
);
287 } while( !fg
.G
.isConnected() );
289 cout
<< "# N = " << N
<< endl
;
290 cout
<< "# M = " << M
<< endl
;
291 cout
<< "# k = " << k
<< endl
;
292 cout
<< "# beta = " << beta
<< endl
;
293 } else if( type
== FULLTYPE
) {
294 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") )
295 throw "Please specify all required arguments";
296 MakeFullFG( N
, sigma_w
, sigma_th
, subtype
, fg
);
298 cout
<< "# N = " << N
<< endl
;
299 cout
<< "# sigma_w = " << sigma_w
<< endl
;
300 cout
<< "# sigma_th = " << sigma_th
<< endl
;
301 cout
<< "# subtype = " << subtype
<< endl
;
302 } else if( type
== GRIDTYPE
) {
303 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") || !vm
.count("periodic") )
304 throw "Please specify all required arguments";
306 size_t n
= (size_t)sqrt((long double)N
);
309 MakeGridFG( periodic
, n
, sigma_w
, sigma_th
, subtype
, fg
);
311 cout
<< "# periodic = " << periodic
<< endl
;
312 cout
<< "# n = " << n
<< endl
;
313 cout
<< "# N = " << N
<< endl
;
314 cout
<< "# sigma_w = " << sigma_w
<< endl
;
315 cout
<< "# sigma_th = " << sigma_th
<< endl
;
316 cout
<< "# subtype = " << subtype
<< endl
;
317 } else if( type
== DREGTYPE
) {
318 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") || !vm
.count("d") )
319 throw "Please specify all required arguments";
321 MakeDRegFG( N
, d
, sigma_w
, sigma_th
, subtype
, fg
);
323 cout
<< "# N = " << N
<< endl
;
324 cout
<< "# d = " << d
<< endl
;
325 cout
<< "# sigma_w = " << sigma_w
<< endl
;
326 cout
<< "# sigma_th = " << sigma_th
<< endl
;
327 cout
<< "# subtype = " << subtype
<< endl
;
330 cout
<< "# seed = " << seed
<< endl
;
333 catch(exception
& e
) {
334 cerr
<< "Error: " << e
.what() << endl
;
337 catch(const char * e
) {
338 cerr
<< "Error: " << e
<< endl
;
342 cerr
<< "Exception of unknown type!" << endl
;