Replaced complex numbers by real numbers
[libdai.git] / tests / test.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <fstream>
24 #include <map>
25 #include <numeric>
26 #include <cmath>
27 #include <cstdlib>
28 #include <cstring>
29 #include <boost/program_options.hpp>
30 #include <dai/util.h>
31 #include <dai/alldai.h>
32
33
34 using namespace std;
35 using namespace dai;
36 namespace po = boost::program_options;
37
38
39 class TestAI {
40 protected:
41 InfAlg *obj;
42 string name;
43 vector<double> err;
44
45 public:
46 vector<Factor> q;
47 double logZ;
48 double maxdiff;
49 double time;
50 bool has_logZ;
51
52 TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), has_logZ(true) {
53 double tic = toc();
54 obj = newInfAlg( name, fg, opts );
55 time += toc() - tic;
56 /*
57 } else if( method.substr(0,5) == "EXACT" ) { // EXACT
58 // Look if the network is small enough to do brute-force exact method
59 bool toolarge = false;
60 size_t total_statespace = 1;
61 for( size_t i = 0; i < fg.nrVars(); i++ ) {
62 total_statespace *= fg.var(i).states();
63 if( total_statespace > (1UL << 16) )
64 toolarge = true;
65 }
66
67 if( !toolarge ) {
68 Factor piet;
69 for( size_t I = 0; I < fg.nrFactors(); I++ )
70 piet *= fg.factor( I );
71 for( size_t i = 0; i < fg.nrVars(); i++ )
72 q.push_back(piet.marginal(fg.var(i)));
73 time += toc() - tic;
74 logZ = fg.ExactlogZ();
75 } else
76 throw "Network too large for EXACT method";
77 }
78 */
79 }
80
81 ~TestAI() {
82 if( obj != NULL )
83 delete obj;
84 }
85
86 string identify() {
87 if( obj != NULL )
88 return obj->identify();
89 else
90 return "NULL";
91 }
92
93 vector<Factor> allBeliefs() {
94 vector<Factor> result;
95 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
96 result.push_back( obj->belief( obj->fg().var(i) ) );
97 return result;
98 }
99
100 void doAI() {
101 double tic = toc();
102 // if( name == "EXACT" ) {
103 // // calculation has already been done
104 // }
105 if( obj != NULL ) {
106 obj->init();
107 obj->run();
108 time += toc() - tic;
109 try {
110 logZ = obj->logZ();
111 has_logZ = true;
112 } catch( Exception &e ) {
113 has_logZ = false;
114 }
115 maxdiff = obj->MaxDiff();
116 q = allBeliefs();
117 };
118 }
119
120 void calcErrs( const TestAI &x ) {
121 err.clear();
122 err.reserve( q.size() );
123 for( size_t i = 0; i < q.size(); i++ )
124 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
125 }
126
127 void calcErrs( const vector<Factor> &x ) {
128 err.clear();
129 err.reserve( q.size() );
130 for( size_t i = 0; i < q.size(); i++ )
131 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
132 }
133
134 double maxErr() {
135 return( *max_element( err.begin(), err.end() ) );
136 }
137
138 double avgErr() {
139 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
140 }
141 };
142
143
144 pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
145 string s = _s;
146 if( aliases.find(_s) != aliases.end() )
147 s = aliases.find(_s)->second;
148
149 pair<string, Properties> result;
150 string & name = result.first;
151 Properties & opts = result.second;
152
153 string::size_type pos = s.find_first_of('[');
154 name = s.substr( 0, pos );
155 if( pos == string::npos )
156 throw "Malformed method";
157 size_t n = 0;
158 for( ; strlen( DAINames[n] ) != 0; n++ )
159 if( name == DAINames[n] )
160 break;
161 if( strlen( DAINames[n] ) == 0 )
162 throw "Unknown inference algorithm";
163
164 stringstream ss;
165 ss << s.substr(pos,s.length());
166 ss >> opts;
167
168 return result;
169 }
170
171
/// Returns minabs if the absolute value of x is smaller than minabs, and x otherwise.
/// Note that the result is +minabs even for small negative x; a NaN is passed
/// through unchanged because the comparison is false for it.
double clipdouble( double x, double minabs ) {
    const bool tooSmall = std::fabs( x ) < minabs;
    return tooSmall ? minabs : x;
}
178
179
180 int main( int argc, char *argv[] ) {
181 try {
182 string filename;
183 string aliases;
184 vector<string> methods;
185 double tol;
186 size_t maxiter;
187 size_t verbose;
188 bool report_time = true;
189
190 po::options_description opts_required("Required options");
191 opts_required.add_options()
192 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
193 ("methods", po::value< vector<string> >(&methods)->multitoken(), "AI methods to test")
194 ;
195
196 po::options_description opts_optional("Allowed options");
197 opts_optional.add_options()
198 ("help", "produce help message")
199 ("aliases", po::value< string >(&aliases), "Filename for aliases")
200 ("tol", po::value< double >(&tol), "Override tolerance")
201 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
202 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
203 ("report-time", po::value< bool >(&report_time), "Report calculation time")
204 ;
205
206 po::options_description cmdline_options;
207 cmdline_options.add(opts_required).add(opts_optional);
208
209 po::variables_map vm;
210 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
211 po::notify(vm);
212
213 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
214 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
215 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
216 cout << "error and relative logZ error (comparing with the results of" << endl;
217 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
218 cout << opts_required << opts_optional << endl;
219 return 1;
220 }
221
222 // Read aliases
223 map<string,string> Aliases;
224 if( !aliases.empty() ) {
225 ifstream infile;
226 infile.open (aliases.c_str());
227 if (infile.is_open()) {
228 while( true ) {
229 string line;
230 getline(infile,line);
231 if( infile.fail() )
232 break;
233 if( (!line.empty()) && (line[0] != '#') ) {
234 string::size_type pos = line.find(':',0);
235 if( pos == string::npos )
236 throw "Invalid alias";
237 else {
238 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
239 string key = line.substr(0, posl + 1);
240 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
241 string val = line.substr(pos + 1 + posr, line.length());
242 Aliases[key] = val;
243 }
244 }
245 }
246 infile.close();
247 } else
248 throw "Error opening aliases file";
249 }
250
251 FactorGraph fg;
252 if( fg.ReadFromFile(filename.c_str()) ) {
253 cout << "Error reading " << filename << endl;
254 return 2;
255 } else {
256 vector<Factor> q0;
257 double logZ0 = 0.0;
258
259 cout << "# " << filename << endl;
260 cout.width( 40 );
261 cout << left << "# METHOD" << " ";
262 if( report_time ) {
263 cout.width( 10 );
264 cout << right << "SECONDS" << " ";
265 }
266 cout.width( 10 );
267 cout << "MAX ERROR" << " ";
268 cout.width( 10 );
269 cout << "AVG ERROR" << " ";
270 cout.width( 10 );
271 cout << "LOGZ ERROR" << " ";
272 cout.width( 10 );
273 cout << "MAXDIFF" << endl;
274
275 for( size_t m = 0; m < methods.size(); m++ ) {
276 pair<string, Properties> meth = parseMethod( methods[m], Aliases );
277
278 if( vm.count("tol") )
279 meth.second.Set("tol",tol);
280 if( vm.count("maxiter") )
281 meth.second.Set("maxiter",maxiter);
282 if( vm.count("verbose") )
283 meth.second.Set("verbose",verbose);
284 TestAI piet(fg, meth.first, meth.second );
285 piet.doAI();
286 if( m == 0 ) {
287 q0 = piet.q;
288 logZ0 = piet.logZ;
289 }
290 piet.calcErrs(q0);
291
292 cout.width( 40 );
293 // cout << left << piet.identify() << " ";
294 cout << left << methods[m] << " ";
295 if( report_time ) {
296 cout.width( 10 );
297 cout << right << piet.time << " ";
298 }
299
300 if( m > 0 ) {
301 cout.setf( ios_base::scientific );
302 cout.precision( 3 );
303 cout.width( 10 );
304 double me = clipdouble( piet.maxErr(), 1e-9 );
305 cout << me << " ";
306 cout.width( 10 );
307 double ae = clipdouble( piet.avgErr(), 1e-9 );
308 cout << ae << " ";
309 cout.width( 10 );
310 if( piet.has_logZ ) {
311 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
312 cout << le << " ";
313 } else {
314 cout << "N/A ";
315 }
316 cout.width( 10 );
317 double md = clipdouble( piet.maxdiff, 1e-9 );
318 if( isnan( me ) )
319 md = me;
320 if( isnan( ae ) )
321 md = ae;
322 cout << md << endl;
323 } else
324 cout << endl;
325 }
326 }
327 } catch(const char *e) {
328 cerr << "Exception: " << e << endl;
329 return 1;
330 } catch(exception& e) {
331 cerr << "Exception: " << e.what() << endl;
332 return 1;
333 }
334 catch(...) {
335 cerr << "Exception of unknown type!" << endl;
336 }
337
338 return 0;
339 }