Fixed tabs and trailing whitespaces
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #include <iostream>
24 #include <iomanip>
25 #include <iterator>
26 #include <map>
27 #include <set>
28 #include <fstream>
29 #include <string>
30 #include <algorithm>
31 #include <functional>
32 #include <dai/factorgraph.h>
33 #include <dai/util.h>
34 #include <dai/exceptions.h>
35 #include <boost/lexical_cast.hpp>
36
37
38 namespace dai {
39
40
41 using namespace std;
42
43
44 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
45 // add factors, obtain variables
46 set<Var> varset;
47 _factors.reserve( P.size() );
48 size_t nrEdges = 0;
49 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
50 _factors.push_back( *p2 );
51 copy( p2->vars().begin(), p2->vars().end(), inserter( varset, varset.begin() ) );
52 nrEdges += p2->vars().size();
53 }
54
55 // add vars
56 _vars.reserve( varset.size() );
57 for( set<Var>::const_iterator p1 = varset.begin(); p1 != varset.end(); p1++ )
58 _vars.push_back( *p1 );
59
60 // create graph structure
61 constructGraph( nrEdges );
62 }
63
64
65 void FactorGraph::constructGraph( size_t nrEdges ) {
66 // create a mapping for indices
67 hash_map<size_t, size_t> hashmap;
68
69 for( size_t i = 0; i < vars().size(); i++ )
70 hashmap[var(i).label()] = i;
71
72 // create edge list
73 vector<Edge> edges;
74 edges.reserve( nrEdges );
75 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
76 const VarSet& ns = factor(i2).vars();
77 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
78 edges.push_back( Edge(hashmap[q->label()], i2) );
79 }
80
81 // create bipartite graph
82 G.construct( nrVars(), nrFactors(), edges.begin(), edges.end() );
83 }
84
85
86 /// Writes a FactorGraph to an output stream
87 std::ostream& operator<< ( std::ostream &os, const FactorGraph &fg ) {
88 os << fg.nrFactors() << endl;
89
90 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
91 os << endl;
92 os << fg.factor(I).vars().size() << endl;
93 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
94 os << i->label() << " ";
95 os << endl;
96 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
97 os << i->states() << " ";
98 os << endl;
99 size_t nr_nonzeros = 0;
100 for( size_t k = 0; k < fg.factor(I).states(); k++ )
101 if( fg.factor(I)[k] != 0.0 )
102 nr_nonzeros++;
103 os << nr_nonzeros << endl;
104 for( size_t k = 0; k < fg.factor(I).states(); k++ )
105 if( fg.factor(I)[k] != 0.0 )
106 os << k << " " << setw(os.precision()+4) << fg.factor(I)[k] << endl;
107 }
108
109 return(os);
110 }
111
112
113 /// Reads a FactorGraph from an input stream
114 std::istream& operator>> ( std::istream& is, FactorGraph &fg ) {
115 long verbose = 0;
116
117 vector<Factor> facs;
118 size_t nr_Factors;
119 string line;
120
121 while( (is.peek()) == '#' )
122 getline(is,line);
123 is >> nr_Factors;
124 if( is.fail() )
125 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Cannot read number of factors");
126 if( verbose >= 2 )
127 cerr << "Reading " << nr_Factors << " factors..." << endl;
128
129 getline (is,line);
130 if( is.fail() )
131 DAI_THROW(INVALID_FACTORGRAPH_FILE);
132
133 map<long,size_t> vardims;
134 for( size_t I = 0; I < nr_Factors; I++ ) {
135 if( verbose >= 3 )
136 cerr << "Reading factor " << I << "..." << endl;
137 size_t nr_members;
138 while( (is.peek()) == '#' )
139 getline(is,line);
140 is >> nr_members;
141 if( verbose >= 3 )
142 cerr << " nr_members: " << nr_members << endl;
143
144 vector<long> labels;
145 for( size_t mi = 0; mi < nr_members; mi++ ) {
146 long mi_label;
147 while( (is.peek()) == '#' )
148 getline(is,line);
149 is >> mi_label;
150 labels.push_back(mi_label);
151 }
152 if( verbose >= 3 )
153 cerr << " labels: " << labels << endl;
154
155 vector<size_t> dims;
156 for( size_t mi = 0; mi < nr_members; mi++ ) {
157 size_t mi_dim;
158 while( (is.peek()) == '#' )
159 getline(is,line);
160 is >> mi_dim;
161 dims.push_back(mi_dim);
162 }
163 if( verbose >= 3 )
164 cerr << " dimensions: " << dims << endl;
165
166 // add the Factor
167 VarSet I_vars;
168 for( size_t mi = 0; mi < nr_members; mi++ ) {
169 map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
170 if( vdi != vardims.end() ) {
171 // check whether dimensions are consistent
172 if( vdi->second != dims[mi] )
173 DAI_THROWE(INVALID_FACTORGRAPH_FILE,"Variable with label " + boost::lexical_cast<string>(labels[mi]) + " has inconsistent dimensions.");
174 } else
175 vardims[labels[mi]] = dims[mi];
176 I_vars |= Var(labels[mi], dims[mi]);
177 }
178 facs.push_back( Factor( I_vars, 0.0 ) );
179
180 // calculate permutation sigma (internally, members are sorted)
181 vector<size_t> sigma(nr_members,0);
182 VarSet::iterator j = I_vars.begin();
183 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
184 long search_for = j->label();
185 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
186 sigma[mi] = j_loc - labels.begin();
187 }
188 if( verbose >= 3 )
189 cerr << " sigma: " << sigma << endl;
190
191 // calculate multindices
192 Permute permindex( dims, sigma );
193
194 // read values
195 size_t nr_nonzeros;
196 while( (is.peek()) == '#' )
197 getline(is,line);
198 is >> nr_nonzeros;
199 if( verbose >= 3 )
200 cerr << " nonzeroes: " << nr_nonzeros << endl;
201 for( size_t k = 0; k < nr_nonzeros; k++ ) {
202 size_t li;
203 double val;
204 while( (is.peek()) == '#' )
205 getline(is,line);
206 is >> li;
207 while( (is.peek()) == '#' )
208 getline(is,line);
209 is >> val;
210
211 // store value, but permute indices first according
212 // to internal representation
213 facs.back()[permindex.convert_linear_index( li )] = val;
214 }
215 }
216
217 if( verbose >= 3 )
218 cerr << "factors:" << facs << endl;
219
220 fg = FactorGraph(facs);
221
222 return is;
223 }
224
225
226 VarSet FactorGraph::delta( unsigned i ) const {
227 return( Delta(i) / var(i) );
228 }
229
230
231 VarSet FactorGraph::Delta( unsigned i ) const {
232 // calculate Markov Blanket
233 VarSet Del;
234 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
235 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
236 Del |= var(j);
237
238 return Del;
239 }
240
241
242 VarSet FactorGraph::Delta( const VarSet &ns ) const {
243 VarSet result;
244 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
245 result |= Delta(findVar(*n));
246 return result;
247 }
248
249
250 void FactorGraph::makeCavity( unsigned i, bool backup ) {
251 // fills all Factors that include var(i) with ones
252 map<size_t,Factor> newFacs;
253 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
254 newFacs[I] = Factor(factor(I).vars(), 1.0);
255 setFactors( newFacs, backup );
256 }
257
258
259 void FactorGraph::ReadFromFile( const char *filename ) {
260 ifstream infile;
261 infile.open( filename );
262 if( infile.is_open() ) {
263 infile >> *this;
264 infile.close();
265 } else
266 DAI_THROWE(CANNOT_READ_FILE,"Cannot read from file " + std::string(filename));
267 }
268
269
270 void FactorGraph::WriteToFile( const char *filename, size_t precision ) const {
271 ofstream outfile;
272 outfile.open( filename );
273 if( outfile.is_open() ) {
274 outfile.precision( precision );
275 outfile << *this;
276 outfile.close();
277 } else
278 DAI_THROWE(CANNOT_WRITE_FILE,"Cannot write to file " + std::string(filename));
279 }
280
281
282 void FactorGraph::printDot( std::ostream &os ) const {
283 os << "graph G {" << endl;
284 os << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
285 for( size_t i = 0; i < nrVars(); i++ )
286 os << "\tv" << var(i).label() << ";" << endl;
287 os << "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl;
288 for( size_t I = 0; I < nrFactors(); I++ )
289 os << "\tf" << I << ";" << endl;
290 for( size_t i = 0; i < nrVars(); i++ )
291 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
292 os << "\tv" << var(i).label() << " -- f" << I << ";" << endl;
293 os << "}" << endl;
294 }
295
296
297 vector<VarSet> FactorGraph::Cliques() const {
298 vector<VarSet> result;
299
300 for( size_t I = 0; I < nrFactors(); I++ ) {
301 bool maximal = true;
302 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
303 if( (factor(J).vars() >> factor(I).vars()) && (factor(J).vars() != factor(I).vars()) )
304 maximal = false;
305
306 if( maximal )
307 result.push_back( factor(I).vars() );
308 }
309
310 return result;
311 }
312
313
314 void FactorGraph::clamp( size_t i, size_t x, bool backup ) {
315 assert( x <= var(i).states() );
316 Factor mask( var(i), 0.0 );
317 mask[x] = 1.0;
318
319 map<size_t, Factor> newFacs;
320 foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
321 newFacs[I] = factor(I) * mask;
322 setFactors( newFacs, backup );
323
324 return;
325 }
326
327
328 void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
329 Var n = var(i);
330 Factor mask_n( n, 0.0 );
331
332 foreach( size_t i, is ) {
333 assert( i <= n.states() );
334 mask_n[i] = 1.0;
335 }
336
337 map<size_t, Factor> newFacs;
338 foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
339 newFacs[I] = factor(I) * mask_n;
340 setFactors( newFacs, backup );
341 }
342
343
344 void FactorGraph::clampFactor( size_t I, const vector<size_t> &is, bool backup ) {
345 size_t st = factor(I).states();
346 Factor newF( factor(I).vars(), 0.0 );
347
348 foreach( size_t i, is ) {
349 assert( i <= st );
350 newF[i] = factor(I)[i];
351 }
352
353 setFactor( I, newF, backup );
354 }
355
356
357 void FactorGraph::backupFactor( size_t I ) {
358 map<size_t,Factor>::iterator it = _backup.find( I );
359 if( it != _backup.end() )
360 DAI_THROW(MULTIPLE_UNDO);
361 _backup[I] = factor(I);
362 }
363
364
365 void FactorGraph::restoreFactor( size_t I ) {
366 map<size_t,Factor>::iterator it = _backup.find( I );
367 if( it != _backup.end() ) {
368 setFactor(I, it->second);
369 _backup.erase(it);
370 }
371 }
372
373
374 void FactorGraph::backupFactors( const VarSet &ns ) {
375 for( size_t I = 0; I < nrFactors(); I++ )
376 if( factor(I).vars().intersects( ns ) )
377 backupFactor( I );
378 }
379
380
381 void FactorGraph::restoreFactors( const VarSet &ns ) {
382 map<size_t,Factor> facs;
383 for( map<size_t,Factor>::iterator uI = _backup.begin(); uI != _backup.end(); ) {
384 if( factor(uI->first).vars().intersects( ns ) ) {
385 facs.insert( *uI );
386 _backup.erase(uI++);
387 } else
388 uI++;
389 }
390 setFactors( facs );
391 }
392
393
394 void FactorGraph::restoreFactors() {
395 setFactors( _backup );
396 _backup.clear();
397 }
398
399
400 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
401 for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
402 backupFactor( *fac );
403 }
404
405
406 bool FactorGraph::isPairwise() const {
407 bool pairwise = true;
408 for( size_t I = 0; I < nrFactors() && pairwise; I++ )
409 if( factor(I).vars().size() > 2 )
410 pairwise = false;
411 return pairwise;
412 }
413
414
415 bool FactorGraph::isBinary() const {
416 bool binary = true;
417 for( size_t i = 0; i < nrVars() && binary; i++ )
418 if( var(i).states() > 2 )
419 binary = false;
420 return binary;
421 }
422
423
424 FactorGraph FactorGraph::clamped( const Var &v, size_t state ) const {
425 Real zeroth_order = 1.0;
426 vector<Factor> clamped_facs;
427 for( size_t I = 0; I < nrFactors(); I++ ) {
428 VarSet v_I = factor(I).vars();
429 Factor new_factor;
430 if( v_I.intersects( v ) )
431 new_factor = factor(I).slice( v, state );
432 else
433 new_factor = factor(I);
434
435 if( new_factor.vars().size() != 0 ) {
436 size_t J = 0;
437 // if it can be merged with a previous one, do that
438 for( J = 0; J < clamped_facs.size(); J++ )
439 if( clamped_facs[J].vars() == new_factor.vars() ) {
440 clamped_facs[J] *= new_factor;
441 break;
442 }
443 // otherwise, push it back
444 if( J == clamped_facs.size() || clamped_facs.size() == 0 )
445 clamped_facs.push_back( new_factor );
446 } else
447 zeroth_order *= new_factor[0];
448 }
449 *(clamped_facs.begin()) *= zeroth_order;
450 return FactorGraph( clamped_facs );
451 }
452
453
454 FactorGraph FactorGraph::maximalFactors() const {
455 vector<size_t> maxfac( nrFactors() );
456 map<size_t,size_t> newindex;
457 size_t nrmax = 0;
458 for( size_t I = 0; I < nrFactors(); I++ ) {
459 maxfac[I] = I;
460 VarSet maxfacvars = factor(maxfac[I]).vars();
461 for( size_t J = 0; J < nrFactors(); J++ ) {
462 VarSet Jvars = factor(J).vars();
463 if( Jvars >> maxfacvars && (Jvars != maxfacvars) ) {
464 maxfac[I] = J;
465 maxfacvars = factor(maxfac[I]).vars();
466 }
467 }
468 if( maxfac[I] == I )
469 newindex[I] = nrmax++;
470 }
471
472 vector<Factor> facs( nrmax );
473 for( size_t I = 0; I < nrFactors(); I++ )
474 facs[newindex[maxfac[I]]] *= factor(I);
475
476 return FactorGraph( facs.begin(), facs.end(), vars().begin(), vars().end(), facs.size(), nrVars() );
477 }
478
479
480 } // end of namespace dai