Now compiles also with Visual Studio 2008 under Windows (still buggy!)
[libdai.git] / tests / test.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <map>
25 #include <numeric>
26 #include <cmath>
27 #include <cstdlib>
28 #include <cstring>
29 #include <boost/program_options.hpp>
30 #include <dai/util.h>
31 #include <dai/alldai.h>
32
33
34 using namespace std;
35 using namespace dai;
36 namespace po = boost::program_options;
37
38
39 class TestAI {
40 protected:
41 InfAlg *obj;
42 string name;
43 vector<double> err;
44
45 public:
46 vector<Factor> q;
47 double logZ;
48 double maxdiff;
49 double time;
50
51 TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
52 double tic = toc();
53 obj = newInfAlg( name, fg, opts );
54 time += toc() - tic;
55 /*
56 } else if( method.substr(0,5) == "EXACT" ) { // EXACT
57 // Look if the network is small enough to do brute-force exact method
58 bool toolarge = false;
59 size_t total_statespace = 1;
60 for( size_t i = 0; i < fg.nrVars(); i++ ) {
61 total_statespace *= fg.var(i).states();
62 if( total_statespace > (1UL << 16) )
63 toolarge = true;
64 }
65
66 if( !toolarge ) {
67 Factor piet;
68 for( size_t I = 0; I < fg.nrFactors(); I++ )
69 piet *= fg.factor( I );
70 for( size_t i = 0; i < fg.nrVars(); i++ )
71 q.push_back(piet.marginal(fg.var(i)));
72 time += toc() - tic;
73 logZ = fg.ExactlogZ();
74 } else
75 throw "Network too large for EXACT method";
76 }
77 */
78 }
79
80 ~TestAI() {
81 if( obj != NULL )
82 delete obj;
83 }
84
85 string identify() {
86 if( obj != NULL )
87 return obj->identify();
88 else
89 return "NULL";
90 }
91
92 vector<Factor> allBeliefs() {
93 vector<Factor> result;
94 for( size_t i = 0; i < obj->nrVars(); i++ )
95 result.push_back( obj->belief( obj->var(i) ) );
96 return result;
97 }
98
99 void doAI() {
100 double tic = toc();
101 // if( name == "EXACT" ) {
102 // // calculation has already been done
103 // }
104 if( obj != NULL ) {
105 obj->init();
106 obj->run();
107 time += toc() - tic;
108 logZ = real(obj->logZ());
109 maxdiff = obj->MaxDiff();
110 q = allBeliefs();
111 };
112 }
113
114 void calcErrs( const TestAI &x ) {
115 err.clear();
116 err.reserve( q.size() );
117 for( size_t i = 0; i < q.size(); i++ )
118 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
119 }
120
121 void calcErrs( const vector<Factor> &x ) {
122 err.clear();
123 err.reserve( q.size() );
124 for( size_t i = 0; i < q.size(); i++ )
125 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
126 }
127
128 double maxErr() {
129 return( *max_element( err.begin(), err.end() ) );
130 }
131
132 double avgErr() {
133 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
134 }
135 };
136
137
138 pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
139 string s = _s;
140 if( aliases.find(_s) != aliases.end() )
141 s = aliases.find(_s)->second;
142
143 pair<string, Properties> result;
144 string & name = result.first;
145 Properties & opts = result.second;
146
147 string::size_type pos = s.find_first_of('[');
148 name = s.substr( 0, pos );
149 if( pos == string::npos )
150 throw "Malformed method";
151 size_t n = 0;
152 for( ; strlen( DAINames[n] ) != 0; n++ )
153 if( name == DAINames[n] )
154 break;
155 if( strlen( DAINames[n] ) == 0 )
156 throw "Unknown inference algorithm";
157
158 stringstream ss;
159 ss << s.substr(pos,s.length());
160 ss >> opts;
161
162 return result;
163 }
164
165
/// Replaces values whose magnitude is below \a minabs by \a minabs.
/**
 *  Returns \a minabs (always with the sign of \a minabs itself) when |x| < minabs,
 *  and \a x unchanged otherwise. A NaN \a x fails the comparison and is returned as is.
 */
