Improved dai::Exception object (it now stores more information and doesn't print...
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>


using namespace std;
using namespace dai;
namespace po = boost::program_options;
24
25
26 std::vector<Real> calcDists( const vector<Factor> &x, const vector<Factor> &y ) {
27 vector<Real> errs;
28 errs.reserve( x.size() );
29 DAI_ASSERT( x.size() == y.size() );
30 for( size_t i = 0; i < x.size(); i++ )
31 errs.push_back( dist( x[i], y[i], DISTTV ) );
32 return errs;
33 }
34
35
36 /// Wrapper class for DAI approximate inference algorithms
37 class TestDAI {
38 protected:
39 /// Stores a pointer to an InfAlg object, managed by this class
40 InfAlg *obj;
41 /// Stores the name of the InfAlg algorithm
42 string name;
43 /// Stores the total variation distances of the variable marginals
44 vector<Real> varErr;
45 /// Stores the total variation distances of the factor marginals
46 vector<Real> facErr;
47
48 public:
49 /// Stores the variable marginals
50 vector<Factor> varMarginals;
51 /// Stores the factor marginals
52 vector<Factor> facMarginals;
53 /// Stores all marginals
54 vector<Factor> allMarginals;
55 /// Stores the logarithm of the partition sum
56 Real logZ;
57 /// Stores the maximum difference in the last iteration
58 Real maxdiff;
59 /// Stores the computation time (in seconds)
60 double time;
61 /// Stores the number of iterations needed
62 size_t iters;
63 /// Does the InfAlg support logZ()?
64 bool has_logZ;
65 /// Does the InfAlg support maxDiff()?
66 bool has_maxdiff;
67 /// Does the InfAlg support Iterations()?
68 bool has_iters;
69
70 /// Construct from factor graph \a fg, name \a _name, and set of properties \a opts
71 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), varErr(), facErr(), varMarginals(), facMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
72 double tic = toc();
73
74 if( name == "LDPC" ) {
75 // special case: simulating a Low Density Parity Check code
76 Real zero[2] = {1.0, 0.0};
77 for( size_t i = 0; i < fg.nrVars(); i++ )
78 varMarginals.push_back( Factor(fg.var(i), zero) );
79 allMarginals = varMarginals;
80 logZ = 0.0;
81 maxdiff = 0.0;
82 iters = 1;
83 has_logZ = false;
84 has_maxdiff = false;
85 has_iters = false;
86 } else
87 // create a corresponding InfAlg object
88 obj = newInfAlg( name, fg, opts );
89
90 // Add the time needed to create the object
91 time += toc() - tic;
92 }
93
94 /// Destructor
95 ~TestDAI() {
96 if( obj != NULL )
97 delete obj;
98 }
99
100 /// Identify
101 string identify() const {
102 if( obj != NULL )
103 return obj->identify();
104 else
105 return "NULL";
106 }
107
108 /// Run the algorithm and store its results
109 void doDAI() {
110 double tic = toc();
111 if( obj != NULL ) {
112 // Initialize
113 obj->init();
114 // Run
115 obj->run();
116 // Record the time
117 time += toc() - tic;
118
119 // Store logarithm of the partition sum (if supported)
120 try {
121 logZ = obj->logZ();
122 has_logZ = true;
123 } catch( Exception &e ) {
124 if( e.getCode() == Exception::NOT_IMPLEMENTED )
125 has_logZ = false;
126 else
127 throw;
128 }
129
130 // Store maximum difference encountered in last iteration (if supported)
131 try {
132 maxdiff = obj->maxDiff();
133 has_maxdiff = true;
134 } catch( Exception &e ) {
135 if( e.getCode() == Exception::NOT_IMPLEMENTED )
136 has_maxdiff = false;
137 else
138 throw;
139 }
140
141 // Store number of iterations needed (if supported)
142 try {
143 iters = obj->Iterations();
144 has_iters = true;
145 } catch( Exception &e ) {
146 if( e.getCode() == Exception::NOT_IMPLEMENTED )
147 has_iters = false;
148 else
149 throw;
150 }
151
152 // Store variable marginals
153 varMarginals.clear();
154 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
155 varMarginals.push_back( obj->beliefV( i ) );
156
157 // Store factor marginals
158 facMarginals.clear();
159 for( size_t I = 0; I < obj->fg().nrFactors(); I++ )
160 try {
161 facMarginals.push_back( obj->beliefF( I ) );
162 } catch( Exception &e ) {
163 if( e.getCode() == Exception::BELIEF_NOT_AVAILABLE )
164 facMarginals.push_back( Factor( obj->fg().factor(I).vars(), INFINITY ) );
165 else
166 throw;
167 }
168
169 // Store all marginals calculated by the method
170 allMarginals = obj->beliefs();
171 };
172 }
173
174 /// Calculate total variation distance of variable and factor marginals with respect to those in \a varMargs and \a facMargs
175 void calcErrors( const vector<Factor>& varMargs, const vector<Factor>& facMargs ) {
176 varErr = calcDists( varMarginals, varMargs );
177 facErr = calcDists( facMarginals, facMargs );
178 }
179
180 /// Return maximum variable error
181 Real maxVarErr() {
182 return( *max_element( varErr.begin(), varErr.end() ) );
183 }
184
185 /// Return average variable error
186 Real avgVarErr() {
187 return( accumulate( varErr.begin(), varErr.end(), 0.0 ) / varErr.size() );
188 }
189
190 /// Return maximum factor error
191 Real maxFacErr() {
192 return( *max_element( facErr.begin(), facErr.end() ) );
193 }
194
195 /// Return average factor error
196 Real avgFacErr() {
197 return( accumulate( facErr.begin(), facErr.end(), 0.0 ) / facErr.size() );
198 }
199 };
200
201
202 /// Clips a real number: if the absolute value of \a x is less than \a minabs, return \a minabs, else return \a x
203 Real clipReal( Real x, Real minabs ) {
204 if( abs(x) < minabs )
205 return minabs;
206 else
207 return x;
208 }
209
210
211 /// Which marginals to outpu (none, only variable, only factor, variable and factor, all)
212 DAI_ENUM(MarginalsOutputType,NONE,VAR,FAC,VARFAC,ALL);
213
214
215 /// Main function
216 int main( int argc, char *argv[] ) {
217 // Variables to store command line options
218 // Filename of factor graph
219 string filename;
220 // Filename for aliases
221 string aliases;
222 // Approximate Inference methods to use
223 vector<string> methods;
224 // Which marginals to output
225 MarginalsOutputType marginals;
226 // Output number of iterations?
227 bool report_iters = true;
228 // Output calculation time?
229 bool report_time = true;
230
231 // Define required command line options
232 po::options_description opts_required("Required options");
233 opts_required.add_options()
234 ("filename", po::value< string >(&filename), "Filename of factor graph")
235 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to perform")
236 ;
237
238 // Define allowed command line options
239 po::options_description opts_optional("Allowed options");
240 opts_optional.add_options()
241 ("help", "Produce help message")
242 ("aliases", po::value< string >(&aliases), "Filename for aliases")
243 ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/FAC/VARFAC/ALL, default=NONE)")
244 ("report-time", po::value< bool >(&report_time), "Output calculation time (default==1)?")
245 ("report-iters", po::value< bool >(&report_iters), "Output iterations needed (default==1)?")
246 ;
247
248 // Define all command line options
249 po::options_description cmdline_options;
250 cmdline_options.add(opts_required).add(opts_optional);
251
252 // Parse command line
253 po::variables_map vm;
254 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
255 po::notify(vm);
256
257 // Display help message if necessary
258 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
259 cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
260 cout << "Usage: ./testdai --filename <filename.fg> --methods <method1> [<method2> <method3> ...]" << endl << endl;
261 cout << "Reads factor graph <filename.fg> and performs the approximate inference algorithms" << endl;
262 cout << "<method*>, reporting for each method:" << endl;
263 cout << " o the calculation time needed, in seconds (if report-time == 1);" << endl;
264 cout << " o the number of iterations needed (if report-iters == 1);" << endl;
265 cout << " o the maximum (over all variables) total variation error in the variable marginals;" << endl;
266 cout << " o the average (over all variables) total variation error in the variable marginals;" << endl;
267 cout << " o the maximum (over all factors) total variation error in the factor marginals;" << endl;
268 cout << " o the average (over all factors) total variation error in the factor marginals;" << endl;
269 cout << " o the error (difference) of the logarithm of the partition sums;" << endl << endl;
270 cout << "All errors are calculated by comparing the results of the current method with" << endl;
271 cout << "the results of the first method (the base method). If marginals==VAR, additional" << endl;
272 cout << "output consists of the variable marginals, if marginals==FAC, the factor marginals" << endl;
273 cout << "if marginals==VARFAC, both variable and factor marginals, and if marginals==ALL, all" << endl;
274 cout << "marginals calculated by the method are reported." << endl << endl;
275 cout << "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl << endl;
276 cout << " name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl << endl;
277 cout << "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl;
278 cout << "filename is provided), followed by a list of properties (surrounded by rectangular" << endl;
279 cout << "brackets), where each property consists of a key=value pair and the properties are" << endl;
280 cout << "seperated by commas. If an alias file is specified, alias substitution is performed." << endl;
281 cout << "This is done by looking up the name in the alias file and substituting the alias" << endl;
282 cout << "by its corresponding method as defined in the alias file. Properties are parsed from" << endl;
283 cout << "left to right, so if a property occurs repeatedly, the right-most value is used." << endl << endl;
284 cout << opts_required << opts_optional << endl;
285 #ifdef DAI_DEBUG
286 cout << "Note: this is a debugging build of libDAI." << endl << endl;
287 #endif
288 cout << "Example: ./testdai --filename testfast.fg --aliases aliases.conf --methods JTREE_HUGIN BP_SEQFIX BP_PARALL[maxiter=5]" << endl;
289 return 1;
290 }
291
292 try {
293 // Read aliases
294 map<string,string> Aliases;
295 if( !aliases.empty() )
296 Aliases = readAliasesFile( aliases );
297
298 // Read factor graph
299 FactorGraph fg;
300 fg.ReadFromFile( filename.c_str() );
301
302 // Declare variables used for storing variable factor marginals and log partition sum of base method
303 vector<Factor> varMarginals0;
304 vector<Factor> facMarginals0;
305 Real logZ0 = 0.0;
306
307 // Output header
308 cout.setf( ios_base::scientific );
309 cout.precision( 3 );
310 cout << "# " << filename << endl;
311 cout.width( 39 );
312 cout << left << "# METHOD" << "\t";
313 if( report_time )
314 cout << right << "SECONDS " << "\t";
315 if( report_iters )
316 cout << "ITERS" << "\t";
317 cout << "MAX VAR ERR" << "\t";
318 cout << "AVG VAR ERR" << "\t";
319 cout << "MAX FAC ERR" << "\t";
320 cout << "AVG FAC ERR" << "\t";
321 cout << "LOGZ ERROR" << "\t";
322 cout << "MAXDIFF" << "\t";
323 cout << endl;
324
325 // For each method...
326 for( size_t m = 0; m < methods.size(); m++ ) {
327 // Parse method
328 pair<string, PropertySet> meth = parseNameProperties( methods[m], Aliases );
329
330 // Construct object for running the method
331 TestDAI testdai(fg, meth.first, meth.second );
332
333 // Run the method
334 testdai.doDAI();
335
336 // For the base method, store its variable marginals and logarithm of the partition sum
337 if( m == 0 ) {
338 varMarginals0 = testdai.varMarginals;
339 facMarginals0 = testdai.facMarginals;
340 logZ0 = testdai.logZ;
341 }
342
343 // Calculate errors relative to base method
344 testdai.calcErrors( varMarginals0, facMarginals0 );
345
346 // Output method name
347 cout.width( 39 );
348 cout << left << methods[m] << "\t";
349 // Output calculation time, if requested
350 if( report_time )
351 cout << right << testdai.time << "\t";
352 // Output number of iterations, if requested
353 if( report_iters ) {
354 if( testdai.has_iters ) {
355 cout << testdai.iters << "\t";
356 } else {
357 cout << "N/A \t";
358 }
359 }
360
361 // If this is not the base method
362 if( m > 0 ) {
363 cout.setf( ios_base::scientific );
364 cout.precision( 3 );
365
366 // Output maximum error in variable marginals
367 Real mev = clipReal( testdai.maxVarErr(), 1e-9 );
368 cout << mev << "\t";
369
370 // Output average error in variable marginals
371 Real aev = clipReal( testdai.avgVarErr(), 1e-9 );
372 cout << aev << "\t";
373
374 // Output maximum error in factor marginals
375 Real mef = clipReal( testdai.maxFacErr(), 1e-9 );
376 if( mef == INFINITY )
377 cout << "N/A \t";
378 else
379 cout << mef << "\t";
380
381 // Output average error in factor marginals
382 Real aef = clipReal( testdai.avgFacErr(), 1e-9 );
383 if( aef == INFINITY )
384 cout << "N/A \t";
385 else
386 cout << aef << "\t";
387
388 // Output error in log partition sum
389 if( testdai.has_logZ ) {
390 cout.setf( ios::showpos );
391 Real le = clipReal( testdai.logZ - logZ0, 1e-9 );
392 cout << le << "\t";
393 cout.unsetf( ios::showpos );
394 } else
395 cout << "N/A \t";
396
397 // Output maximum difference in last iteration
398 if( testdai.has_maxdiff ) {
399 Real md = clipReal( testdai.maxdiff, 1e-9 );
400 if( dai::isnan( mev ) )
401 md = mev;
402 if( dai::isnan( aev ) )
403 md = aev;
404 if( md == INFINITY )
405 md = 1.0;
406 cout << md << "\t";
407 } else
408 cout << "N/A \t";
409 }
410 cout << endl;
411
412 // Output marginals, if requested
413 if( marginals == MarginalsOutputType::VAR || marginals == MarginalsOutputType::VARFAC )
414 for( size_t i = 0; i < testdai.varMarginals.size(); i++ )
415 cout << "# " << testdai.varMarginals[i] << endl;
416 if( marginals == MarginalsOutputType::FAC || marginals == MarginalsOutputType::VARFAC )
417 for( size_t I = 0; I < testdai.facMarginals.size(); I++ )
418 cout << "# " << testdai.facMarginals[I] << endl;
419 if( marginals == MarginalsOutputType::ALL )
420 for( size_t I = 0; I < testdai.allMarginals.size(); I++ )
421 cout << "# " << testdai.allMarginals[I] << endl;
422 }
423
424 return 0;
425 } catch( string &s ) {
426 // Abort with error message
427 cerr << "Exception: " << s << endl;
428 return 2;
429 }
430 }