Initial commit of libDAI-0.2.1
[libdai.git] / tests / test.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include <boost/program_options.hpp>
#include "util.h"
#include "alldai.h"
31
32
33 using namespace std;
34 namespace po = boost::program_options;
35
36
37 class TestAI {
38 protected:
39 InfAlg *obj;
40 string name;
41 vector<double> err;
42
43 public:
44 vector<Factor> q;
45 double logZ;
46 double maxdiff;
47 clock_t time;
48
49 TestAI( const FactorGraph &fg, const string &_name, const Properties &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
50 clock_t tic = toc();
51 obj = newInfAlg( name, fg, opts );
52 time += toc() - tic;
53 /*
54 } else if( method.substr(0,5) == "EXACT" ) { // EXACT
55 // Look if the network is small enough to do brute-force exact method
56 bool toolarge = false;
57 size_t total_statespace = 1;
58 for( size_t i = 0; i < fg.nrVars(); i++ ) {
59 total_statespace *= fg.var(i).states();
60 if( total_statespace > (1UL << 16) )
61 toolarge = true;
62 }
63
64 if( !toolarge ) {
65 Factor piet;
66 for( size_t I = 0; I < fg.nrFactors(); I++ )
67 piet *= fg.factor( I );
68 for( size_t i = 0; i < fg.nrVars(); i++ )
69 q.push_back(piet.marginal(fg.var(i)));
70 time += toc() - tic;
71 logZ = fg.ExactlogZ();
72 } else
73 throw "Network too large for EXACT method";
74 }
75 */
76 }
77
78 ~TestAI() {
79 if( obj != NULL )
80 delete obj;
81 }
82
83 string identify() {
84 if( obj != NULL )
85 return obj->identify();
86 else
87 return "NULL";
88 }
89
90 vector<Factor> allBeliefs() {
91 vector<Factor> result;
92 for( size_t i = 0; i < obj->nrVars(); i++ )
93 result.push_back( obj->belief( obj->var(i) ) );
94 return result;
95 }
96
97 void doAI() {
98 clock_t tic = toc();
99 // if( name == "EXACT" ) {
100 // // calculation has already been done
101 // }
102 if( obj != NULL ) {
103 obj->init();
104 obj->run();
105 time += toc() - tic;
106 logZ = real(obj->logZ());
107 maxdiff = obj->MaxDiff();
108 q = allBeliefs();
109 };
110 }
111
112 void calcErrs( const TestAI &x ) {
113 err.clear();
114 err.reserve( q.size() );
115 for( size_t i = 0; i < q.size(); i++ )
116 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
117 }
118
119 void calcErrs( const vector<Factor> &x ) {
120 err.clear();
121 err.reserve( q.size() );
122 for( size_t i = 0; i < q.size(); i++ )
123 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
124 }
125
126 double maxErr() {
127 return( *max_element( err.begin(), err.end() ) );
128 }
129
130 double avgErr() {
131 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
132 }
133 };
134
135
136 pair<string, Properties> parseMethod( const string &_s, const map<string,string> & aliases ) {
137 string s = _s;
138 if( aliases.find(_s) != aliases.end() )
139 s = aliases.find(_s)->second;
140
141 pair<string, Properties> result;
142 string & name = result.first;
143 Properties & opts = result.second;
144
145 string::size_type pos = s.find_first_of('[');
146 name = s.substr( 0, pos );
147 if( pos == string::npos )
148 throw "Malformed method";
149 size_t n = 0;
150 for( ; n < sizeof(DAINames) / sizeof(string); n++ )
151 if( name == DAINames[n] )
152 break;
153 if( n == sizeof(DAINames) / sizeof(string) )
154 throw "Unknown inference algorithm";
155
156 stringstream ss;
157 ss << s.substr(pos,s.length());
158 ss >> opts;
159
160 return result;
161 }
162
163
/// Replaces values whose magnitude is below \a minabs by \a minabs itself
/// (tiny negative values therefore become +minabs); every other value,
/// including NaN, is passed through unchanged.
double clipdouble( double x, double minabs ) {
    const bool belowFloor = std::fabs( x ) < minabs;
    return belowFloor ? minabs : x;
}
170
171
172 int main( int argc, char *argv[] ) {
173 try {
174 string filename;
175 string aliases;
176 vector<string> methods;
177 double tol;
178 size_t maxiter;
179 size_t verbose;
180
181 po::options_description opts_required("Required options");
182 opts_required.add_options()
183 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
184 ("methods", po::value< vector<string> >(&methods)->multitoken(), "AI methods to test")
185 ;
186
187 po::options_description opts_optional("Allowed options");
188 opts_optional.add_options()
189 ("help", "produce help message")
190 ("aliases", po::value< string >(&aliases), "Filename for aliases")
191 ("tol", po::value< double >(&tol), "Override tolerance")
192 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
193 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
194 ;
195
196 po::options_description cmdline_options;
197 cmdline_options.add(opts_required).add(opts_optional);
198
199 po::variables_map vm;
200 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
201 po::notify(vm);
202
203 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
204 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
205 cout << "inference algorithms <method*>, reporting clocks, max and average" << endl;
206 cout << "error and relative logZ error (comparing with the results of" << endl;
207 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
208 cout << opts_required << opts_optional << endl;
209 return 1;
210 }
211
212 // Read aliases
213 map<string,string> Aliases;
214 if( !aliases.empty() ) {
215 ifstream infile;
216 infile.open (aliases.c_str());
217 if (infile.is_open()) {
218 while( true ) {
219 string line;
220 getline(infile,line);
221 if( infile.fail() )
222 break;
223 if( (!line.empty()) && (line[0] != '#') ) {
224 string::size_type pos = line.find(':',0);
225 if( pos == string::npos )
226 throw "Invalid alias";
227 else {
228 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
229 string key = line.substr(0, posl + 1);
230 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
231 string val = line.substr(pos + 1 + posr, line.length());
232 Aliases[key] = val;
233 }
234 }
235 }
236 infile.close();
237 } else
238 throw "Error opening aliases file";
239 }
240
241 FactorGraph fg;
242 if( fg.ReadFromFile(filename.c_str()) ) {
243 cout << "Error reading " << filename << endl;
244 return 2;
245 } else {
246 vector<Factor> q0;
247 double logZ0 = 0.0;
248
249 cout << "# " << filename << endl;
250 cout.width( 40 );
251 cout << left << "# METHOD" << " ";
252 cout.width( 10 );
253 cout << right << "CLOCKS" << " ";
254 cout.width( 10 );
255 cout << "MAX ERROR" << " ";
256 cout.width( 10 );
257 cout << "AVG ERROR" << " ";
258 cout.width( 10 );
259 cout << "LOGZ ERROR" << " ";
260 cout.width( 10 );
261 cout << "MAXDIFF" << endl;
262
263 for( size_t m = 0; m < methods.size(); m++ ) {
264 pair<string, Properties> meth = parseMethod( methods[m], Aliases );
265
266 if( vm.count("tol") )
267 meth.second.Set("tol",tol);
268 if( vm.count("maxiter") )
269 meth.second.Set("maxiter",maxiter);
270 if( vm.count("verbose") )
271 meth.second.Set("verbose",verbose);
272 TestAI piet(fg, meth.first, meth.second );
273 piet.doAI();
274 if( m == 0 ) {
275 q0 = piet.q;
276 logZ0 = piet.logZ;
277 }
278 piet.calcErrs(q0);
279
280 cout.width( 40 );
281 // cout << left << piet.identify() << " ";
282 cout << left << methods[m] << " ";
283 cout.width( 10 );
284 cout << right << piet.time << " ";
285
286 if( m > 0 ) {
287 cout.setf( ios_base::scientific );
288 cout.precision( 3 );
289 cout.width( 10 );
290 double me = clipdouble( piet.maxErr(), 1e-9 );
291 cout << me << " ";
292 cout.width( 10 );
293 double ae = clipdouble( piet.avgErr(), 1e-9 );
294 cout << ae << " ";
295 cout.width( 10 );
296 double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
297 cout << le << " ";
298 cout.width( 10 );
299 double md = clipdouble( piet.maxdiff, 1e-9 );
300 if( isnan( me ) )
301 md = me;
302 if( isnan( ae ) )
303 md = ae;
304 cout << md << endl;
305 } else
306 cout << endl;
307 }
308 }
309 } catch(const char *e) {
310 cerr << "Exception: " << e << endl;
311 return 1;
312 } catch(exception& e) {
313 cerr << "Exception: " << e.what() << endl;
314 return 1;
315 }
316 catch(...) {
317 cerr << "Exception of unknown type!" << endl;
318 }
319
320 return 0;
321 }