1 /* This file is part of libDAI - http://www.libdai.org/
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
19 #include <boost/program_options.hpp>
21 #include <dai/alldai.h>
26 namespace po
= boost::program_options
;
29 /// Wrapper class for DAI approximate inference algorithms
/// Runs a single named libDAI method on a factor graph and keeps its results
/// (variable marginals, all marginals, log partition sum, timing and iteration
/// statistics) so that they can be compared against those of a base method.
32 /// Stores a pointer to an InfAlg object, managed by this class
/// NOTE(review): raw owning pointer, initialized to NULL by the constructor;
/// confirm that the destructor releases it and that copying TestDAI is disabled,
/// otherwise two copies would own the same InfAlg object.
34 /// Stores the name of the InfAlg algorithm
36 /// Stores the total variation distances of the variable marginals
/// (one entry per variable marginal, appended by calcErrs())
40 /// Stores the variable marginals
/// (one Factor per variable; set in the constructor for the LDPC case,
/// otherwise replaced by the algorithm's beliefs after it has run)
41 vector
<Factor
> varMarginals
;
42 /// Stores all marginals
/// (the list returned by the InfAlg's beliefs() after it has run)
43 vector
<Factor
> allMarginals
;
44 /// Stores the logarithm of the partition sum
46 /// Stores the maximum difference in the last iteration
48 /// Stores the computation time (in seconds)
50 /// Stores the number of iterations needed
52 /// Does the InfAlg support logZ()?
54 /// Does the InfAlg support maxDiff()?
56 /// Does the InfAlg support Iterations()?
59 /// Construct from factor graph \a fg, name \a _name, and set of properties \a opts
/// All result members start out empty or zero, and all has_* flags start out false.
/// For the name "LDPC", every variable marginal is initialized to a Factor built from
/// `zero` (declared on lines not visible here), and allMarginals is a copy of them.
/// An InfAlg object is created via newInfAlg( name, fg, opts ).
/// NOTE(review): presumably only in the non-LDPC case; the branch structure between
/// the lines shown is not visible, so confirm this.
60 TestDAI( const FactorGraph
&fg
, const string
&_name
, const PropertySet
&opts
) : obj(NULL
), name(_name
), err(), varMarginals(), allMarginals(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
63 if( name
== "LDPC" ) {
64 // special case: simulating a Low Density Parity Check code
67 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
68 varMarginals
.push_back( Factor(fg
.var(i
), zero
) );
// In this special case, "all marginals" are exactly the variable marginals
69 allMarginals
= varMarginals
;
77 // create a corresponding InfAlg object
78 obj
= newInfAlg( name
, fg
, opts
);
80 // Add the time needed to create the object
/// Returns the identification string reported by the wrapped InfAlg object.
/// NOTE(review): dereferences obj, which the constructor initializes to NULL;
/// confirm it is non-NULL (or guarded on lines not shown) whenever this is called.
91 string
identify() const {
93 return obj
->identify();
98 /// Run the algorithm and store its results
/// Optional quantities (logZ, maxDiff, Iterations) are each queried inside a try
/// block. An Exception whose code() equals Exception::NOT_IMPLEMENTED signals that
/// the method does not provide that quantity.
109 // Store logarithm of the partition sum (if supported)
113 } catch( Exception
&e
) {
// Only NOT_IMPLEMENTED means "not supported"; the handling of other error
// codes is on lines not visible here.
114 if( e
.code() == Exception::NOT_IMPLEMENTED
)
120 // Store maximum difference encountered in last iteration (if supported)
122 maxdiff
= obj
->maxDiff();
124 } catch( Exception
&e
) {
125 if( e
.code() == Exception::NOT_IMPLEMENTED
)
131 // Store number of iterations needed (if supported)
133 iters
= obj
->Iterations();
135 } catch( Exception
&e
) {
136 if( e
.code() == Exception::NOT_IMPLEMENTED
)
142 // Store variable marginals
// Discard any previous marginals and store one belief per variable of the
// algorithm's own factor graph.
143 varMarginals
.clear();
144 for( size_t i
= 0; i
< obj
->fg().nrVars(); i
++ )
145 varMarginals
.push_back( obj
->beliefV( i
) );
147 // Store all marginals calculated by the method
148 allMarginals
= obj
->beliefs();
152 /// Calculate total variation distance of variable marginals with respect to those in \a x
/// For every i < varMarginals.size(), appends
/// dist( varMarginals[i], x.varMarginals[i], Prob::DISTTV ) to err.
/// NOTE(review): x must have at least as many variable marginals as *this
/// (operator[] is unchecked).
153 void calcErrs( const TestDAI
&x
) {
// Room for exactly one distance per variable marginal
155 err
.reserve( varMarginals
.size() );
156 for( size_t i
= 0; i
< varMarginals
.size(); i
++ )
157 err
.push_back( dist( varMarginals
[i
], x
.varMarginals
[i
], Prob::DISTTV
) );
160 /// Calculate total variation distance of variable marginals with respect to those in \a x
/// Overload taking a plain list of reference marginals (main() passes the base
/// method's variable marginals). For every i < varMarginals.size(), appends
/// dist( varMarginals[i], x[i], Prob::DISTTV ) to err.
/// NOTE(review): x must contain at least varMarginals.size() factors
/// (operator[] is unchecked).
161 void calcErrs( const vector
<Factor
> &x
) {
// Room for exactly one distance per variable marginal
163 err
.reserve( varMarginals
.size() );
164 for( size_t i
= 0; i
< varMarginals
.size(); i
++ )
165 err
.push_back( dist( varMarginals
[i
], x
[i
], Prob::DISTTV
) );
168 /// Return maximum error
/// Largest entry of err (the per-variable total variation distances).
/// NOTE(review): if err is empty, max_element returns err.end(), and dereferencing
/// it is undefined behaviour. Only call this after calcErrs() on a non-empty set of
/// marginals, or add an empty check.
170 return( *max_element( err
.begin(), err
.end() ) );
173 /// Return average error
/// Arithmetic mean of err (the per-variable total variation distances).
/// NOTE(review): if err is empty, this computes 0.0 / 0 and returns NaN.
175 return( accumulate( err
.begin(), err
.end(), 0.0 ) / err
.size() );
180 /// Clips a real number: if the absolute value of \a x is less than \a minabs, return \a minabs, else return \a x
/// Per the description above, small negative values are also mapped to +minabs.
/// main() calls this with minabs = 1e-9 on the error measures it reports.
/// NOTE(review): relies on an abs() overload for floating-point Real (e.g. std::abs
/// from <cmath>); the C abs(int) would truncate \a x, so confirm which one is selected.
181 Real
clipReal( Real x
, Real minabs
) {
182 if( abs(x
) < minabs
)
189 /// Whether to output no marginals, only variable marginals, or all calculated marginals
/// Set from the --marginals command line option (NONE, VAR or ALL).
190 DAI_ENUM(MarginalsOutputType
,NONE
,VAR
,ALL
);
/// Command line driver. Parses the options, optionally reads an alias file, reads
/// the factor graph, and runs every requested method. For each method it reports
/// results, with errors measured relative to the first (base) method.
194 int main( int argc
, char *argv
[] ) {
195 // Variables to store command line options
196 // Filename of factor graph
198 // Filename for aliases
200 // Approximate Inference methods to use
201 vector
<string
> methods
;
202 // Which marginals to output
203 MarginalsOutputType marginals
;
204 // Output number of iterations?
205 bool report_iters
= true;
206 // Output calculation time?
207 bool report_time
= true;
209 // Define required command line options
210 po::options_description
opts_required("Required options");
211 opts_required
.add_options()
212 ("filename", po::value
< string
>(&filename
), "Filename of factor graph")
// multitoken() lets --methods consume several space-separated values
213 ("methods", po::value
< vector
<string
> >(&methods
)->multitoken(), "DAI methods to perform")
216 // Define allowed command line options
217 po::options_description
opts_optional("Allowed options");
218 opts_optional
.add_options()
219 ("help", "Produce help message")
220 ("aliases", po::value
< string
>(&aliases
), "Filename for aliases")
221 ("marginals", po::value
< MarginalsOutputType
>(&marginals
), "Output marginals? (NONE/VAR/ALL, default=NONE)")
222 ("report-time", po::value
< bool >(&report_time
), "Output calculation time (default==1)?")
223 ("report-iters", po::value
< bool >(&report_iters
), "Output iterations needed (default==1)?")
226 // Define all command line options
227 po::options_description cmdline_options
;
228 cmdline_options
.add(opts_required
).add(opts_optional
);
230 // Parse command line
231 po::variables_map vm
;
232 po::store(po::parse_command_line(argc
, argv
, cmdline_options
), vm
);
235 // Display help message if necessary
// Help is shown when explicitly requested or when a required option is missing
236 if( vm
.count("help") || !(vm
.count("filename") && vm
.count("methods")) ) {
237 cout
<< "This program is part of libDAI - http://www.libdai.org/" << endl
<< endl
;
238 cout
<< "Usage: ./testdai --filename <filename.fg> --methods <method1> [<method2> <method3> ...]" << endl
<< endl
;
239 cout
<< "Reads factor graph <filename.fg> and performs the approximate inference algorithms" << endl
;
240 cout
<< "<method*>, reporting for each method:" << endl
;
241 cout
<< " o the calculation time needed, in seconds (if report-time == 1);" << endl
;
242 cout
<< " o the number of iterations needed (if report-iters == 1);" << endl
;
243 cout
<< " o the maximum (over all variables) total variation error in the variable marginals;" << endl
;
244 cout
<< " o the average (over all variables) total variation error in the variable marginals;" << endl
;
245 cout
<< " o the error (difference) of the logarithm of the partition sums;" << endl
<< endl
;
246 cout
<< "All errors are calculated by comparing the results of the current method with" << endl
;
247 cout
<< "the results of the first method (the base method). If marginals==VAR, additional" << endl
;
248 cout
<< "output consists of the variable marginals, and if marginals==ALL, all marginals" << endl
;
249 cout
<< "calculated by the method are reported." << endl
<< endl
;
250 cout
<< "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl
<< endl
;
251 cout
<< " name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl
<< endl
;
252 cout
<< "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl
;
253 cout
<< "filename is provided), followed by a list of properties (surrounded by rectangular" << endl
;
254 cout
<< "brackets), where each property consists of a key=value pair and the properties are" << endl
;
255 cout
<< "seperated by commas. If an alias file is specified, alias substitution is performed." << endl
;
256 cout
<< "This is done by looking up the name in the alias file and substituting the alias" << endl
;
257 cout
<< "by its corresponding method as defined in the alias file. Properties are parsed from" << endl
;
258 cout
<< "left to right, so if a property occurs repeatedly, the right-most value is used." << endl
<< endl
;
259 cout
<< opts_required
<< opts_optional
<< endl
;
261 cout
<< "Note: this is a debugging build of libDAI." << endl
<< endl
;
263 cout
<< "Example: ./testdai --filename testfast.fg --aliases aliases.conf --methods JTREE_HUGIN BP_SEQFIX BP_PARALL[maxiter=5]" << endl
;
// Alias substitution is only possible when an alias file was given
269 map
<string
,string
> Aliases
;
270 if( !aliases
.empty() )
271 Aliases
= readAliasesFile( aliases
);
275 fg
.ReadFromFile( filename
.c_str() );
277 // Declare variables used for storing variable marginals and log partition sum of base method
278 vector
<Factor
> varMarginals0
;
// Print a header: the factor graph filename, then tab-separated column titles
282 cout
.setf( ios_base::scientific
);
284 cout
<< "# " << filename
<< endl
;
286 cout
<< left
<< "# METHOD" << "\t";
288 cout
<< right
<< "SECONDS " << "\t";
290 cout
<< "ITERS" << "\t";
291 cout
<< "MAX ERROR" << "\t";
292 cout
<< "AVG ERROR" << "\t";
293 cout
<< "LOGZ ERROR" << "\t";
294 cout
<< "MAXDIFF" << "\t";
297 // For each method...
// The first entry of methods is the base method that all errors are measured against
298 for( size_t m
= 0; m
< methods
.size(); m
++ ) {
// Split "name[key=val,...]" into the name and its properties, applying aliases
300 pair
<string
, PropertySet
> meth
= parseNameProperties( methods
[m
], Aliases
);
302 // Check whether name is valid
// DAINames is scanned up to its empty-string terminator; stopping on that
// terminator means the name was not found.
304 for( ; strlen( DAINames
[n
] ) != 0; n
++ )
305 if( meth
.first
== DAINames
[n
] )
307 if( strlen( DAINames
[n
] ) == 0 )
308 DAI_THROWE(UNKNOWN_DAI_ALGORITHM
,string("Unknown DAI algorithm \"") + meth
.first
+ string("\" in \"") + methods
[m
] + string("\""));
310 // Construct object for running the method
311 TestDAI
testdai(fg
, meth
.first
, meth
.second
);
316 // For the base method, store its variable marginals and logarithm of the partition sum
318 varMarginals0
= testdai
.varMarginals
;
319 logZ0
= testdai
.logZ
;
322 // Calculate errors relative to base method
323 testdai
.calcErrs( varMarginals0
);
325 // Output method name
327 cout
<< left
<< methods
[m
] << "\t";
328 // Output calculation time, if requested
330 cout
<< right
<< testdai
.time
<< "\t";
331 // Output number of iterations, if requested
333 if( testdai
.has_iters
) {
334 cout
<< testdai
.iters
<< "\t";
340 // If this is not the base method
342 cout
.setf( ios_base::scientific
);
345 // Output maximum error in variable marginals
// Reported errors are clipped to at least 1e-9 in absolute value (see clipReal)
346 Real me
= clipReal( testdai
.maxErr(), 1e-9 );
349 // Output average error in variable marginals
350 Real ae
= clipReal( testdai
.avgErr(), 1e-9 );
353 // Output error in log partition sum
354 if( testdai
.has_logZ
) {
// Print an explicit sign for the logZ difference; it is turned off again below
355 cout
.setf( ios::showpos
);
356 Real le
= clipReal( testdai
.logZ
- logZ0
, 1e-9 );
358 cout
.unsetf( ios::showpos
);
362 // Output maximum difference in last iteration
363 if( testdai
.has_maxdiff
) {
364 Real md
= clipReal( testdai
.maxdiff
, 1e-9 );
377 // Output marginals, if requested
378 if( marginals
== MarginalsOutputType::VAR
) {
379 for( size_t i
= 0; i
< testdai
.varMarginals
.size(); i
++ )
380 cout
<< "# " << testdai
.varMarginals
[i
] << endl
;
381 } else if( marginals
== MarginalsOutputType::ALL
) {
382 for( size_t I
= 0; I
< testdai
.allMarginals
.size(); I
++ )
383 cout
<< "# " << testdai
.allMarginals
[I
] << endl
;
// NOTE(review): only std::string exceptions are caught here. Errors above are
// raised with DAI_THROWE, and TestDAI catches libDAI Exception objects; if
// DAI_THROWE throws such an Exception, confirm it is handled somewhere.
388 } catch( string
&s
) {
389 // Abort with error message
390 cerr
<< "Exception: " << s
<< endl
;