11ea17fb87083ac2e12c47d5c12ca0bd840c5807
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <iostream>
24 #include <iomanip>
25 #include <iterator>
26 #include <map>
27 #include <set>
28 #include <fstream>
29 #include <string>
30 #include <algorithm>
31 #include <functional>
32 #include <dai/factorgraph.h>
33 #include <dai/util.h>
34 #include <dai/exceptions.h>
35
36
37 namespace dai {
38
39
40 using namespace std;
41
42
43 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
44 // add factors, obtain variables
45 set<Var> varset;
46 _factors.reserve( P.size() );
47 size_t nrEdges = 0;
48 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
49 _factors.push_back( *p2 );
50 copy( p2->vars().begin(), p2->vars().end(), inserter( varset, varset.begin() ) );
51 nrEdges += p2->vars().size();
52 }
53
54 // add vars
55 _vars.reserve( varset.size() );
56 for( set<Var>::const_iterator p1 = varset.begin(); p1 != varset.end(); p1++ )
57 _vars.push_back( *p1 );
58
59 // create graph structure
60 constructGraph( nrEdges );
61 }
62
63
64 void FactorGraph::constructGraph( size_t nrEdges ) {
65 // create a mapping for indices
66 hash_map<size_t, size_t> hashmap;
67
68 for( size_t i = 0; i < vars().size(); i++ )
69 hashmap[var(i).label()] = i;
70
71 // create edge list
72 vector<Edge> edges;
73 edges.reserve( nrEdges );
74 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
75 const VarSet& ns = factor(i2).vars();
76 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
77 edges.push_back( Edge(hashmap[q->label()], i2) );
78 }
79
80 // create bipartite graph
81 G.construct( nrVars(), nrFactors(), edges.begin(), edges.end() );
82 }
83
84
85 /// Writes a FactorGraph to an output stream
86 ostream& operator << (ostream& os, const FactorGraph& fg) {
87 os << fg.nrFactors() << endl;
88
89 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
90 os << endl;
91 os << fg.factor(I).vars().size() << endl;
92 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
93 os << i->label() << " ";
94 os << endl;
95 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
96 os << i->states() << " ";
97 os << endl;
98 size_t nr_nonzeros = 0;
99 for( size_t k = 0; k < fg.factor(I).states(); k++ )
100 if( fg.factor(I)[k] != 0.0 )
101 nr_nonzeros++;
102 os << nr_nonzeros << endl;
103 for( size_t k = 0; k < fg.factor(I).states(); k++ )
104 if( fg.factor(I)[k] != 0.0 )
105 os << k << " " << setw(os.precision()+4) << fg.factor(I)[k] << endl;
106 }
107
108 return(os);
109 }
110
111
112 /// Reads a FactorGraph from an input stream
113 istream& operator >> (istream& is, FactorGraph& fg) {
114 long verbose = 0;
115
116 vector<Factor> facs;
117 size_t nr_Factors;
118 string line;
119
120 while( (is.peek()) == '#' )
121 getline(is,line);
122 is >> nr_Factors;
123 if( is.fail() )
124 DAI_THROW(INVALID_FACTORGRAPH_FILE);
125 if( verbose >= 2 )
126 cerr << "Reading " << nr_Factors << " factors..." << endl;
127
128 getline (is,line);
129 if( is.fail() )
130 DAI_THROW(INVALID_FACTORGRAPH_FILE);
131
132 map<long,size_t> vardims;
133 for( size_t I = 0; I < nr_Factors; I++ ) {
134 if( verbose >= 3 )
135 cerr << "Reading factor " << I << "..." << endl;
136 size_t nr_members;
137 while( (is.peek()) == '#' )
138 getline(is,line);
139 is >> nr_members;
140 if( verbose >= 3 )
141 cerr << " nr_members: " << nr_members << endl;
142
143 vector<long> labels;
144 for( size_t mi = 0; mi < nr_members; mi++ ) {
145 long mi_label;
146 while( (is.peek()) == '#' )
147 getline(is,line);
148 is >> mi_label;
149 labels.push_back(mi_label);
150 }
151 if( verbose >= 3 )
152 cerr << " labels: " << labels << endl;
153
154 vector<size_t> dims;
155 for( size_t mi = 0; mi < nr_members; mi++ ) {
156 size_t mi_dim;
157 while( (is.peek()) == '#' )
158 getline(is,line);
159 is >> mi_dim;
160 dims.push_back(mi_dim);
161 }
162 if( verbose >= 3 )
163 cerr << " dimensions: " << dims << endl;
164
165 // add the Factor
166 VarSet I_vars;
167 for( size_t mi = 0; mi < nr_members; mi++ ) {
168 map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
169 if( vdi != vardims.end() ) {
170 // check whether dimensions are consistent
171 if( vdi->second != dims[mi] )
172 DAI_THROW(INVALID_FACTORGRAPH_FILE);
173 } else
174 vardims[labels[mi]] = dims[mi];
175 I_vars |= Var(labels[mi], dims[mi]);
176 }
177 facs.push_back( Factor( I_vars, 0.0 ) );
178
179 // calculate permutation sigma (internally, members are sorted)
180 vector<size_t> sigma(nr_members,0);
181 VarSet::iterator j = I_vars.begin();
182 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
183 long search_for = j->label();
184 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
185 sigma[mi] = j_loc - labels.begin();
186 }
187 if( verbose >= 3 )
188 cerr << " sigma: " << sigma << endl;
189
190 // calculate multindices
191 Permute permindex( dims, sigma );
192
193 // read values
194 size_t nr_nonzeros;
195 while( (is.peek()) == '#' )
196 getline(is,line);
197 is >> nr_nonzeros;
198 if( verbose >= 3 )
199 cerr << " nonzeroes: " << nr_nonzeros << endl;
200 for( size_t k = 0; k < nr_nonzeros; k++ ) {
201 size_t li;
202 double val;
203 while( (is.peek()) == '#' )
204 getline(is,line);
205 is >> li;
206 while( (is.peek()) == '#' )
207 getline(is,line);
208 is >> val;
209
210 // store value, but permute indices first according
211 // to internal representation
212 facs.back()[permindex.convert_linear_index( li )] = val;
213 }
214 }
215
216 if( verbose >= 3 )
217 cerr << "factors:" << facs << endl;
218
219 fg = FactorGraph(facs);
220
221 return is;
222 }
223
224
225 VarSet FactorGraph::delta( unsigned i ) const {
226 return( Delta(i) / var(i) );
227 }
228
229
230 VarSet FactorGraph::Delta( unsigned i ) const {
231 // calculate Markov Blanket
232 VarSet Del;
233 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
234 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
235 Del |= var(j);
236
237 return Del;
238 }
239
240
241 VarSet FactorGraph::Delta( const VarSet &ns ) const {
242 VarSet result;
243 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
244 result |= Delta(findVar(*n));
245 return result;
246 }
247
248
249 void FactorGraph::makeCavity( unsigned i, bool backup ) {
250 // fills all Factors that include var(i) with ones
251 map<size_t,Factor> newFacs;
252 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
253 newFacs[I] = Factor(factor(I).vars(), 1.0);
254 setFactors( newFacs, backup );
255 }
256
257
258 void FactorGraph::ReadFromFile( const char *filename ) {
259 ifstream infile;
260 infile.open( filename );
261 if( infile.is_open() ) {
262 infile >> *this;
263 infile.close();
264 } else
265 DAI_THROW(CANNOT_READ_FILE);
266 }
267
268
269 void FactorGraph::WriteToFile( const char *filename, size_t precision ) const {
270 ofstream outfile;
271 outfile.open( filename );
272 if( outfile.is_open() ) {
273 outfile.precision( precision );
274 outfile << *this;
275 outfile.close();
276 } else
277 DAI_THROW(CANNOT_WRITE_FILE);
278 }
279
280
281 void FactorGraph::printDot( std::ostream &os ) const {
282 os << "graph G {" << endl;
283 os << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
284 for( size_t i = 0; i < nrVars(); i++ )
285 os << "\tv" << var(i).label() << ";" << endl;
286 os << "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl;
287 for( size_t I = 0; I < nrFactors(); I++ )
288 os << "\tf" << I << ";" << endl;
289 for( size_t i = 0; i < nrVars(); i++ )
290 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
291 os << "\tv" << var(i).label() << " -- f" << I << ";" << endl;
292 os << "}" << endl;
293 }
294
295
296 vector<VarSet> FactorGraph::Cliques() const {
297 vector<VarSet> result;
298
299 for( size_t I = 0; I < nrFactors(); I++ ) {
300 bool maximal = true;
301 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
302 if( (factor(J).vars() >> factor(I).vars()) && (factor(J).vars() != factor(I).vars()) )
303 maximal = false;
304
305 if( maximal )
306 result.push_back( factor(I).vars() );
307 }
308
309 return result;
310 }
311
312
313 void FactorGraph::clamp( size_t i, size_t x, bool backup ) {
314 assert( x <= var(i).states() );
315 Factor mask( var(i), 0.0 );
316 mask[x] = 1.0;
317
318 map<size_t, Factor> newFacs;
319 foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
320 newFacs[I] = factor(I) * mask;
321 setFactors( newFacs, backup );
322
323 return;
324 }
325
326
327 void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
328 Var n = var(i);
329 Factor mask_n( n, 0.0 );
330
331 foreach( size_t i, is ) {
332 assert( i <= n.states() );
333 mask_n[i] = 1.0;
334 }
335
336 map<size_t, Factor> newFacs;
337 foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
338 newFacs[I] = factor(I) * mask_n;
339 setFactors( newFacs, backup );
340 }
341
342
343 void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
344 size_t st = factor(I).states();
345 Factor newF( factor(I).vars(), 0.0 );
346
347 foreach( size_t i, is ) {
348 assert( i <= st );
349 newF[i] = factor(I)[i];
350 }
351
352 setFactor( I, newF, backup );
353 }
354
355
356 void FactorGraph::backupFactor( size_t I ) {
357 map<size_t,Factor>::iterator it = _backup.find( I );
358 if( it != _backup.end() )
359 DAI_THROW( MULTIPLE_UNDO );
360 _backup[I] = factor(I);
361 }
362
363
364 void FactorGraph::restoreFactor( size_t I ) {
365 map<size_t,Factor>::iterator it = _backup.find( I );
366 if( it != _backup.end() ) {
367 setFactor(I, it->second);
368 _backup.erase(it);
369 }
370 }
371
372
373 void FactorGraph::backupFactors( const VarSet &ns ) {
374 for( size_t I = 0; I < nrFactors(); I++ )
375 if( factor(I).vars().intersects( ns ) )
376 backupFactor( I );
377 }
378
379
380 void FactorGraph::restoreFactors( const VarSet &ns ) {
381 map<size_t,Factor> facs;
382 for( map<size_t,Factor>::iterator uI = _backup.begin(); uI != _backup.end(); ) {
383 if( factor(uI->first).vars().intersects( ns ) ) {
384 facs.insert( *uI );
385 _backup.erase(uI++);
386 } else
387 uI++;
388 }
389 setFactors( facs );
390 }
391
392
393 void FactorGraph::restoreFactors() {
394 setFactors( _backup );
395 _backup.clear();
396 }
397
398
399 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
400 for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
401 backupFactor( *fac );
402 }
403
404
405 bool FactorGraph::isPairwise() const {
406 bool pairwise = true;
407 for( size_t I = 0; I < nrFactors() && pairwise; I++ )
408 if( factor(I).vars().size() > 2 )
409 pairwise = false;
410 return pairwise;
411 }
412
413
414 bool FactorGraph::isBinary() const {
415 bool binary = true;
416 for( size_t i = 0; i < nrVars() && binary; i++ )
417 if( var(i).states() > 2 )
418 binary = false;
419 return binary;
420 }
421
422
423 FactorGraph FactorGraph::clamped( const Var &v, size_t state ) const {
424 Real zeroth_order = 1.0;
425 vector<Factor> clamped_facs;
426 for( size_t I = 0; I < nrFactors(); I++ ) {
427 VarSet v_I = factor(I).vars();
428 Factor new_factor;
429 if( v_I.intersects( v ) )
430 new_factor = factor(I).slice( v, state );
431 else
432 new_factor = factor(I);
433
434 if( new_factor.vars().size() != 0 ) {
435 size_t J = 0;
436 // if it can be merged with a previous one, do that
437 for( J = 0; J < clamped_facs.size(); J++ )
438 if( clamped_facs[J].vars() == new_factor.vars() ) {
439 clamped_facs[J] *= new_factor;
440 break;
441 }
442 // otherwise, push it back
443 if( J == clamped_facs.size() || clamped_facs.size() == 0 )
444 clamped_facs.push_back( new_factor );
445 } else
446 zeroth_order *= new_factor[0];
447 }
448 *(clamped_facs.begin()) *= zeroth_order;
449 return FactorGraph( clamped_facs );
450 }
451
452
453 FactorGraph FactorGraph::maximalFactors() const {
454 vector<size_t> maxfac( nrFactors() );
455 map<size_t,size_t> newindex;
456 size_t nrmax = 0;
457 for( size_t I = 0; I < nrFactors(); I++ ) {
458 maxfac[I] = I;
459 VarSet maxfacvars = factor(maxfac[I]).vars();
460 for( size_t J = 0; J < nrFactors(); J++ ) {
461 VarSet Jvars = factor(J).vars();
462 if( Jvars >> maxfacvars && (Jvars != maxfacvars) ) {
463 maxfac[I] = J;
464 maxfacvars = factor(maxfac[I]).vars();
465 }
466 }
467 if( maxfac[I] == I )
468 newindex[I] = nrmax++;
469 }
470
471 vector<Factor> facs( nrmax );
472 for( size_t I = 0; I < nrFactors(); I++ )
473 facs[newindex[maxfac[I]]] *= factor(I);
474
475 return FactorGraph( facs.begin(), facs.end(), vars().begin(), vars().end(), facs.size(), nrVars() );
476 }
477
478
479 } // end of namespace dai