Miscellaneous changes:
[libdai.git] / src / factorgraph.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
#include <cassert>
#include <cstdio>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <tr1/unordered_map>
#include <dai/factorgraph.h>
#include <dai/util.h>
33
34
35 namespace dai {
36
37
38 using namespace std;
39
40
41 FactorGraph::FactorGraph( const std::vector<Factor> &P ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
42 // add factors, obtain variables
43 set<Var> _vars;
44 factors.reserve( P.size() );
45 size_t nrEdges = 0;
46 for( vector<Factor>::const_iterator p2 = P.begin(); p2 != P.end(); p2++ ) {
47 factors.push_back( *p2 );
48 copy( p2->vars().begin(), p2->vars().end(), inserter( _vars, _vars.begin() ) );
49 nrEdges += p2->vars().size();
50 }
51
52 // add _vars
53 vars.reserve( _vars.size() );
54 for( set<Var>::const_iterator p1 = _vars.begin(); p1 != _vars.end(); p1++ )
55 vars.push_back( *p1 );
56
57 // create graph structure
58 createGraph( nrEdges );
59 }
60
61
62 /// Part of constructors (creates edges, neighbours and adjacency matrix)
63 void FactorGraph::createGraph( size_t nrEdges ) {
64 // create a mapping for indices
65 std::tr1::unordered_map<size_t, size_t> hashmap;
66
67 for( size_t i = 0; i < vars.size(); i++ )
68 hashmap[var(i).label()] = i;
69
70 // create edge list
71 typedef pair<unsigned,unsigned> Edge;
72 vector<Edge> edges;
73 edges.reserve( nrEdges );
74 for( size_t i2 = 0; i2 < nrFactors(); i2++ ) {
75 const VarSet& ns = factor(i2).vars();
76 for( VarSet::const_iterator q = ns.begin(); q != ns.end(); q++ )
77 edges.push_back( Edge(hashmap[q->label()], i2) );
78 }
79
80 // create bipartite graph
81 G.create( nrVars(), nrFactors(), edges.begin(), edges.end() );
82 }
83
84
85 ostream& operator << (ostream& os, const FactorGraph& fg) {
86 os << fg.nrFactors() << endl;
87
88 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
89 os << endl;
90 os << fg.factor(I).vars().size() << endl;
91 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
92 os << i->label() << " ";
93 os << endl;
94 for( VarSet::const_iterator i = fg.factor(I).vars().begin(); i != fg.factor(I).vars().end(); i++ )
95 os << i->states() << " ";
96 os << endl;
97 size_t nr_nonzeros = 0;
98 for( size_t k = 0; k < fg.factor(I).states(); k++ )
99 if( fg.factor(I)[k] != 0.0 )
100 nr_nonzeros++;
101 os << nr_nonzeros << endl;
102 for( size_t k = 0; k < fg.factor(I).states(); k++ )
103 if( fg.factor(I)[k] != 0.0 ) {
104 char buf[20];
105 sprintf(buf,"%18.14g", fg.factor(I)[k]);
106 os << k << " " << buf << endl;
107 }
108 }
109
110 return(os);
111 }
112
113
114 istream& operator >> (istream& is, FactorGraph& fg) {
115 long verbose = 0;
116
117 try {
118 vector<Factor> factors;
119 size_t nr_f;
120 string line;
121
122 while( (is.peek()) == '#' )
123 getline(is,line);
124 is >> nr_f;
125 if( is.fail() )
126 throw "ReadFromFile: unable to read number of Factors";
127 if( verbose >= 2 )
128 cout << "Reading " << nr_f << " factors..." << endl;
129
130 getline (is,line);
131 if( is.fail() )
132 throw "ReadFromFile: empty line expected";
133
134 for( size_t I = 0; I < nr_f; I++ ) {
135 if( verbose >= 3 )
136 cout << "Reading factor " << I << "..." << endl;
137 size_t nr_members;
138 while( (is.peek()) == '#' )
139 getline(is,line);
140 is >> nr_members;
141 if( verbose >= 3 )
142 cout << " nr_members: " << nr_members << endl;
143
144 vector<long> labels;
145 for( size_t mi = 0; mi < nr_members; mi++ ) {
146 long mi_label;
147 while( (is.peek()) == '#' )
148 getline(is,line);
149 is >> mi_label;
150 labels.push_back(mi_label);
151 }
152 if( verbose >= 3 ) {
153 cout << " labels: ";
154 copy (labels.begin(), labels.end(), ostream_iterator<int>(cout, " "));
155 cout << endl;
156 }
157
158 vector<size_t> dims;
159 for( size_t mi = 0; mi < nr_members; mi++ ) {
160 size_t mi_dim;
161 while( (is.peek()) == '#' )
162 getline(is,line);
163 is >> mi_dim;
164 dims.push_back(mi_dim);
165 }
166 if( verbose >= 3 ) {
167 cout << " dimensions: ";
168 copy (dims.begin(), dims.end(), ostream_iterator<int>(cout, " "));
169 cout << endl;
170 }
171
172 // add the Factor
173 VarSet I_vars;
174 for( size_t mi = 0; mi < nr_members; mi++ )
175 I_vars.insert( Var(labels[mi], dims[mi]) );
176 factors.push_back(Factor(I_vars,0.0));
177
178 // calculate permutation sigma (internally, members are sorted)
179 vector<size_t> sigma(nr_members,0);
180 VarSet::iterator j = I_vars.begin();
181 for( size_t mi = 0; mi < nr_members; mi++,j++ ) {
182 long search_for = j->label();
183 vector<long>::iterator j_loc = find(labels.begin(),labels.end(),search_for);
184 sigma[mi] = j_loc - labels.begin();
185 }
186 if( verbose >= 3 ) {
187 cout << " sigma: ";
188 copy( sigma.begin(), sigma.end(), ostream_iterator<int>(cout," "));
189 cout << endl;
190 }
191
192 // calculate multindices
193 Permute permindex( dims, sigma );
194
195 // read values
196 size_t nr_nonzeros;
197 while( (is.peek()) == '#' )
198 getline(is,line);
199 is >> nr_nonzeros;
200 if( verbose >= 3 )
201 cout << " nonzeroes: " << nr_nonzeros << endl;
202 for( size_t k = 0; k < nr_nonzeros; k++ ) {
203 size_t li;
204 double val;
205 while( (is.peek()) == '#' )
206 getline(is,line);
207 is >> li;
208 while( (is.peek()) == '#' )
209 getline(is,line);
210 is >> val;
211
212 // store value, but permute indices first according
213 // to internal representation
214 factors.back()[permindex.convert_linear_index( li )] = val;
215 }
216 }
217
218 if( verbose >= 3 ) {
219 cout << "factors:" << endl;
220 copy(factors.begin(), factors.end(), ostream_iterator<Factor>(cout,"\n"));
221 }
222
223 fg = FactorGraph(factors);
224 } catch (char *e) {
225 cout << e << endl;
226 }
227
228 return is;
229 }
230
231
232 VarSet FactorGraph::delta( unsigned i ) const {
233 // calculate Markov Blanket
234 VarSet del;
235 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
236 foreach( const Neighbor &j, nbF(I) ) // for all neighboring variables j of I
237 if( j != i )
238 del |= var(j);
239
240 return del;
241 }
242
243
244 VarSet FactorGraph::Delta( unsigned i ) const {
245 return( delta(i) | var(i) );
246 }
247
248
249 void FactorGraph::makeCavity( unsigned i ) {
250 // fills all Factors that include var(i) with ones
251 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
252 factor(I).fill( 1.0 );
253 }
254
255
256 bool FactorGraph::hasNegatives() const {
257 bool result = false;
258 for( size_t I = 0; I < nrFactors() && !result; I++ )
259 if( factor(I).hasNegatives() )
260 result = true;
261 return result;
262 }
263
264
265 long FactorGraph::ReadFromFile(const char *filename) {
266 ifstream infile;
267 infile.open (filename);
268 if (infile.is_open()) {
269 infile >> *this;
270 infile.close();
271 return 0;
272 } else {
273 cout << "ERROR OPENING FILE" << endl;
274 return 1;
275 }
276 }
277
278
279 long FactorGraph::WriteToFile(const char *filename) const {
280 ofstream outfile;
281 outfile.open (filename);
282 if (outfile.is_open()) {
283 try {
284 outfile << *this;
285 } catch (char *e) {
286 cout << e << endl;
287 return 1;
288 }
289 outfile.close();
290 return 0;
291 } else {
292 cout << "ERROR OPENING FILE" << endl;
293 return 1;
294 }
295 }
296
297
298 long FactorGraph::WriteToDotFile(const char *filename) const {
299 ofstream outfile;
300 outfile.open (filename);
301 if (outfile.is_open()) {
302 try {
303 outfile << "graph G {" << endl;
304 outfile << "graph[size=\"9,9\"];" << endl;
305 outfile << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
306 for( size_t i = 0; i < nrVars(); i++ )
307 outfile << "\tx" << var(i).label() << ";" << endl;
308 outfile << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
309 for( size_t I = 0; I < nrFactors(); I++ )
310 outfile << "\tp" << I << ";" << endl;
311 for( size_t i = 0; i < nrVars(); i++ )
312 foreach( const Neighbor &I, nbV(i) ) // for all neighboring factors I of i
313 outfile << "\tx" << var(i).label() << " -- p" << I << ";" << endl;
314 outfile << "}" << endl;
315 } catch (char *e) {
316 cout << e << endl;
317 return 1;
318 }
319 outfile.close();
320 return 0;
321 } else {
322 cout << "ERROR OPENING FILE" << endl;
323 return 1;
324 }
325 }
326
327
328 bool hasShortLoops( const vector<Factor> &P ) {
329 bool found = false;
330 vector<Factor>::const_iterator I, J;
331 for( I = P.begin(); I != P.end(); I++ ) {
332 J = I;
333 J++;
334 for( ; J != P.end(); J++ )
335 if( (I->vars() & J->vars()).size() >= 2 ) {
336 found = true;
337 break;
338 }
339 if( found )
340 break;
341 }
342 return found;
343 }
344
345
346 void RemoveShortLoops(vector<Factor> &P) {
347 bool found = true;
348 while( found ) {
349 found = false;
350 vector<Factor>::iterator I, J;
351 for( I = P.begin(); I != P.end(); I++ ) {
352 J = I;
353 J++;
354 for( ; J != P.end(); J++ )
355 if( (I->vars() & J->vars()).size() >= 2 ) {
356 found = true;
357 break;
358 }
359 if( found )
360 break;
361 }
362 if( found ) {
363 cout << "Merging factors " << I->vars() << " and " << J->vars() << endl;
364 *I *= *J;
365 P.erase(J);
366 }
367 }
368 }
369
370
371 vector<VarSet> FactorGraph::Cliques() const {
372 vector<VarSet> result;
373
374 for( size_t I = 0; I < nrFactors(); I++ ) {
375 bool maximal = true;
376 for( size_t J = 0; (J < nrFactors()) && maximal; J++ )
377 if( (factor(J).vars() >> factor(I).vars()) && !(factor(J).vars() == factor(I).vars()) )
378 maximal = false;
379
380 if( maximal )
381 result.push_back( factor(I).vars() );
382 }
383
384 return result;
385 }
386
387
388 void FactorGraph::clamp( const Var & n, size_t i ) {
389 assert( i <= n.states() );
390
391 // Multiply each factor that contains the variable with a delta function
392
393 Factor delta_n_i(n,0.0);
394 delta_n_i[i] = 1.0;
395
396 // For all factors that contain n
397 for( size_t I = 0; I < nrFactors(); I++ )
398 if( factor(I).vars() && n )
399 // Multiply it with a delta function
400 factor(I) *= delta_n_i;
401
402 return;
403 }
404
405
406 void FactorGraph::saveProb( size_t I ) {
407 map<size_t,Prob>::iterator it = _undoProbs.find( I );
408 if( it != _undoProbs.end() )
409 cout << "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl;
410 _undoProbs[I] = factor(I).p();
411 }
412
413
414 void FactorGraph::undoProb( size_t I ) {
415 map<size_t,Prob>::iterator it = _undoProbs.find( I );
416 if( it != _undoProbs.end() ) {
417 factor(I).p() = (*it).second;
418 _undoProbs.erase(it);
419 }
420 }
421
422
423 void FactorGraph::saveProbs( const VarSet &ns ) {
424 if( !_undoProbs.empty() )
425 cout << "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl;
426 for( size_t I = 0; I < nrFactors(); I++ )
427 if( factor(I).vars() && ns )
428 _undoProbs[I] = factor(I).p();
429 }
430
431
432 void FactorGraph::undoProbs( const VarSet &ns ) {
433 for( map<size_t,Prob>::iterator uI = _undoProbs.begin(); uI != _undoProbs.end(); ) {
434 if( factor((*uI).first).vars() && ns ) {
435 // cout << "undoing " << factor((*uI).first).vars() << endl;
436 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
437 factor((*uI).first).p() = (*uI).second;
438 _undoProbs.erase(uI++);
439 } else
440 uI++;
441 }
442 }
443
444
445 } // end of namespace dai