Various cleanups
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <cstdio>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <fstream>
#include <string>
#include <algorithm>
#include <functional>
#include <dai/factorgraph.h>
#include <dai/util.h>
#include <dai/exceptions.h>
34
35
36 namespace dai {
37
38
39 using namespace std;
40
41
42 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
43 // add factors, obtain variables
44 set<Var> varset;
45 _factors.reserve( P.size() );
46 size_t nrEdges = 0;
47 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
48 _factors.push_back( *p2 );
49 copy( p2->vars().begin(), p2->vars().end(), inserter( varset, varset.begin() ) );
50 nrEdges += p2->vars().size();
51 }
52
53 // add vars
54 _vars.reserve( varset.size() );
55 for( set<Var>::const_iterator p1 = varset.begin(); p1 != varset.end(); p1++ )
56 _vars.push_back( *p1 );
57
58 // create graph structure
59 constructGraph( nrEdges );
60 }
61
62
63 void FactorGraph::constructGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 hash_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars().size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 vector<Edge> edges;
72 edges.reserve( nrEdges );
73 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
74 const VarSet& ns = factor(i2).vars();
75 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
76 edges.push_back( Edge(hashmap[q->label()], i2) );
77 }
78
79 // create bipartite graph
80 G.construct( nrVars(), nrFactors(), edges.begin(), edges.end() );
81 }
82
83
84 /// Writes a FactorGraph to an output stream
85 ostream& operator << (ostream& os, const FactorGraph& fg) {
86 os << fg.nrFactors() << endl;
87
88 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
89 os << endl;
90 os << fg.factor(I).vars().size() << endl;
91 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
92 os << i->label() << " ";
93 os << endl;
94 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
95 os << i->states() << " ";
96 os << endl;
97 size_t nr_nonzeros = 0;
98 for( size_t k = 0; k < fg.factor(I).states(); k++ )
99 if( fg.factor(I)[k] != 0.0 )
100 nr_nonzeros++;
101 os << nr_nonzeros << endl;
102 for( size_t k = 0; k < fg.factor(I).states(); k++ )
103 if( fg.factor(I)[k] != 0.0 ) {
104 char buf[20];
105 sprintf(buf,"%18.14g", fg.factor(I)[k]);
106 os << k << " " << buf << endl;
107 }
108 }
109
110 return(os);
111 }
112
113
114 /// Reads a FactorGraph from an input stream
115 istream& operator >> (istream& is, FactorGraph& fg) {
116 long verbose = 0;
117
118 vector<Factor> facs;
119 size_t nr_Factors;
120 string line;
121
122 while( (is.peek()) == '#' )
123 getline(is,line);
124 is >> nr_Factors;
125 if( is.fail() )
126 DAI_THROW(INVALID_FACTORGRAPH_FILE);
127 if( verbose >= 2 )
128 cerr << "Reading " << nr_Factors << " factors..." << endl;
129
130 getline (is,line);
131 if( is.fail() )
132 DAI_THROW(INVALID_FACTORGRAPH_FILE);
133
134 map<long,size_t> vardims;
135 for( size_t I = 0; I < nr_Factors; I++ ) {
136 if( verbose >= 3 )
137 cerr << "Reading factor " << I << "..." << endl;
138 size_t nr_members;
139 while( (is.peek()) == '#' )
140 getline(is,line);
141 is >> nr_members;
142 if( verbose >= 3 )
143 cerr << " nr_members: " << nr_members << endl;
144
145 vector<long> labels;
146 for( size_t mi = 0; mi < nr_members; mi++ ) {
147 long mi_label;
148 while( (is.peek()) == '#' )
149 getline(is,line);
150 is >> mi_label;
151 labels.push_back(mi_label);
152 }
153 if( verbose >= 3 )
154 cerr << " labels: " << labels << endl;
155
156 vector<size_t> dims;
157 for( size_t mi = 0; mi < nr_members; mi++ ) {
158 size_t mi_dim;
159 while( (is.peek()) == '#' )
160 getline(is,line);
161 is >> mi_dim;
162 dims.push_back(mi_dim);
163 }
164 if( verbose >= 3 )
165 cerr << " dimensions: " << dims << endl;
166
167 // add the Factor
168 VarSet I_vars;
169 for( size_t mi = 0; mi < nr_members; mi++ ) {
170 map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
171 if( vdi != vardims.end() ) {
172 // check whether dimensions are consistent
173 if( vdi->second != dims[mi] )
174 DAI_THROW(INVALID_FACTORGRAPH_FILE);
175 } else
176 vardims[labels[mi]] = dims[mi];
177 I_vars |= Var(labels[mi], dims[mi]);
178 }
179 facs.push_back( Factor( I_vars, 0.0 ) );
180
181 // calculate permutation sigma (internally, members are sorted)
182 vector<size_t> sigma(nr_members,0);
183 VarSet::iterator j = I_vars.begin();
184 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
185 long search_for = j->label();
186 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
187 sigma[mi] = j_loc - labels.begin();
188 }
189 if( verbose >= 3 )
190 cerr << " sigma: " << sigma << endl;
191
192 // calculate multindices
193 Permute permindex( dims, sigma );
194
195 // read values
196 size_t nr_nonzeros;
197 while( (is.peek()) == '#' )
198 getline(is,line);
199 is >> nr_nonzeros;
200 if( verbose >= 3 )
201 cerr << " nonzeroes: " << nr_nonzeros << endl;
202 for( size_t k = 0; k < nr_nonzeros; k++ ) {
203 size_t li;
204 double val;
205 while( (is.peek()) == '#' )
206 getline(is,line);
207 is >> li;
208 while( (is.peek()) == '#' )
209 getline(is,line);
210 is >> val;
211
212 // store value, but permute indices first according
213 // to internal representation
214 facs.back()[permindex.convert_linear_index( li )] = val;
215 }
216 }
217
218 if( verbose >= 3 )
219 cerr << "factors:" << facs << endl;
220
221 fg = FactorGraph(facs);
222
223 return is;
224 }
225
226
227 VarSet FactorGraph::delta( unsigned i ) const {
228 return( Delta(i) / var(i) );
229 }
230
231
232 VarSet FactorGraph::Delta( unsigned i ) const {
233 // calculate Markov Blanket
234 VarSet Del;
235 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
236 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
237 Del |= var(j);
238
239 return Del;
240 }
241
242
243 VarSet FactorGraph::Delta( const VarSet &ns ) const {
244 VarSet result;
245 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
246 result |= Delta(findVar(*n));
247 return result;
248 }
249
250
251 void FactorGraph::makeCavity( unsigned i, bool backup ) {
252 // fills all Factors that include var(i) with ones
253 map<size_t,Factor> newFacs;
254 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
255 newFacs[I] = Factor(factor(I).vars(), 1.0);
256 setFactors( newFacs, backup );
257 }
258
259
260 void FactorGraph::ReadFromFile( const char *filename ) {
261 ifstream infile;
262 infile.open( filename );
263 if( infile.is_open() ) {
264 infile >> *this;
265 infile.close();
266 } else
267 DAI_THROW(CANNOT_READ_FILE);
268 }
269
270
271 void FactorGraph::WriteToFile( const char *filename ) const {
272 ofstream outfile;
273 outfile.open( filename );
274 if( outfile.is_open() ) {
275 outfile << *this;
276 outfile.close();
277 } else
278 DAI_THROW(CANNOT_WRITE_FILE);
279 }
280
281
282 void FactorGraph::printDot( std::ostream &os ) const {
283 os << "graph G {" << endl;
284 os << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
285 for( size_t i = 0; i < nrVars(); i++ )
286 os << "\tv" << var(i).label() << ";" << endl;
287 os << "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl;
288 for( size_t I = 0; I < nrFactors(); I++ )
289 os << "\tf" << I << ";" << endl;
290 for( size_t i = 0; i < nrVars(); i++ )
291 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
292 os << "\tv" << var(i).label() << " -- f" << I << ";" << endl;
293 os << "}" << endl;
294 }
295
296
297 vector<VarSet> FactorGraph::Cliques() const {
298 vector<VarSet> result;
299
300 for( size_t I = 0; I < nrFactors(); I++ ) {
301 bool maximal = true;
302 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
303 if( (factor(J).vars() >> factor(I).vars()) && (factor(J).vars() != factor(I).vars()) )
304 maximal = false;
305
306 if( maximal )
307 result.push_back( factor(I).vars() );
308 }
309
310 return result;
311 }
312
313
314 void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
315 assert( i <= n.states() );
316
317 // Multiply each factor that contains the variable with a delta function
318
319 Factor delta_n_i(n,0.0);
320 delta_n_i[i] = 1.0;
321
322 map<size_t, Factor> newFacs;
323 // For all factors that contain n
324 for( size_t I = 0; I < nrFactors(); I++ )
325 if( factor(I).vars().contains( n ) )
326 // Multiply it with a delta function
327 newFacs[I] = factor(I) * delta_n_i;
328 setFactors( newFacs, backup );
329
330 return;
331 }
332
333
334 void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
335 Var n = var(i);
336 Factor mask_n( n, 0.0 );
337
338 foreach( size_t i, is ) {
339 assert( i <= n.states() );
340 mask_n[i] = 1.0;
341 }
342
343 map<size_t, Factor> newFacs;
344 for( size_t I = 0; I < nrFactors(); I++ )
345 if( factor(I).vars().contains( n ) ) {
346 newFacs[I] = factor(I) * mask_n;
347 }
348 setFactors( newFacs, backup );
349 }
350
351
352 void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
353 size_t st = factor(I).states();
354 Factor newF( factor(I).vars(), 0.0 );
355
356 foreach( size_t i, is ) {
357 assert( i <= st );
358 newF[i] = factor(I)[i];
359 }
360
361 setFactor( I, newF, backup );
362 }
363
364
365 void FactorGraph::backupFactor( size_t I ) {
366 map<size_t,Factor>::iterator it = _backup.find( I );
367 if( it != _backup.end() )
368 DAI_THROW( MULTIPLE_UNDO );
369 _backup[I] = factor(I);
370 }
371
372
373 void FactorGraph::restoreFactor( size_t I ) {
374 map<size_t,Factor>::iterator it = _backup.find( I );
375 if( it != _backup.end() ) {
376 setFactor(I, it->second);
377 _backup.erase(it);
378 }
379 }
380
381
382 void FactorGraph::backupFactors( const VarSet &ns ) {
383 for( size_t I = 0; I < nrFactors(); I++ )
384 if( factor(I).vars().intersects( ns ) )
385 backupFactor( I );
386 }
387
388
389 void FactorGraph::restoreFactors( const VarSet &ns ) {
390 map<size_t,Factor> facs;
391 for( map<size_t,Factor>::iterator uI = _backup.begin(); uI != _backup.end(); ) {
392 if( factor(uI->first).vars().intersects( ns ) ) {
393 facs.insert( *uI );
394 _backup.erase(uI++);
395 } else
396 uI++;
397 }
398 setFactors( facs );
399 }
400
401
402 void FactorGraph::restoreFactors() {
403 setFactors( _backup );
404 _backup.clear();
405 }
406
407
408 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
409 for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
410 backupFactor( *fac );
411 }
412
413
414 bool FactorGraph::isPairwise() const {
415 bool pairwise = true;
416 for( size_t I = 0; I < nrFactors() && pairwise; I++ )
417 if( factor(I).vars().size() > 2 )
418 pairwise = false;
419 return pairwise;
420 }
421
422
423 bool FactorGraph::isBinary() const {
424 bool binary = true;
425 for( size_t i = 0; i < nrVars() && binary; i++ )
426 if( var(i).states() > 2 )
427 binary = false;
428 return binary;
429 }
430
431
432 FactorGraph FactorGraph::clamped( const Var & v_i, size_t state ) const {
433 Real zeroth_order = 1.0;
434 vector<Factor> clamped_facs;
435 for( size_t I = 0; I < nrFactors(); I++ ) {
436 VarSet v_I = factor(I).vars();
437 Factor new_factor;
438 if( v_I.intersects( v_i ) )
439 new_factor = factor(I).slice( v_i, state );
440 else
441 new_factor = factor(I);
442
443 if( new_factor.vars().size() != 0 ) {
444 size_t J = 0;
445 // if it can be merged with a previous one, do that
446 for( J = 0; J < clamped_facs.size(); J++ )
447 if( clamped_facs[J].vars() == new_factor.vars() ) {
448 clamped_facs[J] *= new_factor;
449 break;
450 }
451 // otherwise, push it back
452 if( J == clamped_facs.size() || clamped_facs.size() == 0 )
453 clamped_facs.push_back( new_factor );
454 } else
455 zeroth_order *= new_factor[0];
456 }
457 *(clamped_facs.begin()) *= zeroth_order;
458 return FactorGraph( clamped_facs );
459 }
460
461
462 FactorGraph FactorGraph::maximalFactors() const {
463 vector<size_t> maxfac( nrFactors() );
464 map<size_t,size_t> newindex;
465 size_t nrmax = 0;
466 for( size_t I = 0; I < nrFactors(); I++ ) {
467 maxfac[I] = I;
468 VarSet maxfacvars = factor(maxfac[I]).vars();
469 for( size_t J = 0; J < nrFactors(); J++ ) {
470 VarSet Jvars = factor(J).vars();
471 if( Jvars >> maxfacvars && (Jvars != maxfacvars) ) {
472 maxfac[I] = J;
473 maxfacvars = factor(maxfac[I]).vars();
474 }
475 }
476 if( maxfac[I] == I )
477 newindex[I] = nrmax++;
478 }
479
480 vector<Factor> facs( nrmax );
481 for( size_t I = 0; I < nrFactors(); I++ )
482 facs[newindex[maxfac[I]]] *= factor(I);
483
484 return FactorGraph( facs.begin(), facs.end(), vars().begin(), vars().end(), facs.size(), nrVars() );
485 }
486
487
488 } // end of namespace dai