double clipdouble( double x, double minabs ) {
    const bool tooSmall = std::fabs( x ) < minabs;
    return tooSmall ? minabs : x;
}
172
173
174 int main( int argc, char *argv[] ) {
175 try {
176 string filename;
177 string aliases;
178 vector<string> methods;
179 double tol;
180 size_t maxiter;
181 size_t verbose;
182
183 po::options_description opts_required("Required options");
184 opts_required.add_options()
185 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
186 ("methods", po::value< vector<string> >(&methods)->multitoken(), "AI methods to test")
187 ;
188
189 po::options_description opts_optional("Allowed options");
190 opts_optional.add_options()
191 ("help", "produce help message")
192 ("aliases", po::value< string >(&aliases), "Filename for aliases")
193 ("tol", po::value< double >(&tol), "Override tolerance")
194 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
195 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
196 ;
197
198 po::options_description cmdline_options;
199 cmdline_options.add(opts_required).add(opts_optional);
200
201 po::variables_map vm;
202 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
203 po::notify(vm);
204
205 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
206 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
207 cout << "inference algorithms <method*>, reporting clocks, max and average" << endl;
208 cout << "error and relative logZ error (comparing with the results of" << endl;
209 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
210 cout << opts_required << opts_optional << endl;
211 return 1;
212 }
213
214 // Read aliases
215 map<string,string> Aliases;
216 if( !aliases.empty() ) {
217 ifstream infile;
218 infile.open (aliases.c_str());
219 if (infile.is_open()) {
220 while( true ) {
221 string line;
222 getline(infile,line);
223 if( infile.fail() )
224 break;
225 if( (!line.empty()) && (line[0] != '#') ) {
226 string::size_type pos = line.find(':',0);
227 if( pos == string::npos )
228 throw "Invalid alias";
229 else {
230 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
231 string key = line.substr(0, posl + 1);
232 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
233 string val = line.substr(pos + 1 + posr, line.length());
234 Aliases[key] = val;
235 }
236 }
237 }
238 infile.close();
239 } else
240 throw "Error opening aliases file";
241 }
242
243 FactorGraph fg;
244 if( fg.ReadFromFile(filename.c_str()) ) {
245 cout << "Error reading " << filename << endl;
246 return 2;
247 } else {
248 vector<Factor> q0;
249 double logZ0 = 0.0;
250
251 cout << "# " << filename << endl;
252 cout.width( 40 );
253 cout << left << "# METHOD" << " ";
254 cout.width( 10 );
255 cout << right << "SECONDS" << " ";
256 cout.width( 10 );
257 cout << "MAX ERROR" << " ";
258 cout.width( 10 );
259 cout << "AVG ERROR" << " ";
260 cout.width( 10 );
261 cout << "LOGZ ERROR" << " ";
262 cout.width( 10 );
263 cout << "MAXDIFF" << endl;
264
265 for( size_t m = 0; m < methods.size(); m++ ) {
266 pair<string, Properties> meth = parseMethod( methods[m], Aliases );
267
268 if( vm.count("tol") )
269 meth.second.Set("tol",tol);
270 if( vm.count("maxiter") )
271 meth.second.Set("maxiter",maxiter);
272 if( vm.count("verbose") )
273 meth.second.Set("verbose",verbose);
274 TestAI piet(fg, meth.first, meth.second );
275 piet.doAI();
276 if( m == 0 ) {
277 q0 = piet.q;
278 logZ0 = piet.logZ;
279 }
280 piet.calcErrs(q0);
281
282 cout.width( 40 );
283 // cout << left << piet.identify() << " ";
284 cout << left << methods[m] << " ";
285 cout.width( 10 );
286 cout << right << piet.time << " ";
287
288 if( m > 0 ) {
289 cout.setf( ios_base::scientific );
290 cout.precision( 3 );
291 cout.width( 10 );
292 double me = clipdouble( piet.maxErr(), 1e-9 );
293 cout << me << " ";
294 cout.width( 10 );
295 double ae = clipdouble( piet.avgErr(), 1e-9 );
296 cout << ae << " ";
297 cout.width( 10 );
298 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
299 cout << le << " ";
300 cout.width( 10 );
301 double md = clipdouble( piet.maxdiff, 1e-9 );
302 if( isnan( me ) )
303 md = me;
304 if( isnan( ae ) )
305 md = ae;
306 cout << md << endl;
307 } else
308 cout << endl;
309 }
310 }
311 } catch(const char *e) {
312 cerr << "Exception: " << e << endl;
313 return 1;
314 } catch(exception& e) {
315 cerr << "Exception: " << e.what() << endl;
316 return 1;
317 }
318 catch(...) {
319 cerr << "Exception of unknown type!" << endl;
320 }
321
322 return 0;
323 }