Merged var.h and varset.h from SVN head
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <fstream>
#include <string>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <tr1/unordered_map>
#include <dai/factorgraph.h>
#include <dai/util.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
42 // add factors, obtain variables
43 set<Var> _vars;
44 factors.reserve( P.size() );
45 size_t nrEdges = 0;
46 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
47 factors.push_back( *p2 );
48 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
49 nrEdges += p2->vars().size();
50 }
51
52 // add _vars
53 vars.reserve( _vars.size() );
54 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
55 vars.push_back( *p1 );
56
57 // create graph structure
58 createGraph( nrEdges );
59 }
60
61
62 /// Part of constructors (creates edges, neighbours and adjacency matrix)
63 void FactorGraph::createGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 std::tr1::unordered_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars.size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 vector<Edge> edges;
72 edges.reserve( nrEdges );
73 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
74 const VarSet& ns = factor(i2).vars();
75 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
76 edges.push_back( Edge(hashmap[q->label()], i2) );
77 }
78
79 // create bipartite graph
80 G.create( nrVars(), nrFactors(), edges.begin(), edges.end() );
81 }
82
83
84 ostream& operator << (ostream& os, const FactorGraph& fg) {
85 os << fg.nrFactors() << endl;
86
87 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
88 os << endl;
89 os << fg.factor(I).vars().size() << endl;
90 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
91 os << i->label() << " ";
92 os << endl;
93 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
94 os << i->states() << " ";
95 os << endl;
96 size_t nr_nonzeros = 0;
97 for( size_t k = 0; k < fg.factor(I).states(); k++ )
98 if( fg.factor(I)[k] != 0.0 )
99 nr_nonzeros++;
100 os << nr_nonzeros << endl;
101 for( size_t k = 0; k < fg.factor(I).states(); k++ )
102 if( fg.factor(I)[k] != 0.0 ) {
103 char buf[20];
104 sprintf(buf,"%18.14g", fg.factor(I)[k]);
105 os << k << " " << buf << endl;
106 }
107 }
108
109 return(os);
110 }
111
112
113 istream& operator >> (istream& is, FactorGraph& fg) {
114 long verbose = 0;
115
116 try {
117 vector<Factor> factors;
118 size_t nr_f;
119 string line;
120
121 while( (is.peek()) == '#' )
122 getline(is,line);
123 is >> nr_f;
124 if( is.fail() )
125 throw "ReadFromFile: unable to read number of Factors";
126 if( verbose >= 2 )
127 cout << "Reading " << nr_f << " factors..." << endl;
128
129 getline (is,line);
130 if( is.fail() )
131 throw "ReadFromFile: empty line expected";
132
133 for( size_t I = 0; I < nr_f; I++ ) {
134 if( verbose >= 3 )
135 cout << "Reading factor " << I << "..." << endl;
136 size_t nr_members;
137 while( (is.peek()) == '#' )
138 getline(is,line);
139 is >> nr_members;
140 if( verbose >= 3 )
141 cout << " nr_members: " << nr_members << endl;
142
143 vector<long> labels;
144 for( size_t mi = 0; mi < nr_members; mi++ ) {
145 long mi_label;
146 while( (is.peek()) == '#' )
147 getline(is,line);
148 is >> mi_label;
149 labels.push_back(mi_label);
150 }
151 if( verbose >= 3 ) {
152 cout << " labels: ";
153 copy (labels.begin(), labels.end(), ostream_iterator<int>(cout, " "));
154 cout << endl;
155 }
156
157 vector<size_t> dims;
158 for( size_t mi = 0; mi < nr_members; mi++ ) {
159 size_t mi_dim;
160 while( (is.peek()) == '#' )
161 getline(is,line);
162 is >> mi_dim;
163 dims.push_back(mi_dim);
164 }
165 if( verbose >= 3 ) {
166 cout << " dimensions: ";
167 copy (dims.begin(), dims.end(), ostream_iterator<int>(cout, " "));
168 cout << endl;
169 }
170
171 // add the Factor
172 VarSet I_vars;
173 for( size_t mi = 0; mi < nr_members; mi++ )
174 I_vars |= Var(labels[mi], dims[mi]);
175 factors.push_back(Factor(I_vars,0.0));
176
177 // calculate permutation sigma (internally, members are sorted)
178 vector<size_t> sigma(nr_members,0);
179 VarSet::iterator j = I_vars.begin();
180 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
181 long search_for = j->label();
182 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
183 sigma[mi] = j_loc - labels.begin();
184 }
185 if( verbose >= 3 ) {
186 cout << " sigma: ";
187 copy( sigma.begin(), sigma.end(), ostream_iterator<int>(cout," "));
188 cout << endl;
189 }
190
191 // calculate multindices
192 Permute permindex( dims, sigma );
193
194 // read values
195 size_t nr_nonzeros;
196 while( (is.peek()) == '#' )
197 getline(is,line);
198 is >> nr_nonzeros;
199 if( verbose >= 3 )
200 cout << " nonzeroes: " << nr_nonzeros << endl;
201 for( size_t k = 0; k < nr_nonzeros; k++ ) {
202 size_t li;
203 double val;
204 while( (is.peek()) == '#' )
205 getline(is,line);
206 is >> li;
207 while( (is.peek()) == '#' )
208 getline(is,line);
209 is >> val;
210
211 // store value, but permute indices first according
212 // to internal representation
213 factors.back()[permindex.convert_linear_index( li )] = val;
214 }
215 }
216
217 if( verbose >= 3 ) {
218 cout << "factors:" << endl;
219 copy(factors.begin(), factors.end(), ostream_iterator<Factor>(cout,"\n"));
220 }
221
222 fg = FactorGraph(factors);
223 } catch (char *e) {
224 cout << e << endl;
225 }
226
227 return is;
228 }
229
230
231 VarSet FactorGraph::delta( unsigned i ) const {
232 // calculate Markov Blanket
233 VarSet del;
234 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
235 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
236 if( j != i )
237 del |= var(j);
238
239 return del;
240 }
241
242
243 VarSet FactorGraph::Delta( unsigned i ) const {
244 return( delta(i) | var(i) );
245 }
246
247
248 void FactorGraph::makeCavity( unsigned i ) {
249 // fills all Factors that include var(i) with ones
250 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
251 factor(I).fill( 1.0 );
252 }
253
254
255 bool FactorGraph::hasNegatives() const {
256 bool result = false;
257 for( size_t I = 0; I < nrFactors() && !result; I++ )
258 if( factor(I).hasNegatives() )
259 result = true;
260 return result;
261 }
262
263
264 long FactorGraph::ReadFromFile(const char *filename) {
265 ifstream infile;
266 infile.open (filename);
267 if (infile.is_open()) {
268 infile >> *this;
269 infile.close();
270 return 0;
271 } else {
272 cout << "ERROR OPENING FILE" << endl;
273 return 1;
274 }
275 }
276
277
278 long FactorGraph::WriteToFile(const char *filename) const {
279 ofstream outfile;
280 outfile.open (filename);
281 if (outfile.is_open()) {
282 try {
283 outfile << *this;
284 } catch (char *e) {
285 cout << e << endl;
286 return 1;
287 }
288 outfile.close();
289 return 0;
290 } else {
291 cout << "ERROR OPENING FILE" << endl;
292 return 1;
293 }
294 }
295
296
297 long FactorGraph::WriteToDotFile(const char *filename) const {
298 ofstream outfile;
299 outfile.open (filename);
300 if (outfile.is_open()) {
301 try {
302 outfile << "graph G {" << endl;
303 outfile << "graph[size=\"9,9\"];" << endl;
304 outfile << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
305 for( size_t i = 0; i < nrVars(); i++ )
306 outfile << "\tx" << var(i).label() << ";" << endl;
307 outfile << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
308 for( size_t I = 0; I < nrFactors(); I++ )
309 outfile << "\tp" << I << ";" << endl;
310 for( size_t i = 0; i < nrVars(); i++ )
311 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
312 outfile << "\tx" << var(i).label() << " -- p" << I << ";" << endl;
313 outfile << "}" << endl;
314 } catch (char *e) {
315 cout << e << endl;
316 return 1;
317 }
318 outfile.close();
319 return 0;
320 } else {
321 cout << "ERROR OPENING FILE" << endl;
322 return 1;
323 }
324 }
325
326
327 bool hasShortLoops( const vector<Factor> &P ) {
328 bool found = false;
329 vector<Factor>::const_iterator I, J;
330 for( I = P.begin(); I != P.end(); I++ ) {
331 J = I;
332 J++;
333 for( ; J != P.end(); J++ )
334 if( (I->vars() & J->vars()).size() >= 2 ) {
335 found = true;
336 break;
337 }
338 if( found )
339 break;
340 }
341 return found;
342 }
343
344
345 void RemoveShortLoops(vector<Factor> &P) {
346 bool found = true;
347 while( found ) {
348 found = false;
349 vector<Factor>::iterator I, J;
350 for( I = P.begin(); I != P.end(); I++ ) {
351 J = I;
352 J++;
353 for( ; J != P.end(); J++ )
354 if( (I->vars() & J->vars()).size() >= 2 ) {
355 found = true;
356 break;
357 }
358 if( found )
359 break;
360 }
361 if( found ) {
362 cout << "Merging factors " << I->vars() << " and " << J->vars() << endl;
363 *I *= *J;
364 P.erase(J);
365 }
366 }
367 }
368
369
370 vector<VarSet> FactorGraph::Cliques() const {
371 vector<VarSet> result;
372
373 for( size_t I = 0; I < nrFactors(); I++ ) {
374 bool maximal = true;
375 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
376 if( (factor(J).vars() >> factor(I).vars()) && !(factor(J).vars() == factor(I).vars()) )
377 maximal = false;
378
379 if( maximal )
380 result.push_back( factor(I).vars() );
381 }
382
383 return result;
384 }
385
386
387 void FactorGraph::clamp( const Var & n, size_t i ) {
388 assert( i <= n.states() );
389
390 // Multiply each factor that contains the variable with a delta function
391
392 Factor delta_n_i(n,0.0);
393 delta_n_i[i] = 1.0;
394
395 // For all factors that contain n
396 for( size_t I = 0; I < nrFactors(); I++ )
397 if( factor(I).vars().contains( n ) )
398 // Multiply it with a delta function
399 factor(I) *= delta_n_i;
400
401 return;
402 }
403
404
405 void FactorGraph::saveProb( size_t I ) {
406 map<size_t,Prob>::iterator it = _undoProbs.find( I );
407 if( it != _undoProbs.end() )
408 cout << "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl;
409 _undoProbs[I] = factor(I).p();
410 }
411
412
413 void FactorGraph::undoProb( size_t I ) {
414 map<size_t,Prob>::iterator it = _undoProbs.find( I );
415 if( it != _undoProbs.end() ) {
416 factor(I).p() = (*it).second;
417 _undoProbs.erase(it);
418 }
419 }
420
421
422 void FactorGraph::saveProbs( const VarSet &ns ) {
423 if( !_undoProbs.empty() )
424 cout << "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl;
425 for( size_t I = 0; I < nrFactors(); I++ )
426 if( factor(I).vars().intersects( ns ) )
427 _undoProbs[I] = factor(I).p();
428 }
429
430
431 void FactorGraph::undoProbs( const VarSet &ns ) {
432 for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
433 if( factor((*uI).first).vars().intersects( ns ) ) {
434 // cout << "undoing " << factor((*uI).first).vars() << endl;
435 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
436 factor((*uI).first).p() = (*uI).second;
437 _undoProbs.erase(uI++);
438 } else
439 uI++;
440 }
441 }
442
443
444 } // end of namespace dai