Several changes by Giuseppe Passino
[libdai.git] / tests / testdai.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
33
34
35 using namespace std;
36 using namespace dai;
37 namespace po = boost::program_options;
38
39
40 class TestDAI {
41 protected:
42 InfAlg *obj;
43 string name;
44 vector<double> err;
45
46 public:
47 vector<Factor> q;
48 double logZ;
49 double maxdiff;
50 double time;
51 size_t iters;
52 bool has_logZ;
53 bool has_maxdiff;
54 bool has_iters;
55
56 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
57 double tic = toc();
58 if( name == "LDPC" ) {
59 double zero[2] = {1.0, 0.0};
60 q.clear();
61 for( size_t i = 0; i < fg.nrVars(); i++ )
62 q.push_back( Factor(Var(i,2), zero) );
63 logZ = 0.0;
64 maxdiff = 0.0;
65 iters = 1;
66 has_logZ = false;
67 has_maxdiff = false;
68 has_iters = false;
69 } else
70 obj = newInfAlg( name, fg, opts );
71 time += toc() - tic;
72 }
73
74 ~TestDAI() {
75 if( obj != NULL )
76 delete obj;
77 }
78
79 string identify() {
80 if( obj != NULL )
81 return obj->identify();
82 else
83 return "NULL";
84 }
85
86 vector<Factor> allBeliefs() {
87 vector<Factor> result;
88 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
89 result.push_back( obj->belief( obj->fg().var(i) ) );
90 return result;
91 }
92
93 void doDAI() {
94 double tic = toc();
95 if( obj != NULL ) {
96 obj->init();
97 obj->run();
98 time += toc() - tic;
99 try {
100 logZ = obj->logZ();
101 has_logZ = true;
102 } catch( Exception &e ) {
103 has_logZ = false;
104 }
105 try {
106 maxdiff = obj->maxDiff();
107 has_maxdiff = true;
108 } catch( Exception &e ) {
109 has_maxdiff = false;
110 }
111 try {
112 iters = obj->Iterations();
113 has_iters = true;
114 } catch( Exception &e ) {
115 has_iters = false;
116 }
117 q = allBeliefs();
118 };
119 }
120
121 void calcErrs( const TestDAI &x ) {
122 err.clear();
123 err.reserve( q.size() );
124 for( size_t i = 0; i < q.size(); i++ )
125 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
126 }
127
128 void calcErrs( const vector<Factor> &x ) {
129 err.clear();
130 err.reserve( q.size() );
131 for( size_t i = 0; i < q.size(); i++ )
132 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
133 }
134
135 double maxErr() {
136 return( *max_element( err.begin(), err.end() ) );
137 }
138
139 double avgErr() {
140 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
141 }
142 };
143
144
145 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
146 // s = first part of _s, until '['
147 string::size_type pos = _s.find_first_of('[');
148 string s;
149 if( pos == string::npos )
150 s = _s;
151 else
152 s = _s.substr(0,pos);
153
154 // if the first part is an alias, substitute
155 if( aliases.find(s) != aliases.end() )
156 s = aliases.find(s)->second;
157
158 // attach second part, merging properties if necessary
159 if( pos != string::npos ) {
160 if( s.at(s.length()-1) == ']' ) {
161 s = s.erase(s.length()-1,1) + ',' + _s.substr(pos+1);
162 } else
163 s = s + _s.substr(pos);
164 }
165
166 pair<string, PropertySet> result;
167 string & name = result.first;
168 PropertySet & opts = result.second;
169
170 pos = s.find_first_of('[');
171 if( pos == string::npos )
172 throw "Malformed method";
173 name = s.substr( 0, pos );
174 size_t n = 0;
175 for( ; strlen( DAINames[n] ) != 0; n++ )
176 if( name == DAINames[n] )
177 break;
178 if( strlen( DAINames[n] ) == 0 && (name != "LDPC") )
179 DAI_THROW(UNKNOWN_DAI_ALGORITHM);
180
181 stringstream ss;
182 ss << s.substr(pos,s.length());
183 ss >> opts;
184
185 return result;
186 }
187
188
/// Floors the magnitude of a value for display: if |x| is below @a minabs the
/// positive value @a minabs is returned, otherwise @a x is returned as-is
/// (including its sign; NaN is passed through because the comparison fails).
/// main() applies this with minabs = 1e-9 to the error columns before
/// printing them.
double clipdouble( double x, double minabs ) {
    const bool belowFloor = std::fabs( x ) < minabs;
    return belowFloor ? minabs : x;
}
195
196
197 int main( int argc, char *argv[] ) {
198 try {
199 string filename;
200 string aliases;
201 vector<string> methods;
202 double tol;
203 size_t maxiter;
204 size_t verbose;
205 bool marginals = false;
206 bool report_iters = true;
207 bool report_time = true;
208
209 po::options_description opts_required("Required options");
210 opts_required.add_options()
211 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
212 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
213 ;
214
215 po::options_description opts_optional("Allowed options");
216 opts_optional.add_options()
217 ("help", "produce help message")
218 ("aliases", po::value< string >(&aliases), "Filename for aliases")
219 ("tol", po::value< double >(&tol), "Override tolerance")
220 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
221 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
222 ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
223 ("report-time", po::value< bool >(&report_time), "Report calculation time")
224 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
225 ;
226
227 po::options_description cmdline_options;
228 cmdline_options.add(opts_required).add(opts_optional);
229
230 po::variables_map vm;
231 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
232 po::notify(vm);
233
234 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
235 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
236 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
237 cout << "error and relative logZ error (comparing with the results of" << endl;
238 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
239 cout << opts_required << opts_optional << endl;
240 return 1;
241 }
242
243 // Read aliases
244 map<string,string> Aliases;
245 if( !aliases.empty() ) {
246 ifstream infile;
247 infile.open (aliases.c_str());
248 if (infile.is_open()) {
249 while( true ) {
250 string line;
251 getline(infile,line);
252 if( infile.fail() )
253 break;
254 if( (!line.empty()) && (line[0] != '#') ) {
255 string::size_type pos = line.find(':',0);
256 if( pos == string::npos )
257 throw "Invalid alias";
258 else {
259 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
260 string key = line.substr(0, posl + 1);
261 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
262 string val = line.substr(pos + 1 + posr, line.length());
263 Aliases[key] = val;
264 }
265 }
266 }
267 infile.close();
268 } else
269 throw "Error opening aliases file";
270 }
271
272 FactorGraph fg;
273 fg.ReadFromFile( filename.c_str() );
274
275 vector<Factor> q0;
276 double logZ0 = 0.0;
277
278 cout.setf( ios_base::scientific );
279 cout.precision( 3 );
280
281 cout << "# " << filename << endl;
282 cout.width( 40 );
283 cout << left << "# METHOD" << " ";
284 if( report_time ) {
285 cout.width( 10 );
286 cout << right << "SECONDS" << " ";
287 }
288 if( report_iters ) {
289 cout.width( 10 );
290 cout << "ITERS" << " ";
291 }
292 cout.width( 10 );
293 cout << "MAX ERROR" << " ";
294 cout.width( 10 );
295 cout << "AVG ERROR" << " ";
296 cout.width( 10 );
297 cout << "LOGZ ERROR" << " ";
298 cout.width( 10 );
299 cout << "MAXDIFF" << " ";
300 cout << endl;
301
302 for( size_t m = 0; m < methods.size(); m++ ) {
303 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
304
305 if( vm.count("tol") )
306 meth.second.Set("tol",tol);
307 if( vm.count("maxiter") )
308 meth.second.Set("maxiter",maxiter);
309 if( vm.count("verbose") )
310 meth.second.Set("verbose",verbose);
311 TestDAI piet(fg, meth.first, meth.second );
312 piet.doDAI();
313 if( m == 0 ) {
314 q0 = piet.q;
315 logZ0 = piet.logZ;
316 }
317 piet.calcErrs(q0);
318
319 cout.width( 40 );
320 cout << left << methods[m] << " ";
321 if( report_time ) {
322 cout.width( 10 );
323 cout << right << piet.time << " ";
324 }
325 if( report_iters ) {
326 cout.width( 10 );
327 if( piet.has_iters ) {
328 cout << piet.iters << " ";
329 } else {
330 cout << "N/A ";
331 }
332 }
333
334 if( m > 0 ) {
335 cout.setf( ios_base::scientific );
336 cout.precision( 3 );
337
338 cout.width( 10 );
339 double me = clipdouble( piet.maxErr(), 1e-9 );
340 cout << me << " ";
341
342 cout.width( 10 );
343 double ae = clipdouble( piet.avgErr(), 1e-9 );
344 cout << ae << " ";
345
346 cout.width( 10 );
347 if( piet.has_logZ ) {
348 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
349 cout << le << " ";
350 } else
351 cout << "N/A ";
352
353 cout.width( 10 );
354 if( piet.has_maxdiff ) {
355 double md = clipdouble( piet.maxdiff, 1e-9 );
356 if( isnan( me ) )
357 md = me;
358 if( isnan( ae ) )
359 md = ae;
360 cout << md << " ";
361 } else
362 cout << "N/A ";
363 }
364 cout << endl;
365
366 if( marginals ) {
367 for( size_t i = 0; i < piet.q.size(); i++ )
368 cout << "# " << piet.q[i] << endl;
369 }
370 }
371 } catch(const char *e) {
372 cerr << "Exception: " << e << endl;
373 return 1;
374 } catch(exception& e) {
375 cerr << "Exception: " << e.what() << endl;
376 return 1;
377 }
378 catch(...) {
379 cerr << "Exception of unknown type!" << endl;
380 }
381
382 return 0;
383 }