Significant improvement of documentation
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <dai/factorgraph.h>
#include <dai/util.h>
#include <dai/exceptions.h>
34
35
36 namespace dai {
37
38
39 using namespace std;
40
41
42 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
43 // add factors, obtain variables
44 set<Var> varset;
45 _factors.reserve( P.size() );
46 size_t nrEdges = 0;
47 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
48 _factors.push_back( *p2 );
49 copy( p2->vars().begin(), p2->vars().end(), inserter( varset, varset.begin() ) );
50 nrEdges += p2->vars().size();
51 }
52
53 // add vars
54 _vars.reserve( varset.size() );
55 for( set<Var>::const_iterator p1 = varset.begin(); p1 != varset.end(); p1++ )
56 _vars.push_back( *p1 );
57
58 // create graph structure
59 constructGraph( nrEdges );
60 }
61
62
63 void FactorGraph::constructGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 hash_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars().size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 vector<Edge> edges;
72 edges.reserve( nrEdges );
73 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
74 const VarSet& ns = factor(i2).vars();
75 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
76 edges.push_back( Edge(hashmap[q->label()], i2) );
77 }
78
79 // create bipartite graph
80 G.construct( nrVars(), nrFactors(), edges.begin(), edges.end() );
81 }
82
83
84 /// Writes a FactorGraph to an output stream
85 ostream& operator << (ostream& os, const FactorGraph& fg) {
86 os << fg.nrFactors() << endl;
87
88 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
89 os << endl;
90 os << fg.factor(I).vars().size() << endl;
91 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
92 os << i->label() << " ";
93 os << endl;
94 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
95 os << i->states() << " ";
96 os << endl;
97 size_t nr_nonzeros = 0;
98 for( size_t k = 0; k < fg.factor(I).states(); k++ )
99 if( fg.factor(I)[k] != 0.0 )
100 nr_nonzeros++;
101 os << nr_nonzeros << endl;
102 for( size_t k = 0; k < fg.factor(I).states(); k++ )
103 if( fg.factor(I)[k] != 0.0 ) {
104 char buf[20];
105 sprintf(buf,"%18.14g", fg.factor(I)[k]);
106 os << k << " " << buf << endl;
107 }
108 }
109
110 return(os);
111 }
112
113
114 /// Reads a FactorGraph from an input stream
115 istream& operator >> (istream& is, FactorGraph& fg) {
116 long verbose = 0;
117
118 try {
119 vector<Factor> facs;
120 size_t nr_Factors;
121 string line;
122
123 while( (is.peek()) == '#' )
124 getline(is,line);
125 is >> nr_Factors;
126 if( is.fail() )
127 DAI_THROW(INVALID_FACTORGRAPH_FILE);
128 if( verbose >= 2 )
129 cout << "Reading " << nr_Factors << " factors..." << endl;
130
131 getline (is,line);
132 if( is.fail() )
133 DAI_THROW(INVALID_FACTORGRAPH_FILE);
134
135 map<long,size_t> vardims;
136 for( size_t I = 0; I < nr_Factors; I++ ) {
137 if( verbose >= 3 )
138 cout << "Reading factor " << I << "..." << endl;
139 size_t nr_members;
140 while( (is.peek()) == '#' )
141 getline(is,line);
142 is >> nr_members;
143 if( verbose >= 3 )
144 cout << " nr_members: " << nr_members << endl;
145
146 vector<long> labels;
147 for( size_t mi = 0; mi < nr_members; mi++ ) {
148 long mi_label;
149 while( (is.peek()) == '#' )
150 getline(is,line);
151 is >> mi_label;
152 labels.push_back(mi_label);
153 }
154 if( verbose >= 3 )
155 cout << " labels: " << labels << endl;
156
157 vector<size_t> dims;
158 for( size_t mi = 0; mi < nr_members; mi++ ) {
159 size_t mi_dim;
160 while( (is.peek()) == '#' )
161 getline(is,line);
162 is >> mi_dim;
163 dims.push_back(mi_dim);
164 }
165 if( verbose >= 3 )
166 cout << " dimensions: " << dims << endl;
167
168 // add the Factor
169 VarSet I_vars;
170 for( size_t mi = 0; mi < nr_members; mi++ ) {
171 map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
172 if( vdi != vardims.end() ) {
173 // check whether dimensions are consistent
174 if( vdi->second != dims[mi] )
175 DAI_THROW(INVALID_FACTORGRAPH_FILE);
176 } else
177 vardims[labels[mi]] = dims[mi];
178 I_vars |= Var(labels[mi], dims[mi]);
179 }
180 facs.push_back( Factor( I_vars, 0.0 ) );
181
182 // calculate permutation sigma (internally, members are sorted)
183 vector<size_t> sigma(nr_members,0);
184 VarSet::iterator j = I_vars.begin();
185 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
186 long search_for = j->label();
187 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
188 sigma[mi] = j_loc - labels.begin();
189 }
190 if( verbose >= 3 )
191 cout << " sigma: " << sigma << endl;
192
193 // calculate multindices
194 Permute permindex( dims, sigma );
195
196 // read values
197 size_t nr_nonzeros;
198 while( (is.peek()) == '#' )
199 getline(is,line);
200 is >> nr_nonzeros;
201 if( verbose >= 3 )
202 cout << " nonzeroes: " << nr_nonzeros << endl;
203 for( size_t k = 0; k < nr_nonzeros; k++ ) {
204 size_t li;
205 double val;
206 while( (is.peek()) == '#' )
207 getline(is,line);
208 is >> li;
209 while( (is.peek()) == '#' )
210 getline(is,line);
211 is >> val;
212
213 // store value, but permute indices first according
214 // to internal representation
215 facs.back()[permindex.convert_linear_index( li )] = val;
216 }
217 }
218
219 if( verbose >= 3 )
220 cout << "factors:" << facs << endl;
221
222 fg = FactorGraph(facs);
223 } catch (char *e) {
224 cout << e << endl;
225 }
226
227 return is;
228 }
229
230
231 VarSet FactorGraph::delta( unsigned i ) const {
232 return( Delta(i) / var(i) );
233 }
234
235
236 VarSet FactorGraph::Delta( unsigned i ) const {
237 // calculate Markov Blanket
238 VarSet Del;
239 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
240 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
241 Del |= var(j);
242
243 return Del;
244 }
245
246
247 VarSet FactorGraph::Delta( const VarSet &ns ) const {
248 VarSet result;
249 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
250 result |= Delta(findVar(*n));
251 return result;
252 }
253
254
255 void FactorGraph::makeCavity( unsigned i, bool backup ) {
256 // fills all Factors that include var(i) with ones
257 map<size_t,Factor> newFacs;
258 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
259 newFacs[I] = Factor(factor(I).vars(), 1.0);
260 setFactors( newFacs, backup );
261 }
262
263
264 void FactorGraph::ReadFromFile( const char *filename ) {
265 ifstream infile;
266 infile.open( filename );
267 if( infile.is_open() ) {
268 infile >> *this;
269 infile.close();
270 } else
271 DAI_THROW(CANNOT_READ_FILE);
272 }
273
274
275 void FactorGraph::WriteToFile( const char *filename ) const {
276 ofstream outfile;
277 outfile.open( filename );
278 if( outfile.is_open() ) {
279 outfile << *this;
280 outfile.close();
281 } else
282 DAI_THROW(CANNOT_WRITE_FILE);
283 }
284
285
286 void FactorGraph::printDot( std::ostream &os ) const {
287 os << "graph G {" << endl;
288 os << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
289 for( size_t i = 0; i < nrVars(); i++ )
290 os << "\tv" << var(i).label() << ";" << endl;
291 os << "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl;
292 for( size_t I = 0; I < nrFactors(); I++ )
293 os << "\tf" << I << ";" << endl;
294 for( size_t i = 0; i < nrVars(); i++ )
295 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
296 os << "\tv" << var(i).label() << " -- f" << I << ";" << endl;
297 os << "}" << endl;
298 }
299
300
301 vector<VarSet> FactorGraph::Cliques() const {
302 vector<VarSet> result;
303
304 for( size_t I = 0; I < nrFactors(); I++ ) {
305 bool maximal = true;
306 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
307 if( (factor(J).vars() >> factor(I).vars()) && (factor(J).vars() != factor(I).vars()) )
308 maximal = false;
309
310 if( maximal )
311 result.push_back( factor(I).vars() );
312 }
313
314 return result;
315 }
316
317
318 void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
319 assert( i <= n.states() );
320
321 // Multiply each factor that contains the variable with a delta function
322
323 Factor delta_n_i(n,0.0);
324 delta_n_i[i] = 1.0;
325
326 map<size_t, Factor> newFacs;
327 // For all factors that contain n
328 for( size_t I = 0; I < nrFactors(); I++ )
329 if( factor(I).vars().contains( n ) )
330 // Multiply it with a delta function
331 newFacs[I] = factor(I) * delta_n_i;
332 setFactors( newFacs, backup );
333
334 return;
335 }
336
337
338 void FactorGraph::backupFactor( size_t I ) {
339 map<size_t,Factor>::iterator it = _backup.find( I );
340 if( it != _backup.end() )
341 DAI_THROW( MULTIPLE_UNDO );
342 _backup[I] = factor(I);
343 }
344
345
346 void FactorGraph::restoreFactor( size_t I ) {
347 map<size_t,Factor>::iterator it = _backup.find( I );
348 if( it != _backup.end() ) {
349 setFactor(I, it->second);
350 _backup.erase(it);
351 }
352 }
353
354
355 void FactorGraph::backupFactors( const VarSet &ns ) {
356 for( size_t I = 0; I < nrFactors(); I++ )
357 if( factor(I).vars().intersects( ns ) )
358 backupFactor( I );
359 }
360
361
362 void FactorGraph::restoreFactors( const VarSet &ns ) {
363 map<size_t,Factor> facs;
364 for( map<size_t,Factor>::iterator uI = _backup.begin(); uI != _backup.end(); ) {
365 if( factor(uI->first).vars().intersects( ns ) ) {
366 facs.insert( *uI );
367 _backup.erase(uI++);
368 } else
369 uI++;
370 }
371 setFactors( facs );
372 }
373
374
375 void FactorGraph::restoreFactors() {
376 setFactors( _backup );
377 _backup.clear();
378 }
379
380 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
381 for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
382 backupFactor( *fac );
383 }
384
385
386 bool FactorGraph::isPairwise() const {
387 bool pairwise = true;
388 for( size_t I = 0; I < nrFactors() && pairwise; I++ )
389 if( factor(I).vars().size() > 2 )
390 pairwise = false;
391 return pairwise;
392 }
393
394
395 bool FactorGraph::isBinary() const {
396 bool binary = true;
397 for( size_t i = 0; i < nrVars() && binary; i++ )
398 if( var(i).states() > 2 )
399 binary = false;
400 return binary;
401 }
402
403
404 FactorGraph FactorGraph::clamped( const Var & v_i, size_t state ) const {
405 Real zeroth_order = 1.0;
406 vector<Factor> clamped_facs;
407 for( size_t I = 0; I < nrFactors(); I++ ) {
408 VarSet v_I = factor(I).vars();
409 Factor new_factor;
410 if( v_I.intersects( v_i ) )
411 new_factor = factor(I).slice( v_i, state );
412 else
413 new_factor = factor(I);
414
415 if( new_factor.vars().size() != 0 ) {
416 size_t J = 0;
417 // if it can be merged with a previous one, do that
418 for( J = 0; J < clamped_facs.size(); J++ )
419 if( clamped_facs[J].vars() == new_factor.vars() ) {
420 clamped_facs[J] *= new_factor;
421 break;
422 }
423 // otherwise, push it back
424 if( J == clamped_facs.size() || clamped_facs.size() == 0 )
425 clamped_facs.push_back( new_factor );
426 } else
427 zeroth_order *= new_factor[0];
428 }
429 *(clamped_facs.begin()) *= zeroth_order;
430 return FactorGraph( clamped_facs );
431 }
432
433
434 FactorGraph FactorGraph::maximalFactors() const {
435 vector<size_t> maxfac( nrFactors() );
436 map<size_t,size_t> newindex;
437 size_t nrmax = 0;
438 for( size_t I = 0; I < nrFactors(); I++ ) {
439 maxfac[I] = I;
440 VarSet maxfacvars = factor(maxfac[I]).vars();
441 for( size_t J = 0; J < nrFactors(); J++ ) {
442 VarSet Jvars = factor(J).vars();
443 if( Jvars >> maxfacvars && (Jvars != maxfacvars) ) {
444 maxfac[I] = J;
445 maxfacvars = factor(maxfac[I]).vars();
446 }
447 }
448 if( maxfac[I] == I )
449 newindex[I] = nrmax++;
450 }
451
452 vector<Factor> facs( nrmax );
453 for( size_t I = 0; I < nrFactors(); I++ )
454 facs[newindex[maxfac[I]]] *= factor(I);
455
456 return FactorGraph( facs.begin(), facs.end(), vars().begin(), vars().end(), facs.size(), nrVars() );
457 }
458
459
460 } // end of namespace dai