1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
19 #include <boost/program_options.hpp>
21 #include <dai/alldai.h>
26 namespace po
= boost::program_options
;
45 TestDAI( const FactorGraph
&fg
, const string
&_name
, const PropertySet
&opts
) : obj(NULL
), name(_name
), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
47 if( name
== "LDPC" ) {
51 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
52 q
.push_back( Factor(Var(i
,2), zero
) );
60 obj
= newInfAlg( name
, fg
, opts
);
69 string
identify() const {
71 return obj
->identify();
76 vector
<Factor
> allBeliefs() {
77 vector
<Factor
> result
;
78 for( size_t i
= 0; i
< obj
->fg().nrVars(); i
++ )
79 result
.push_back( obj
->belief( obj
->fg().var(i
) ) );
93 } catch( Exception
&e
) {
94 if( e
.code() == Exception::NOT_IMPLEMENTED
)
101 maxdiff
= obj
->maxDiff();
103 } catch( Exception
&e
) {
104 if( e
.code() == Exception::NOT_IMPLEMENTED
)
111 iters
= obj
->Iterations();
113 } catch( Exception
&e
) {
114 if( e
.code() == Exception::NOT_IMPLEMENTED
)
124 void calcErrs( const TestDAI
&x
) {
126 err
.reserve( q
.size() );
127 for( size_t i
= 0; i
< q
.size(); i
++ )
128 err
.push_back( dist( q
[i
], x
.q
[i
], Prob::DISTTV
) );
131 void calcErrs( const vector
<Factor
> &x
) {
133 err
.reserve( q
.size() );
134 for( size_t i
= 0; i
< q
.size(); i
++ )
135 err
.push_back( dist( q
[i
], x
[i
], Prob::DISTTV
) );
139 return( *max_element( err
.begin(), err
.end() ) );
143 return( accumulate( err
.begin(), err
.end(), 0.0 ) / err
.size() );
148 pair
<string
, PropertySet
> parseMethodRaw( const string
&s
) {
149 string::size_type pos
= s
.find_first_of('[');
152 if( pos
== string::npos
) {
155 name
= s
.substr(0,pos
);
158 ss
<< s
.substr(pos
,s
.length());
161 return make_pair(name
,opts
);
165 pair
<string
, PropertySet
> parseMethod( const string
&_s
, const map
<string
,string
> & aliases
) {
166 // break string into method[properties]
167 pair
<string
,PropertySet
> ps
= parseMethodRaw(_s
);
170 // as long as 'method' is an alias, update:
171 while( aliases
.find(ps
.first
) != aliases
.end() && !looped
) {
172 string astr
= aliases
.find(ps
.first
)->second
;
173 pair
<string
,PropertySet
> aps
= parseMethodRaw(astr
);
174 if( aps
.first
== ps
.first
)
176 // override aps properties by ps properties
177 aps
.second
.Set( ps
.second
);
180 // repeat until method name == alias name ('looped'), or
181 // there is no longer an alias 'method'
184 // check whether name is valid
186 for( ; strlen( DAINames
[n
] ) != 0; n
++ )
187 if( ps
.first
== DAINames
[n
] )
189 if( strlen( DAINames
[n
] ) == 0 && (ps
.first
!= "LDPC") )
190 DAI_THROWE(UNKNOWN_DAI_ALGORITHM
,string("Unknown DAI algorithm \"") + ps
.first
+ string("\" in \"") + _s
+ string("\""));
196 Real
clipReal( Real x
, Real minabs
) {
197 if( abs(x
) < minabs
)
204 int main( int argc
, char *argv
[] ) {
207 vector
<string
> methods
;
211 bool marginals
= false;
212 bool report_iters
= true;
213 bool report_time
= true;
215 po::options_description
opts_required("Required options");
216 opts_required
.add_options()
217 ("filename", po::value
< string
>(&filename
), "Filename of FactorGraph")
218 ("methods", po::value
< vector
<string
> >(&methods
)->multitoken(), "DAI methods to test")
221 po::options_description
opts_optional("Allowed options");
222 opts_optional
.add_options()
223 ("help", "produce help message")
224 ("aliases", po::value
< string
>(&aliases
), "Filename for aliases")
225 ("tol", po::value
< Real
>(&tol
), "Override tolerance")
226 ("maxiter", po::value
< size_t >(&maxiter
), "Override maximum number of iterations")
227 ("verbose", po::value
< size_t >(&verbose
), "Override verbosity")
228 ("marginals", po::value
< bool >(&marginals
), "Output single node marginals?")
229 ("report-time", po::value
< bool >(&report_time
), "Report calculation time")
230 ("report-iters", po::value
< bool >(&report_iters
), "Report iterations needed")
233 po::options_description cmdline_options
;
234 cmdline_options
.add(opts_required
).add(opts_optional
);
236 po::variables_map vm
;
237 po::store(po::parse_command_line(argc
, argv
, cmdline_options
), vm
);
240 if( vm
.count("help") || !(vm
.count("filename") && vm
.count("methods")) ) {
241 cout
<< "Reads factorgraph <filename.fg> and performs the approximate" << endl
;
242 cout
<< "inference algorithms <method*>, reporting calculation time, max and average" << endl
;
243 cout
<< "error and relative logZ error (comparing with the results of" << endl
;
244 cout
<< "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl
<< endl
;
245 cout
<< opts_required
<< opts_optional
<< endl
;
247 cout
<< "This is a debugging (unoptimised) build of libDAI." << endl
;
254 map
<string
,string
> Aliases
;
255 if( !aliases
.empty() ) {
257 infile
.open (aliases
.c_str());
258 if (infile
.is_open()) {
261 getline(infile
,line
);
264 if( (!line
.empty()) && (line
[0] != '#') ) {
265 string::size_type pos
= line
.find(':',0);
266 if( pos
== string::npos
)
267 DAI_THROWE(RUNTIME_ERROR
,"Invalid alias");
269 string::size_type posl
= line
.substr(0, pos
).find_last_not_of(" \t");
270 string key
= line
.substr(0, posl
+ 1);
271 string::size_type posr
= line
.substr(pos
+ 1, line
.length()).find_first_not_of(" \t");
272 string val
= line
.substr(pos
+ 1 + posr
, line
.length());
279 DAI_THROWE(RUNTIME_ERROR
,"Error opening aliases file");
283 fg
.ReadFromFile( filename
.c_str() );
288 cout
.setf( ios_base::scientific
);
291 cout
<< "# " << filename
<< endl
;
293 cout
<< left
<< "# METHOD" << "\t";
295 cout
<< right
<< "SECONDS " << "\t";
297 cout
<< "ITERS" << "\t";
298 cout
<< "MAX ERROR" << "\t";
299 cout
<< "AVG ERROR" << "\t";
300 cout
<< "LOGZ ERROR" << "\t";
301 cout
<< "MAXDIFF" << "\t";
304 for( size_t m
= 0; m
< methods
.size(); m
++ ) {
305 pair
<string
, PropertySet
> meth
= parseMethod( methods
[m
], Aliases
);
307 if( vm
.count("tol") )
308 meth
.second
.Set("tol",tol
);
309 if( vm
.count("maxiter") )
310 meth
.second
.Set("maxiter",maxiter
);
311 if( vm
.count("verbose") )
312 meth
.second
.Set("verbose",verbose
);
313 TestDAI
piet(fg
, meth
.first
, meth
.second
);
322 cout
<< left
<< methods
[m
] << "\t";
324 cout
<< right
<< piet
.time
<< "\t";
326 if( piet
.has_iters
) {
327 cout
<< piet
.iters
<< "\t";
334 cout
.setf( ios_base::scientific
);
337 Real me
= clipReal( piet
.maxErr(), 1e-9 );
340 Real ae
= clipReal( piet
.avgErr(), 1e-9 );
343 if( piet
.has_logZ
) {
344 cout
.setf( ios::showpos
);
345 Real le
= clipReal( piet
.logZ
/ logZ0
- 1.0, 1e-9 );
347 cout
.unsetf( ios::showpos
);
351 if( piet
.has_maxdiff
) {
352 Real md
= clipReal( piet
.maxdiff
, 1e-9 );
366 for( size_t i
= 0; i
< piet
.q
.size(); i
++ )
367 cout
<< "# " << piet
.q
[i
] << endl
;
372 } catch( string
&s
) {
373 cerr
<< "Exception: " << s
<< endl
;