Merge branch 'no-edges2'
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <fstream>
#include <string>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <tr1/unordered_map>
#include <dai/factorgraph.h>
#include <dai/util.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 FactorGraph::FactorGraph( const vector<Factor> &P ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
42 // add factors, obtain variables
43 set<Var> _vars;
44 factors.reserve( P.size() );
45 size_t nrEdges = 0;
46 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
47 factors.push_back( *p2 );
48 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
49 nrEdges += p2->vars().size();
50 }
51
52 // add _vars
53 vars.reserve( _vars.size() );
54 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
55 vars.push_back( *p1 );
56
57 // create graph structure
58 createGraph( nrEdges );
59 }
60
61
62 /// Part of constructors (creates edges, neighbours and adjacency matrix)
63 void FactorGraph::createGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 std::tr1::unordered_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars.size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 typedef pair<unsigned,unsigned> Edge;
72 vector<Edge> edges;
73 edges.reserve( nrEdges );
74 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
75 const VarSet& ns = factor(i2).vars();
76 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
77 edges.push_back( Edge(hashmap[q->label()], i2) );
78 }
79
80 // create bipartite graph
81 G.create( nrVars(), nrFactors(), edges.begin(), edges.end() );
82 }
83
84
85 ostream& operator << (ostream& os, const FactorGraph& fg) {
86 os << fg.nrFactors() << endl;
87
88 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
89 os << endl;
90 os << fg.factor(I).vars().size() << endl;
91 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
92 os << i->label() << " ";
93 os << endl;
94 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
95 os << i->states() << " ";
96 os << endl;
97 size_t nr_nonzeros = 0;
98 for( size_t k = 0; k < fg.factor(I).stateSpace(); k++ )
99 if( fg.factor(I)[k] != 0.0 )
100 nr_nonzeros++;
101 os << nr_nonzeros << endl;
102 for( size_t k = 0; k < fg.factor(I).stateSpace(); k++ )
103 if( fg.factor(I)[k] != 0.0 ) {
104 char buf[20];
105 sprintf(buf,"%18.14g", fg.factor(I)[k]);
106 os << k << " " << buf << endl;
107 }
108 }
109
110 return(os);
111 }
112
113
114 istream& operator >> (istream& is, FactorGraph& fg) {
115 long verbose = 0;
116
117 try {
118 vector<Factor> factors;
119 size_t nr_f;
120 string line;
121
122 while( (is.peek()) == '#' )
123 getline(is,line);
124 is >> nr_f;
125 if( is.fail() )
126 throw "ReadFromFile: unable to read number of Factors";
127 if( verbose >= 2 )
128 cout << "Reading " << nr_f << " factors..." << endl;
129
130 getline (is,line);
131 if( is.fail() )
132 throw "ReadFromFile: empty line expected";
133
134 for( size_t I = 0; I < nr_f; I++ ) {
135 if( verbose >= 3 )
136 cout << "Reading factor " << I << "..." << endl;
137 size_t nr_members;
138 while( (is.peek()) == '#' )
139 getline(is,line);
140 is >> nr_members;
141 if( verbose >= 3 )
142 cout << " nr_members: " << nr_members << endl;
143
144 vector<long> labels;
145 for( size_t mi = 0; mi < nr_members; mi++ ) {
146 long mi_label;
147 while( (is.peek()) == '#' )
148 getline(is,line);
149 is >> mi_label;
150 labels.push_back(mi_label);
151 }
152 if( verbose >= 3 ) {
153 cout << " labels: ";
154 copy (labels.begin(), labels.end(), ostream_iterator<int>(cout, " "));
155 cout << endl;
156 }
157
158 vector<size_t> dims;
159 for( size_t mi = 0; mi < nr_members; mi++ ) {
160 size_t mi_dim;
161 while( (is.peek()) == '#' )
162 getline(is,line);
163 is >> mi_dim;
164 dims.push_back(mi_dim);
165 }
166 if( verbose >= 3 ) {
167 cout << " dimensions: ";
168 copy (dims.begin(), dims.end(), ostream_iterator<int>(cout, " "));
169 cout << endl;
170 }
171
172 // add the Factor
173 VarSet I_vars;
174 for( size_t mi = 0; mi < nr_members; mi++ )
175 I_vars.insert( Var(labels[mi], dims[mi]) );
176 factors.push_back(Factor(I_vars,0.0));
177
178 // calculate permutation sigma (internally, members are sorted)
179 vector<long> sigma(nr_members,0);
180 VarSet::iterator j = I_vars.begin();
181 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
182 long search_for = j->label();
183 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
184 sigma[mi] = j_loc - labels.begin();
185 }
186 if( verbose >= 3 ) {
187 cout << " sigma: ";
188 copy( sigma.begin(), sigma.end(), ostream_iterator<int>(cout," "));
189 cout << endl;
190 }
191
192 // calculate multindices
193 vector<size_t> sdims(nr_members,0);
194 for( size_t k = 0; k < nr_members; k++ ) {
195 sdims[k] = dims[sigma[k]];
196 }
197 multind mi(dims);
198 multind smi(sdims);
199 if( verbose >= 3 ) {
200 cout << " mi.max(): " << mi.max() << endl;
201 cout << " ";
202 for( size_t k=0; k < nr_members; k++ )
203 cout << labels[k] << " ";
204 cout << " ";
205 for( size_t k=0; k < nr_members; k++ )
206 cout << labels[sigma[k]] << " ";
207 cout << endl;
208 }
209
210 // read values
211 size_t nr_nonzeros;
212 while( (is.peek()) == '#' )
213 getline(is,line);
214 is >> nr_nonzeros;
215 if( verbose >= 3 )
216 cout << " nonzeroes: " << nr_nonzeros << endl;
217 for( size_t k = 0; k < nr_nonzeros; k++ ) {
218 size_t li;
219 double val;
220 while( (is.peek()) == '#' )
221 getline(is,line);
222 is >> li;
223 while( (is.peek()) == '#' )
224 getline(is,line);
225 is >> val;
226
227 vector<size_t> vi = mi.vi(li);
228 vector<size_t> svi(vi.size(),0);
229 for( size_t k = 0; k < vi.size(); k++ )
230 svi[k] = vi[sigma[k]];
231 size_t sli = smi.li(svi);
232 if( verbose >= 3 ) {
233 cout << " " << li << ": ";
234 copy( vi.begin(), vi.end(), ostream_iterator<size_t>(cout," "));
235 cout << "-> ";
236 copy( svi.begin(), svi.end(), ostream_iterator<size_t>(cout," "));
237 cout << ": " << sli << endl;
238 }
239 factors.back()[sli] = val;
240 }
241 }
242
243 if( verbose >= 3 ) {
244 cout << "factors:" << endl;
245 copy(factors.begin(), factors.end(), ostream_iterator<Factor>(cout,"\n"));
246 }
247
248 fg = FactorGraph(factors);
249 } catch (char *e) {
250 cout << e << endl;
251 }
252
253 return is;
254 }
255
256
257 VarSet FactorGraph::delta( unsigned i ) const {
258 // calculate Markov Blanket
259 VarSet del;
260 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
261 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
262 if( j != i )
263 del |= var(j);
264
265 return del;
266 }
267
268
269 VarSet FactorGraph::Delta( unsigned i ) const {
270 return( delta(i) | var(i) );
271 }
272
273
274 void FactorGraph::makeCavity( unsigned i ) {
275 // fills all Factors that include var(i) with ones
276 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
277 factor(I).fill( 1.0 );
278 }
279
280
281 bool FactorGraph::hasNegatives() const {
282 bool result = false;
283 for( size_t I = 0; I < nrFactors() && !result; I++ )
284 if( factor(I).hasNegatives() )
285 result = true;
286 return result;
287 }
288
289
290 long FactorGraph::ReadFromFile(const char *filename) {
291 ifstream infile;
292 infile.open (filename);
293 if (infile.is_open()) {
294 infile >> *this;
295 infile.close();
296 return 0;
297 } else {
298 cout << "ERROR OPENING FILE" << endl;
299 return 1;
300 }
301 }
302
303
304 long FactorGraph::WriteToFile(const char *filename) const {
305 ofstream outfile;
306 outfile.open (filename);
307 if (outfile.is_open()) {
308 try {
309 outfile << *this;
310 } catch (char *e) {
311 cout << e << endl;
312 return 1;
313 }
314 outfile.close();
315 return 0;
316 } else {
317 cout << "ERROR OPENING FILE" << endl;
318 return 1;
319 }
320 }
321
322
323 long FactorGraph::WriteToDotFile(const char *filename) const {
324 ofstream outfile;
325 outfile.open (filename);
326 if (outfile.is_open()) {
327 try {
328 outfile << "graph G {" << endl;
329 outfile << "graph[size=\"9,9\"];" << endl;
330 outfile << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
331 for( size_t i = 0; i < nrVars(); i++ )
332 outfile << "\tx" << var(i).label() << ";" << endl;
333 outfile << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
334 for( size_t I = 0; I < nrFactors(); I++ )
335 outfile << "\tp" << I << ";" << endl;
336 for( size_t i = 0; i < nrVars(); i++ )
337 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
338 outfile << "\tx" << var(i).label() << " -- p" << I << ";" << endl;
339 outfile << "}" << endl;
340 } catch (char *e) {
341 cout << e << endl;
342 return 1;
343 }
344 outfile.close();
345 return 0;
346 } else {
347 cout << "ERROR OPENING FILE" << endl;
348 return 1;
349 }
350 }
351
352
353 bool hasShortLoops( const vector<Factor> &P ) {
354 bool found = false;
355 vector<Factor>::const_iterator I, J;
356 for( I = P.begin(); I != P.end(); I++ ) {
357 J = I;
358 J++;
359 for( ; J != P.end(); J++ )
360 if( (I->vars() & J->vars()).size() >= 2 ) {
361 found = true;
362 break;
363 }
364 if( found )
365 break;
366 }
367 return found;
368 }
369
370
371 void RemoveShortLoops(vector<Factor> &P) {
372 bool found = true;
373 while( found ) {
374 found = false;
375 vector<Factor>::iterator I, J;
376 for( I = P.begin(); I != P.end(); I++ ) {
377 J = I;
378 J++;
379 for( ; J != P.end(); J++ )
380 if( (I->vars() & J->vars()).size() >= 2 ) {
381 found = true;
382 break;
383 }
384 if( found )
385 break;
386 }
387 if( found ) {
388 cout << "Merging factors " << I->vars() << " and " << J->vars() << endl;
389 *I *= *J;
390 P.erase(J);
391 }
392 }
393 }
394
395
396 vector<VarSet> FactorGraph::Cliques() const {
397 vector<VarSet> result;
398
399 for( size_t I = 0; I < nrFactors(); I++ ) {
400 bool maximal = true;
401 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
402 if( (factor(J).vars() >> factor(I).vars()) && !(factor(J).vars() == factor(I).vars()) )
403 maximal = false;
404
405 if( maximal )
406 result.push_back( factor(I).vars() );
407 }
408
409 return result;
410 }
411
412
413 void FactorGraph::clamp( const Var & n, size_t i ) {
414 assert( i <= n.states() );
415
416 /* if( do_surgery ) {
417 size_t ni = find( vars().begin(), vars().end(), n) - vars().begin();
418
419 if( ni != nrVars() ) {
420 for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); I++ ) {
421 if( factor(*I).size() == 1 )
422 // Remove this single-variable factor
423 // I = (_V2.erase(I))--;
424 _E12.erase( _E12.begin() + VV2E(ni, *I) );
425 else {
426 // Replace it by the slice
427 Index ind_I_min_n( factor(*I), factor(*I) / n );
428 Index ind_n( factor(*I), n );
429 Factor slice_I( factor(*I) / n );
430 for( size_t ind_I = 0; ind_I < factor(*I).stateSpace(); ++ind_I, ++ind_I_min_n, ++ind_n )
431 if( ind_n == i )
432 slice_I[ind_I_min_n] = factor(*I)[ind_I];
433 factor(*I) = slice_I;
434
435 // Remove the edge between n and I
436 _E12.erase( _E12.begin() + VV2E(ni, *I) );
437 }
438 }
439
440 Regenerate();
441
442 // remove all unconnected factors
443 for( size_t I = 0; I < nrFactors(); I++ )
444 if( nb2(I).size() == 0 )
445 DeleteFactor(I--);
446
447 DeleteVar( ni );
448
449 // FIXME
450 }
451 } */
452
453 // The cheap solution (at least in terms of coding time) is to multiply every factor
454 // that contains the variable with a delta function
455
456 Factor delta_n_i(n,0.0);
457 delta_n_i[i] = 1.0;
458
459 // For all factors that contain n
460 for( size_t I = 0; I < nrFactors(); I++ )
461 if( factor(I).vars() && n )
462 // Multiply it with a delta function
463 factor(I) *= delta_n_i;
464
465 return;
466 }
467
468
469 void FactorGraph::saveProb( size_t I ) {
470 map<size_t,Prob>::iterator it = _undoProbs.find( I );
471 if( it != _undoProbs.end() )
472 cout << "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl;
473 _undoProbs[I] = factor(I).p();
474 }
475
476
477 void FactorGraph::undoProb( size_t I ) {
478 map<size_t,Prob>::iterator it = _undoProbs.find( I );
479 if( it != _undoProbs.end() ) {
480 factor(I).p() = (*it).second;
481 _undoProbs.erase(it);
482 }
483 }
484
485
486 void FactorGraph::saveProbs( const VarSet &ns ) {
487 if( !_undoProbs.empty() )
488 cout << "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl;
489 for( size_t I = 0; I < nrFactors(); I++ )
490 if( factor(I).vars() && ns )
491 _undoProbs[I] = factor(I).p();
492 }
493
494
495 void FactorGraph::undoProbs( const VarSet &ns ) {
496 for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
497 if( factor((*uI).first).vars() && ns ) {
498 // cout << "undoing " << factor((*uI).first).vars() << endl;
499 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
500 factor((*uI).first).p() = (*uI).second;
501 _undoProbs.erase(uI++);
502 } else
503 uI++;
504 }
505 }
506
507
508 } // end of namespace dai