Moved alias code from tests/testdai.cpp to src/alldai.cpp
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>


using namespace std;
using namespace dai;
namespace po = boost::program_options;
27
28
29 class TestDAI {
30 protected:
31 InfAlg *obj;
32 string name;
33 vector<Real> err;
34
35 public:
36 vector<Factor> varmargs;
37 vector<Factor> marginals;
38 Real logZ;
39 Real maxdiff;
40 double time;
41 size_t iters;
42 bool has_logZ;
43 bool has_maxdiff;
44 bool has_iters;
45
46 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), varmargs(), marginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
47 double tic = toc();
48 if( name == "LDPC" ) {
49 Prob zero(2,0.0);
50 zero[0] = 1.0;
51 for( size_t i = 0; i < fg.nrVars(); i++ )
52 varmargs.push_back( Factor(fg.var(i), zero) );
53 marginals = varmargs;
54 logZ = 0.0;
55 maxdiff = 0.0;
56 iters = 1;
57 has_logZ = false;
58 has_maxdiff = false;
59 has_iters = false;
60 } else
61 obj = newInfAlg( name, fg, opts );
62 time += toc() - tic;
63 }
64
65 ~TestDAI() {
66 if( obj != NULL )
67 delete obj;
68 }
69
70 string identify() const {
71 if( obj != NULL )
72 return obj->identify();
73 else
74 return "NULL";
75 }
76
77 void doDAI() {
78 double tic = toc();
79 if( obj != NULL ) {
80 obj->init();
81 obj->run();
82 time += toc() - tic;
83
84 try {
85 logZ = obj->logZ();
86 has_logZ = true;
87 } catch( Exception &e ) {
88 if( e.code() == Exception::NOT_IMPLEMENTED )
89 has_logZ = false;
90 else
91 throw;
92 }
93
94 try {
95 maxdiff = obj->maxDiff();
96 has_maxdiff = true;
97 } catch( Exception &e ) {
98 if( e.code() == Exception::NOT_IMPLEMENTED )
99 has_maxdiff = false;
100 else
101 throw;
102 }
103
104 try {
105 iters = obj->Iterations();
106 has_iters = true;
107 } catch( Exception &e ) {
108 if( e.code() == Exception::NOT_IMPLEMENTED )
109 has_iters = false;
110 else
111 throw;
112 }
113
114 varmargs.clear();
115 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
116 varmargs.push_back( obj->beliefV( i ) );
117
118 marginals = obj->beliefs();
119 };
120 }
121
122 void calcErrs( const TestDAI &x ) {
123 err.clear();
124 err.reserve( varmargs.size() );
125 for( size_t i = 0; i < varmargs.size(); i++ )
126 err.push_back( dist( varmargs[i], x.varmargs[i], Prob::DISTTV ) );
127 }
128
129 void calcErrs( const vector<Factor> &x ) {
130 err.clear();
131 err.reserve( varmargs.size() );
132 for( size_t i = 0; i < varmargs.size(); i++ )
133 err.push_back( dist( varmargs[i], x[i], Prob::DISTTV ) );
134 }
135
136 Real maxErr() {
137 return( *max_element( err.begin(), err.end() ) );
138 }
139
140 Real avgErr() {
141 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
142 }
143 };
144
145
146 Real clipReal( Real x, Real minabs ) {
147 if( abs(x) < minabs )
148 return minabs;
149 else
150 return x;
151 }
152
153
154 DAI_ENUM(MarginalsOutputType,NONE,VAR,ALL);
155
156
157 int main( int argc, char *argv[] ) {
158 string filename;
159 string aliases;
160 vector<string> methods;
161 Real tol;
162 size_t maxiter;
163 size_t verbose;
164 MarginalsOutputType marginals;
165 bool report_iters = true;
166 bool report_time = true;
167
168 po::options_description opts_required("Required options");
169 opts_required.add_options()
170 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
171 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
172 ;
173
174 po::options_description opts_optional("Allowed options");
175 opts_optional.add_options()
176 ("help", "produce help message")
177 ("aliases", po::value< string >(&aliases), "Filename for aliases")
178 ("tol", po::value< Real >(&tol), "Override tolerance")
179 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
180 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
181 ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE,VAR,ALL)")
182 ("report-time", po::value< bool >(&report_time), "Report calculation time")
183 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
184 ;
185
186 po::options_description cmdline_options;
187 cmdline_options.add(opts_required).add(opts_optional);
188
189 po::variables_map vm;
190 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
191 po::notify(vm);
192
193 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
194 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
195 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
196 cout << "error and relative logZ error (comparing with the results of" << endl;
197 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
198 cout << opts_required << opts_optional << endl;
199 #ifdef DAI_DEBUG
200 cout << "This is a debugging (unoptimised) build of libDAI." << endl;
201 #endif
202 return 1;
203 }
204
205 try {
206 // Read aliases
207 map<string,string> Aliases;
208 if( !aliases.empty() )
209 Aliases = readAliasesFile( aliases );
210
211 FactorGraph fg;
212 fg.ReadFromFile( filename.c_str() );
213
214 vector<Factor> varmargs0;
215 Real logZ0 = 0.0;
216
217 cout.setf( ios_base::scientific );
218 cout.precision( 3 );
219
220 cout << "# " << filename << endl;
221 cout.width( 39 );
222 cout << left << "# METHOD" << "\t";
223 if( report_time )
224 cout << right << "SECONDS " << "\t";
225 if( report_iters )
226 cout << "ITERS" << "\t";
227 cout << "MAX ERROR" << "\t";
228 cout << "AVG ERROR" << "\t";
229 cout << "LOGZ ERROR" << "\t";
230 cout << "MAXDIFF" << "\t";
231 cout << endl;
232
233 for( size_t m = 0; m < methods.size(); m++ ) {
234 // parse method
235 pair<string, PropertySet> meth = parseNameProperties( methods[m], Aliases );
236
237 // check whether name is valid
238 size_t n = 0;
239 for( ; strlen( DAINames[n] ) != 0; n++ )
240 if( meth.first == DAINames[n] )
241 break;
242 if( strlen( DAINames[n] ) == 0 )
243 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + meth.first + string("\" in \"") + methods[m] + string("\""));
244
245 if( vm.count("tol") )
246 meth.second.Set("tol",tol);
247 if( vm.count("maxiter") )
248 meth.second.Set("maxiter",maxiter);
249 if( vm.count("verbose") )
250 meth.second.Set("verbose",verbose);
251 TestDAI testdai(fg, meth.first, meth.second );
252 testdai.doDAI();
253 if( m == 0 ) {
254 varmargs0 = testdai.varmargs;
255 logZ0 = testdai.logZ;
256 }
257 testdai.calcErrs(varmargs0);
258
259 cout.width( 39 );
260 cout << left << methods[m] << "\t";
261 if( report_time )
262 cout << right << testdai.time << "\t";
263 if( report_iters ) {
264 if( testdai.has_iters ) {
265 cout << testdai.iters << "\t";
266 } else {
267 cout << "N/A \t";
268 }
269 }
270
271 if( m > 0 ) {
272 cout.setf( ios_base::scientific );
273 cout.precision( 3 );
274
275 Real me = clipReal( testdai.maxErr(), 1e-9 );
276 cout << me << "\t";
277
278 Real ae = clipReal( testdai.avgErr(), 1e-9 );
279 cout << ae << "\t";
280
281 if( testdai.has_logZ ) {
282 cout.setf( ios::showpos );
283 Real le = clipReal( testdai.logZ / logZ0 - 1.0, 1e-9 );
284 cout << le << "\t";
285 cout.unsetf( ios::showpos );
286 } else
287 cout << "N/A \t";
288
289 if( testdai.has_maxdiff ) {
290 Real md = clipReal( testdai.maxdiff, 1e-9 );
291 if( isnan( me ) )
292 md = me;
293 if( isnan( ae ) )
294 md = ae;
295 if( md == INFINITY )
296 md = 1.0;
297 cout << md << "\t";
298 } else
299 cout << "N/A \t";
300 }
301 cout << endl;
302
303 if( marginals == MarginalsOutputType::VAR ) {
304 for( size_t i = 0; i < testdai.varmargs.size(); i++ )
305 cout << "# " << testdai.varmargs[i] << endl;
306 } else if( marginals == MarginalsOutputType::ALL ) {
307 for( size_t I = 0; I < testdai.marginals.size(); I++ )
308 cout << "# " << testdai.marginals[I] << endl;
309 }
310 }
311
312 return 0;
313 } catch( string &s ) {
314 cerr << "Exception: " << s << endl;
315 return 2;
316 }
317 }