[Frederik Eaton] Added Fractional Belief Propagation
[libdai.git] / tests / testdai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <boost/program_options.hpp>
#include <dai/util.h>
#include <dai/alldai.h>
22
23
24 using namespace std;
25 using namespace dai;
26 namespace po = boost::program_options;
27
28
29 class TestDAI {
30 protected:
31 InfAlg *obj;
32 string name;
33 vector<Real> err;
34
35 public:
36 vector<Factor> q;
37 Real logZ;
38 Real maxdiff;
39 double time;
40 size_t iters;
41 bool has_logZ;
42 bool has_maxdiff;
43 bool has_iters;
44
45 TestDAI( const FactorGraph &fg, const string &_name, const PropertySet &opts ) : obj(NULL), name(_name), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
46 double tic = toc();
47 if( name == "LDPC" ) {
48 Prob zero(2,0.0);
49 zero[0] = 1.0;
50 q.clear();
51 for( size_t i = 0; i < fg.nrVars(); i++ )
52 q.push_back( Factor(Var(i,2), zero) );
53 logZ = 0.0;
54 maxdiff = 0.0;
55 iters = 1;
56 has_logZ = false;
57 has_maxdiff = false;
58 has_iters = false;
59 } else
60 obj = newInfAlg( name, fg, opts );
61 time += toc() - tic;
62 }
63
64 ~TestDAI() {
65 if( obj != NULL )
66 delete obj;
67 }
68
69 string identify() const {
70 if( obj != NULL )
71 return obj->identify();
72 else
73 return "NULL";
74 }
75
76 vector<Factor> allBeliefs() {
77 vector<Factor> result;
78 for( size_t i = 0; i < obj->fg().nrVars(); i++ )
79 result.push_back( obj->belief( obj->fg().var(i) ) );
80 return result;
81 }
82
83 void doDAI() {
84 double tic = toc();
85 if( obj != NULL ) {
86 obj->init();
87 obj->run();
88 time += toc() - tic;
89
90 try {
91 logZ = obj->logZ();
92 has_logZ = true;
93 } catch( Exception &e ) {
94 if( e.code() == Exception::NOT_IMPLEMENTED )
95 has_logZ = false;
96 else
97 throw;
98 }
99
100 try {
101 maxdiff = obj->maxDiff();
102 has_maxdiff = true;
103 } catch( Exception &e ) {
104 if( e.code() == Exception::NOT_IMPLEMENTED )
105 has_maxdiff = false;
106 else
107 throw;
108 }
109
110 try {
111 iters = obj->Iterations();
112 has_iters = true;
113 } catch( Exception &e ) {
114 if( e.code() == Exception::NOT_IMPLEMENTED )
115 has_iters = false;
116 else
117 throw;
118 }
119
120 q = allBeliefs();
121 };
122 }
123
124 void calcErrs( const TestDAI &x ) {
125 err.clear();
126 err.reserve( q.size() );
127 for( size_t i = 0; i < q.size(); i++ )
128 err.push_back( dist( q[i], x.q[i], Prob::DISTTV ) );
129 }
130
131 void calcErrs( const vector<Factor> &x ) {
132 err.clear();
133 err.reserve( q.size() );
134 for( size_t i = 0; i < q.size(); i++ )
135 err.push_back( dist( q[i], x[i], Prob::DISTTV ) );
136 }
137
138 Real maxErr() {
139 return( *max_element( err.begin(), err.end() ) );
140 }
141
142 Real avgErr() {
143 return( accumulate( err.begin(), err.end(), 0.0 ) / err.size() );
144 }
145 };
146
147
148 pair<string, PropertySet> parseMethodRaw( const string &s ) {
149 string::size_type pos = s.find_first_of('[');
150 string name;
151 PropertySet opts;
152 if( pos == string::npos ) {
153 name = s;
154 } else {
155 name = s.substr(0,pos);
156
157 stringstream ss;
158 ss << s.substr(pos,s.length());
159 ss >> opts;
160 }
161 return make_pair(name,opts);
162 }
163
164
165 pair<string, PropertySet> parseMethod( const string &_s, const map<string,string> & aliases ) {
166 // break string into method[properties]
167 pair<string,PropertySet> ps = parseMethodRaw(_s);
168 bool looped = false;
169
170 // as long as 'method' is an alias, update:
171 while( aliases.find(ps.first) != aliases.end() && !looped ) {
172 string astr = aliases.find(ps.first)->second;
173 pair<string,PropertySet> aps = parseMethodRaw(astr);
174 if( aps.first == ps.first )
175 looped = true;
176 // override aps properties by ps properties
177 aps.second.Set( ps.second );
178 // replace ps by aps
179 ps = aps;
180 // repeat until method name == alias name ('looped'), or
181 // there is no longer an alias 'method'
182 }
183
184 // check whether name is valid
185 size_t n = 0;
186 for( ; strlen( DAINames[n] ) != 0; n++ )
187 if( ps.first == DAINames[n] )
188 break;
189 if( strlen( DAINames[n] ) == 0 && (ps.first != "LDPC") )
190 DAI_THROWE(UNKNOWN_DAI_ALGORITHM,string("Unknown DAI algorithm \"") + ps.first + string("\" in \"") + _s + string("\""));
191
192 return ps;
193 }
194
195
/// Replaces values of small magnitude by a fixed floor
/** \return \a minabs if |x| < minabs (regardless of the sign of \a x), otherwise \a x itself.
 *  A NaN \a x fails the comparison and is therefore returned unchanged.
 */
