Fixed previous partial commit which left everything broken
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <fstream>
#include <string>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <dai/factorgraph.h>
#include <dai/util.h>
#include <dai/exceptions.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _backup() {
42 // add factors, obtain variables
43 set<Var> _vars;
44 factors.reserve( P.size() );
45 size_t nrEdges = 0;
46 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
47 factors.push_back( *p2 );
48 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
49 nrEdges += p2->vars().size();
50 }
51
52 // add _vars
53 vars.reserve( _vars.size() );
54 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
55 vars.push_back( *p1 );
56
57 // create graph structure
58 createGraph( nrEdges );
59 }
60
61
62 /// Part of constructors (creates edges, neighbours and adjacency matrix)
63 void FactorGraph::createGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 hash_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars.size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 vector<Edge> edges;
72 edges.reserve( nrEdges );
73 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
74 const VarSet& ns = factor(i2).vars();
75 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
76 edges.push_back( Edge(hashmap[q->label()], i2) );
77 }
78
79 // create bipartite graph
80 G.create( nrVars(), nrFactors(), edges.begin(), edges.end() );
81 }
82
83
84 ostream& operator << (ostream& os, const FactorGraph& fg) {
85 os << fg.nrFactors() << endl;
86
87 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
88 os << endl;
89 os << fg.factor(I).vars().size() << endl;
90 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
91 os << i->label() << " ";
92 os << endl;
93 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
94 os << i->states() << " ";
95 os << endl;
96 size_t nr_nonzeros = 0;
97 for( size_t k = 0; k < fg.factor(I).states(); k++ )
98 if( fg.factor(I)[k] != 0.0 )
99 nr_nonzeros++;
100 os << nr_nonzeros << endl;
101 for( size_t k = 0; k < fg.factor(I).states(); k++ )
102 if( fg.factor(I)[k] != 0.0 ) {
103 char buf[20];
104 sprintf(buf,"%18.14g", fg.factor(I)[k]);
105 os << k << " " << buf << endl;
106 }
107 }
108
109 return(os);
110 }
111
112
113 istream& operator >> (istream& is, FactorGraph& fg) {
114 long verbose = 0;
115
116 try {
117 vector<Factor> factors;
118 size_t nr_Factors;
119 string line;
120
121 while( (is.peek()) == '#' )
122 getline(is,line);
123 is >> nr_Factors;
124 if( is.fail() )
125 DAI_THROW(INVALID_FACTORGRAPH_FILE);
126 if( verbose >= 2 )
127 cout << "Reading " << nr_Factors << " factors..." << endl;
128
129 getline (is,line);
130 if( is.fail() )
131 DAI_THROW(INVALID_FACTORGRAPH_FILE);
132
133 map<long,size_t> vardims;
134 for( size_t I = 0; I < nr_Factors; I++ ) {
135 if( verbose >= 3 )
136 cout << "Reading factor " << I << "..." << endl;
137 size_t nr_members;
138 while( (is.peek()) == '#' )
139 getline(is,line);
140 is >> nr_members;
141 if( verbose >= 3 )
142 cout << " nr_members: " << nr_members << endl;
143
144 vector<long> labels;
145 for( size_t mi = 0; mi < nr_members; mi++ ) {
146 long mi_label;
147 while( (is.peek()) == '#' )
148 getline(is,line);
149 is >> mi_label;
150 labels.push_back(mi_label);
151 }
152 if( verbose >= 3 )
153 cout << " labels: " << labels << endl;
154
155 vector<size_t> dims;
156 for( size_t mi = 0; mi < nr_members; mi++ ) {
157 size_t mi_dim;
158 while( (is.peek()) == '#' )
159 getline(is,line);
160 is >> mi_dim;
161 dims.push_back(mi_dim);
162 }
163 if( verbose >= 3 )
164 cout << " dimensions: " << dims << endl;
165
166 // add the Factor
167 VarSet I_vars;
168 for( size_t mi = 0; mi < nr_members; mi++ ) {
169 map<long,size_t>::iterator vdi = vardims.find( labels[mi] );
170 if( vdi != vardims.end() ) {
171 // check whether dimensions are consistent
172 if( vdi->second != dims[mi] )
173 DAI_THROW(INVALID_FACTORGRAPH_FILE);
174 } else
175 vardims[labels[mi]] = dims[mi];
176 I_vars |= Var(labels[mi], dims[mi]);
177 }
178 factors.push_back( Factor( I_vars, 0.0 ) );
179
180 // calculate permutation sigma (internally, members are sorted)
181 vector<size_t> sigma(nr_members,0);
182 VarSet::iterator j = I_vars.begin();
183 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
184 long search_for = j->label();
185 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
186 sigma[mi] = j_loc - labels.begin();
187 }
188 if( verbose >= 3 )
189 cout << " sigma: " << sigma << endl;
190
191 // calculate multindices
192 Permute permindex( dims, sigma );
193
194 // read values
195 size_t nr_nonzeros;
196 while( (is.peek()) == '#' )
197 getline(is,line);
198 is >> nr_nonzeros;
199 if( verbose >= 3 )
200 cout << " nonzeroes: " << nr_nonzeros << endl;
201 for( size_t k = 0; k < nr_nonzeros; k++ ) {
202 size_t li;
203 double val;
204 while( (is.peek()) == '#' )
205 getline(is,line);
206 is >> li;
207 while( (is.peek()) == '#' )
208 getline(is,line);
209 is >> val;
210
211 // store value, but permute indices first according
212 // to internal representation
213 factors.back()[permindex.convert_linear_index( li )] = val;
214 }
215 }
216
217 if( verbose >= 3 )
218 cout << "factors:" << factors << endl;
219
220 fg = FactorGraph(factors);
221 } catch (char *e) {
222 cout << e << endl;
223 }
224
225 return is;
226 }
227
228
229 VarSet FactorGraph::delta( unsigned i ) const {
230 return( Delta(i) / var(i) );
231 }
232
233
234 VarSet FactorGraph::Delta( unsigned i ) const {
235 // calculate Markov Blanket
236 VarSet Del;
237 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
238 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
239 Del |= var(j);
240
241 return Del;
242 }
243
244
245 VarSet FactorGraph::Delta( const VarSet &ns ) const {
246 VarSet result;
247 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
248 result |= Delta(findVar(*n));
249 return result;
250 }
251
252
253 void FactorGraph::makeCavity( unsigned i, bool backup ) {
254 // fills all Factors that include var(i) with ones
255 map<size_t,Factor> newFacs;
256 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
257 newFacs[I] = Factor(factor(I).vars(), 1.0);
258 setFactors( newFacs, backup );
259 }
260
261
262 void FactorGraph::ReadFromFile( const char *filename ) {
263 ifstream infile;
264 infile.open( filename );
265 if( infile.is_open() ) {
266 infile >> *this;
267 infile.close();
268 } else
269 DAI_THROW(CANNOT_READ_FILE);
270 }
271
272
273 void FactorGraph::WriteToFile( const char *filename ) const {
274 ofstream outfile;
275 outfile.open( filename );
276 if( outfile.is_open() ) {
277 outfile << *this;
278 outfile.close();
279 } else
280 DAI_THROW(CANNOT_WRITE_FILE);
281 }
282
283
284 void FactorGraph::printDot( std::ostream &os ) const {
285 os << "graph G {" << endl;
286 os << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
287 for( size_t i = 0; i < nrVars(); i++ )
288 os << "\tv" << var(i).label() << ";" << endl;
289 os << "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl;
290 for( size_t I = 0; I < nrFactors(); I++ )
291 os << "\tf" << I << ";" << endl;
292 for( size_t i = 0; i < nrVars(); i++ )
293 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
294 os << "\tv" << var(i).label() << " -- f" << I << ";" << endl;
295 os << "}" << endl;
296 }
297
298
299 vector<VarSet> FactorGraph::Cliques() const {
300 vector<VarSet> result;
301
302 for( size_t I = 0; I < nrFactors(); I++ ) {
303 bool maximal = true;
304 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
305 if( (factor(J).vars() >> factor(I).vars()) && (factor(J).vars() != factor(I).vars()) )
306 maximal = false;
307
308 if( maximal )
309 result.push_back( factor(I).vars() );
310 }
311
312 return result;
313 }
314
315
316 void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
317 assert( i <= n.states() );
318
319 // Multiply each factor that contains the variable with a delta function
320
321 Factor delta_n_i(n,0.0);
322 delta_n_i[i] = 1.0;
323
324 map<size_t, Factor> newFacs;
325 // For all factors that contain n
326 for( size_t I = 0; I < nrFactors(); I++ )
327 if( factor(I).vars().contains( n ) )
328 // Multiply it with a delta function
329 newFacs[I] = factor(I) * delta_n_i;
330 setFactors( newFacs, backup );
331
332 return;
333 }
334
335
336 void FactorGraph::backupFactor( size_t I ) {
337 map<size_t,Factor>::iterator it = _backup.find( I );
338 if( it != _backup.end() )
339 DAI_THROW( MULTIPLE_UNDO );
340 _backup[I] = factor(I);
341 }
342
343
344 void FactorGraph::restoreFactor( size_t I ) {
345 map<size_t,Factor>::iterator it = _backup.find( I );
346 if( it != _backup.end() ) {
347 setFactor(I, it->second);
348 _backup.erase(it);
349 }
350 }
351
352
353 void FactorGraph::backupFactors( const VarSet &ns ) {
354 for( size_t I = 0; I < nrFactors(); I++ )
355 if( factor(I).vars().intersects( ns ) )
356 backupFactor( I );
357 }
358
359
360 void FactorGraph::restoreFactors( const VarSet &ns ) {
361 map<size_t,Factor> facs;
362 for( map<size_t,Factor>::iterator uI = _backup.begin(); uI != _backup.end(); ) {
363 if( factor(uI->first).vars().intersects( ns ) ) {
364 facs.insert( *uI );
365 _backup.erase(uI++);
366 } else
367 uI++;
368 }
369 setFactors( facs );
370 }
371
372
373 void FactorGraph::restoreFactors() {
374 setFactors( _backup );
375 _backup.clear();
376 }
377
378 void FactorGraph::backupFactors( const std::set<size_t> & facs ) {
379 for( std::set<size_t>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ )
380 backupFactor( *fac );
381 }
382
383
384 bool FactorGraph::isPairwise() const {
385 bool pairwise = true;
386 for( size_t I = 0; I < nrFactors() && pairwise; I++ )
387 if( factor(I).vars().size() > 2 )
388 pairwise = false;
389 return pairwise;
390 }
391
392
393 bool FactorGraph::isBinary() const {
394 bool binary = true;
395 for( size_t i = 0; i < nrVars() && binary; i++ )
396 if( var(i).states() > 2 )
397 binary = false;
398 return binary;
399 }
400
401
402 FactorGraph FactorGraph::clamped( const Var & v_i, size_t state ) const {
403 Real zeroth_order = 1.0;
404 vector<Factor> clamped_facs;
405 for( size_t I = 0; I < nrFactors(); I++ ) {
406 VarSet v_I = factor(I).vars();
407 Factor new_factor;
408 if( v_I.intersects( v_i ) )
409 new_factor = factor(I).slice( v_i, state );
410 else
411 new_factor = factor(I);
412
413 if( new_factor.vars().size() != 0 ) {
414 size_t J = 0;
415 // if it can be merged with a previous one, do that
416 for( J = 0; J < clamped_facs.size(); J++ )
417 if( clamped_facs[J].vars() == new_factor.vars() ) {
418 clamped_facs[J] *= new_factor;
419 break;
420 }
421 // otherwise, push it back
422 if( J == clamped_facs.size() || clamped_facs.size() == 0 )
423 clamped_facs.push_back( new_factor );
424 } else
425 zeroth_order *= new_factor[0];
426 }
427 *(clamped_facs.begin()) *= zeroth_order;
428 return FactorGraph( clamped_facs );
429 }
430
431
432 FactorGraph FactorGraph::maximalFactors() const {
433 vector<size_t> maxfac( nrFactors() );
434 map<size_t,size_t> newindex;
435 size_t nrmax = 0;
436 for( size_t I = 0; I < nrFactors(); I++ ) {
437 maxfac[I] = I;
438 VarSet maxfacvars = factor(maxfac[I]).vars();
439 for( size_t J = 0; J < nrFactors(); J++ ) {
440 VarSet Jvars = factor(J).vars();
441 if( Jvars >> maxfacvars && (Jvars != maxfacvars) ) {
442 maxfac[I] = J;
443 maxfacvars = factor(maxfac[I]).vars();
444 }
445 }
446 if( maxfac[I] == I )
447 newindex[I] = nrmax++;
448 }
449
450 vector<Factor> facs( nrmax );
451 for( size_t I = 0; I < nrFactors(); I++ )
452 facs[newindex[maxfac[I]]] *= factor(I);
453
454 return FactorGraph( facs.begin(), facs.end(), vars.begin(), vars.end(), facs.size(), nrVars() );
455 }
456
457
458 } // end of namespace dai