1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
5 This file is part of libDAI.
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
30 #include <boost/program_options.hpp>
32 #include <dai/alldai.h>
37 namespace po
= boost::program_options
;
// TestDAI constructor: sets up a test of inference method `_name` on factor
// graph `fg` with properties `opts`. The initializer list sets obj to NULL,
// stores the name, and zeroes/clears every result field (logZ, maxdiff, time,
// iters) and the has_logZ / has_maxdiff / has_iters flags.
// NOTE(review): this listing is fragmentary — gaps in the embedded original line
// numbers (57, 60, 63-69) mean statements, branch keywords and braces are missing.
56 TestDAI( const FactorGraph
&fg
, const string
&_name
, const PropertySet
&opts
) : obj(NULL
), name(_name
), err(), q(), logZ(0.0), maxdiff(0.0), time(0), iters(0U), has_logZ(false), has_maxdiff(false), has_iters(false) {
// "LDPC" is special-cased: instead of running an algorithm, every variable is
// given a fixed belief built from the `zero` array below.
58 if( name
== "LDPC" ) {
// Two values {1.0, 0.0}: all weight on the first state.
59 double zero
[2] = {1.0, 0.0};
// One factor per variable of fg, over Var(i,2) (presumably a 2-state variable
// with label i), initialized from `zero`.
61 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
62 q
.push_back( Factor(Var(i
,2), zero
) );
// Builds the inference object via newInfAlg(name, fg, opts). Presumably this is
// the non-LDPC branch — the enclosing else is not visible in this listing.
70 obj
= newInfAlg( name
, fg
, opts
);
// Delegates to the wrapped inference object's identify().
// NOTE(review): obj starts as NULL in the constructor; if it was never assigned
// (e.g. on the "LDPC" path, if that path skips newInfAlg) this dereference is
// invalid — confirm callers guard against it.
81 return obj
->identify();
// Collects obj->belief(v) for every variable v of the inference object's own
// factor graph (obj->fg()), in variable-index order 0 .. nrVars()-1.
// The trailing `return result;` and closing brace are not in this listing.
86 vector
<Factor
> allBeliefs() {
87 vector
<Factor
> result
;
88 for( size_t i
= 0; i
< obj
->fg().nrVars(); i
++ )
89 result
.push_back( obj
->belief( obj
->fg().var(i
) ) );
// Fragment of the routine that pulls results out of the inference object.
// Each query sits in its own try/catch on Exception, so a query that throws
// does not abort the others. The catch bodies (and the first try, ending at
// original line 102) are not visible here; presumably they leave the
// corresponding has_* flag false.
102 } catch( Exception
&e
) {
// Maximum difference reported by the algorithm.
106 maxdiff
= obj
->maxDiff();
108 } catch( Exception
&e
) {
// Iteration count reported by the algorithm.
112 iters
= obj
->Iterations();
114 } catch( Exception
&e
) {
// Appends to `err` one distance per variable: dist(q[i], x.q[i]) using the
// Prob::DISTTV metric (presumably total-variation distance), comparing this
// run's beliefs with those of another run `x`.
// NOTE(review): x.q[i] is indexed without a bounds check — x.q must hold at
// least q.size() factors. Original line 122 is missing from this listing; if it
// does not clear `err`, repeated calls would keep appending.
121 void calcErrs( const TestDAI
&x
) {
123 err
.reserve( q
.size() );
124 for( size_t i
= 0; i
< q
.size(); i
++ )
125 err
.push_back( dist( q
[i
], x
.q
[i
], Prob::DISTTV
) );
// Overload of calcErrs taking a plain vector of reference beliefs: appends
// dist(q[i], x[i]) with metric Prob::DISTTV for each i in 0 .. q.size()-1.
// NOTE(review): x[i] is indexed without a bounds check — x.size() must be at
// least q.size(). Original line 129 is not visible here.
128 void calcErrs( const vector
<Factor
> &x
) {
130 err
.reserve( q
.size() );
131 for( size_t i
= 0; i
< q
.size(); i
++ )
132 err
.push_back( dist( q
[i
], x
[i
], Prob::DISTTV
) );
// Returns the largest entry of `err` (presumably the body of maxErr(), which
// main() calls; the function header is not in this listing).
// NOTE(review): if `err` is empty, max_element returns end() and dereferencing
// it is undefined behaviour — calcErrs must have filled `err` first.
136 return( *max_element( err
.begin(), err
.end() ) );
// Returns the arithmetic mean of `err` (presumably the body of avgErr(), which
// main() calls; the function header is not in this listing).
// NOTE(review): with an empty `err` this evaluates 0.0 / 0 and yields NaN.
140 return( accumulate( err
.begin(), err
.end(), 0.0 ) / err
.size() );
// Parses a method specification such as NAME[key=value,...] into the method
// name (result.first) and its properties (result.second).
// - If NAME (the text before '[') is a key of `aliases`, it is replaced by the
//   alias value. If that value itself ends in ']', the user's bracketed text is
//   merged into it by turning the alias's ']' into ','.
// - Throws the C string "Malformed method" when the expanded spec has no '['.
// - Throws UNKNOWN_DAI_ALGORITHM (via DAI_THROW) when the name is not an entry
//   of DAINames, except for the special name "LDPC".
// NOTE(review): this listing is fragmentary (e.g. original lines 148, 150-151,
// 157, 162, 164, 174, 177, 180-181 and the tail after 182 are missing).
145 pair
<string
, PropertySet
> parseMethod( const string
&_s
, const map
<string
,string
> & aliases
) {
146 // s = first part of _s, until '['
147 string::size_type pos
= _s
.find_first_of('[');
// The branch bodies for pos == npos are only partly visible; line 152 takes
// the prefix of _s before '['.
149 if( pos
== string::npos
)
152 s
= _s
.substr(0,pos
);
154 // if the first part is an alias, substitute
155 if( aliases
.find(s
) != aliases
.end() )
156 s
= aliases
.find(s
)->second
;
158 // attach second part, merging properties if necessary
159 if( pos
!= string::npos
) {
// NOTE(review): if s is empty here (e.g. _s starts with '[' and no alias maps
// the empty name), s.length()-1 wraps to npos and at() throws
// std::out_of_range instead of reporting a malformed method.
160 if( s
.at(s
.length()-1) == ']' ) {
// Alias already carries properties: drop its ']' and append ',' plus the
// user's text after '[' (which still ends in ']').
161 s
= s
.erase(s
.length()-1,1) + ',' + _s
.substr(pos
+1);
// Otherwise append the user's bracketed part, '[' included.
163 s
= s
+ _s
.substr(pos
);
// name and opts are references into the returned pair.
166 pair
<string
, PropertySet
> result
;
167 string
& name
= result
.first
;
168 PropertySet
& opts
= result
.second
;
// Re-locate '[' in the expanded spec; it is mandatory from here on.
170 pos
= s
.find_first_of('[');
171 if( pos
== string::npos
)
172 throw "Malformed method";
173 name
= s
.substr( 0, pos
);
// Linear search of DAINames, which is terminated by an empty string. The
// declaration of n (line 174) and the action on a match (line 177, presumably
// a break) are not visible.
175 for( ; strlen( DAINames
[n
] ) != 0; n
++ )
176 if( name
== DAINames
[n
] )
// Reaching the empty terminator means no match; "LDPC" is accepted anyway.
178 if( strlen( DAINames
[n
] ) == 0 && (name
!= "LDPC") )
179 DAI_THROW(UNKNOWN_DAI_ALGORITHM
);
// Feed the bracketed property list into stream ss (declared outside this
// listing); presumably it is then read into opts.
182 ss
<< s
.substr(pos
,s
.length());
// Clips values of tiny magnitude: the visible part tests whether |x| < minabs.
// The branch body and the return statement (original lines 191-195) are not in
// this listing, so the replacement value used for small |x| cannot be confirmed
// here. main() calls this with minabs = 1e-9 on every reported error value.
189 double clipdouble( double x
, double minabs
) {
190 if( fabs(x
) < minabs
)
// Command-line driver. As its own help text describes, it reads a factor graph
// from --filename, runs every inference method listed in --methods, and reports
// calculation time, max/average error and relative logZ error, comparing each
// method with the first ("base") method.
// NOTE(review): this listing is fragmentary. Several names used below
// (filename, aliases, tol, maxiter, verbose, infile, line, fg, logZ0) are
// declared on lines that are not visible, as are the try block that the final
// catch clauses belong to and the end of the function.
197 int main( int argc
, char *argv
[] ) {
// Method specifications exactly as typed on the command line.
201 vector
<string
> methods
;
// Defaults for the optional boolean reporting switches.
205 bool marginals
= false;
206 bool report_iters
= true;
207 bool report_time
= true;
// Required options: the factor-graph file and the methods to test. multitoken()
// lets --methods take several space-separated values.
209 po::options_description
opts_required("Required options");
210 opts_required
.add_options()
211 ("filename", po::value
< string
>(&filename
), "Filename of FactorGraph")
212 ("methods", po::value
< vector
<string
> >(&methods
)->multitoken(), "DAI methods to test")
// Optional settings. --tol / --maxiter / --verbose are applied to a method only
// when present (see the vm.count() checks in the method loop).
215 po::options_description
opts_optional("Allowed options");
216 opts_optional
.add_options()
217 ("help", "produce help message")
218 ("aliases", po::value
< string
>(&aliases
), "Filename for aliases")
219 ("tol", po::value
< double >(&tol
), "Override tolerance")
220 ("maxiter", po::value
< size_t >(&maxiter
), "Override maximum number of iterations")
221 ("verbose", po::value
< size_t >(&verbose
), "Override verbosity")
222 ("marginals", po::value
< bool >(&marginals
), "Output single node marginals?")
223 ("report-time", po::value
< bool >(&report_time
), "Report calculation time")
224 ("report-iters", po::value
< bool >(&report_iters
), "Report iterations needed")
// Parse argv against the union of both option groups into vm.
227 po::options_description cmdline_options
;
228 cmdline_options
.add(opts_required
).add(opts_optional
);
230 po::variables_map vm
;
// NOTE(review): the variables bound via po::value<T>(&var) are only written
// when po::notify(vm) runs; that call is not visible in this listing — confirm
// it follows po::store.
231 po::store(po::parse_command_line(argc
, argv
, cmdline_options
), vm
);
// Print usage when --help is given or either required option is missing.
234 if( vm
.count("help") || !(vm
.count("filename") && vm
.count("methods")) ) {
235 cout
<< "Reads factorgraph <filename.fg> and performs the approximate" << endl
;
236 cout
<< "inference algorithms <method*>, reporting calculation time, max and average" << endl
;
237 cout
<< "error and relative logZ error (comparing with the results of" << endl
;
238 cout
<< "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl
<< endl
;
239 cout
<< opts_required
<< opts_optional
<< endl
;
// Optional alias table (alias name -> method spec), read from --aliases.
244 map
<string
,string
> Aliases
;
245 if( !aliases
.empty() ) {
247 infile
.open (aliases
.c_str());
248 if (infile
.is_open()) {
251 getline(infile
,line
);
// Empty lines and lines starting with '#' are skipped; every other line must
// have the form "key: value".
254 if( (!line
.empty()) && (line
[0] != '#') ) {
255 string::size_type pos
= line
.find(':',0);
256 if( pos
== string::npos
)
257 throw "Invalid alias";
// key = text before ':' with trailing spaces/tabs removed; leading whitespace
// is kept. If that text is empty or all whitespace, posl is npos and
// posl + 1 wraps to 0, giving an empty key.
259 string::size_type posl
= line
.substr(0, pos
).find_last_not_of(" \t");
260 string key
= line
.substr(0, posl
+ 1);
// val = text after ':' with leading spaces/tabs removed; trailing whitespace
// is kept.
// NOTE(review): if nothing but whitespace follows ':', posr is npos and
// pos + 1 + npos wraps around to pos, so val starts at the ':' itself.
// Consider rejecting such lines explicitly.
261 string::size_type posr
= line
.substr(pos
+ 1, line
.length()).find_first_not_of(" \t");
262 string val
= line
.substr(pos
+ 1 + posr
, line
.length());
// Presumably the else-branch of infile.is_open() (the branch keyword is not
// visible): the aliases file could not be opened.
269 throw "Error opening aliases file";
// Load the factor graph under test.
273 fg
.ReadFromFile( filename
.c_str() );
// All numeric output uses scientific notation.
278 cout
.setf( ios_base::scientific
);
281 cout
<< "# " << filename
<< endl
;
// Column header line. Some columns are presumably conditional on report_time /
// report_iters; the surrounding conditions are not visible here.
283 cout
<< left
<< "# METHOD" << " ";
286 cout
<< right
<< "SECONDS" << " ";
290 cout
<< "ITERS" << " ";
293 cout
<< "MAX ERROR" << " ";
295 cout
<< "AVG ERROR" << " ";
297 cout
<< "LOGZ ERROR" << " ";
299 cout
<< "MAXDIFF" << " ";
// Run and report every method in the order given on the command line.
302 for( size_t m
= 0; m
< methods
.size(); m
++ ) {
// Expand aliases and split the spec into (name, properties).
303 pair
<string
, PropertySet
> meth
= parseMethod( methods
[m
], Aliases
);
// Command-line overrides replace the per-method property values.
305 if( vm
.count("tol") )
306 meth
.second
.Set("tol",tol
);
307 if( vm
.count("maxiter") )
308 meth
.second
.Set("maxiter",maxiter
);
309 if( vm
.count("verbose") )
310 meth
.second
.Set("verbose",verbose
);
// Constructing TestDAI sets up the inference object for this method.
311 TestDAI
piet(fg
, meth
.first
, meth
.second
);
// Method column shows the spec as typed (before alias expansion).
320 cout
<< left
<< methods
[m
] << " ";
323 cout
<< right
<< piet
.time
<< " ";
327 if( piet
.has_iters
) {
328 cout
<< piet
.iters
<< " ";
335 cout
.setf( ios_base::scientific
);
// Error values smaller than 1e-9 in magnitude are clipped by clipdouble.
339 double me
= clipdouble( piet
.maxErr(), 1e-9 );
343 double ae
= clipdouble( piet
.avgErr(), 1e-9 );
347 if( piet
.has_logZ
) {
// Relative logZ error: logZ / logZ0 - 1.
// NOTE(review): logZ0 is defined outside this listing (presumably the base
// method's logZ); the division is unguarded if logZ0 can be 0.
348 double le
= clipdouble( piet
.logZ
/ logZ0
- 1.0, 1e-9 );
354 if( piet
.has_maxdiff
) {
355 double md
= clipdouble( piet
.maxdiff
, 1e-9 );
// One belief per line, prefixed with "# " (presumably only when --marginals is
// set; the guarding condition is not visible).
367 for( size_t i
= 0; i
< piet
.q
.size(); i
++ )
368 cout
<< "# " << piet
.q
[i
] << endl
;
// Plain C-string exceptions, e.g. "Invalid alias" or "Malformed method".
371 } catch(const char *e
) {
372 cerr
<< "Exception: " << e
<< endl
;
// Standard-library exceptions.
374 } catch(exception
& e
) {
375 cerr
<< "Exception: " << e
.what() << endl
;
// Presumably inside a catch(...) handler, which is not visible here.
379 cerr
<< "Exception of unknown type!" << endl
;