Merged SVN head ...
[libdai.git] / tests / test.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
30 #include <dai/util.h>
31 #include <dai/alldai.h>
32
33
34 using namespace std;
35 using namespace dai;
36 namespace po = boost::program_options;
37
38
39 class TestDAI {
40 protected:
41 InfAlg *obj;
42 string name;
43 vector<double> err;
44
45 public:
46 vector<Factor> q;
47 double logZ;
48 double maxdiff;
49 double time;
50 size_t iters;
51 bool has_logZ;
52 bool has_maxdiff;
53 bool has_iters;
54
55 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
56 double tic = toc();
57 if( name == "LDPC" ) {
58 double zero[2] = {1.0, 0.0};
59 q.clear();
60 for( size_t i = 0; i < fg.nrVars(); i++ )
61 q.push_back( Factor(Var(i,2), zero) );
62 logZ = 0.0;
63 maxdiff = 0.0;
64 iters = 1;
65 has_logZ = false;
66 has_maxdiff = false;
67 has_iters = false;
68 } else
69 obj = newInfAlg( name, fg, opts );
70 time += toc() - tic;
71 }
72
73 ~TestDAI() {
74 if( obj != NULL )
75 delete obj;
76 }
77
78 string identify() {
79 if( obj != NULL )
80 return obj->identify();
81 else
82 return "NULL";
83 }
84
85 vector<Factor> allBeliefs() {
86 vector<Factor> result;
87 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
88 result.push_back( obj->belief( obj->fg().var(i) ) );
89 return result;
90 }
91
92 void doDAI() {
93 double tic = toc();
94 if( obj != NULL ) {
95 obj->init();
96 obj->run();
97 time += toc() - tic;
98 try {
99 logZ = obj->logZ();
100 has_logZ = true;
101 } catch( Exception &e ) {
102 has_logZ = false;
103 }
104 try {
105 maxdiff = obj->maxDiff();
106 has_maxdiff = true;
107 } catch( Exception &e ) {
108 has_maxdiff = false;
109 }
110 try {
111 iters = obj->Iterations();
112 has_iters = true;
113 } catch( Exception &e ) {
114 has_iters = false;
115 }
116 q = allBeliefs();
117 };
118 }
119
120 void calcErrs( const TestDAI &x ) {
121 err.clear();
122 err.reserve( q.size() );
123 for( size_t i = 0; i < q.size(); i++ )
124 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
125 }
126
127 void calcErrs( const vector<Factor> &x ) {
128 err.clear();
129 err.reserve( q.size() );
130 for( size_t i = 0; i < q.size(); i++ )
131 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
132 }
133
134 double maxErr() {
135 return( *max_element( err.begin(), err.end() ) );
136 }
137
138 double avgErr() {
139 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
140 }
141 };
142
143
144 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
145 // s = first part of _s, until '['
146 string::size_type pos = _s.find_first_of('[');
147 string s;
148 if( pos == string::npos )
149 s = _s;
150 else
151 s = _s.substr(0,pos);
152
153 // if the first part is an alias, substitute
154 if( aliases.find(s) != aliases.end() )
155 s = aliases.find(s)->second;
156
157 // attach second part, merging properties if necessary
158 if( pos != string::npos ) {
159 if( s.at(s.length()-1) == ']' ) {
160 s = s.erase(s.length()-1,1) + ',' + _s.substr(pos+1);
161 } else
162 s = s + _s.substr(pos);
163 }
164
165 pair<string, PropertySet> result;
166 string & name = result.first;
167 PropertySet & opts = result.second;
168
169 pos = s.find_first_of('[');
170 if( pos == string::npos )
171 throw "Malformed method";
172 name = s.substr( 0, pos );
173 size_t n = 0;
174 for( ; strlen( DAINames[n] ) != 0; n++ )
175 if( name == DAINames[n] )
176 break;
177 if( strlen( DAINames[n] ) == 0 && (name != "LDPC") )
178 DAI_THROW(UNKNOWN_DAI_ALGORITHM);
179
180 stringstream ss;
181 ss << s.substr(pos,s.length());
182 ss >> opts;
183
184 return result;
185 }
186
187
/// Clamps values of small magnitude up to a floor.
///
/// Returns @a x unchanged when |x| >= minabs. Otherwise it returns minabs itself,
/// which means the sign of a tiny negative @a x is not preserved. A NaN input
/// fails the comparison and is returned unchanged.
double clipdouble( double x, double minabs ) {
    const bool belowFloor = std::fabs( x ) < minabs;
    return belowFloor ? minabs : x;
}
194
195
196 int main( int argc, char *argv[] ) {
197 try {
198 string filename;
199 string aliases;
200 vector<string> methods;
201 double tol;
202 size_t maxiter;
203 size_t verbose;
204 bool marginals = false;
205 bool report_time = true;
206
207 po::options_description opts_required("Required options");
208 opts_required.add_options()
209 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
210 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
211 ;
212
213 po::options_description opts_optional("Allowed options");
214 opts_optional.add_options()
215 ("help", "produce help message")
216 ("aliases", po::value< string >(&aliases), "Filename for aliases")
217 ("tol", po::value< double >(&tol), "Override tolerance")
218 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
219 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
220 ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
221 ("report-time", po::value< bool >(&report_time), "Report calculation time")
222 ;
223
224 po::options_description cmdline_options;
225 cmdline_options.add(opts_required).add(opts_optional);
226
227 po::variables_map vm;
228 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
229 po::notify(vm);
230
231 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
232 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
233 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
234 cout << "error and relative logZ error (comparing with the results of" << endl;
235 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
236 cout << opts_required << opts_optional << endl;
237 return 1;
238 }
239
240 // Read aliases
241 map<string,string> Aliases;
242 if( !aliases.empty() ) {
243 ifstream infile;
244 infile.open (aliases.c_str());
245 if (infile.is_open()) {
246 while( true ) {
247 string line;
248 getline(infile,line);
249 if( infile.fail() )
250 break;
251 if( (!line.empty()) && (line[0] != '#') ) {
252 string::size_type pos = line.find(':',0);
253 if( pos == string::npos )
254 throw "Invalid alias";
255 else {
256 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
257 string key = line.substr(0, posl + 1);
258 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
259 string val = line.substr(pos + 1 + posr, line.length());
260 Aliases[key] = val;
261 }
262 }
263 }
264 infile.close();
265 } else
266 throw "Error opening aliases file";
267 }
268
269 FactorGraph fg;
270 fg.ReadFromFile( filename.c_str() );
271
272 vector<Factor> q0;
273 double logZ0 = 0.0;
274
275 cout << "# " << filename << endl;
276 cout.width( 40 );
277 cout << left << "# METHOD" << " ";
278 if( report_time ) {
279 cout.width( 10 );
280 cout << right << "SECONDS" << " ";
281 }
282 cout.width( 10 );
283 cout << "MAX ERROR" << " ";
284 cout.width( 10 );
285 cout << "AVG ERROR" << " ";
286 cout.width( 10 );
287 cout << "LOGZ ERROR" << " ";
288 cout.width( 10 );
289 cout << "MAXDIFF" << " ";
290 cout.width( 10 );
291 cout << "ITERS" << endl;
292
293 for( size_t m = 0; m < methods.size(); m++ ) {
294 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
295
296 if( vm.count("tol") )
297 meth.second.Set("tol",tol);
298 if( vm.count("maxiter") )
299 meth.second.Set("maxiter",maxiter);
300 if( vm.count("verbose") )
301 meth.second.Set("verbose",verbose);
302 TestDAI piet(fg, meth.first, meth.second );
303 piet.doDAI();
304 if( m == 0 ) {
305 q0 = piet.q;
306 logZ0 = piet.logZ;
307 }
308 piet.calcErrs(q0);
309
310 cout.width( 40 );
311 // cout << left << piet.identify() << " ";
312 cout << left << methods[m] << " ";
313 if( report_time ) {
314 cout.width( 10 );
315 cout << right << piet.time << " ";
316 }
317
318 if( m > 0 ) {
319 cout.setf( ios_base::scientific );
320 cout.precision( 3 );
321 cout.width( 10 );
322 double me = clipdouble( piet.maxErr(), 1e-9 );
323 cout << me << " ";
324 cout.width( 10 );
325 double ae = clipdouble( piet.avgErr(), 1e-9 );
326 cout << ae << " ";
327 cout.width( 10 );
328 if( piet.has_logZ ) {
329 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
330 cout << le << " ";
331 } else {
332 cout << "N/A ";
333 }
334 cout.width( 10 );
335 if( piet.has_maxdiff ) {
336 double md = clipdouble( piet.maxdiff, 1e-9 );
337 if( isnan( me ) )
338 md = me;
339 if( isnan( ae ) )
340 md = ae;
341 cout << md << " ";
342 } else {
343 cout << "N/A ";
344 }
345 cout.width( 10 );
346 if( piet.has_iters ) {
347 cout << piet.iters << " ";
348 } else {
349 cout << "N/A ";
350 }
351 }
352 cout << endl;
353
354 if( marginals ) {
355 for( size_t i = 0; i < piet.q.size(); i++ )
356 cout << "# " << piet.q[i] << endl;
357 }
358 }
359 } catch(const char *e) {
360 cerr << "Exception: " << e << endl;
361 return 1;
362 } catch(exception& e) {
363 cerr << "Exception: " << e.what() << endl;
364 return 1;
365 }
366 catch(...) {
367 cerr << "Exception of unknown type!" << endl;
368 }
369
370 return 0;
371 }