Moved platform independent build options into Makefile.ALL and documented tests/testdai
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
22
23
24 using namespace std;
25 using namespace dai;
26 namespace po = boost::program_options;
27
28
29 /// Wrapper class for DAI approximate inference algorithms
30 class TestDAI {
31 protected:
32 /// Stores a pointer to an InfAlg object, managed by this class
33 InfAlg *obj;
34 /// Stores the name of the InfAlg algorithm
35 string name;
36 /// Stores the total variation distances of the variable marginals
37 vector<Real> err;
38
39 public:
40 /// Stores the variable marginals
41 vector<Factor> varMarginals;
42 /// Stores all marginals
43 vector<Factor> allMarginals;
44 /// Stores the logarithm of the partition sum
45 Real logZ;
46 /// Stores the maximum difference in the last iteration
47 Real maxdiff;
48 /// Stores the computation time (in seconds)
49 double time;
50 /// Stores the number of iterations needed
51 size_t iters;
52 /// Does the InfAlg support logZ()?
53 bool has_logZ;
54 /// Does the InfAlg support maxDiff()?
55 bool has_maxdiff;
56 /// Does the InfAlg support Iterations()?
57 bool has_iters;
58
59 /// Construct from factor graph \a fg, name \a _name, and set of properties \a opts
60 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), varMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
61 double tic = toc();
62
63 if( name == "LDPC" ) {
64 // special case: simulating a Low Density Parity Check code
65 Prob zero(2,0.0);
66 zero[0] = 1.0;
67 for( size_t i = 0; i < fg.nrVars(); i++ )
68 varMarginals.push_back( Factor(fg.var(i), zero) );
69 allMarginals = varMarginals;
70 logZ = 0.0;
71 maxdiff = 0.0;
72 iters = 1;
73 has_logZ = false;
74 has_maxdiff = false;
75 has_iters = false;
76 } else
77 // create a corresponding InfAlg object
78 obj = newInfAlg( name, fg, opts );
79
80 // Add the time needed to create the object
81 time += toc() - tic;
82 }
83
84 /// Destructor
85 ~TestDAI() {
86 if( obj != NULL )
87 delete obj;
88 }
89
90 /// Identify
91 string identify() const {
92 if( obj != NULL )
93 return obj->identify();
94 else
95 return "NULL";
96 }
97
98 /// Run the algorithm and store its results
99 void doDAI() {
100 double tic = toc();
101 if( obj != NULL ) {
102 // Initialize
103 obj->init();
104 // Run
105 obj->run();
106 // Record the time
107 time += toc() - tic;
108
109 // Store logarithm of the partition sum (if supported)
110 try {
111 logZ = obj->logZ();
112 has_logZ = true;
113 } catch( Exception &e ) {
114 if( e.code() == Exception::NOT_IMPLEMENTED )
115 has_logZ = false;
116 else
117 throw;
118 }
119
120 // Store maximum difference encountered in last iteration (if supported)
121 try {
122 maxdiff = obj->maxDiff();
123 has_maxdiff = true;
124 } catch( Exception &e ) {
125 if( e.code() == Exception::NOT_IMPLEMENTED )
126 has_maxdiff = false;
127 else
128 throw;
129 }
130
131 // Store number of iterations needed (if supported)
132 try {
133 iters = obj->Iterations();
134 has_iters = true;
135 } catch( Exception &e ) {
136 if( e.code() == Exception::NOT_IMPLEMENTED )
137 has_iters = false;
138 else
139 throw;
140 }
141
142 // Store variable marginals
143 varMarginals.clear();
144 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
145 varMarginals.push_back( obj->beliefV( i ) );
146
147 // Store all marginals calculated by the method
148 allMarginals = obj->beliefs();
149 };
150 }
151
152 /// Calculate total variation distance of variable marginals with respect to those in \a x
153 void calcErrs( const TestDAI &x ) {
154 err.clear();
155 err.reserve( varMarginals.size() );
156 for( size_t i = 0; i < varMarginals.size(); i++ )
157 err.push_back( dist( varMarginals[i], x.varMarginals[i], Prob::DISTTV ) );
158 }
159
160 /// Calculate total variation distance of variable marginals with respect to those in \a x
161 void calcErrs( const vector<Factor> &x ) {
162 err.clear();
163 err.reserve( varMarginals.size() );
164 for( size_t i = 0; i < varMarginals.size(); i++ )
165 err.push_back( dist( varMarginals[i], x[i], Prob::DISTTV ) );
166 }
167
168 /// Return maximum error
169 Real maxErr() {
170 return( *max_element( err.begin(), err.end() ) );
171 }
172
173 /// Return average error
174 Real avgErr() {
175 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
176 }
177 };
178
179
180 /// Clips a real number: if the absolute value of \a x is less than \a minabs, return \a minabs, else return \a x
181 Real clipReal( Real x, Real minabs ) {
182 if( abs(x) < minabs )
183 return minabs;
184 else
185 return x;
186 }
187
188
/// Whether to output no marginals, only variable marginals, or all calculated marginals
/** Selected with the --marginals command line option:
 *   - NONE: print no marginals;
 *   - VAR:  print each variable marginal on its own line, prefixed by "# ";
 *   - ALL:  print all marginals calculated by the method (InfAlg::beliefs()),
 *           each on its own line, prefixed by "# ".
 *  NOTE(review): the --marginals help text advertises NONE as the default; this
 *  relies on a default-constructed DAI_ENUM holding its first value — confirm in dai/enum.h.
 */
