Merge branch 'pletscher'
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
22
23
24 using namespace std;
25 using namespace dai;
26 namespace po = boost::program_options;
27
28
29 class TestDAI {
30 protected:
31 InfAlg *obj;
32 string name;
33 vector<double> err;
34
35 public:
36 vector<Factor> q;
37 double logZ;
38 double maxdiff;
39 double time;
40 size_t iters;
41 bool has_logZ;
42 bool has_maxdiff;
43 bool has_iters;
44
45 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
46 double tic = toc();
47 if( name == "LDPC" ) {
48 double zero[2] = {1.0, 0.0};
49 q.clear();
50 for( size_t i = 0; i < fg.nrVars(); i++ )
51 q.push_back( Factor(Var(i,2), zero) );
52 logZ = 0.0;
53 maxdiff = 0.0;
54 iters = 1;
55 has_logZ = false;
56 has_maxdiff = false;
57 has_iters = false;
58 } else
59 obj = newInfAlg( name, fg, opts );
60 time += toc() - tic;
61 }
62
63 ~TestDAI() {
64 if( obj != NULL )
65 delete obj;
66 }
67
68 string identify() const {
69 if( obj != NULL )
70 return obj->identify();
71 else
72 return "NULL";
73 }
74
75 vector<Factor> allBeliefs() {
76 vector<Factor> result;
77 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
78 result.push_back( obj->belief( obj->fg().var(i) ) );
79 return result;
80 }
81
82 void doDAI() {
83 double tic = toc();
84 if( obj != NULL ) {
85 obj->init();
86 obj->run();
87 time += toc() - tic;
88
89 try {
90 logZ = obj->logZ();
91 has_logZ = true;
92 } catch( Exception &e ) {
93 if( e.code() == Exception::NOT_IMPLEMENTED )
94 has_logZ = false;
95 else
96 throw;
97 }
98
99 try {
100 maxdiff = obj->maxDiff();
101 has_maxdiff = true;
102 } catch( Exception &e ) {
103 if( e.code() == Exception::NOT_IMPLEMENTED )
104 has_maxdiff = false;
105 else
106 throw;
107 }
108
109 try {
110 iters = obj->Iterations();
111 has_iters = true;
112 } catch( Exception &e ) {
113 if( e.code() == Exception::NOT_IMPLEMENTED )
114 has_iters = false;
115 else
116 throw;
117 }
118
119 q = allBeliefs();
120 };
121 }
122
123 void calcErrs( const TestDAI &x ) {
124 err.clear();
125 err.reserve( q.size() );
126 for( size_t i = 0; i < q.size(); i++ )
127 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
128 }
129
130 void calcErrs( const vector<Factor> &x ) {
131 err.clear();
132 err.reserve( q.size() );
133 for( size_t i = 0; i < q.size(); i++ )
134 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
135 }
136
137 double maxErr() {
138 return( *max_element( err.begin(), err.end() ) );
139 }
140
141 double avgErr() {
142 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
143 }
144 };
145
146
147 pair<string, PropertySet> parseMethodRaw( const string &s ) {
148 string::size_type pos = s.find_first_of('[');
149 string name;
150 PropertySet opts;
151 if( pos == string::npos ) {
152 name = s;
153 } else {
154 name = s.substr(0,pos);
155
156 stringstream ss;
157 ss << s.substr(pos,s.length());
158 ss >> opts;
159 }
160 return make_pair(name,opts);
161 }
162
163
164 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
165 // break string into method[properties]
166 pair<string,PropertySet> ps = parseMethodRaw(_s);
167 bool looped = false;
168
169 // as long as 'method' is an alias, update:
170 while( aliases.find(ps.first) != aliases.end() && !looped ) {
171 string astr = aliases.find(ps.first)->second;
172 pair<string,PropertySet> aps = parseMethodRaw(astr);
173 if( aps.first == ps.first )
174 looped = true;
175 // override aps properties by ps properties
176 aps.second.Set( ps.second );
177 // replace ps by aps
178 ps = aps;
179 // repeat until method name == alias name ('looped'), or
180 // there is no longer an alias 'method'
181 }
182
183 // check whether name is valid
184 size_t n = 0;
185 for( ; strlen( DAINames[n] ) != 0; n++ )
186 if( ps.first == DAINames[n] )
187 break;
188 if( strlen( DAINames[n] ) == 0 && (ps.first != "LDPC") )
189 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + ps.first + string("\" in \"") + _s + string("\""));
190
191 return ps;
192 }
193
194
/// Returns x unchanged unless |x| < minabs, in which case minabs itself is
/// returned (so tiny or zero values, including tiny negative ones, become
/// +minabs). A NaN x fails the comparison and is passed through unchanged.
double clipdouble( double x, double minabs ) {
    return ( std::fabs( x ) < minabs ) ? minabs : x;
}
201
202
203 int main( int argc, char *argv[] ) {
204 string filename;
205 string aliases;
206 vector<string> methods;
207 double tol;
208 size_t maxiter;
209 size_t verbose;
210 bool marginals = false;
211 bool report_iters = true;
212 bool report_time = true;
213
214 po::options_description opts_required("Required options");
215 opts_required.add_options()
216 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
217 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
218 ;
219
220 po::options_description opts_optional("Allowed options");
221 opts_optional.add_options()
222 ("help", "produce help message")
223 ("aliases", po::value< string >(&aliases), "Filename for aliases")
224 ("tol", po::value< double >(&tol), "Override tolerance")
225 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
226 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
227 ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
228 ("report-time", po::value< bool >(&report_time), "Report calculation time")
229 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
230 ;
231
232 po::options_description cmdline_options;
233 cmdline_options.add(opts_required).add(opts_optional);
234
235 po::variables_map vm;
236 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
237 po::notify(vm);
238
239 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
240 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
241 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
242 cout << "error and relative logZ error (comparing with the results of" << endl;
243 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
244 cout << opts_required << opts_optional << endl;
245 #ifdef DAI_DEBUG
246 cout << "This is a debugging (unoptimised) build of libDAI." << endl;
247 #endif
248 return 1;
249 }
250
251 try {
252 // Read aliases
253 map<string,string> Aliases;
254 if( !aliases.empty() ) {
255 ifstream infile;
256 infile.open (aliases.c_str());
257 if (infile.is_open()) {
258 while( true ) {
259 string line;
260 getline(infile,line);
261 if( infile.fail() )
262 break;
263 if( (!line.empty()) && (line[0] != '#') ) {
264 string::size_type pos = line.find(':',0);
265 if( pos == string::npos )
266 DAI_THROWE(RUNTIME_ERROR,"Invalid alias");
267 else {
268 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
269 string key = line.substr(0, posl + 1);
270 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
271 string val = line.substr(pos + 1 + posr, line.length());
272 Aliases[key] = val;
273 }
274 }
275 }
276 infile.close();
277 } else
278 DAI_THROWE(RUNTIME_ERROR,"Error opening aliases file");
279 }
280
281 FactorGraph fg;
282 fg.ReadFromFile( filename.c_str() );
283
284 vector<Factor> q0;
285 double logZ0 = 0.0;
286
287 cout.setf( ios_base::scientific );
288 cout.precision( 3 );
289
290 cout << "# " << filename << endl;
291 cout.width( 39 );
292 cout << left << "# METHOD" << "\t";
293 if( report_time )
294 cout << right << "SECONDS " << "\t";
295 if( report_iters )
296 cout << "ITERS" << "\t";
297 cout << "MAX ERROR" << "\t";
298 cout << "AVG ERROR" << "\t";
299 cout << "LOGZ ERROR" << "\t";
300 cout << "MAXDIFF" << "\t";
301 cout << endl;
302
303 for( size_t m = 0; m < methods.size(); m++ ) {
304 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
305
306 if( vm.count("tol") )
307 meth.second.Set("tol",tol);
308 if( vm.count("maxiter") )
309 meth.second.Set("maxiter",maxiter);
310 if( vm.count("verbose") )
311 meth.second.Set("verbose",verbose);
312 TestDAI piet(fg, meth.first, meth.second );
313 piet.doDAI();
314 if( m == 0 ) {
315 q0 = piet.q;
316 logZ0 = piet.logZ;
317 }
318 piet.calcErrs(q0);
319
320 cout.width( 39 );
321 cout << left << methods[m] << "\t";
322 if( report_time )
323 cout << right << piet.time << "\t";
324 if( report_iters ) {
325 if( piet.has_iters ) {
326 cout << piet.iters << "\t";
327 } else {
328 cout << "N/A \t";
329 }
330 }
331
332 if( m > 0 ) {
333 cout.setf( ios_base::scientific );
334 cout.precision( 3 );
335
336 double me = clipdouble( piet.maxErr(), 1e-9 );
337 cout << me << "\t";
338
339 double ae = clipdouble( piet.avgErr(), 1e-9 );
340 cout << ae << "\t";
341
342 if( piet.has_logZ ) {
343 cout.setf( ios::showpos );
344 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
345 cout << le << "\t";
346 cout.unsetf( ios::showpos );
347 } else
348 cout << "N/A \t";
349
350 if( piet.has_maxdiff ) {
351 double md = clipdouble( piet.maxdiff, 1e-9 );
352 if( isnan( me ) )
353 md = me;
354 if( isnan( ae ) )
355 md = ae;
356 cout << md << "\t";
357 } else
358 cout << "N/A \t";
359 }
360 cout << endl;
361
362 if( marginals ) {
363 for( size_t i = 0; i < piet.q.size(); i++ )
364 cout << "# " << piet.q[i] << endl;
365 }
366 }
367
368 return 0;
369 } catch( string &s ) {
370 cerr << "Exception: " << s << endl;
371 return 2;
372 }
373 }