Improved index.h (merged from SVN head), which yields a 25% speedup.
[libdai.git] / tests / test.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <map>
25 #include <numeric>
26 #include <cmath>
27 #include <cstdlib>
28 #include <boost/program_options.hpp>
29 #include <dai/util.h>
30 #include <dai/alldai.h>
31
32
33 using namespace std;
34 using namespace dai;
35 namespace po = boost::program_options;
36
37
38 class TestAI {
39 protected:
40 InfAlg *obj;
41 string name;
42 vector<double> err;
43
44 public:
45 vector<Factor> q;
46 double logZ;
47 double maxdiff;
48 clock_t time;
49
50 TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
51 clock_t tic = toc();
52 obj = newInfAlg( name, fg, opts );
53 time += toc() - tic;
54 /*
55 } else if( method.substr(0,5) == "EXACT" ) { // EXACT
56 // Look if the network is small enough to do brute-force exact method
57 bool toolarge = false;
58 size_t total_statespace = 1;
59 for( size_t i = 0; i < fg.nrVars(); i++ ) {
60 total_statespace *= fg.var(i).states();
61 if( total_statespace > (1UL << 16) )
62 toolarge = true;
63 }
64
65 if( !toolarge ) {
66 Factor piet;
67 for( size_t I = 0; I < fg.nrFactors(); I++ )
68 piet *= fg.factor( I );
69 for( size_t i = 0; i < fg.nrVars(); i++ )
70 q.push_back(piet.marginal(fg.var(i)));
71 time += toc() - tic;
72 logZ = fg.ExactlogZ();
73 } else
74 throw "Network too large for EXACT method";
75 }
76 */
77 }
78
79 ~TestAI() {
80 if( obj != NULL )
81 delete obj;
82 }
83
84 string identify() {
85 if( obj != NULL )
86 return obj->identify();
87 else
88 return "NULL";
89 }
90
91 vector<Factor> allBeliefs() {
92 vector<Factor> result;
93 for( size_t i = 0; i < obj->nrVars(); i++ )
94 result.push_back( obj->belief( obj->var(i) ) );
95 return result;
96 }
97
98 void doAI() {
99 clock_t tic = toc();
100 // if( name == "EXACT" ) {
101 // // calculation has already been done
102 // }
103 if( obj != NULL ) {
104 obj->init();
105 obj->run();
106 time += toc() - tic;
107 logZ = real(obj->logZ());
108 maxdiff = obj->MaxDiff();
109 q = allBeliefs();
110 };
111 }
112
113 void calcErrs( const TestAI &x ) {
114 err.clear();
115 err.reserve( q.size() );
116 for( size_t i = 0; i < q.size(); i++ )
117 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
118 }
119
120 void calcErrs( const vector<Factor> &x ) {
121 err.clear();
122 err.reserve( q.size() );
123 for( size_t i = 0; i < q.size(); i++ )
124 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
125 }
126
127 double maxErr() {
128 return( *max_element( err.begin(), err.end() ) );
129 }
130
131 double avgErr() {
132 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
133 }
134 };
135
136
137 pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
138 string s = _s;
139 if( aliases.find(_s) != aliases.end() )
140 s = aliases.find(_s)->second;
141
142 pair<string, Properties> result;
143 string & name = result.first;
144 Properties & opts = result.second;
145
146 string::size_type pos = s.find_first_of('[');
147 name = s.substr( 0, pos );
148 if( pos == string::npos )
149 throw "Malformed method";
150 size_t n = 0;
151 for( ; n < sizeof(DAINames) / sizeof(string); n++ )
152 if( name == DAINames[n] )
153 break;
154 if( n == sizeof(DAINames) / sizeof(string) )
155 throw "Unknown inference algorithm";
156
157 stringstream ss;
158 ss << s.substr(pos,s.length());
159 ss >> opts;
160
161 return result;
162 }
163
164
/// Returns \a x unchanged unless its magnitude is below \a minabs, in which case
/// \a minabs itself is returned (so the sign of such tiny values is not kept).
/// A NaN input is returned as-is, because comparisons involving NaN are false.
double clipdouble( double x, double minabs ) {
    const bool belowThreshold = fabs( x ) < minabs;
    return belowThreshold ? minabs : x;
}
171
172
173 int main( int argc, char *argv[] ) {
174 try {
175 string filename;
176 string aliases;
177 vector<string> methods;
178 double tol;
179 size_t maxiter;
180 size_t verbose;
181
182 po::options_description opts_required("Required options");
183 opts_required.add_options()
184 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
185 ("methods", po::value< vector<string> >(&methods)->multitoken(), "AI methods to test")
186 ;
187
188 po::options_description opts_optional("Allowed options");
189 opts_optional.add_options()
190 ("help", "produce help message")
191 ("aliases", po::value< string >(&aliases), "Filename for aliases")
192 ("tol", po::value< double >(&tol), "Override tolerance")
193 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
194 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
195 ;
196
197 po::options_description cmdline_options;
198 cmdline_options.add(opts_required).add(opts_optional);
199
200 po::variables_map vm;
201 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
202 po::notify(vm);
203
204 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
205 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
206 cout << "inference algorithms <method*>, reporting clocks, max and average" << endl;
207 cout << "error and relative logZ error (comparing with the results of" << endl;
208 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
209 cout << opts_required << opts_optional << endl;
210 return 1;
211 }
212
213 // Read aliases
214 map<string,string> Aliases;
215 if( !aliases.empty() ) {
216 ifstream infile;
217 infile.open (aliases.c_str());
218 if (infile.is_open()) {
219 while( true ) {
220 string line;
221 getline(infile,line);
222 if( infile.fail() )
223 break;
224 if( (!line.empty()) && (line[0] != '#') ) {
225 string::size_type pos = line.find(':',0);
226 if( pos == string::npos )
227 throw "Invalid alias";
228 else {
229 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
230 string key = line.substr(0, posl + 1);
231 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
232 string val = line.substr(pos + 1 + posr, line.length());
233 Aliases[key] = val;
234 }
235 }
236 }
237 infile.close();
238 } else
239 throw "Error opening aliases file";
240 }
241
242 FactorGraph fg;
243 if( fg.ReadFromFile(filename.c_str()) ) {
244 cout << "Error reading " << filename << endl;
245 return 2;
246 } else {
247 vector<Factor> q0;
248 double logZ0 = 0.0;
249
250 cout << "# " << filename << endl;
251 cout.width( 40 );
252 cout << left << "# METHOD" << " ";
253 cout.width( 10 );
254 cout << right << "CLOCKS" << " ";
255 cout.width( 10 );
256 cout << "MAX ERROR" << " ";
257 cout.width( 10 );
258 cout << "AVG ERROR" << " ";
259 cout.width( 10 );
260 cout << "LOGZ ERROR" << " ";
261 cout.width( 10 );
262 cout << "MAXDIFF" << endl;
263
264 for( size_t m = 0; m < methods.size(); m++ ) {
265 pair<string, Properties> meth = parseMethod( methods[m], Aliases );
266
267 if( vm.count("tol") )
268 meth.second.Set("tol",tol);
269 if( vm.count("maxiter") )
270 meth.second.Set("maxiter",maxiter);
271 if( vm.count("verbose") )
272 meth.second.Set("verbose",verbose);
273 TestAI piet(fg, meth.first, meth.second );
274 piet.doAI();
275 if( m == 0 ) {
276 q0 = piet.q;
277 logZ0 = piet.logZ;
278 }
279 piet.calcErrs(q0);
280
281 cout.width( 40 );
282 // cout << left << piet.identify() << " ";
283 cout << left << methods[m] << " ";
284 cout.width( 10 );
285 cout << right << piet.time << " ";
286
287 if( m > 0 ) {
288 cout.setf( ios_base::scientific );
289 cout.precision( 3 );
290 cout.width( 10 );
291 double me = clipdouble( piet.maxErr(), 1e-9 );
292 cout << me << " ";
293 cout.width( 10 );
294 double ae = clipdouble( piet.avgErr(), 1e-9 );
295 cout << ae << " ";
296 cout.width( 10 );
297 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
298 cout << le << " ";
299 cout.width( 10 );
300 double md = clipdouble( piet.maxdiff, 1e-9 );
301 if( isnan( me ) )
302 md = me;
303 if( isnan( ae ) )
304 md = ae;
305 cout << md << endl;
306 } else
307 cout << endl;
308 }
309 }
310 } catch(const char *e) {
311 cerr << "Exception: " << e << endl;
312 return 1;
313 } catch(exception& e) {
314 cerr << "Exception: " << e.what() << endl;
315 return 1;
316 }
317 catch(...) {
318 cerr << "Exception of unknown type!" << endl;
319 }
320
321 return 0;
322 }