DAI_ENUM(MarginalsOutputType,NONE,VAR,ALL);
191
192
193 /// Main function
194 int main( int argc, char *argv[] ) {
195 // Variables to store command line options
196 // Filename of factor graph
197 string filename;
198 // Filename for aliases
199 string aliases;
200 // Approximate Inference methods to use
201 vector<string> methods;
202 // Which marginals to output
203 MarginalsOutputType marginals;
204 // Output number of iterations?
205 bool report_iters = true;
206 // Output calculation time?
207 bool report_time = true;
208
209 // Define required command line options
210 po::options_description opts_required("Required options");
211 opts_required.add_options()
212 ("filename", po::value< string >(&filename), "Filename of factor graph")
213 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to perform")
214 ;
215
216 // Define allowed command line options
217 po::options_description opts_optional("Allowed options");
218 opts_optional.add_options()
219 ("help", "Produce help message")
220 ("aliases", po::value< string >(&aliases), "Filename for aliases")
221 ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/ALL, default=NONE)")
222 ("report-time", po::value< bool >(&report_time), "Output calculation time (default==1)?")
223 ("report-iters", po::value< bool >(&report_iters), "Output iterations needed (default==1)?")
224 ;
225
226 // Define all command line options
227 po::options_description cmdline_options;
228 cmdline_options.add(opts_required).add(opts_optional);
229
230 // Parse command line
231 po::variables_map vm;
232 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
233 po::notify(vm);
234
235 // Display help message if necessary
236 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
237 cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
238 cout << "Usage: ./testdai --filename <filename.fg> --methods <method1> [<method2> <method3> ...]" << endl << endl;
239 cout << "Reads factor graph <filename.fg> and performs the approximate inference algorithms" << endl;
240 cout << "<method*>, reporting for each method:" << endl;
241 cout << " o the calculation time needed, in seconds (if report-time == 1);" << endl;
242 cout << " o the number of iterations needed (if report-iters == 1);" << endl;
243 cout << " o the maximum (over all variables) total variation error in the variable marginals;" << endl;
244 cout << " o the average (over all variables) total variation error in the variable marginals;" << endl;
245 cout << " o the error (difference) of the logarithm of the partition sums;" << endl << endl;
246 cout << "All errors are calculated by comparing the results of the current method with" << endl;
247 cout << "the results of the first method (the base method). If marginals==VAR, additional" << endl;
248 cout << "output consists of the variable marginals, and if marginals==ALL, all marginals" << endl;
249 cout << "calculated by the method are reported." << endl << endl;
250 cout << "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl << endl;
251 cout << " name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl << endl;
252 cout << "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl;
253 cout << "filename is provided), followed by a list of properties (surrounded by rectangular" << endl;
254 cout << "brackets), where each property consists of a key=value pair and the properties are" << endl;
255 cout << "seperated by commas. If an alias file is specified, alias substitution is performed." << endl;
256 cout << "This is done by looking up the name in the alias file and substituting the alias" << endl;
257 cout << "by its corresponding method as defined in the alias file. Properties are parsed from" << endl;
258 cout << "left to right, so if a property occurs repeatedly, the right-most value is used." << endl << endl;
259 cout << opts_required << opts_optional << endl;
260 #ifdef DAI_DEBUG
261 cout << "Note: this is a debugging build of libDAI." << endl << endl;
262 #endif
263 cout << "Example: ./testdai --filename testfast.fg --aliases aliases.conf --methods JTREE_HUGIN BP_SEQFIX BP_PARALL[maxiter=5]" << endl;
264 return 1;
265 }
266
267 try {
268 // Read aliases
269 map<string,string> Aliases;
270 if( !aliases.empty() )
271 Aliases = readAliasesFile( aliases );
272
273 // Read factor graph
274 FactorGraph fg;
275 fg.ReadFromFile( filename.c_str() );
276
277 // Declare variables used for storing variable marginals and log partition sum of base method
278 vector<Factor> varMarginals0;
279 Real logZ0 = 0.0;
280
281 // Output header
282 cout.setf( ios_base::scientific );
283 cout.precision( 3 );
284 cout << "# " << filename << endl;
285 cout.width( 39 );
286 cout << left << "# METHOD" << "\t";
287 if( report_time )
288 cout << right << "SECONDS " << "\t";
289 if( report_iters )
290 cout << "ITERS" << "\t";
291 cout << "MAX ERROR" << "\t";
292 cout << "AVG ERROR" << "\t";
293 cout << "LOGZ ERROR" << "\t";
294 cout << "MAXDIFF" << "\t";
295 cout << endl;
296
297 // For each method...
298 for( size_t m = 0; m < methods.size(); m++ ) {
299 // Parse method
300 pair<string, PropertySet> meth = parseNameProperties( methods[m], Aliases );
301
302 // Check whether name is valid
303 size_t n = 0;
304 for( ; strlen( DAINames[n] ) != 0; n++ )
305 if( meth.first == DAINames[n] )
306 break;
307 if( strlen( DAINames[n] ) == 0 )
308 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + meth.first + string("\" in \"") + methods[m] + string("\""));
309
310 // Construct object for running the method
311 TestDAI testdai(fg, meth.first, meth.second );
312
313 // Run the method
314 testdai.doDAI();
315
316 // For the base method, store its variable marginals and logarithm of the partition sum
317 if( m == 0 ) {
318 varMarginals0 = testdai.varMarginals;
319 logZ0 = testdai.logZ;
320 }
321
322 // Calculate errors relative to base method
323 testdai.calcErrs( varMarginals0 );
324
325 // Output method name
326 cout.width( 39 );
327 cout << left << methods[m] << "\t";
328 // Output calculation time, if requested
329 if( report_time )
330 cout << right << testdai.time << "\t";
331 // Output number of iterations, if requested
332 if( report_iters ) {
333 if( testdai.has_iters ) {
334 cout << testdai.iters << "\t";
335 } else {
336 cout << "N/A \t";
337 }
338 }
339
340 // If this is not the base method
341 if( m > 0 ) {
342 cout.setf( ios_base::scientific );
343 cout.precision( 3 );
344
345 // Output maximum error in variable marginals
346 Real me = clipReal( testdai.maxErr(), 1e-9 );
347 cout << me << "\t";
348
349 // Output average error in variable marginals
350 Real ae = clipReal( testdai.avgErr(), 1e-9 );
351 cout << ae << "\t";
352
353 // Output error in log partition sum
354 if( testdai.has_logZ ) {
355 cout.setf( ios::showpos );
356 Real le = clipReal( testdai.logZ - logZ0, 1e-9 );
357 cout << le << "\t";
358 cout.unsetf( ios::showpos );
359 } else
360 cout << "N/A \t";
361
362 // Output maximum difference in last iteration
363 if( testdai.has_maxdiff ) {
364 Real md = clipReal( testdai.maxdiff, 1e-9 );
365 if( isnan( me ) )
366 md = me;
367 if( isnan( ae ) )
368 md = ae;
369 if( md == INFINITY )
370 md = 1.0;
371 cout << md << "\t";
372 } else
373 cout << "N/A \t";
374 }
375 cout << endl;
376
377 // Output marginals, if requested
378 if( marginals == MarginalsOutputType::VAR ) {
379 for( size_t i = 0; i < testdai.varMarginals.size(); i++ )
380 cout << "# " << testdai.varMarginals[i] << endl;
381 } else if( marginals == MarginalsOutputType::ALL ) {
382 for( size_t I = 0; I < testdai.allMarginals.size(); I++ )
383 cout << "# " << testdai.allMarginals[I] << endl;
384 }
385 }
386
387 return 0;
388 } catch( string &s ) {
389 // Abort with error message
390 cerr << "Exception: " << s << endl;
391 return 2;
392 }
393 }