1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
19 #include <boost/program_options.hpp>
21 #include <dai/alldai.h>
// Short alias for Boost.Program_options, which main() uses to declare and
// parse the command-line options.
26 namespace po
= boost::program_options
;
/// Single-variable marginals, one Factor per variable of the factor graph
/// (filled from obj->beliefV(i) for each variable index i).
36 vector
<Factor
> varmargs
;
/// All marginals reported by the inference object (filled from obj->beliefs()).
37 vector
<Factor
> marginals
;
/// Creates a test harness that runs inference method `_name` on factor graph
/// `fg` with properties `opts`.
/// All result fields start out empty/zero: obj is NULL, logZ/maxdiff are 0.0,
/// time/iters are 0, and the has_logZ/has_maxdiff/has_iters flags are false.
/// NOTE(review): several lines of this body are missing from this copy of the
/// file; the exact branch structure around the LDPC case and the newInfAlg
/// call cannot be confirmed here.
46 TestDAI( const FactorGraph
&fg
, const string
&_name
, const PropertySet
&opts
) : obj(NULL
), name(_name
), err(), varmargs(), marginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
// Special-cased method name: builds one Factor per variable from fg.var(i)
// and `zero` (presumably a constant defined elsewhere — TODO confirm).
48 if( name
== "LDPC" ) {
51 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
52 varmargs
.push_back( Factor(fg
.var(i
), zero
) );
// Instantiate the inference algorithm by name via the libDAI factory.
61 obj
= newInfAlg( name
, fg
, opts
);
/// Returns the identification string of the wrapped inference object,
/// obtained from obj->identify().
/// NOTE(review): a line of this body is missing from this copy; confirm
/// whether a NULL check on obj guards the call.
70 string
identify() const {
72 return obj
->identify();
87 } catch( Exception
&e
) {
88 if( e
.code() == Exception::NOT_IMPLEMENTED
)
95 maxdiff
= obj
->maxDiff();
97 } catch( Exception
&e
) {
98 if( e
.code() == Exception::NOT_IMPLEMENTED
)
105 iters
= obj
->Iterations();
107 } catch( Exception
&e
) {
108 if( e
.code() == Exception::NOT_IMPLEMENTED
)
115 for( size_t i
= 0; i
< obj
->fg().nrVars(); i
++ )
116 varmargs
.push_back( obj
->beliefV( i
) );
118 marginals
= obj
->beliefs();
/// Computes, for each variable i, the total-variation distance
/// (Prob::DISTTV) between this object's varmargs[i] and x.varmargs[i], and
/// pushes it onto err (capacity is reserved for varmargs.size() entries).
/// NOTE(review): x.varmargs is indexed without a size check, so it must hold
/// at least varmargs.size() entries. Whether err is cleared first is not
/// visible in this copy.
122 void calcErrs( const TestDAI
&x
) {
124 err
.reserve( varmargs
.size() );
125 for( size_t i
= 0; i
< varmargs
.size(); i
++ )
126 err
.push_back( dist( varmargs
[i
], x
.varmargs
[i
], Prob::DISTTV
) );
/// Overload comparing against a plain list of reference marginals: for each
/// variable i, pushes the total-variation distance (Prob::DISTTV) between
/// varmargs[i] and x[i] onto err. main() calls this with the marginals of
/// the base method.
/// NOTE(review): x is indexed without a size check, so it must hold at least
/// varmargs.size() entries. Whether err is cleared first is not visible in
/// this copy.
129 void calcErrs( const vector
<Factor
> &x
) {
131 err
.reserve( varmargs
.size() );
132 for( size_t i
= 0; i
< varmargs
.size(); i
++ )
133 err
.push_back( dist( varmargs
[i
], x
[i
], Prob::DISTTV
) );
137 return( *max_element( err
.begin(), err
.end() ) );
141 return( accumulate( err
.begin(), err
.end(), 0.0 ) / err
.size() );
/// Helper applied to the error values before printing (main() calls it with
/// minabs = 1e-9). It tests whether |x| is below minabs.
/// NOTE(review): the branch body and return statements are missing from this
/// copy; presumably tiny magnitudes are clipped (e.g. to 0) — TODO confirm.
146 Real
clipReal( Real x
, Real minabs
) {
147 if( abs(x
) < minabs
)
// Value of the --marginals option: which marginals are printed after each
// method's result line.
//   NONE - print no marginals
//   VAR  - print each entry of TestDAI::varmargs
//   ALL  - print each entry of TestDAI::marginals
154 DAI_ENUM(MarginalsOutputType
,NONE
,VAR
,ALL
);
157 int main( int argc
, char *argv
[] ) {
160 vector
<string
> methods
;
164 MarginalsOutputType marginals
;
165 bool report_iters
= true;
166 bool report_time
= true;
168 po::options_description
opts_required("Required options");
169 opts_required
.add_options()
170 ("filename", po::value
< string
>(&filename
), "Filename of FactorGraph")
171 ("methods", po::value
< vector
<string
> >(&methods
)->multitoken(), "DAI methods to test")
174 po::options_description
opts_optional("Allowed options");
175 opts_optional
.add_options()
176 ("help", "produce help message")
177 ("aliases", po::value
< string
>(&aliases
), "Filename for aliases")
178 ("tol", po::value
< Real
>(&tol
), "Override tolerance")
179 ("maxiter", po::value
< size_t >(&maxiter
), "Override maximum number of iterations")
180 ("verbose", po::value
< size_t >(&verbose
), "Override verbosity")
181 ("marginals", po::value
< MarginalsOutputType
>(&marginals
), "Output marginals? (NONE,VAR,ALL)")
182 ("report-time", po::value
< bool >(&report_time
), "Report calculation time")
183 ("report-iters", po::value
< bool >(&report_iters
), "Report iterations needed")
186 po::options_description cmdline_options
;
187 cmdline_options
.add(opts_required
).add(opts_optional
);
189 po::variables_map vm
;
190 po::store(po::parse_command_line(argc
, argv
, cmdline_options
), vm
);
193 if( vm
.count("help") || !(vm
.count("filename") && vm
.count("methods")) ) {
194 cout
<< "Reads factorgraph <filename.fg> and performs the approximate" << endl
;
195 cout
<< "inference algorithms <method*>, reporting calculation time, max and average" << endl
;
196 cout
<< "error and relative logZ error (comparing with the results of" << endl
;
197 cout
<< "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl
<< endl
;
198 cout
<< opts_required
<< opts_optional
<< endl
;
200 cout
<< "This is a debugging (unoptimised) build of libDAI." << endl
;
207 map
<string
,string
> Aliases
;
208 if( !aliases
.empty() )
209 Aliases
= readAliasesFile( aliases
);
213 fg
.ReadFromFile( filename
.c_str() );
215 vector
<Factor
> varmargs0
;
218 cout
.setf( ios_base::scientific
);
221 cout
<< "# " << filename
<< endl
;
223 cout
<< left
<< "# METHOD" << "\t";
225 cout
<< right
<< "SECONDS " << "\t";
227 cout
<< "ITERS" << "\t";
228 cout
<< "MAX ERROR" << "\t";
229 cout
<< "AVG ERROR" << "\t";
230 cout
<< "LOGZ ERROR" << "\t";
231 cout
<< "MAXDIFF" << "\t";
234 for( size_t m
= 0; m
< methods
.size(); m
++ ) {
236 pair
<string
, PropertySet
> meth
= parseNameProperties( methods
[m
], Aliases
);
238 // check whether name is valid
240 for( ; strlen( DAINames
[n
] ) != 0; n
++ )
241 if( meth
.first
== DAINames
[n
] )
243 if( strlen( DAINames
[n
] ) == 0 )
244 DAI_THROWE(UNKNOWN_DAI_ALGORITHM
,string("Unknown DAI algorithm \"") + meth
.first
+ string("\" in \"") + methods
[m
] + string("\""));
246 if( vm
.count("tol") )
247 meth
.second
.Set("tol",tol
);
248 if( vm
.count("maxiter") )
249 meth
.second
.Set("maxiter",maxiter
);
250 if( vm
.count("verbose") )
251 meth
.second
.Set("verbose",verbose
);
252 TestDAI
testdai(fg
, meth
.first
, meth
.second
);
255 varmargs0
= testdai
.varmargs
;
256 logZ0
= testdai
.logZ
;
258 testdai
.calcErrs(varmargs0
);
261 cout
<< left
<< methods
[m
] << "\t";
263 cout
<< right
<< testdai
.time
<< "\t";
265 if( testdai
.has_iters
) {
266 cout
<< testdai
.iters
<< "\t";
273 cout
.setf( ios_base::scientific
);
276 Real me
= clipReal( testdai
.maxErr(), 1e-9 );
279 Real ae
= clipReal( testdai
.avgErr(), 1e-9 );
282 if( testdai
.has_logZ
) {
283 cout
.setf( ios::showpos
);
284 Real le
= clipReal( testdai
.logZ
/ logZ0
- 1.0, 1e-9 );
286 cout
.unsetf( ios::showpos
);
290 if( testdai
.has_maxdiff
) {
291 Real md
= clipReal( testdai
.maxdiff
, 1e-9 );
304 if( marginals
== MarginalsOutputType::VAR
) {
305 for( size_t i
= 0; i
< testdai
.varmargs
.size(); i
++ )
306 cout
<< "# " << testdai
.varmargs
[i
] << endl
;
307 } else if( marginals
== MarginalsOutputType::ALL
) {
308 for( size_t I
= 0; I
< testdai
.marginals
.size(); I
++ )
309 cout
<< "# " << testdai
.marginals
[I
] << endl
;
314 } catch( string
&s
) {
315 cerr
<< "Exception: " << s
<< endl
;