[Frederik Eaton] Change testdai help message to warn if the user is running a debugging build
[libdai.git] / tests / testdai.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
33
34
35 using namespace std;
36 using namespace dai;
37 namespace po = boost::program_options;
38
39
40 class TestDAI {
41 protected:
42 InfAlg *obj;
43 string name;
44 vector<double> err;
45
46 public:
47 vector<Factor> q;
48 double logZ;
49 double maxdiff;
50 double time;
51 size_t iters;
52 bool has_logZ;
53 bool has_maxdiff;
54 bool has_iters;
55
56 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
57 double tic = toc();
58 if( name == "LDPC" ) {
59 double zero[2] = {1.0, 0.0};
60 q.clear();
61 for( size_t i = 0; i < fg.nrVars(); i++ )
62 q.push_back( Factor(Var(i,2), zero) );
63 logZ = 0.0;
64 maxdiff = 0.0;
65 iters = 1;
66 has_logZ = false;
67 has_maxdiff = false;
68 has_iters = false;
69 } else
70 obj = newInfAlg( name, fg, opts );
71 time += toc() - tic;
72 }
73
74 ~TestDAI() {
75 if( obj != NULL )
76 delete obj;
77 }
78
79 string identify() const {
80 if( obj != NULL )
81 return obj->identify();
82 else
83 return "NULL";
84 }
85
86 vector<Factor> allBeliefs() {
87 vector<Factor> result;
88 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
89 result.push_back( obj->belief( obj->fg().var(i) ) );
90 return result;
91 }
92
93 void doDAI() {
94 double tic = toc();
95 if( obj != NULL ) {
96 obj->init();
97 obj->run();
98 time += toc() - tic;
99
100 try {
101 logZ = obj->logZ();
102 has_logZ = true;
103 } catch( Exception &e ) {
104 if( e.code() == Exception::NOT_IMPLEMENTED )
105 has_logZ = false;
106 else
107 throw;
108 }
109
110 try {
111 maxdiff = obj->maxDiff();
112 has_maxdiff = true;
113 } catch( Exception &e ) {
114 if( e.code() == Exception::NOT_IMPLEMENTED )
115 has_maxdiff = false;
116 else
117 throw;
118 }
119
120 try {
121 iters = obj->Iterations();
122 has_iters = true;
123 } catch( Exception &e ) {
124 if( e.code() == Exception::NOT_IMPLEMENTED )
125 has_iters = false;
126 else
127 throw;
128 }
129
130 q = allBeliefs();
131 };
132 }
133
134 void calcErrs( const TestDAI &x ) {
135 err.clear();
136 err.reserve( q.size() );
137 for( size_t i = 0; i < q.size(); i++ )
138 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
139 }
140
141 void calcErrs( const vector<Factor> &x ) {
142 err.clear();
143 err.reserve( q.size() );
144 for( size_t i = 0; i < q.size(); i++ )
145 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
146 }
147
148 double maxErr() {
149 return( *max_element( err.begin(), err.end() ) );
150 }
151
152 double avgErr() {
153 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
154 }
155 };
156
157
158 pair<string, PropertySet> parseMethodRaw( const string &s ) {
159 string::size_type pos = s.find_first_of('[');
160 string name;
161 PropertySet opts;
162 if( pos == string::npos ) {
163 name = s;
164 } else {
165 name = s.substr(0,pos);
166
167 stringstream ss;
168 ss << s.substr(pos,s.length());
169 ss >> opts;
170 }
171 return make_pair(name,opts);
172 }
173
174
175 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
176 // break string into method[properties]
177 pair<string,PropertySet> ps = parseMethodRaw(_s);
178 bool looped = false;
179
180 // as long as 'method' is an alias, update:
181 while( aliases.find(ps.first) != aliases.end() && !looped ) {
182 string astr = aliases.find(ps.first)->second;
183 pair<string,PropertySet> aps = parseMethodRaw(astr);
184 if( aps.first == ps.first )
185 looped = true;
186 // override aps properties by ps properties
187 aps.second.Set( ps.second );
188 // replace ps by aps
189 ps = aps;
190 // repeat until method name == alias name ('looped'), or
191 // there is no longer an alias 'method'
192 }
193
194 // check whether name is valid
195 size_t n = 0;
196 for( ; strlen( DAINames[n] ) != 0; n++ )
197 if( ps.first == DAINames[n] )
198 break;
199 if( strlen( DAINames[n] ) == 0 && (ps.first != "LDPC") )
200 throw string("Unknown DAI algorithm \"") + ps.first + string("\" in \"") + _s + string("\"");
201
202 return ps;
203 }
204
205
/// Returns @p x unless its magnitude is strictly below @p minabs, in which
/// case +minabs is returned (so tiny negative values also become +minabs).
/// NaN is passed through unchanged because the comparison is false for it.
double clipdouble( double x, double minabs ) {
    const bool tooSmall = std::fabs( x ) < minabs;
    return tooSmall ? minabs : x;
}
212
213
214 int main( int argc, char *argv[] ) {
215 string filename;
216 string aliases;
217 vector<string> methods;
218 double tol;
219 size_t maxiter;
220 size_t verbose;
221 bool marginals = false;
222 bool report_iters = true;
223 bool report_time = true;
224
225 po::options_description opts_required("Required options");
226 opts_required.add_options()
227 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
228 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
229 ;
230
231 po::options_description opts_optional("Allowed options");
232 opts_optional.add_options()
233 ("help", "produce help message")
234 ("aliases", po::value< string >(&aliases), "Filename for aliases")
235 ("tol", po::value< double >(&tol), "Override tolerance")
236 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
237 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
238 ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
239 ("report-time", po::value< bool >(&report_time), "Report calculation time")
240 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
241 ;
242
243 po::options_description cmdline_options;
244 cmdline_options.add(opts_required).add(opts_optional);
245
246 po::variables_map vm;
247 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
248 po::notify(vm);
249
250 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
251 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
252 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
253 cout << "error and relative logZ error (comparing with the results of" << endl;
254 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
255 cout << opts_required << opts_optional << endl;
256 #ifdef DAI_DEBUG
257 cout << "This is a debugging (unoptimised) build of libdai" << endl;
258 #endif
259 return 1;
260 }
261
262 // Read aliases
263 map<string,string> Aliases;
264 if( !aliases.empty() ) {
265 ifstream infile;
266 infile.open (aliases.c_str());
267 if (infile.is_open()) {
268 while( true ) {
269 string line;
270 getline(infile,line);
271 if( infile.fail() )
272 break;
273 if( (!line.empty()) && (line[0] != '#') ) {
274 string::size_type pos = line.find(':',0);
275 if( pos == string::npos )
276 throw "Invalid alias";
277 else {
278 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
279 string key = line.substr(0, posl + 1);
280 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
281 string val = line.substr(pos + 1 + posr, line.length());
282 Aliases[key] = val;
283 }
284 }
285 }
286 infile.close();
287 } else
288 throw "Error opening aliases file";
289 }
290
291 FactorGraph fg;
292 fg.ReadFromFile( filename.c_str() );
293
294 vector<Factor> q0;
295 double logZ0 = 0.0;
296
297 cout.setf( ios_base::scientific );
298 cout.precision( 3 );
299
300 cout << "# " << filename << endl;
301 cout.width( 39 );
302 cout << left << "# METHOD" << "\t";
303 if( report_time )
304 cout << right << "SECONDS " << "\t";
305 if( report_iters )
306 cout << "ITERS" << "\t";
307 cout << "MAX ERROR" << "\t";
308 cout << "AVG ERROR" << "\t";
309 cout << "LOGZ ERROR" << "\t";
310 cout << "MAXDIFF" << "\t";
311 cout << endl;
312
313 for( size_t m = 0; m < methods.size(); m++ ) {
314 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
315
316 if( vm.count("tol") )
317 meth.second.Set("tol",tol);
318 if( vm.count("maxiter") )
319 meth.second.Set("maxiter",maxiter);
320 if( vm.count("verbose") )
321 meth.second.Set("verbose",verbose);
322 TestDAI piet(fg, meth.first, meth.second );
323 piet.doDAI();
324 if( m == 0 ) {
325 q0 = piet.q;
326 logZ0 = piet.logZ;
327 }
328 piet.calcErrs(q0);
329
330 cout.width( 39 );
331 cout << left << methods[m] << "\t";
332 if( report_time )
333 cout << right << piet.time << "\t";
334 if( report_iters ) {
335 if( piet.has_iters ) {
336 cout << piet.iters << "\t";
337 } else {
338 cout << "N/A \t";
339 }
340 }
341
342 if( m > 0 ) {
343 cout.setf( ios_base::scientific );
344 cout.precision( 3 );
345
346 double me = clipdouble( piet.maxErr(), 1e-9 );
347 cout << me << "\t";
348
349 double ae = clipdouble( piet.avgErr(), 1e-9 );
350 cout << ae << "\t";
351
352 if( piet.has_logZ ) {
353 cout.setf( ios::showpos );
354 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
355 cout << le << "\t";
356 cout.unsetf( ios::showpos );
357 } else
358 cout << "N/A \t";
359
360 if( piet.has_maxdiff ) {
361 double md = clipdouble( piet.maxdiff, 1e-9 );
362 if( isnan( me ) )
363 md = me;
364 if( isnan( ae ) )
365 md = ae;
366 cout << md << "\t";
367 } else
368 cout << "N/A \t";
369 }
370 cout << endl;
371
372 if( marginals ) {
373 for( size_t i = 0; i < piet.q.size(); i++ )
374 cout << "# " << piet.q[i] << endl;
375 }
376 }
377
378 return 0;
379 }