1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
#include <algorithm>
#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include <boost/program_options.hpp>

#include "factorgraph.h"
#include "weightedgraph.h"
31 namespace po
= boost::program_options
;
34 void MakeHOIFG( size_t N
, size_t M
, size_t k
, double sigma
, FactorGraph
&fg
) {
36 vector
<Factor
> factors
;
38 for( size_t i
= 0; i
< N
; i
++ )
39 vars
.push_back(Var(i
,2));
41 for( size_t I
= 0; I
< M
; I
++ ) {
43 while( vars
.size() < k
) {
45 size_t newind
= (size_t)(N
* rnd_uniform());
46 Var newvar
= Var(newind
, 2);
47 if( !(vars
&& newvar
) ) {
54 for( size_t t
= 0; t
< newfac
.stateSpace(); t
++ )
55 newfac
[t
] = exp(rnd_stdnormal() * sigma
);
56 factors
.push_back(newfac
);
59 fg
= FactorGraph(factors
);
63 void MakeFullFG( size_t N
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
65 vector
<Factor
> factors
;
71 for( size_t i
= 0; i
< N
; i
++ )
72 vars
.push_back(Var(i
,2));
74 for( size_t i
= 0; i
< N
; i
++ )
75 for( size_t j
= 0; j
< N
; j
++ )
78 for( size_t i
= 0; i
< N
; i
++ )
79 for( size_t j
= i
+1; j
< N
; j
++ ) {
80 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
82 w
[i
][j
] = fabs(w
[i
][j
]);
83 else if( type
== "af" )
84 w
[i
][j
] = -fabs(w
[i
][j
]);
86 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
87 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
88 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
91 for( size_t i
= 0; i
< N
; i
++ ) {
92 th
[i
] = rnd_stdnormal() * sigma_th
;
95 factors
.push_back(Factor(vars
[i
],buf
));
98 fg
= FactorGraph(factors
);
102 void MakeGridFG( long periodic
, long n
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
104 vector
<Factor
> factors
;
112 for( long i
= 0; i
< N
; i
++ )
113 vars
.push_back(Var(i
,2));
115 for( long i
= 0; i
< N
; i
++ )
116 for( long j
= 0; j
< N
; j
++ )
119 for( long i
= 0; i
< n
; i
++ )
120 for( long j
= 0; j
< n
; j
++ ) {
121 if( i
+1 < n
|| periodic
)
122 w
[i
*n
+j
][((i
+1)%n
)*n
+j
] = 1.0;
123 if( i
> 0 || periodic
)
124 w
[i
*n
+j
][((i
+n
-1)%n
)*n
+j
] = 1.0;
125 if( j
+1 < n
|| periodic
)
126 w
[i
*n
+j
][i
*n
+((j
+1)%n
)] = 1.0;
127 if( j
> 0 || periodic
)
128 w
[i
*n
+j
][i
*n
+((j
+n
-1)%n
)] = 1.0;
131 for( long i
= 0; i
< N
; i
++ )
132 for( long j
= i
+1; j
< N
; j
++ )
134 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
136 w
[i
][j
] = fabs(w
[i
][j
]);
137 else if( type
== "af" )
138 w
[i
][j
] = -fabs(w
[i
][j
]);
140 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
141 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
142 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
145 for( long i
= 0; i
< N
; i
++ ) {
146 th
[i
] = rnd_stdnormal() * sigma_th
;
148 buf
[1] = exp(-th
[i
]);
149 factors
.push_back(Factor(vars
[i
],buf
));
152 fg
= FactorGraph(factors
);
156 void MakeDRegFG( size_t N
, size_t d
, double sigma_w
, double sigma_th
, string type
, FactorGraph
&fg
) {
158 vector
<Factor
> factors
;
164 for( size_t i
= 0; i
< N
; i
++ )
165 vars
.push_back(Var(i
,2));
167 for( size_t i
= 0; i
< N
; i
++ )
168 for( size_t j
= 0; j
< N
; j
++ )
171 UEdgeVec g
= RandomDRegularGraph( N
, d
);
172 for( size_t i
= 0; i
< g
.size(); i
++ ) {
173 w
[g
[i
].n1
][g
[i
].n2
] = 1.0;
174 w
[g
[i
].n2
][g
[i
].n1
] = 1.0;
177 for( size_t i
= 0; i
< N
; i
++ )
178 for( size_t j
= i
+1; j
< N
; j
++ )
180 w
[i
][j
] = rnd_stdnormal() * sigma_w
;
182 w
[i
][j
] = fabs(w
[i
][j
]);
183 else if( type
== "af" )
184 w
[i
][j
] = -fabs(w
[i
][j
]);
186 buf
[0] = (buf
[3] = exp(w
[i
][j
]));
187 buf
[1] = (buf
[2] = exp(-w
[i
][j
]));
188 factors
.push_back(Factor(VarSet(vars
[i
],vars
[j
]),buf
));
191 for( size_t i
= 0; i
< N
; i
++ ) {
192 th
[i
] = rnd_stdnormal() * sigma_th
;
194 buf
[1] = exp(-th
[i
]);
195 factors
.push_back(Factor(vars
[i
],buf
));
198 fg
= FactorGraph(factors
);
// Values accepted by the --type option; main() compares the parsed option
// string against these. Both the pointers and the pointed-to text are const,
// so the names cannot be reassigned by accident.
const char * const HOITYPE  = "hoi";
const char * const FULLTYPE = "full";
const char * const GRIDTYPE = "grid";
const char * const DREGTYPE = "dreg";
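// The lines below appear to record the positional argument order of earlier
// stand-alone generators; this program instead takes named options (run it
// with --help for the list):
// create_full_fg <N> <sigma_w> <sigma_th> <subtype>
// create_grid_fg <periodic> <n> <sigma_w> <sigma_th> <subtype>
// create_dreg_fg <d> <N> <sigma_w> <sigma_th> <subtype>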
214 int main( int argc
, char *argv
[] ) {
219 double beta
, sigma_w
, sigma_th
;
220 string type
, subtype
;
222 // Declare the supported options.
223 po::options_description
desc("Allowed options");
225 ("help", "produce help message")
226 ("type", po::value
<string
>(&type
), "factor graph type:\n\t'full', 'grid', 'dreg' or 'hoi'")
227 ("seed", po::value
<size_t>(&seed
), "random number seed")
228 ("subtype", po::value
<string
>(&subtype
), "interactions type:\n\t'sg', 'fe' or 'af'\n\t(ignored for type=='hoi')")
229 ("N", po::value
<size_t>(&N
), "number of (binary) variables")
230 ("M", po::value
<size_t>(&M
), "number of factors\n\t(only for type=='hoi')")
231 ("k", po::value
<size_t>(&k
), "connectivity of the factors\n\t(only for type=='hoi')")
232 ("d", po::value
<size_t>(&d
), "variable connectivity\n\t(only for type=='dreg')")
233 ("beta", po::value
<double>(&beta
), "stddev of log-factor entries\n\t(only for type=='hoi')")
234 ("sigma_w", po::value
<double>(&sigma_w
), "stddev of pairwise interactions w_{ij}\n\t(ignored for type=='hoi')")
235 ("sigma_th", po::value
<double>(&sigma_th
), "stddev of singleton interactions th_i\n\t(ignored for type=='hoi')")
236 ("periodic", po::value
<size_t>(&periodic
), "0/1 corresponding to nonperiodic/periodic grid\n\t(only for type=='grid')")
239 po::variables_map vm
;
240 po::store(po::parse_command_line(argc
, argv
, desc
), vm
);
243 if( vm
.count("help") || !vm
.count("type") ) {
244 if( vm
.count("type") ) {
245 if( type
== HOITYPE
) {
246 cout
<< "Creates a random factor graph of <N> binary variables and" << endl
;
247 cout
<< "<M> factors, each factor being an interaction of <k> variables." << endl
;
248 cout
<< "The entries of the factors are exponentials of i.i.d. Gaussian" << endl
;
249 cout
<< "variables with mean 0 and standard deviation <beta>." << endl
;
250 } else if( type
== FULLTYPE
) {
251 cout
<< "Creates fully connected pairwise graphical model of <N> variables;" << endl
;
252 } else if( type
== GRIDTYPE
) {
253 cout
<< "Creates 2D Ising grid (periodic if <periodic>!=0) of (approx.) <N> variables;" << endl
;
254 } else if( type
== DREGTYPE
) {
255 cout
<< "Creates random d-regular graph of <N> nodes with uniform degree <d>" << endl
;
256 cout
<< "(where <d><N> should be even);" << endl
;
258 cerr
<< "Unknown type (should be one of 'full', 'grid', 'dreg' or 'hoi')" << endl
;
260 if( type
== FULLTYPE
|| type
== GRIDTYPE
|| type
== DREGTYPE
) {
261 cout
<< "singleton interactions are Gaussian with mean 0 and standard" << endl
;
262 cout
<< "deviation <sigma_th>; pairwise interactions are Gaussian with mean 0" << endl
;
263 cout
<< "and standard deviation <sigma_w> if <subtype>=='sg', absolute value" << endl
;
264 cout
<< "is taken if <subtype>=='fe' and a minus sign is added if <subtype>=='af'." << endl
;
267 cout
<< endl
<< desc
<< endl
;
271 if( !vm
.count("seed") )
272 throw "Please specify random number seed.";
274 // srand( gsl_rng_default_seed );
278 cout
<< "# Factor graph made by " << argv
[0] << endl
;
279 cout
<< "# type = " << type
<< endl
;
281 if( type
== HOITYPE
) {
282 if( !vm
.count("N") || !vm
.count("M") || !vm
.count("k") || !vm
.count("beta") )
283 throw "Please specify all required arguments";
285 MakeHOIFG( N
, M
, k
, beta
, fg
);
286 } while( !fg
.isConnected() );
288 cout
<< "# N = " << N
<< endl
;
289 cout
<< "# M = " << M
<< endl
;
290 cout
<< "# k = " << k
<< endl
;
291 cout
<< "# beta = " << beta
<< endl
;
292 } else if( type
== FULLTYPE
) {
293 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") )
294 throw "Please specify all required arguments";
295 MakeFullFG( N
, sigma_w
, sigma_th
, subtype
, fg
);
297 cout
<< "# N = " << N
<< endl
;
298 cout
<< "# sigma_w = " << sigma_w
<< endl
;
299 cout
<< "# sigma_th = " << sigma_th
<< endl
;
300 cout
<< "# subtype = " << subtype
<< endl
;
301 } else if( type
== GRIDTYPE
) {
302 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") || !vm
.count("periodic") )
303 throw "Please specify all required arguments";
305 size_t n
= (size_t)sqrt((long double)N
);
308 MakeGridFG( periodic
, n
, sigma_w
, sigma_th
, subtype
, fg
);
310 cout
<< "# periodic = " << periodic
<< endl
;
311 cout
<< "# n = " << n
<< endl
;
312 cout
<< "# N = " << N
<< endl
;
313 cout
<< "# sigma_w = " << sigma_w
<< endl
;
314 cout
<< "# sigma_th = " << sigma_th
<< endl
;
315 cout
<< "# subtype = " << subtype
<< endl
;
316 } else if( type
== DREGTYPE
) {
317 if( !vm
.count("N") || !vm
.count("sigma_w") || !vm
.count("sigma_th") || !vm
.count("subtype") || !vm
.count("d") )
318 throw "Please specify all required arguments";
320 MakeDRegFG( N
, d
, sigma_w
, sigma_th
, subtype
, fg
);
322 cout
<< "# N = " << N
<< endl
;
323 cout
<< "# d = " << d
<< endl
;
324 cout
<< "# sigma_w = " << sigma_w
<< endl
;
325 cout
<< "# sigma_th = " << sigma_th
<< endl
;
326 cout
<< "# subtype = " << subtype
<< endl
;
329 cout
<< "# seed = " << seed
<< endl
;
332 catch(exception
& e
) {
333 cerr
<< "Error: " << e
.what() << endl
;
336 catch(const char * e
) {
337 cerr
<< "Error: " << e
<< endl
;
341 cerr
<< "Exception of unknown type!" << endl
;