Cleaned up example_imagesegmentation
[libdai.git] / utils / uai2fg.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2008-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
#include <algorithm>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <dai/alldai.h>
#include <dai/util.h>
#include <dai/index.h>
#include <dai/jtree.h>
17
18
19 using namespace std;
20 using namespace dai;
21
22
23 /// Reads "evidence" (a mapping from observed variable labels to the observed values) from a UAI evidence file
24 map<size_t, size_t> ReadUAIEvidenceFile( char* filename ) {
25 map<size_t, size_t> evid;
26
27 // open file
28 ifstream is;
29 is.open( filename );
30 if( is.is_open() ) {
31 // read number of observed variables
32 size_t nr_evid;
33 is >> nr_evid;
34 if( is.fail() )
35 DAI_THROWE(INVALID_EVIDENCE_FILE,"Cannot read number of observed variables");
36
37 // for each observation, read the variable label and the observed value
38 for( size_t i = 0; i < nr_evid; i++ ) {
39 size_t label, val;
40 is >> label;
41 if( is.fail() )
42 DAI_THROWE(INVALID_EVIDENCE_FILE,"Cannot read label for " + toString(i) + "'th observed variable");
43 is >> val;
44 if( is.fail() )
45 DAI_THROWE(INVALID_EVIDENCE_FILE,"Cannot read value of " + toString(i) + "'th observed variable");
46 evid[label] = val;
47 }
48
49 // close file
50 is.close();
51 } else
52 DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
53
54 return evid;
55 }
56
57
58 /// Reads factor graph (as a pair of a variable vector and factor vector) from a UAI factor graph file
59 pair<vector<Var>, vector<Factor> > ReadUAIFGFile( const char *filename, size_t verbose ) {
60 pair<vector<Var>, vector<Factor> > result;
61 vector<Var>& vars = result.first;
62 vector<Factor>& factors = result.second;
63
64 // open file
65 ifstream is;
66 is.open( filename );
67 if( is.is_open() ) {
68 size_t nrFacs, nrVars;
69 string line;
70
71 // read header line
72 getline(is,line);
73 if( is.fail() || (line != "BAYES" && line != "MARKOV" && line != "BAYES\r" && line != "MARKOV\r") )
74 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"UAI factor graph file should start with \"BAYES\" or \"MARKOV\"");
75 if( verbose >= 2 )
76 cout << "Reading " << line << " network..." << endl;
77
78 // read number of variables
79 is >> nrVars;
80 if( is.fail() )
81 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of variables");
82 if( verbose >= 2 )
83 cout << "Reading " << nrVars << " variables..." << endl;
84
85 // for each variable, read its number of states
86 vars.reserve( nrVars );
87 for( size_t i = 0; i < nrVars; i++ ) {
88 size_t dim;
89 is >> dim;
90 if( is.fail() )
91 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of states of " + toString(i) + "'th variable");
92 vars.push_back( Var( i, dim ) );
93 }
94
95 // read number of factors
96 is >> nrFacs;
97 if( is.fail() )
98 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of factors");
99 if( verbose >= 2 )
100 cout << "Reading " << nrFacs << " factors..." << endl;
101
102 // for each factor, read the variables on which it depends
103 vector<vector<long> > labels;
104 factors.reserve( nrFacs );
105 labels.reserve( nrFacs );
106 for( size_t I = 0; I < nrFacs; I++ ) {
107 if( verbose >= 3 )
108 cout << "Reading factor " << I << "..." << endl;
109
110 // read number of variables for factor I
111 size_t I_nrVars;
112 is >> I_nrVars;
113 if( is.fail() )
114 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of variables for " + toString(I) + "'th factor");
115 if( verbose >= 3 )
116 cout << " which depends on " << I_nrVars << " variables" << endl;
117
118 // for each of the variables, read its label and number of states
119 vector<long> I_labels;
120 vector<size_t> I_dims;
121 VarSet I_vars;
122 I_labels.reserve( I_nrVars );
123 I_dims.reserve( I_nrVars );
124 for( size_t _i = 0; _i < I_nrVars; _i++ ) {
125 long label;
126 is >> label;
127 if( is.fail() )
128 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read variable labels for " + toString(I) + "'th factor");
129 I_labels.push_back( label );
130 I_dims.push_back( vars[label].states() );
131 I_vars |= vars[label];
132 }
133 if( verbose >= 3 )
134 cout << " labels: " << I_labels << ", dimensions " << I_dims << endl;
135
136 // add the factor and the labels
137 factors.push_back( Factor(I_vars,0.0) );
138 labels.push_back( I_labels );
139 }
140
141 // for each factor, read its values
142 for( size_t I = 0; I < nrFacs; I++ ) {
143 if( verbose >= 3 )
144 cout << "Reading factor " << I << "..." << endl;
145
146 // last label is least significant, so we reverse the label vector
147 reverse( labels[I].begin(), labels[I].end() );
148
149 // prepare a vector containing the dimensionalities of the variables for this factor
150 size_t I_nrVars = factors[I].vars().size();
151 vector<size_t> I_dims;
152 I_dims.reserve( I_nrVars );
153 for( size_t _i = 0; _i < I_nrVars; _i++ )
154 I_dims.push_back( vars[labels[I][_i]].states() );
155 if( verbose >= 3 )
156 cout << " labels: " << labels[I] << ", dimensions " << I_dims << endl;
157
158 // calculate permutation sigma (internally, members are sorted canonically,
159 // which may be different from the way they are sorted in the file)
160 vector<size_t> sigma( I_nrVars, 0 );
161 VarSet::const_iterator j = factors[I].vars().begin();
162 for( size_t mi = 0; mi < I_nrVars; mi++, j++ )
163 sigma[mi] = distance( labels[I].begin(), find( labels[I].begin(), labels[I].end(), j->label() ) );
164 if( verbose >= 3 )
165 cout << " permutation: " << sigma << endl;
166
167 // construct permutation object
168 Permute permindex( I_dims, sigma );
169
170 // read factor values
171 size_t nrNonZeros;
172 is >> nrNonZeros;
173 if( is.fail() )
174 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of nonzero factor values for " + toString(I) + "'th factor");
175 if( verbose >= 3 )
176 cout << " number of nonzero values: " << nrNonZeros << endl;
177 DAI_ASSERT( nrNonZeros == factors[I].states() );
178 for( size_t li = 0; li < nrNonZeros; li++ ) {
179 Real val;
180 is >> val;
181 if( is.fail() )
182 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read factor values of " + toString(I) + "'th factor");
183 // assign value after calculating its linear index corresponding to the permutation
184 factors[I][permindex.convertLinearIndex( li )] = val;
185 }
186 }
187 if( verbose >= 3 )
188 cout << "factors:" << factors << endl;
189
190 // close file
191 is.close();
192 } else
193 DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
194
195 return result;
196 }
197
198
199 int main( int argc, char *argv[] ) {
200 if ( argc != 7 ) {
201 cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
202 cout << "Usage: ./uai2fg <filename.uai> <filename.uai.evid> <filename.fg> <type> <run_jtree> <verbose>" << endl << endl;
203 cout << "Converts input files in the UAI 2008 approximate inference evaluation format" << endl;
204 cout << "(see http://graphmod.ics.uci.edu/uai08/) to the libDAI factor graph format." << endl;
205 cout << "Reads factor graph <filename.uai> and evidence <filename.uai.evid>" << endl;
206 cout << "and writes the resulting clamped factor graph to <filename.fg>." << endl;
207 cout << "If type==0, uses surgery (recommended), otherwise, uses just adds delta factors." << endl;
208 cout << "If run_jtree!=0, runs a junction tree and reports the results in the UAI 2008 results file format." << endl;
209 return 1;
210 } else {
211 long verbose = atoi( argv[6] );
212 long type = atoi( argv[4] );
213 bool run_jtree = atoi( argv[5] );
214
215 // read factor graph and evidence
216 pair<vector<Var>, vector<Factor> > varfacs = ReadUAIFGFile( argv[1], verbose );
217 map<size_t,size_t> evid = ReadUAIEvidenceFile( argv[2] );
218 vector<Var>& vars = varfacs.first;
219 vector<Factor>& facs = varfacs.second;
220
221 // construct unclamped factor graph
222 FactorGraph fg0( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
223
224 // change factor graph to reflect observed evidence
225 if( type == 0 ) {
226 // replace factors with clamped variables with slices
227 for( size_t I = 0; I < facs.size(); I++ ) {
228 for( map<size_t,size_t>::const_iterator e = evid.begin(); e != evid.end(); e++ ) {
229 if( facs[I].vars() >> vars[e->first] ) {
230 if( verbose >= 2 )
231 cout << "Clamping " << e->first << " to value " << e->second << " in factor " << I << " = " << facs[I].vars() << endl;
232 facs[I] = facs[I].slice( vars[e->first], e->second );
233 if( verbose >= 2 )
234 cout << "...remaining vars: " << facs[I].vars() << endl;
235 }
236 }
237 }
238 // remove empty factors
239 double logZcorr = 0.0;
240 for( vector<Factor>::iterator I = facs.begin(); I != facs.end(); )
241 if( I->vars().size() == 0 ) {
242 logZcorr += std::log( (Real)(*I)[0] );
243 I = facs.erase( I );
244 } else
245 I++;
246 // multiply with logZcorr constant
247 if( facs.size() == 0 )
248 facs.push_back( Factor( VarSet(), std::exp(logZcorr) ) );
249 else
250 facs.front() *= std::exp(logZcorr);
251 }
252 // add delta factors corresponding to observed variable values
253 for( map<size_t,size_t>::const_iterator e = evid.begin(); e != evid.end(); e++ )
254 facs.push_back( createFactorDelta( vars[e->first], e->second ) );
255
256 // construct clamped factor graph
257 FactorGraph fg( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), vars.size() );
258
259 // write it to a file
260 fg.WriteToFile( argv[3] );
261
262 // if requested, perform various inference tasks
263 if( run_jtree ) {
264 // construct junction tree on unclamped factor graph
265 JTree jt0( fg0, PropertySet()("updates",string("HUGIN")) );
266 jt0.init();
267 jt0.run();
268
269 // construct junction tree on clamped factor graph
270 JTree jt( fg, PropertySet()("updates",string("HUGIN")) );
271 jt.init();
272 jt.run();
273
274 // output probability of evidence
275 cout.precision( 8 );
276 if( evid.size() )
277 cout << "z " << (jt.logZ() - jt0.logZ()) / dai::log((Real)10.0) << endl;
278 else
279 cout << "z " << jt.logZ() / dai::log((Real)10.0) << endl;
280
281 // output variable marginals
282 cout << "m " << jt.nrVars() << " ";
283 for( size_t i = 0; i < jt.nrVars(); i++ ) {
284 cout << jt.var(i).states() << " ";
285 for( size_t s = 0; s < jt.var(i).states(); s++ )
286 cout << jt.beliefV(i)[s] << " ";
287 }
288 cout << endl;
289
290 // calculate MAP state
291 jt.props.inference = JTree::Properties::InfType::MAXPROD;
292 jt.init();
293 jt.run();
294 vector<size_t> MAP = jt.findMaximum();
295 map<Var, size_t> state;
296 for( size_t i = 0; i < MAP.size(); i++ )
297 state[jt.var(i)] = MAP[i];
298 double log_MAP_prob = 0.0;
299 for( size_t I = 0; I < jt.nrFactors(); I++ )
300 log_MAP_prob += dai::log( jt.factor(I)[calcLinearState( jt.factor(I).vars(), state )] );
301
302 // output MAP state
303 cout << "s ";
304 if( evid.size() )
305 cout << (log_MAP_prob - jt0.logZ()) / dai::log((Real)10.0) << " ";
306 else
307 cout << log_MAP_prob / dai::log((Real)10.0) << " ";
308 cout << jt.nrVars() << " ";
309 for( size_t i = 0; i < jt.nrVars(); i++ )
310 cout << MAP[i] << " ";
311 cout << endl;
312 }
313 }
314
315 return 0;
316 }