Fixed bug in HAK and changed tests/testdai "marginals" option
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
22
23
24 using namespace std;
25 using namespace dai;
26 namespace po = boost::program_options;
27
28
29 class TestDAI {
30 protected:
31 InfAlg *obj;
32 string name;
33 vector<Real> err;
34
35 public:
36 vector<Factor> varmargs;
37 vector<Factor> marginals;
38 Real logZ;
39 Real maxdiff;
40 double time;
41 size_t iters;
42 bool has_logZ;
43 bool has_maxdiff;
44 bool has_iters;
45
46 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), varmargs(), marginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
47 double tic = toc();
48 if( name == "LDPC" ) {
49 Prob zero(2,0.0);
50 zero[0] = 1.0;
51 for( size_t i = 0; i < fg.nrVars(); i++ )
52 varmargs.push_back( Factor(fg.var(i), zero) );
53 marginals = varmargs;
54 logZ = 0.0;
55 maxdiff = 0.0;
56 iters = 1;
57 has_logZ = false;
58 has_maxdiff = false;
59 has_iters = false;
60 } else
61 obj = newInfAlg( name, fg, opts );
62 time += toc() - tic;
63 }
64
65 ~TestDAI() {
66 if( obj != NULL )
67 delete obj;
68 }
69
70 string identify() const {
71 if( obj != NULL )
72 return obj->identify();
73 else
74 return "NULL";
75 }
76
77 void doDAI() {
78 double tic = toc();
79 if( obj != NULL ) {
80 obj->init();
81 obj->run();
82 time += toc() - tic;
83
84 try {
85 logZ = obj->logZ();
86 has_logZ = true;
87 } catch( Exception &e ) {
88 if( e.code() == Exception::NOT_IMPLEMENTED )
89 has_logZ = false;
90 else
91 throw;
92 }
93
94 try {
95 maxdiff = obj->maxDiff();
96 has_maxdiff = true;
97 } catch( Exception &e ) {
98 if( e.code() == Exception::NOT_IMPLEMENTED )
99 has_maxdiff = false;
100 else
101 throw;
102 }
103
104 try {
105 iters = obj->Iterations();
106 has_iters = true;
107 } catch( Exception &e ) {
108 if( e.code() == Exception::NOT_IMPLEMENTED )
109 has_iters = false;
110 else
111 throw;
112 }
113
114 varmargs.clear();
115 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
116 varmargs.push_back( obj->beliefV( i ) );
117
118 marginals = obj->beliefs();
119 };
120 }
121
122 void calcErrs( const TestDAI &x ) {
123 err.clear();
124 err.reserve( varmargs.size() );
125 for( size_t i = 0; i < varmargs.size(); i++ )
126 err.push_back( dist( varmargs[i], x.varmargs[i], Prob::DISTTV ) );
127 }
128
129 void calcErrs( const vector<Factor> &x ) {
130 err.clear();
131 err.reserve( varmargs.size() );
132 for( size_t i = 0; i < varmargs.size(); i++ )
133 err.push_back( dist( varmargs[i], x[i], Prob::DISTTV ) );
134 }
135
136 Real maxErr() {
137 return( *max_element( err.begin(), err.end() ) );
138 }
139
140 Real avgErr() {
141 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
142 }
143 };
144
145
146 pair<string, PropertySet> parseMethodRaw( const string &s ) {
147 string::size_type pos = s.find_first_of('[');
148 string name;
149 PropertySet opts;
150 if( pos == string::npos ) {
151 name = s;
152 } else {
153 name = s.substr(0,pos);
154
155 stringstream ss;
156 ss << s.substr(pos,s.length());
157 ss >> opts;
158 }
159 return make_pair(name,opts);
160 }
161
162
163 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
164 // break string into method[properties]
165 pair<string,PropertySet> ps = parseMethodRaw(_s);
166 bool looped = false;
167
168 // as long as 'method' is an alias, update:
169 while( aliases.find(ps.first) != aliases.end() && !looped ) {
170 string astr = aliases.find(ps.first)->second;
171 pair<string,PropertySet> aps = parseMethodRaw(astr);
172 if( aps.first == ps.first )
173 looped = true;
174 // override aps properties by ps properties
175 aps.second.Set( ps.second );
176 // replace ps by aps
177 ps = aps;
178 // repeat until method name == alias name ('looped'), or
179 // there is no longer an alias 'method'
180 }
181
182 // check whether name is valid
183 size_t n = 0;
184 for( ; strlen( DAINames[n] ) != 0; n++ )
185 if( ps.first == DAINames[n] )
186 break;
187 if( strlen( DAINames[n] ) == 0 && (ps.first != "LDPC") )
188 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + ps.first + string("\" in \"") + _s + string("\""));
189
190 return ps;
191 }
192
193
194 Real clipReal( Real x, Real minabs ) {
195 if( abs(x) < minabs )
196 return minabs;
197 else
198 return x;
199 }
200
201
202 DAI_ENUM(MarginalsOutputType,NONE,VAR,ALL);
203
204
205 int main( int argc, char *argv[] ) {
206 string filename;
207 string aliases;
208 vector<string> methods;
209 Real tol;
210 size_t maxiter;
211 size_t verbose;
212 MarginalsOutputType marginals;
213 bool report_iters = true;
214 bool report_time = true;
215
216 po::options_description opts_required("Required options");
217 opts_required.add_options()
218 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
219 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
220 ;
221
222 po::options_description opts_optional("Allowed options");
223 opts_optional.add_options()
224 ("help", "produce help message")
225 ("aliases", po::value< string >(&aliases), "Filename for aliases")
226 ("tol", po::value< Real >(&tol), "Override tolerance")
227 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
228 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
229 ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE,VAR,ALL)")
230 ("report-time", po::value< bool >(&report_time), "Report calculation time")
231 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
232 ;
233
234 po::options_description cmdline_options;
235 cmdline_options.add(opts_required).add(opts_optional);
236
237 po::variables_map vm;
238 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
239 po::notify(vm);
240
241 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
242 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
243 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
244 cout << "error and relative logZ error (comparing with the results of" << endl;
245 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
246 cout << opts_required << opts_optional << endl;
247 #ifdef DAI_DEBUG
248 cout << "This is a debugging (unoptimised) build of libDAI." << endl;
249 #endif
250 return 1;
251 }
252
253 try {
254 // Read aliases
255 map<string,string> Aliases;
256 if( !aliases.empty() ) {
257 ifstream infile;
258 infile.open (aliases.c_str());
259 if (infile.is_open()) {
260 while( true ) {
261 string line;
262 getline(infile,line);
263 if( infile.fail() )
264 break;
265 if( (!line.empty()) && (line[0] != '#') ) {
266 string::size_type pos = line.find(':',0);
267 if( pos == string::npos )
268 DAI_THROWE(RUNTIME_ERROR,"Invalid alias");
269 else {
270 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
271 string key = line.substr(0, posl + 1);
272 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
273 string val = line.substr(pos + 1 + posr, line.length());
274 Aliases[key] = val;
275 }
276 }
277 }
278 infile.close();
279 } else
280 DAI_THROWE(RUNTIME_ERROR,"Error opening aliases file");
281 }
282
283 FactorGraph fg;
284 fg.ReadFromFile( filename.c_str() );
285
286 vector<Factor> varmargs0;
287 Real logZ0 = 0.0;
288
289 cout.setf( ios_base::scientific );
290 cout.precision( 3 );
291
292 cout << "# " << filename << endl;
293 cout.width( 39 );
294 cout << left << "# METHOD" << "\t";
295 if( report_time )
296 cout << right << "SECONDS " << "\t";
297 if( report_iters )
298 cout << "ITERS" << "\t";
299 cout << "MAX ERROR" << "\t";
300 cout << "AVG ERROR" << "\t";
301 cout << "LOGZ ERROR" << "\t";
302 cout << "MAXDIFF" << "\t";
303 cout << endl;
304
305 for( size_t m = 0; m < methods.size(); m++ ) {
306 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
307
308 if( vm.count("tol") )
309 meth.second.Set("tol",tol);
310 if( vm.count("maxiter") )
311 meth.second.Set("maxiter",maxiter);
312 if( vm.count("verbose") )
313 meth.second.Set("verbose",verbose);
314 TestDAI testdai(fg, meth.first, meth.second );
315 testdai.doDAI();
316 if( m == 0 ) {
317 varmargs0 = testdai.varmargs;
318 logZ0 = testdai.logZ;
319 }
320 testdai.calcErrs(varmargs0);
321
322 cout.width( 39 );
323 cout << left << methods[m] << "\t";
324 if( report_time )
325 cout << right << testdai.time << "\t";
326 if( report_iters ) {
327 if( testdai.has_iters ) {
328 cout << testdai.iters << "\t";
329 } else {
330 cout << "N/A \t";
331 }
332 }
333
334 if( m > 0 ) {
335 cout.setf( ios_base::scientific );
336 cout.precision( 3 );
337
338 Real me = clipReal( testdai.maxErr(), 1e-9 );
339 cout << me << "\t";
340
341 Real ae = clipReal( testdai.avgErr(), 1e-9 );
342 cout << ae << "\t";
343
344 if( testdai.has_logZ ) {
345 cout.setf( ios::showpos );
346 Real le = clipReal( testdai.logZ / logZ0 - 1.0, 1e-9 );
347 cout << le << "\t";
348 cout.unsetf( ios::showpos );
349 } else
350 cout << "N/A \t";
351
352 if( testdai.has_maxdiff ) {
353 Real md = clipReal( testdai.maxdiff, 1e-9 );
354 if( isnan( me ) )
355 md = me;
356 if( isnan( ae ) )
357 md = ae;
358 if( md == INFINITY )
359 md = 1.0;
360 cout << md << "\t";
361 } else
362 cout << "N/A \t";
363 }
364 cout << endl;
365
366 if( marginals == MarginalsOutputType::VAR ) {
367 for( size_t i = 0; i < testdai.varmargs.size(); i++ )
368 cout << "# " << testdai.varmargs[i] << endl;
369 } else if( marginals == MarginalsOutputType::ALL ) {
370 for( size_t I = 0; I < testdai.marginals.size(); I++ )
371 cout << "# " << testdai.marginals[I] << endl;
372 }
373 }
374
375 return 0;
376 } catch( string &s ) {
377 cerr << "Exception: " << s << endl;
378 return 2;
379 }
380 }