Now also compiles with Visual Studio 2008 under Windows (still buggy!)
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <cassert>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <fstream>
#include <string>
#include <algorithm>
#include <functional>
#include <dai/factorgraph.h>
#include <dai/util.h>
32
33
34 namespace dai {
35
36
37 using namespace std;
38
39
40 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
41 // add factors, obtain variables
42 set<Var> _vars;
43 factors.reserve( P.size() );
44 size_t nrEdges = 0;
45 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
46 factors.push_back( *p2 );
47 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
48 nrEdges += p2->vars().size();
49 }
50
51 // add _vars
52 vars.reserve( _vars.size() );
53 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
54 vars.push_back( *p1 );
55
56 // create graph structure
57 createGraph( nrEdges );
58 }
59
60
61 /// Part of constructors (creates edges, neighbours and adjacency matrix)
62 void FactorGraph::createGraph( size_t nrEdges ) {
63 // create a mapping for indices
64 hash_map<size_t, size_t> hashmap;
65
66 for( size_t i = 0; i < vars.size(); i++ )
67 hashmap[var(i).label()] = i;
68
69 // create edge list
70 vector<Edge> edges;
71 edges.reserve( nrEdges );
72 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
73 const VarSet& ns = factor(i2).vars();
74 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
75 edges.push_back( Edge(hashmap[q->label()], i2) );
76 }
77
78 // create bipartite graph
79 G.create( nrVars(), nrFactors(), edges.begin(), edges.end() );
80 }
81
82
83 ostream& operator << (ostream& os, const FactorGraph& fg) {
84 os << fg.nrFactors() << endl;
85
86 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
87 os << endl;
88 os << fg.factor(I).vars().size() << endl;
89 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
90 os << i->label() << " ";
91 os << endl;
92 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
93 os << i->states() << " ";
94 os << endl;
95 size_t nr_nonzeros = 0;
96 for( size_t k = 0; k < fg.factor(I).states(); k++ )
97 if( fg.factor(I)[k] != 0.0 )
98 nr_nonzeros++;
99 os << nr_nonzeros << endl;
100 for( size_t k = 0; k < fg.factor(I).states(); k++ )
101 if( fg.factor(I)[k] != 0.0 ) {
102 char buf[20];
103 sprintf(buf,"%18.14g", fg.factor(I)[k]);
104 os << k << " " << buf << endl;
105 }
106 }
107
108 return(os);
109 }
110
111
112 istream& operator >> (istream& is, FactorGraph& fg) {
113 long verbose = 0;
114
115 try {
116 vector<Factor> factors;
117 size_t nr_f;
118 string line;
119
120 while( (is.peek()) == '#' )
121 getline(is,line);
122 is >> nr_f;
123 if( is.fail() )
124 throw "ReadFromFile: unable to read number of Factors";
125 if( verbose >= 2 )
126 cout << "Reading " << nr_f << " factors..." << endl;
127
128 getline (is,line);
129 if( is.fail() )
130 throw "ReadFromFile: empty line expected";
131
132 for( size_t I = 0; I < nr_f; I++ ) {
133 if( verbose >= 3 )
134 cout << "Reading factor " << I << "..." << endl;
135 size_t nr_members;
136 while( (is.peek()) == '#' )
137 getline(is,line);
138 is >> nr_members;
139 if( verbose >= 3 )
140 cout << " nr_members: " << nr_members << endl;
141
142 vector<long> labels;
143 for( size_t mi = 0; mi < nr_members; mi++ ) {
144 long mi_label;
145 while( (is.peek()) == '#' )
146 getline(is,line);
147 is >> mi_label;
148 labels.push_back(mi_label);
149 }
150 if( verbose >= 3 ) {
151 cout << " labels: ";
152 copy (labels.begin(), labels.end(), ostream_iterator<int>(cout, " "));
153 cout << endl;
154 }
155
156 vector<size_t> dims;
157 for( size_t mi = 0; mi < nr_members; mi++ ) {
158 size_t mi_dim;
159 while( (is.peek()) == '#' )
160 getline(is,line);
161 is >> mi_dim;
162 dims.push_back(mi_dim);
163 }
164 if( verbose >= 3 ) {
165 cout << " dimensions: ";
166 copy (dims.begin(), dims.end(), ostream_iterator<int>(cout, " "));
167 cout << endl;
168 }
169
170 // add the Factor
171 VarSet I_vars;
172 for( size_t mi = 0; mi < nr_members; mi++ )
173 I_vars |= Var(labels[mi], dims[mi]);
174 factors.push_back(Factor(I_vars,0.0));
175
176 // calculate permutation sigma (internally, members are sorted)
177 vector<size_t> sigma(nr_members,0);
178 VarSet::iterator j = I_vars.begin();
179 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
180 long search_for = j->label();
181 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
182 sigma[mi] = j_loc - labels.begin();
183 }
184 if( verbose >= 3 ) {
185 cout << " sigma: ";
186 copy( sigma.begin(), sigma.end(), ostream_iterator<int>(cout," "));
187 cout << endl;
188 }
189
190 // calculate multindices
191 Permute permindex( dims, sigma );
192
193 // read values
194 size_t nr_nonzeros;
195 while( (is.peek()) == '#' )
196 getline(is,line);
197 is >> nr_nonzeros;
198 if( verbose >= 3 )
199 cout << " nonzeroes: " << nr_nonzeros << endl;
200 for( size_t k = 0; k < nr_nonzeros; k++ ) {
201 size_t li;
202 double val;
203 while( (is.peek()) == '#' )
204 getline(is,line);
205 is >> li;
206 while( (is.peek()) == '#' )
207 getline(is,line);
208 is >> val;
209
210 // store value, but permute indices first according
211 // to internal representation
212 factors.back()[permindex.convert_linear_index( li )] = val;
213 }
214 }
215
216 if( verbose >= 3 ) {
217 cout << "factors:" << endl;
218 copy(factors.begin(), factors.end(), ostream_iterator<Factor>(cout,"\n"));
219 }
220
221 fg = FactorGraph(factors);
222 } catch (char *e) {
223 cout << e << endl;
224 }
225
226 return is;
227 }
228
229
230 VarSet FactorGraph::delta( unsigned i ) const {
231 // calculate Markov Blanket
232 VarSet del;
233 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
234 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
235 if( j != i )
236 del |= var(j);
237
238 return del;
239 }
240
241
242 VarSet FactorGraph::Delta( unsigned i ) const {
243 return( delta(i) | var(i) );
244 }
245
246
247 void FactorGraph::makeCavity( unsigned i ) {
248 // fills all Factors that include var(i) with ones
249 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
250 factor(I).fill( 1.0 );
251 }
252
253
254 bool FactorGraph::hasNegatives() const {
255 bool result = false;
256 for( size_t I = 0; I < nrFactors() && !result; I++ )
257 if( factor(I).hasNegatives() )
258 result = true;
259 return result;
260 }
261
262
263 long FactorGraph::ReadFromFile(const char *filename) {
264 ifstream infile;
265 infile.open (filename);
266 if (infile.is_open()) {
267 infile >> *this;
268 infile.close();
269 return 0;
270 } else {
271 cout << "ERROR OPENING FILE" << endl;
272 return 1;
273 }
274 }
275
276
277 long FactorGraph::WriteToFile(const char *filename) const {
278 ofstream outfile;
279 outfile.open (filename);
280 if (outfile.is_open()) {
281 try {
282 outfile << *this;
283 } catch (char *e) {
284 cout << e << endl;
285 return 1;
286 }
287 outfile.close();
288 return 0;
289 } else {
290 cout << "ERROR OPENING FILE" << endl;
291 return 1;
292 }
293 }
294
295
296 long FactorGraph::WriteToDotFile(const char *filename) const {
297 ofstream outfile;
298 outfile.open (filename);
299 if (outfile.is_open()) {
300 try {
301 outfile << "graph G {" << endl;
302 outfile << "graph[size=\"9,9\"];" << endl;
303 outfile << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
304 for( size_t i = 0; i < nrVars(); i++ )
305 outfile << "\tx" << var(i).label() << ";" << endl;
306 outfile << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
307 for( size_t I = 0; I < nrFactors(); I++ )
308 outfile << "\tp" << I << ";" << endl;
309 for( size_t i = 0; i < nrVars(); i++ )
310 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
311 outfile << "\tx" << var(i).label() << " -- p" << I << ";" << endl;
312 outfile << "}" << endl;
313 } catch (char *e) {
314 cout << e << endl;
315 return 1;
316 }
317 outfile.close();
318 return 0;
319 } else {
320 cout << "ERROR OPENING FILE" << endl;
321 return 1;
322 }
323 }
324
325
326 bool hasShortLoops( const vector<Factor> &P ) {
327 bool found = false;
328 vector<Factor>::const_iterator I, J;
329 for( I = P.begin(); I != P.end(); I++ ) {
330 J = I;
331 J++;
332 for( ; J != P.end(); J++ )
333 if( (I->vars() & J->vars()).size() >= 2 ) {
334 found = true;
335 break;
336 }
337 if( found )
338 break;
339 }
340 return found;
341 }
342
343
344 void RemoveShortLoops(vector<Factor> &P) {
345 bool found = true;
346 while( found ) {
347 found = false;
348 vector<Factor>::iterator I, J;
349 for( I = P.begin(); I != P.end(); I++ ) {
350 J = I;
351 J++;
352 for( ; J != P.end(); J++ )
353 if( (I->vars() & J->vars()).size() >= 2 ) {
354 found = true;
355 break;
356 }
357 if( found )
358 break;
359 }
360 if( found ) {
361 cout << "Merging factors " << I->vars() << " and " << J->vars() << endl;
362 *I *= *J;
363 P.erase(J);
364 }
365 }
366 }
367
368
369 vector<VarSet> FactorGraph::Cliques() const {
370 vector<VarSet> result;
371
372 for( size_t I = 0; I < nrFactors(); I++ ) {
373 bool maximal = true;
374 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
375 if( (factor(J).vars() >> factor(I).vars()) && !(factor(J).vars() == factor(I).vars()) )
376 maximal = false;
377
378 if( maximal )
379 result.push_back( factor(I).vars() );
380 }
381
382 return result;
383 }
384
385
386 void FactorGraph::clamp( const Var & n, size_t i ) {
387 assert( i <= n.states() );
388
389 // Multiply each factor that contains the variable with a delta function
390
391 Factor delta_n_i(n,0.0);
392 delta_n_i[i] = 1.0;
393
394 // For all factors that contain n
395 for( size_t I = 0; I < nrFactors(); I++ )
396 if( factor(I).vars().contains( n ) )
397 // Multiply it with a delta function
398 factor(I) *= delta_n_i;
399
400 return;
401 }
402
403
404 void FactorGraph::saveProb( size_t I ) {
405 map<size_t,Prob>::iterator it = _undoProbs.find( I );
406 if( it != _undoProbs.end() )
407 cout << "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl;
408 _undoProbs[I] = factor(I).p();
409 }
410
411
412 void FactorGraph::undoProb( size_t I ) {
413 map<size_t,Prob>::iterator it = _undoProbs.find( I );
414 if( it != _undoProbs.end() ) {
415 factor(I).p() = (*it).second;
416 _undoProbs.erase(it);
417 }
418 }
419
420
421 void FactorGraph::saveProbs( const VarSet &ns ) {
422 if( !_undoProbs.empty() )
423 cout << "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl;
424 for( size_t I = 0; I < nrFactors(); I++ )
425 if( factor(I).vars().intersects( ns ) )
426 _undoProbs[I] = factor(I).p();
427 }
428
429
430 void FactorGraph::undoProbs( const VarSet &ns ) {
431 for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
432 if( factor((*uI).first).vars().intersects( ns ) ) {
433 // cout << "undoing " << factor((*uI).first).vars() << endl;
434 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
435 factor((*uI).first).p() = (*uI).second;
436 _undoProbs.erase(uI++);
437 } else
438 uI++;
439 }
440 }
441
442
443 } // end of namespace dai