1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
29 #include <boost/program_options.hpp>
31 #include <dai/alldai.h>
36 namespace po
= boost::program_options
;
// Constructs a test harness around one inference algorithm applied to `fg`.
// The member initializer list sets obj to NULL, stores _name in `name`, and
// zero-initializes the result members (err, q, logZ, maxdiff, time); the body
// then creates the algorithm object with newInfAlg( name, fg, opts ).
// NOTE(review): several original lines of this constructor are not visible in
// this view (the embedded numbering jumps 51->53->56, 62->68 and 71->73->75),
// so its complete control flow cannot be confirmed from here.
51 TestAI( const FactorGraph
&fg
, const string
&_name
, const Properties
&opts
) : obj(NULL
), name(_name
), err(), q(), logZ(0.0), maxdiff(0.0), time(0) {
53 obj
= newInfAlg( name
, fg
, opts
);
// NOTE(review): the EXACT branch below refers to `method` and `piet`, neither
// of which is declared in the visible code (the constructor parameter is
// `_name`). It looks like disabled legacy code, presumably inside a block
// comment opened on a line not shown -- confirm before relying on it.
// As written, it guards on the product of all variable state counts versus
// 1UL << 16, multiplies all factors into `piet`, pushes each variable's
// marginal onto q, takes logZ from fg.ExactlogZ(), and throws
// "Network too large for EXACT method" on a path whose condition is not shown.
56 } else if( method.substr(0,5) == "EXACT" ) { // EXACT
57 // Look if the network is small enough to do brute-force exact method
58 bool toolarge = false;
59 size_t total_statespace = 1;
60 for( size_t i = 0; i < fg.nrVars(); i++ ) {
61 total_statespace *= fg.var(i).states();
62 if( total_statespace > (1UL << 16) )
68 for( size_t I = 0; I < fg.nrFactors(); I++ )
69 piet *= fg.factor( I );
70 for( size_t i = 0; i < fg.nrVars(); i++ )
71 q.push_back(piet.marginal(fg.var(i)));
73 logZ = fg.ExactlogZ();
75 throw "Network too large for EXACT method";
// Forwards to the wrapped inference algorithm: returns whatever
// obj->identify() reports (presumably a string naming the algorithm and its
// settings -- confirm).
// NOTE(review): the enclosing function header (original line 86) is not
// visible in this view.
87 return obj
->identify();
// allBeliefs: collects the wrapped algorithm's single-variable beliefs into a
// vector, one entry per variable index 0 .. obj->nrVars()-1, in index order.
// NOTE(review): the remainder of this function (return statement, closing
// brace) is not visible in this view.
92 vector
<Factor
> allBeliefs() {
93 vector
<Factor
> result
;
94 for( size_t i
= 0; i
< obj
->nrVars(); i
++ )
95 result
.push_back( obj
->belief( obj
->var(i
) ) );
// Fragment of the method that presumably runs the wrapped inference algorithm.
// Its header and several lines are not visible here.
// The two commented-out lines below are a disabled special case for EXACT.
// The visible statements store the real part of obj->logZ() in `logZ` and
// obj->MaxDiff() in `maxdiff`. MaxDiff is presumably the algorithm's last
// reported convergence difference -- confirm.
101 // if( name == "EXACT" ) {
102 // // calculation has already been done
108 logZ
= real(obj
->logZ());
109 maxdiff
= obj
->MaxDiff();
// calcErrs( x ): compares this object's beliefs q against those of another
// TestAI. For each index i of q, it appends dist( q[i], x.q[i], Prob::DISTTV )
// (the total-variation distance kind) to err, after reserving q.size() slots.
// Assumes x.q has at least q.size() entries, since x.q[i] is indexed unchecked.
// NOTE(review): original line 115 is not visible here; confirm whether err is
// cleared before the new entries are appended.
114 void calcErrs( const TestAI
&x
) {
116 err
.reserve( q
.size() );
117 for( size_t i
= 0; i
< q
.size(); i
++ )
118 err
.push_back( dist( q
[i
], x
.q
[i
], Prob::DISTTV
) );
// calcErrs( x ): overload that compares this object's beliefs q against an
// explicit vector of reference factors. For each index i of q, it appends
// dist( q[i], x[i], Prob::DISTTV ) (the total-variation distance kind) to err,
// after reserving q.size() slots.
// Assumes x has at least q.size() entries, since x[i] is indexed unchecked.
// NOTE(review): original line 122 is not visible here; confirm whether err is
// cleared before the new entries are appended.
121 void calcErrs( const vector
<Factor
> &x
) {
123 err
.reserve( q
.size() );
124 for( size_t i
= 0; i
< q
.size(); i
++ )
125 err
.push_back( dist( q
[i
], x
[i
], Prob::DISTTV
) );
// Returns the largest per-variable error stored in err.
// NOTE(review): *max_element on an empty range dereferences end(), which is
// undefined behaviour, so err must be non-empty when this is called.
// The function header (original line 128) is not visible in this view.
129 return( *max_element( err
.begin(), err
.end() ) );
// Returns the arithmetic mean of the per-variable errors stored in err.
// NOTE(review): with an empty err this evaluates 0.0 / 0, which yields NaN
// under IEEE-754 arithmetic.
// The function header (original line 132) is not visible in this view.
133 return( accumulate( err
.begin(), err
.end(), 0.0 ) / err
.size() );
// parseMethod: splits a method specification of the form NAME[...] into a pair
// of (algorithm name, Properties). If `_s` is a key of `aliases`, the string
// `s` is replaced by the aliased specification before parsing.
// Throws "Malformed method" if there is no '[', and "Unknown inference
// algorithm" if NAME is not found in DAINames.
// NOTE(review): the declarations of `s`, `n` and `ss`, and the final return,
// are on lines not visible in this view. Presumably `s` starts as a copy of
// `_s`, `n` starts at 0, and `ss` is a string stream that is then read into
// `opts` -- confirm.
138 pair
<string
, Properties
> parseMethod( const string
&_s
, const map
<string
,string
> & aliases
) {
140 if( aliases
.find(_s
) != aliases
.end() )
141 s
= aliases
.find(_s
)->second
;
143 pair
<string
, Properties
> result
;
144 string
& name
= result
.first
;
145 Properties
& opts
= result
.second
;
// The algorithm name is everything before the first '['; a specification
// without any '[' is rejected.
147 string::size_type pos
= s
.find_first_of('[');
148 name
= s
.substr( 0, pos
);
149 if( pos
== string::npos
)
150 throw "Malformed method";
// Scan the DAINames table, which is terminated by an empty string; reaching
// the terminator means the name is unknown. (The loop's exit on a match is on
// a line not visible here.)
152 for( ; strlen( DAINames
[n
] ) != 0; n
++ )
153 if( name
== DAINames
[n
] )
155 if( strlen( DAINames
[n
] ) == 0 )
156 throw "Unknown inference algorithm";
// Write the bracketed options text, from '[' to the end of s, into ss.
159 ss
<< s
.substr(pos
,s
.length());
// clipdouble( x, minabs ): tests whether |x| is below the threshold minabs.
// In this file it is applied to the reported errors and maxdiff with
// minabs = 1e-9.
// NOTE(review): the statements following this test (original line 168
// onwards) are not visible here, so the value returned in each case cannot be
// confirmed from this view.
166 double clipdouble( double x
, double minabs
) {
167 if( fabs(x
) < minabs
)
174 int main( int argc
, char *argv
[] ) {
178 vector
<string
> methods
;
183 po::options_description
opts_required("Required options");
184 opts_required
.add_options()
185 ("filename", po::value
< string
>(&filename
), "Filename of FactorGraph")
186 ("methods", po::value
< vector
<string
> >(&methods
)->multitoken(), "AI methods to test")
189 po::options_description
opts_optional("Allowed options");
190 opts_optional
.add_options()
191 ("help", "produce help message")
192 ("aliases", po::value
< string
>(&aliases
), "Filename for aliases")
193 ("tol", po::value
< double >(&tol
), "Override tolerance")
194 ("maxiter", po::value
< size_t >(&maxiter
), "Override maximum number of iterations")
195 ("verbose", po::value
< size_t >(&verbose
), "Override verbosity")
198 po::options_description cmdline_options
;
199 cmdline_options
.add(opts_required
).add(opts_optional
);
201 po::variables_map vm
;
202 po::store(po::parse_command_line(argc
, argv
, cmdline_options
), vm
);
205 if( vm
.count("help") || !(vm
.count("filename") && vm
.count("methods")) ) {
206 cout
<< "Reads factorgraph <filename.fg> and performs the approximate" << endl
;
207 cout
<< "inference algorithms <method*>, reporting clocks, max and average" << endl
;
208 cout
<< "error and relative logZ error (comparing with the results of" << endl
;
209 cout
<< "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl
<< endl
;
210 cout
<< opts_required
<< opts_optional
<< endl
;
215 map
<string
,string
> Aliases
;
216 if( !aliases
.empty() ) {
218 infile
.open (aliases
.c_str());
219 if (infile
.is_open()) {
222 getline(infile
,line
);
225 if( (!line
.empty()) && (line
[0] != '#') ) {
226 string::size_type pos
= line
.find(':',0);
227 if( pos
== string::npos
)
228 throw "Invalid alias";
230 string::size_type posl
= line
.substr(0, pos
).find_last_not_of(" \t");
231 string key
= line
.substr(0, posl
+ 1);
232 string::size_type posr
= line
.substr(pos
+ 1, line
.length()).find_first_not_of(" \t");
233 string val
= line
.substr(pos
+ 1 + posr
, line
.length());
240 throw "Error opening aliases file";
244 if( fg
.ReadFromFile(filename
.c_str()) ) {
245 cout
<< "Error reading " << filename
<< endl
;
251 cout
<< "# " << filename
<< endl
;
253 cout
<< left
<< "# METHOD" << " ";
255 cout
<< right
<< "SECONDS" << " ";
257 cout
<< "MAX ERROR" << " ";
259 cout
<< "AVG ERROR" << " ";
261 cout
<< "LOGZ ERROR" << " ";
263 cout
<< "MAXDIFF" << endl
;
265 for( size_t m
= 0; m
< methods
.size(); m
++ ) {
266 pair
<string
, Properties
> meth
= parseMethod( methods
[m
], Aliases
);
268 if( vm
.count("tol") )
269 meth
.second
.Set("tol",tol
);
270 if( vm
.count("maxiter") )
271 meth
.second
.Set("maxiter",maxiter
);
272 if( vm
.count("verbose") )
273 meth
.second
.Set("verbose",verbose
);
274 TestAI
piet(fg
, meth
.first
, meth
.second
);
283 // cout << left << piet.identify() << " ";
284 cout
<< left
<< methods
[m
] << " ";
286 cout
<< right
<< piet
.time
<< " ";
289 cout
.setf( ios_base::scientific
);
292 double me
= clipdouble( piet
.maxErr(), 1e-9 );
295 double ae
= clipdouble( piet
.avgErr(), 1e-9 );
298 double le
= clipdouble( piet
.logZ
/ logZ0
- 1.0, 1e-9 );
301 double md
= clipdouble( piet
.maxdiff
, 1e-9 );
311 } catch(const char *e
) {
312 cerr
<< "Exception: " << e
<< endl
;
314 } catch(exception
& e
) {
315 cerr
<< "Exception: " << e
.what() << endl
;
319 cerr
<< "Exception of unknown type!" << endl
;