Some improvements to jtree and regiongraph and started work on regiongraph unit tests
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <stdexcept>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
22
23
24 using namespace std;
25 using namespace dai;
26 namespace po = boost::program_options;
27
28
29 /// Wrapper class for DAI approximate inference algorithms
30 class TestDAI {
31 protected:
32 /// Stores a pointer to an InfAlg object, managed by this class
33 InfAlg *obj;
34 /// Stores the name of the InfAlg algorithm
35 string name;
36 /// Stores the total variation distances of the variable marginals
37 vector<Real> err;
38
39 public:
40 /// Stores the variable marginals
41 vector<Factor> varMarginals;
42 /// Stores all marginals
43 vector<Factor> allMarginals;
44 /// Stores the logarithm of the partition sum
45 Real logZ;
46 /// Stores the maximum difference in the last iteration
47 Real maxdiff;
48 /// Stores the computation time (in seconds)
49 double time;
50 /// Stores the number of iterations needed
51 size_t iters;
52 /// Does the InfAlg support logZ()?
53 bool has_logZ;
54 /// Does the InfAlg support maxDiff()?
55 bool has_maxdiff;
56 /// Does the InfAlg support Iterations()?
57 bool has_iters;
58
59 /// Construct from factor graph \a fg, name \a _name, and set of properties \a opts
60 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), varMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
61 double tic = toc();
62
63 if( name == "LDPC" ) {
64 // special case: simulating a Low Density Parity Check code
65 Real zero[2] = {1.0, 0.0};
66 for( size_t i = 0; i < fg.nrVars(); i++ )
67 varMarginals.push_back( Factor(fg.var(i), zero) );
68 allMarginals = varMarginals;
69 logZ = 0.0;
70 maxdiff = 0.0;
71 iters = 1;
72 has_logZ = false;
73 has_maxdiff = false;
74 has_iters = false;
75 } else
76 // create a corresponding InfAlg object
77 obj = newInfAlg( name, fg, opts );
78
79 // Add the time needed to create the object
80 time += toc() - tic;
81 }
82
83 /// Destructor
84 ~TestDAI() {
85 if( obj != NULL )
86 delete obj;
87 }
88
89 /// Identify
90 string identify() const {
91 if( obj != NULL )
92 return obj->identify();
93 else
94 return "NULL";
95 }
96
97 /// Run the algorithm and store its results
98 void doDAI() {
99 double tic = toc();
100 if( obj != NULL ) {
101 // Initialize
102 obj->init();
103 // Run
104 obj->run();
105 // Record the time
106 time += toc() - tic;
107
108 // Store logarithm of the partition sum (if supported)
109 try {
110 logZ = obj->logZ();
111 has_logZ = true;
112 } catch( Exception &e ) {
113 if( e.code() == Exception::NOT_IMPLEMENTED )
114 has_logZ = false;
115 else
116 throw;
117 }
118
119 // Store maximum difference encountered in last iteration (if supported)
120 try {
121 maxdiff = obj->maxDiff();
122 has_maxdiff = true;
123 } catch( Exception &e ) {
124 if( e.code() == Exception::NOT_IMPLEMENTED )
125 has_maxdiff = false;
126 else
127 throw;
128 }
129
130 // Store number of iterations needed (if supported)
131 try {
132 iters = obj->Iterations();
133 has_iters = true;
134 } catch( Exception &e ) {
135 if( e.code() == Exception::NOT_IMPLEMENTED )
136 has_iters = false;
137 else
138 throw;
139 }
140
141 // Store variable marginals
142 varMarginals.clear();
143 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
144 varMarginals.push_back( obj->beliefV( i ) );
145
146 // Store all marginals calculated by the method
147 allMarginals = obj->beliefs();
148 };
149 }
150
151 /// Calculate total variation distance of variable marginals with respect to those in \a x
152 void calcErrs( const TestDAI &x ) {
153 err.clear();
154 err.reserve( varMarginals.size() );
155 for( size_t i = 0; i < varMarginals.size(); i++ )
156 err.push_back( dist( varMarginals[i], x.varMarginals[i], DISTTV ) );
157 }
158
159 /// Calculate total variation distance of variable marginals with respect to those in \a x
160 void calcErrs( const vector<Factor> &x ) {
161 err.clear();
162 err.reserve( varMarginals.size() );
163 for( size_t i = 0; i < varMarginals.size(); i++ )
164 err.push_back( dist( varMarginals[i], x[i], DISTTV ) );
165 }
166
167 /// Return maximum error
168 Real maxErr() {
169 return( *max_element( err.begin(), err.end() ) );
170 }
171
172 /// Return average error
173 Real avgErr() {
174 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
175 }
176 };
177
178
179 /// Clips a real number: if the absolute value of \a x is less than \a minabs, return \a minabs, else return \a x
180 Real clipReal( Real x, Real minabs ) {
181 if( abs(x) < minabs )
182 return minabs;
183 else
184 return x;
185 }
186
187
188 /// Selects which marginals main() prints below each method's result line: NONE prints none, VAR prints the variable marginals, ALL prints all marginals calculated by the method
189 DAI_ENUM(MarginalsOutputType,NONE,VAR,ALL);
190
191
192 /// Main function
193 int main( int argc, char *argv[] ) {
194 // Variables to store command line options
195 // Filename of factor graph
196 string filename;
197 // Filename for aliases
198 string aliases;
199 // Approximate Inference methods to use
200 vector<string> methods;
201 // Which marginals to output
202 MarginalsOutputType marginals;
203 // Output number of iterations?
204 bool report_iters = true;
205 // Output calculation time?
206 bool report_time = true;
207
208 // Define required command line options
209 po::options_description opts_required("Required options");
210 opts_required.add_options()
211 ("filename", po::value< string >(&filename), "Filename of factor graph")
212 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to perform")
213 ;
214
215 // Define allowed command line options
216 po::options_description opts_optional("Allowed options");
217 opts_optional.add_options()
218 ("help", "Produce help message")
219 ("aliases", po::value< string >(&aliases), "Filename for aliases")
220 ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/ALL, default=NONE)")
221 ("report-time", po::value< bool >(&report_time), "Output calculation time (default==1)?")
222 ("report-iters", po::value< bool >(&report_iters), "Output iterations needed (default==1)?")
223 ;
224
225 // Define all command line options
226 po::options_description cmdline_options;
227 cmdline_options.add(opts_required).add(opts_optional);
228
229 // Parse command line
230 po::variables_map vm;
231 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
232 po::notify(vm);
233
234 // Display help message if necessary
235 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
236 cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
237 cout << "Usage: ./testdai --filename <filename.fg> --methods <method1> [<method2> <method3> ...]" << endl << endl;
238 cout << "Reads factor graph <filename.fg> and performs the approximate inference algorithms" << endl;
239 cout << "<method*>, reporting for each method:" << endl;
240 cout << " o the calculation time needed, in seconds (if report-time == 1);" << endl;
241 cout << " o the number of iterations needed (if report-iters == 1);" << endl;
242 cout << " o the maximum (over all variables) total variation error in the variable marginals;" << endl;
243 cout << " o the average (over all variables) total variation error in the variable marginals;" << endl;
244 cout << " o the error (difference) of the logarithm of the partition sums;" << endl << endl;
245 cout << "All errors are calculated by comparing the results of the current method with" << endl;
246 cout << "the results of the first method (the base method). If marginals==VAR, additional" << endl;
247 cout << "output consists of the variable marginals, and if marginals==ALL, all marginals" << endl;
248 cout << "calculated by the method are reported." << endl << endl;
249 cout << "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl << endl;
250 cout << " name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl << endl;
251 cout << "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl;
252 cout << "filename is provided), followed by a list of properties (surrounded by rectangular" << endl;
253 cout << "brackets), where each property consists of a key=value pair and the properties are" << endl;
254 cout << "seperated by commas. If an alias file is specified, alias substitution is performed." << endl;
255 cout << "This is done by looking up the name in the alias file and substituting the alias" << endl;
256 cout << "by its corresponding method as defined in the alias file. Properties are parsed from" << endl;
257 cout << "left to right, so if a property occurs repeatedly, the right-most value is used." << endl << endl;
258 cout << opts_required << opts_optional << endl;
259 #ifdef DAI_DEBUG
260 cout << "Note: this is a debugging build of libDAI." << endl << endl;
261 #endif
262 cout << "Example: ./testdai --filename testfast.fg --aliases aliases.conf --methods JTREE_HUGIN BP_SEQFIX BP_PARALL[maxiter=5]" << endl;
263 return 1;
264 }
265
266 try {
267 // Read aliases
268 map<string,string> Aliases;
269 if( !aliases.empty() )
270 Aliases = readAliasesFile( aliases );
271
272 // Read factor graph
273 FactorGraph fg;
274 fg.ReadFromFile( filename.c_str() );
275
276 // Declare variables used for storing variable marginals and log partition sum of base method
277 vector<Factor> varMarginals0;
278 Real logZ0 = 0.0;
279
280 // Output header
281 cout.setf( ios_base::scientific );
282 cout.precision( 3 );
283 cout << "# " << filename << endl;
284 cout.width( 39 );
285 cout << left << "# METHOD" << "\t";
286 if( report_time )
287 cout << right << "SECONDS " << "\t";
288 if( report_iters )
289 cout << "ITERS" << "\t";
290 cout << "MAX ERROR" << "\t";
291 cout << "AVG ERROR" << "\t";
292 cout << "LOGZ ERROR" << "\t";
293 cout << "MAXDIFF" << "\t";
294 cout << endl;
295
296 // For each method...
297 for( size_t m = 0; m < methods.size(); m++ ) {
298 // Parse method
299 pair<string, PropertySet> meth = parseNameProperties( methods[m], Aliases );
300
301 // Check whether name is valid
302 size_t n = 0;
303 for( ; strlen( DAINames[n] ) != 0; n++ )
304 if( meth.first == DAINames[n] )
305 break;
306 if( strlen( DAINames[n] ) == 0 )
307 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + meth.first + string("\" in \"") + methods[m] + string("\""));
308
309 // Construct object for running the method
310 TestDAI testdai(fg, meth.first, meth.second );
311
312 // Run the method
313 testdai.doDAI();
314
315 // For the base method, store its variable marginals and logarithm of the partition sum
316 if( m == 0 ) {
317 varMarginals0 = testdai.varMarginals;
318 logZ0 = testdai.logZ;
319 }
320
321 // Calculate errors relative to base method
322 testdai.calcErrs( varMarginals0 );
323
324 // Output method name
325 cout.width( 39 );
326 cout << left << methods[m] << "\t";
327 // Output calculation time, if requested
328 if( report_time )
329 cout << right << testdai.time << "\t";
330 // Output number of iterations, if requested
331 if( report_iters ) {
332 if( testdai.has_iters ) {
333 cout << testdai.iters << "\t";
334 } else {
335 cout << "N/A \t";
336 }
337 }
338
339 // If this is not the base method
340 if( m > 0 ) {
341 cout.setf( ios_base::scientific );
342 cout.precision( 3 );
343
344 // Output maximum error in variable marginals
345 Real me = clipReal( testdai.maxErr(), 1e-9 );
346 cout << me << "\t";
347
348 // Output average error in variable marginals
349 Real ae = clipReal( testdai.avgErr(), 1e-9 );
350 cout << ae << "\t";
351
352 // Output error in log partition sum
353 if( testdai.has_logZ ) {
354 cout.setf( ios::showpos );
355 Real le = clipReal( testdai.logZ - logZ0, 1e-9 );
356 cout << le << "\t";
357 cout.unsetf( ios::showpos );
358 } else
359 cout << "N/A \t";
360
361 // Output maximum difference in last iteration
362 if( testdai.has_maxdiff ) {
363 Real md = clipReal( testdai.maxdiff, 1e-9 );
364 if( isnan( me ) )
365 md = me;
366 if( isnan( ae ) )
367 md = ae;
368 if( md == INFINITY )
369 md = 1.0;
370 cout << md << "\t";
371 } else
372 cout << "N/A \t";
373 }
374 cout << endl;
375
376 // Output marginals, if requested
377 if( marginals == MarginalsOutputType::VAR ) {
378 for( size_t i = 0; i < testdai.varMarginals.size(); i++ )
379 cout << "# " << testdai.varMarginals[i] << endl;
380 } else if( marginals == MarginalsOutputType::ALL ) {
381 for( size_t I = 0; I < testdai.allMarginals.size(); I++ )
382 cout << "# " << testdai.allMarginals[I] << endl;
383 }
384 }
385
386 return 0;
387 } catch( string &s ) {
388 // Abort with error message
389 cerr << "Exception: " << s << endl;
390 return 2;
391 }
392 }