Partial adoption of contributions by Giuseppe:
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <tr1/unordered_map>
#include <dai/factorgraph.h>
32
33
34 namespace dai {
35
36
37 using namespace std;
38
39
40 FactorGraph::FactorGraph( const vector<Factor> &P ) : BipartiteGraph<Var,Factor>(), _undoProbs(), _normtype(Prob::NORMPROB) {
41 // add factors, obtain variables
42 set<Var> _vars;
43 V2s().reserve( P.size() );
44 size_t nrEdges = 0;
45 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
46 V2s().push_back( *p2 );
47 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
48 nrEdges += p2->vars().size();
49 }
50
51 // add _vars
52 V1s().reserve( _vars.size() );
53 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
54 V1s().push_back( *p1 );
55
56 // create graph structure
57 createGraph( nrEdges );
58 }
59
60
61 /// Part of constructors (creates edges, neighbours and adjacency matrix)
62 void FactorGraph::createGraph( size_t nrEdges ) {
63 // create a mapping for indices
64 std::tr1::unordered_map<size_t, size_t> hashmap;
65
66 for( size_t i = 0; i < vars().size(); i++ )
67 hashmap[vars()[i].label()] = i;
68
69 // create edges
70 edges().reserve( nrEdges );
71 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
72 const VarSet& ns = factor(i2).vars();
73 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
74 edges().push_back(_edge_t(hashmap[q->label()], i2));
75 }
76
77 // calc neighbours and adjacency matrix
78 Regenerate();
79 }
80
81
82 /*FactorGraph& FactorGraph::addFactor( const Factor &I ) {
83 // add Factor
84 _V2.push_back( I );
85
86 // add new vars in Factor
87 for( VarSet::const_iterator i = I.vars().begin(); i != I.vars().end(); i++ ) {
88 size_t i_ind = find(vars().begin(), vars().end(), *i) - vars().begin();
89 if( i_ind == vars().size() )
90 _V1.push_back( *i );
91 _E12.push_back( _edge_t( i_ind, nrFactors() - 1 ) );
92 }
93
94 Regenerate();
95 return(*this);
96 }*/
97
98
99 ostream& operator << (ostream& os, const FactorGraph& fg) {
100 os << fg.nrFactors() << endl;
101
102 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
103 os << endl;
104 os << fg.factor(I).vars().size() << endl;
105 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
106 os << i->label() << " ";
107 os << endl;
108 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
109 os << i->states() << " ";
110 os << endl;
111 size_t nr_nonzeros = 0;
112 for( size_t k = 0; k < fg.factor(I).stateSpace(); k++ )
113 if( fg.factor(I)[k] != 0.0 )
114 nr_nonzeros++;
115 os << nr_nonzeros << endl;
116 for( size_t k = 0; k < fg.factor(I).stateSpace(); k++ )
117 if( fg.factor(I)[k] != 0.0 ) {
118 char buf[20];
119 sprintf(buf,"%18.14g", fg.factor(I)[k]);
120 os << k << " " << buf << endl;
121 }
122 }
123
124 return(os);
125 }
126
127
128 istream& operator >> (istream& is, FactorGraph& fg) {
129 long verbose = 0;
130
131 try {
132 vector<Factor> factors;
133 size_t nr_f;
134 string line;
135
136 while( (is.peek()) == '#' )
137 getline(is,line);
138 is >> nr_f;
139 if( is.fail() )
140 throw "ReadFromFile: unable to read number of Factors";
141 if( verbose >= 2 )
142 cout << "Reading " << nr_f << " factors..." << endl;
143
144 getline (is,line);
145 if( is.fail() )
146 throw "ReadFromFile: empty line expected";
147
148 for( size_t I = 0; I < nr_f; I++ ) {
149 if( verbose >= 3 )
150 cout << "Reading factor " << I << "..." << endl;
151 size_t nr_members;
152 while( (is.peek()) == '#' )
153 getline(is,line);
154 is >> nr_members;
155 if( verbose >= 3 )
156 cout << " nr_members: " << nr_members << endl;
157
158 vector<long> labels;
159 for( size_t mi = 0; mi < nr_members; mi++ ) {
160 long mi_label;
161 while( (is.peek()) == '#' )
162 getline(is,line);
163 is >> mi_label;
164 labels.push_back(mi_label);
165 }
166 if( verbose >= 3 ) {
167 cout << " labels: ";
168 copy (labels.begin(), labels.end(), ostream_iterator<int>(cout, " "));
169 cout << endl;
170 }
171
172 vector<size_t> dims;
173 for( size_t mi = 0; mi < nr_members; mi++ ) {
174 size_t mi_dim;
175 while( (is.peek()) == '#' )
176 getline(is,line);
177 is >> mi_dim;
178 dims.push_back(mi_dim);
179 }
180 if( verbose >= 3 ) {
181 cout << " dimensions: ";
182 copy (dims.begin(), dims.end(), ostream_iterator<int>(cout, " "));
183 cout << endl;
184 }
185
186 // add the Factor
187 VarSet I_vars;
188 for( size_t mi = 0; mi < nr_members; mi++ )
189 I_vars.insert( Var(labels[mi], dims[mi]) );
190 factors.push_back(Factor(I_vars,0.0));
191
192 // calculate permutation sigma (internally, members are sorted)
193 vector<long> sigma(nr_members,0);
194 VarSet::iterator j = I_vars.begin();
195 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
196 long search_for = j->label();
197 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
198 sigma[mi] = j_loc - labels.begin();
199 }
200 if( verbose >= 3 ) {
201 cout << " sigma: ";
202 copy( sigma.begin(), sigma.end(), ostream_iterator<int>(cout," "));
203 cout << endl;
204 }
205
206 // calculate multindices
207 vector<size_t> sdims(nr_members,0);
208 for( size_t k = 0; k < nr_members; k++ ) {
209 sdims[k] = dims[sigma[k]];
210 }
211 multind mi(dims);
212 multind smi(sdims);
213 if( verbose >= 3 ) {
214 cout << " mi.max(): " << mi.max() << endl;
215 cout << " ";
216 for( size_t k=0; k < nr_members; k++ )
217 cout << labels[k] << " ";
218 cout << " ";
219 for( size_t k=0; k < nr_members; k++ )
220 cout << labels[sigma[k]] << " ";
221 cout << endl;
222 }
223
224 // read values
225 size_t nr_nonzeros;
226 while( (is.peek()) == '#' )
227 getline(is,line);
228 is >> nr_nonzeros;
229 if( verbose >= 3 )
230 cout << " nonzeroes: " << nr_nonzeros << endl;
231 for( size_t k = 0; k < nr_nonzeros; k++ ) {
232 size_t li;
233 double val;
234 while( (is.peek()) == '#' )
235 getline(is,line);
236 is >> li;
237 while( (is.peek()) == '#' )
238 getline(is,line);
239 is >> val;
240
241 vector<size_t> vi = mi.vi(li);
242 vector<size_t> svi(vi.size(),0);
243 for( size_t k = 0; k < vi.size(); k++ )
244 svi[k] = vi[sigma[k]];
245 size_t sli = smi.li(svi);
246 if( verbose >= 3 ) {
247 cout << " " << li << ": ";
248 copy( vi.begin(), vi.end(), ostream_iterator<size_t>(cout," "));
249 cout << "-> ";
250 copy( svi.begin(), svi.end(), ostream_iterator<size_t>(cout," "));
251 cout << ": " << sli << endl;
252 }
253 factors.back()[sli] = val;
254 }
255 }
256
257 if( verbose >= 3 ) {
258 cout << "factors:" << endl;
259 copy(factors.begin(), factors.end(), ostream_iterator<Factor>(cout,"\n"));
260 }
261
262 fg = FactorGraph(factors);
263 } catch (char *e) {
264 cout << e << endl;
265 }
266
267 return is;
268 }
269
270
271 VarSet FactorGraph::delta(const Var & n) const {
272 // calculate Markov Blanket
273 size_t i = findVar( n );
274
275 VarSet del;
276 for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); ++I )
277 for( _nb_cit j = nb2(*I).begin(); j != nb2(*I).end(); ++j )
278 if( *j != i )
279 del |= var(*j);
280
281 return del;
282 }
283
284
285 VarSet FactorGraph::Delta(const Var & n) const {
286 return( delta(n) | n );
287 }
288
289
290 void FactorGraph::makeFactorCavity(size_t I) {
291 // fill Factor I with ones
292 factor(I).fill(1.0);
293 }
294
295
296 void FactorGraph::makeCavity(const Var & n) {
297 // fills all Factors that include Var n with ones
298 size_t i = findVar( n );
299
300 for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); ++I )
301 factor(*I).fill(1.0);
302 }
303
304
305 bool FactorGraph::hasNegatives() const {
306 bool result = false;
307 for( size_t I = 0; I < nrFactors() && !result; I++ )
308 if( factor(I).hasNegatives() )
309 result = true;
310 return result;
311 }
312
313
314 /*FactorGraph & FactorGraph::DeleteFactor(size_t I) {
315 // Go through all edges
316 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
317 if( edge->second >= I ) {
318 if( edge->second == I )
319 edge->second = -1UL;
320 else
321 (edge->second)--;
322 }
323 // Remove all edges containing I
324 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
325 if( edge->second == -1UL )
326 edge = _E12.erase( edge );
327 // vector<_edge_t>::iterator new_end = _E12.remove_if( _E12.begin(), _E12.end(), compose1( bind2nd(equal_to<size_t>(), -1), select2nd<_edge_t>() ) );
328 // _E12.erase( new_end, _E12.end() );
329
330 // Erase the factor
331 _V2.erase( _V2.begin() + I );
332
333 Regenerate();
334
335 return *this;
336 }
337
338
339 FactorGraph & FactorGraph::DeleteVar(size_t i) {
340 // Go through all edges
341 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
342 if( edge->first >= i ) {
343 if( edge->first == i )
344 edge->first = -1UL;
345 else
346 (edge->first)--;
347 }
348 // Remove all edges containing i
349 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
350 if( edge->first == -1UL )
351 edge = _E12.erase( edge );
352
353 // vector<_edge_t>::iterator new_end = _E12.remove_if( _E12.begin(), _E12.end(), compose1( bind2nd(equal_to<size_t>(), -1), select1st<_edge_t>() ) );
354 // _E12.erase( new_end, _E12.end() );
355
356 // Erase the variable
357 _V1.erase( _V1.begin() + i );
358
359 Regenerate();
360
361 return *this;
362 }*/
363
364
365 long FactorGraph::ReadFromFile(const char *filename) {
366 ifstream infile;
367 infile.open (filename);
368 if (infile.is_open()) {
369 infile >> *this;
370 infile.close();
371 return 0;
372 } else {
373 cout << "ERROR OPENING FILE" << endl;
374 return 1;
375 }
376 }
377
378
379 long FactorGraph::WriteToFile(const char *filename) const {
380 ofstream outfile;
381 outfile.open (filename);
382 if (outfile.is_open()) {
383 try {
384 outfile << *this;
385 } catch (char *e) {
386 cout << e << endl;
387 return 1;
388 }
389 outfile.close();
390 return 0;
391 } else {
392 cout << "ERROR OPENING FILE" << endl;
393 return 1;
394 }
395 }
396
397
398 long FactorGraph::WriteToDotFile(const char *filename) const {
399 ofstream outfile;
400 outfile.open (filename);
401 if (outfile.is_open()) {
402 try {
403 outfile << "graph G {" << endl;
404 outfile << "graph[size=\"9,9\"];" << endl;
405 outfile << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
406 for( size_t i = 0; i < nrVars(); i++ )
407 outfile << "\tx" << var(i).label() << ";" << endl;
408 outfile << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
409 for( size_t I = 0; I < nrFactors(); I++ )
410 outfile << "\tp" << I << ";" << endl;
411 for( size_t iI = 0; iI < nr_edges(); iI++ )
412 outfile << "\tx" << var(edge(iI).first).label() << " -- p" << edge(iI).second << ";" << endl;
413 outfile << "}" << endl;
414 } catch (char *e) {
415 cout << e << endl;
416 return 1;
417 }
418 outfile.close();
419 return 0;
420 } else {
421 cout << "ERROR OPENING FILE" << endl;
422 return 1;
423 }
424 }
425
426
427 bool hasShortLoops( const vector<Factor> &P ) {
428 bool found = false;
429 vector<Factor>::const_iterator I, J;
430 for( I = P.begin(); I != P.end(); I++ ) {
431 J = I;
432 J++;
433 for( ; J != P.end(); J++ )
434 if( (I->vars() & J->vars()).size() >= 2 ) {
435 found = true;
436 break;
437 }
438 if( found )
439 break;
440 }
441 return found;
442 }
443
444
445 void RemoveShortLoops(vector<Factor> &P) {
446 bool found = true;
447 while( found ) {
448 found = false;
449 vector<Factor>::iterator I, J;
450 for( I = P.begin(); I != P.end(); I++ ) {
451 J = I;
452 J++;
453 for( ; J != P.end(); J++ )
454 if( (I->vars() & J->vars()).size() >= 2 ) {
455 found = true;
456 break;
457 }
458 if( found )
459 break;
460 }
461 if( found ) {
462 cout << "Merging factors " << I->vars() << " and " << J->vars() << endl;
463 *I *= *J;
464 P.erase(J);
465 }
466 }
467 }
468
469
470 Factor FactorGraph::ExactMarginal(const VarSet & x) const {
471 Factor P;
472 for( size_t I = 0; I < nrFactors(); I++ )
473 P *= factor(I);
474 return P.marginal(x);
475 }
476
477
478 Real FactorGraph::ExactlogZ() const {
479 Factor P;
480 for( size_t I = 0; I < nrFactors(); I++ )
481 P *= factor(I);
482 return std::log(P.totalSum());
483 }
484
485
486 vector<VarSet> FactorGraph::Cliques() const {
487 vector<VarSet> result;
488
489 for( size_t I = 0; I < nrFactors(); I++ ) {
490 bool maximal = true;
491 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
492 if( (factor(J).vars() >> factor(I).vars()) && !(factor(J).vars() == factor(I).vars()) )
493 maximal = false;
494
495 if( maximal )
496 result.push_back( factor(I).vars() );
497 }
498
499 return result;
500 }
501
502
503 void FactorGraph::clamp( const Var & n, size_t i ) {
504 assert( i <= n.states() );
505
506 /* if( do_surgery ) {
507 size_t ni = find( vars().begin(), vars().end(), n) - vars().begin();
508
509 if( ni != nrVars() ) {
510 for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); I++ ) {
511 if( factor(*I).size() == 1 )
512 // Remove this single-variable factor
513 // I = (_V2.erase(I))--;
514 _E12.erase( _E12.begin() + VV2E(ni, *I) );
515 else {
516 // Replace it by the slice
517 Index ind_I_min_n( factor(*I), factor(*I) / n );
518 Index ind_n( factor(*I), n );
519 Factor slice_I( factor(*I) / n );
520 for( size_t ind_I = 0; ind_I < factor(*I).stateSpace(); ++ind_I, ++ind_I_min_n, ++ind_n )
521 if( ind_n == i )
522 slice_I[ind_I_min_n] = factor(*I)[ind_I];
523 factor(*I) = slice_I;
524
525 // Remove the edge between n and I
526 _E12.erase( _E12.begin() + VV2E(ni, *I) );
527 }
528 }
529
530 Regenerate();
531
532 // remove all unconnected factors
533 for( size_t I = 0; I < nrFactors(); I++ )
534 if( nb2(I).size() == 0 )
535 DeleteFactor(I--);
536
537 DeleteVar( ni );
538
539 // FIXME
540 }
541 } */
542
543 // The cheap solution (at least in terms of coding time) is to multiply every factor
544 // that contains the variable with a delta function
545
546 Factor delta_n_i(n,0.0);
547 delta_n_i[i] = 1.0;
548
549 // For all factors that contain n
550 for( size_t I = 0; I < nrFactors(); I++ )
551 if( factor(I).vars() && n )
552 // Multiply it with a delta function
553 factor(I) *= delta_n_i;
554
555 return;
556 }
557
558
559 void FactorGraph::saveProb( size_t I ) {
560 map<size_t,Prob>::iterator it = _undoProbs.find( I );
561 if( it != _undoProbs.end() )
562 cout << "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl;
563 _undoProbs[I] = factor(I).p();
564 }
565
566
567 void FactorGraph::undoProb( size_t I ) {
568 map<size_t,Prob>::iterator it = _undoProbs.find( I );
569 if( it != _undoProbs.end() ) {
570 factor(I).p() = (*it).second;
571 _undoProbs.erase(it);
572 }
573 }
574
575
576 void FactorGraph::saveProbs( const VarSet &ns ) {
577 if( !_undoProbs.empty() )
578 cout << "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl;
579 for( size_t I = 0; I < nrFactors(); I++ )
580 if( factor(I).vars() && ns )
581 _undoProbs[I] = factor(I).p();
582 }
583
584
585 void FactorGraph::undoProbs( const VarSet &ns ) {
586 for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
587 if( factor((*uI).first).vars() && ns ) {
588 // cout << "undoing " << factor((*uI).first).vars() << endl;
589 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
590 factor((*uI).first).p() = (*uI).second;
591 _undoProbs.erase(uI++);
592 } else
593 uI++;
594 }
595 }
596
597
598 bool FactorGraph::isConnected() const {
599 if( nrVars() == 0 )
600 return false;
601 else {
602 Var n = var( 0 );
603
604 VarSet component = n;
605
606 VarSet remaining;
607 for( size_t i = 1; i < nrVars(); i++ )
608 remaining |= var(i);
609
610 bool found_new_vars = true;
611 while( found_new_vars ) {
612 VarSet new_vars;
613 for( VarSet::const_iterator m = remaining.begin(); m != remaining.end(); m++ )
614 if( delta(*m) && component )
615 new_vars |= *m;
616
617 if( new_vars.empty() )
618 found_new_vars = false;
619 else
620 found_new_vars = true;
621
622 component |= new_vars;
623 remaining /= new_vars;
624 };
625 return remaining.empty();
626 }
627 }
628
629
630 } // end of namespace dai