New git HEAD version
[libdai.git] / src / io.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/io.h>
10 #include <dai/alldai.h>
11 #include <iostream>
12 #include <fstream>
13
14
15 namespace dai {
16
17
18 using namespace std;
19
20
21 void ReadUaiAieFactorGraphFile( const char *filename, size_t verbose, std::vector<Var>& vars, std::vector<Factor>& factors, std::vector<Permute>& permutations ) {
22 vars.clear();
23 factors.clear();
24 permutations.clear();
25
26 // open file
27 ifstream is;
28 is.open( filename );
29 if( is.is_open() ) {
30 size_t nrFacs, nrVars;
31 string line;
32
33 // read header line
34 getline(is,line);
35 if( is.fail() || line.size() == 0 )
36 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"UAI factor graph file should start with nonempty header line");
37 if( line[line.size() - 1] == '\r' )
38 line.resize( line.size() - 1 ); // for DOS text files
39 if( line != "BAYES" && line != "MARKOV" )
40 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"UAI factor graph file should start with \"BAYES\" or \"MARKOV\"");
41 if( verbose >= 2 )
42 cout << "Reading " << line << " network..." << endl;
43
44 // read number of variables
45 is >> nrVars;
46 if( is.fail() )
47 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of variables");
48 if( verbose >= 2 )
49 cout << "Reading " << nrVars << " variables..." << endl;
50
51 // for each variable, read its number of states
52 vars.reserve( nrVars );
53 for( size_t i = 0; i < nrVars; i++ ) {
54 size_t dim;
55 is >> dim;
56 if( is.fail() )
57 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of states of " + toString(i) + "'th variable");
58 vars.push_back( Var( i, dim ) );
59 }
60
61 // read number of factors
62 is >> nrFacs;
63 if( is.fail() )
64 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of factors");
65 if( verbose >= 2 )
66 cout << "Reading " << nrFacs << " factors..." << endl;
67
68 // for each factor, read the variables on which it depends
69 vector<vector<Var> > factorVars;
70 factors.reserve( nrFacs );
71 factorVars.reserve( nrFacs );
72 for( size_t I = 0; I < nrFacs; I++ ) {
73 if( verbose >= 3 )
74 cout << "Reading factor " << I << "..." << endl;
75
76 // read number of variables for factor I
77 size_t I_nrVars;
78 is >> I_nrVars;
79 if( is.fail() )
80 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of variables for " + toString(I) + "'th factor");
81 if( verbose >= 3 )
82 cout << " which depends on " << I_nrVars << " variables" << endl;
83
84 // read the variable labels
85 vector<long> I_labels;
86 vector<size_t> I_dims;
87 I_labels.reserve( I_nrVars );
88 I_dims.reserve( I_nrVars );
89 factorVars[I].reserve( I_nrVars );
90 for( size_t _i = 0; _i < I_nrVars; _i++ ) {
91 long label;
92 is >> label;
93 if( is.fail() )
94 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read variable labels for " + toString(I) + "'th factor");
95 I_labels.push_back( label );
96 I_dims.push_back( vars[label].states() );
97 factorVars[I].push_back( vars[label] );
98 }
99 if( verbose >= 3 )
100 cout << " labels: " << I_labels << ", dimensions " << I_dims << endl;
101
102 // add the factor and the labels
103 factors.push_back( Factor( VarSet( factorVars[I].begin(), factorVars[I].end(), factorVars[I].size() ), (Real)0 ) );
104 }
105
106 // for each factor, read its values
107 permutations.reserve( nrFacs );
108 for( size_t I = 0; I < nrFacs; I++ ) {
109 if( verbose >= 3 )
110 cout << "Reading factor " << I << "..." << endl;
111
112 // calculate permutation object, reversing the indexing in factorVars[I] first
113 Permute permindex( factorVars[I], true );
114 permutations.push_back( permindex );
115
116 // read factor values
117 size_t nrNonZeros;
118 is >> nrNonZeros;
119 if( is.fail() )
120 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of nonzero factor values for " + toString(I) + "'th factor");
121 if( verbose >= 3 )
122 cout << " number of nonzero values: " << nrNonZeros << endl;
123 DAI_ASSERT( nrNonZeros == factors[I].nrStates() );
124 for( size_t li = 0; li < nrNonZeros; li++ ) {
125 Real val;
126 is >> val;
127 if( is.fail() )
128 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read factor values of " + toString(I) + "'th factor");
129 // assign value after calculating its linear index corresponding to the permutation
130 if( verbose >= 4 )
131 cout << " " << li << "'th value " << val << " corresponds with index " << permindex.convertLinearIndex(li) << endl;
132 factors[I].set( permindex.convertLinearIndex( li ), val );
133 }
134 }
135 if( verbose >= 3 )
136 cout << "variables:" << vars << endl;
137 if( verbose >= 3 )
138 cout << "factors:" << factors << endl;
139
140 // close file
141 is.close();
142 } else
143 DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
144 }
145
146
147 std::vector<std::map<size_t, size_t> > ReadUaiAieEvidenceFile( const char* filename, size_t verbose ) {
148 vector<map<size_t, size_t> > evid;
149
150 // open file
151 ifstream is;
152 string line;
153 is.open( filename );
154 if( is.is_open() ) {
155 // read number of lines
156 getline( is, line );
157 if( is.fail() || line.size() == 0 )
158 DAI_THROWE(INVALID_EVIDENCE_FILE,"Cannot read header line of evidence file");
159 if( line[line.size() - 1] == '\r' )
160 line.resize( line.size() - 1 ); // for DOS text files
161 size_t nrLines = fromString<size_t>( line );
162 if( verbose >= 2 )
163 cout << "Reading " << nrLines << " evidence file lines..." << endl;
164
165 if( nrLines ) {
166 // detect version (pre-2010 or 2010)
167 streampos pos = is.tellg();
168 getline( is, line );
169 if( is.fail() || line.size() == 0 )
170 DAI_THROWE(INVALID_EVIDENCE_FILE,"Cannot read second line of evidence file");
171 if( line[line.size() - 1] == '\r' )
172 line.resize( line.size() - 1 ); // for DOS text files
173 vector<string> cols;
174 cols = tokenizeString( line, false, " \t" );
175 bool oldVersion = true;
176 if( cols.size() % 2 )
177 oldVersion = false;
178 if( verbose >= 2 ) {
179 if( oldVersion )
180 cout << "Detected old (2006, 2008) evidence file format" << endl;
181 else
182 cout << "Detected new (2010) evidence file format" << endl;
183 }
184 size_t nrEvid;
185 if( oldVersion ) {
186 nrEvid = 1;
187 is.seekg( 0 );
188 } else {
189 nrEvid = nrLines;
190 is.seekg( pos );
191 }
192
193 // read all evidence cases
194 if( verbose >= 2 )
195 cout << "Reading " << nrEvid << " evidence cases..." << endl;
196 evid.resize( nrEvid );
197 for( size_t i = 0; i < nrEvid; i++ ) {
198 // read number of variables
199 size_t nrObs;
200 is >> nrObs;
201 if( is.fail() )
202 DAI_THROWE(INVALID_EVIDENCE_FILE,"Evidence case " + toString(i) + ": Cannot read number of observations");
203 if( verbose >= 2 )
204 cout << "Evidence case " << i << ": reading " << nrObs << " observations..." << endl;
205
206 // for each observation, read the variable label and the observed value
207 for( size_t j = 0; j < nrObs; j++ ) {
208 size_t label, val;
209 is >> label;
210 if( is.fail() )
211 DAI_THROWE(INVALID_EVIDENCE_FILE,"Evidence case " + toString(i) + ": Cannot read label for " + toString(j) + "'th observed variable");
212 is >> val;
213 if( is.fail() )
214 DAI_THROWE(INVALID_EVIDENCE_FILE,"Evidence case " + toString(i) + ": Cannot read value of " + toString(j) + "'th observed variable");
215 if( verbose >= 3 )
216 cout << " variable: " << label << ", value: " << val << endl;
217 evid[i][label] = val;
218 }
219 }
220 }
221
222 // close file
223 is.close();
224 } else
225 DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
226
227 if( evid.size() == 0 )
228 evid.resize( 1 );
229
230 return evid;
231 }
232
233
234 } // end of namespace dai