Real clipReal( Real x, Real minabs ) {
    const bool belowFloor = abs(x) < minabs;
    return belowFloor ? minabs : x;
}
202
203
204 int main( int argc, char *argv[] ) {
205 string filename;
206 string aliases;
207 vector<string> methods;
208 Real tol;
209 size_t maxiter;
210 size_t verbose;
211 bool marginals = false;
212 bool report_iters = true;
213 bool report_time = true;
214
215 po::options_description opts_required("Required options");
216 opts_required.add_options()
217 ("filename", po::value< string >(&filename), "Filename of FactorGraph")
218 ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
219 ;
220
221 po::options_description opts_optional("Allowed options");
222 opts_optional.add_options()
223 ("help", "produce help message")
224 ("aliases", po::value< string >(&aliases), "Filename for aliases")
225 ("tol", po::value< Real >(&tol), "Override tolerance")
226 ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
227 ("verbose", po::value< size_t >(&verbose), "Override verbosity")
228 ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
229 ("report-time", po::value< bool >(&report_time), "Report calculation time")
230 ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
231 ;
232
233 po::options_description cmdline_options;
234 cmdline_options.add(opts_required).add(opts_optional);
235
236 po::variables_map vm;
237 po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
238 po::notify(vm);
239
240 if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
241 cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
242 cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
243 cout << "error and relative logZ error (comparing with the results of" << endl;
244 cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
245 cout << opts_required << opts_optional << endl;
246 #ifdef DAI_DEBUG
247 cout << "This is a debugging (unoptimised) build of libDAI." << endl;
248 #endif
249 return 1;
250 }
251
252 try {
253 // Read aliases
254 map<string,string> Aliases;
255 if( !aliases.empty() ) {
256 ifstream infile;
257 infile.open (aliases.c_str());
258 if (infile.is_open()) {
259 while( true ) {
260 string line;
261 getline(infile,line);
262 if( infile.fail() )
263 break;
264 if( (!line.empty()) && (line[0] != '#') ) {
265 string::size_type pos = line.find(':',0);
266 if( pos == string::npos )
267 DAI_THROWE(RUNTIME_ERROR,"Invalid alias");
268 else {
269 string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
270 string key = line.substr(0, posl + 1);
271 string::size_type posr = line.substr(pos + 1, line.length()).find_first_not_of(" \t");
272 string val = line.substr(pos + 1 + posr, line.length());
273 Aliases[key] = val;
274 }
275 }
276 }
277 infile.close();
278 } else
279 DAI_THROWE(RUNTIME_ERROR,"Error opening aliases file");
280 }
281
282 FactorGraph fg;
283 fg.ReadFromFile( filename.c_str() );
284
285 vector<Factor> q0;
286 Real logZ0 = 0.0;
287
288 cout.setf( ios_base::scientific );
289 cout.precision( 3 );
290
291 cout << "# " << filename << endl;
292 cout.width( 39 );
293 cout << left << "# METHOD" << "\t";
294 if( report_time )
295 cout << right << "SECONDS " << "\t";
296 if( report_iters )
297 cout << "ITERS" << "\t";
298 cout << "MAX ERROR" << "\t";
299 cout << "AVG ERROR" << "\t";
300 cout << "LOGZ ERROR" << "\t";
301 cout << "MAXDIFF" << "\t";
302 cout << endl;
303
304 for( size_t m = 0; m < methods.size(); m++ ) {
305 pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );
306
307 if( vm.count("tol") )
308 meth.second.Set("tol",tol);
309 if( vm.count("maxiter") )
310 meth.second.Set("maxiter",maxiter);
311 if( vm.count("verbose") )
312 meth.second.Set("verbose",verbose);
313 TestDAI piet(fg, meth.first, meth.second );
314 piet.doDAI();
315 if( m == 0 ) {
316 q0 = piet.q;
317 logZ0 = piet.logZ;
318 }
319 piet.calcErrs(q0);
320
321 cout.width( 39 );
322 cout << left << methods[m] << "\t";
323 if( report_time )
324 cout << right << piet.time << "\t";
325 if( report_iters ) {
326 if( piet.has_iters ) {
327 cout << piet.iters << "\t";
328 } else {
329 cout << "N/A \t";
330 }
331 }
332
333 if( m > 0 ) {
334 cout.setf( ios_base::scientific );
335 cout.precision( 3 );
336
337 Real me = clipReal( piet.maxErr(), 1e-9 );
338 cout << me << "\t";
339
340 Real ae = clipReal( piet.avgErr(), 1e-9 );
341 cout << ae << "\t";
342
343 if( piet.has_logZ ) {
344 cout.setf( ios::showpos );
345 Real le = clipReal( piet.logZ / logZ0 - 1.0, 1e-9 );
346 cout << le << "\t";
347 cout.unsetf( ios::showpos );
348 } else
349 cout << "N/A \t";
350
351 if( piet.has_maxdiff ) {
352 Real md = clipReal( piet.maxdiff, 1e-9 );
353 if( isnan( me ) )
354 md = me;
355 if( isnan( ae ) )
356 md = ae;
357 if( md == INFINITY )
358 md = 1.0;
359 cout << md << "\t";
360 } else
361 cout << "N/A \t";
362 }
363 cout << endl;
364
365 if( marginals ) {
366 for( size_t i = 0; i < piet.q.size(); i++ )
367 cout << "# " << piet.q[i] << endl;
368 }
369 }
370
371 return 0;
372 } catch( string &s ) {
373 cerr << "Exception: " << s << endl;
374 return 2;
375 }
376 }