1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
30 #include <tr1/unordered_map>
31 #include <dai/factorgraph.h>
/// Constructs a FactorGraph from a list of factors.
/// The BipartiteGraph<Var,Factor> base, _undoProbs and _normtype (Prob::NORMPROB)
/// are initialised in the member-initialiser list.
/// \param P the factors; each one is appended to V2s(), its variables are
///          inserted into _vars, and its variable count is added to nrEdges.
/// NOTE(review): the declarations of _vars and nrEdges and the closing braces
/// are not visible in this listing.
40 FactorGraph::FactorGraph( const vector
<Factor
> &P
) : BipartiteGraph
<Var
,Factor
>(), _undoProbs(), _normtype(Prob::NORMPROB
) {
41 // add factors, obtain variables
43 V2s().reserve( P
.size() );
45 for( vector
<Factor
>::const_iterator p2
= P
.begin(); p2
!= P
.end(); p2
++ ) {
46 V2s().push_back( *p2
);
47 copy( p2
->vars().begin(), p2
->vars().end(), inserter( _vars
, _vars
.begin() ) );
48 nrEdges
+= p2
->vars().size();
// every element of _vars (iterated as a set<Var>) becomes one entry of V1s()
52 V1s().reserve( _vars
.size() );
53 for( set
<Var
>::const_iterator p1
= _vars
.begin(); p1
!= _vars
.end(); p1
++ )
54 V1s().push_back( *p1
);
56 // create graph structure
57 createGraph( nrEdges
);
/// Builds the edge list between variables and factors.
/// \param nrEdges number of edges to reserve space for in edges().
/// A hash map from variable label to position in vars() is filled first; then,
/// for every factor i2 and every variable q in factor(i2).vars(), the edge
/// (index of q, i2) is appended to edges().
/// NOTE(review): whatever follows the "calc neighbours and adjacency matrix"
/// comment is not visible in this listing.
61 /// Part of constructors (creates edges, neighbours and adjacency matrix)
62 void FactorGraph::createGraph( size_t nrEdges
) {
63 // create a mapping for indices
64 std::tr1::unordered_map
<size_t, size_t> hashmap
;
66 for( size_t i
= 0; i
< vars().size(); i
++ )
67 hashmap
[vars()[i
].label()] = i
;
70 edges().reserve( nrEdges
);
71 for( size_t i2
= 0; i2
< nrFactors(); i2
++ ) {
72 const VarSet
& ns
= factor(i2
).vars();
73 for( VarSet::const_iterator q
= ns
.begin(); q
!= ns
.end(); q
++ )
74 edges().push_back(_edge_t(hashmap
[q
->label()], i2
));
77 // calc neighbours and adjacency matrix
82 /*FactorGraph& FactorGraph::addFactor( const Factor &I ) {
86 // add new vars in Factor
87 for( VarSet::const_iterator i = I.vars().begin(); i != I.vars().end(); i++ ) {
88 size_t i_ind = find(vars().begin(), vars().end(), *i) - vars().begin();
89 if( i_ind == vars().size() )
91 _E12.push_back( _edge_t( i_ind, nrFactors() - 1 ) );
/// Writes a FactorGraph to an output stream in a line-oriented text format.
/// Output order as far as visible here: the number of factors; then, per factor,
/// the number of variables, the variable labels (space separated), the variable
/// state counts (space separated), nr_nonzeros, and one "index value" line for
/// every table entry that differs from 0.0, with the value formatted by "%18.14g".
/// \param os destination stream
/// \param fg factor graph to write
/// NOTE(review): the statement that counts nr_nonzeros, the declaration of buf,
/// and the return statement are not visible in this listing.
/// NOTE(review): sprintf writes into buf without a size limit; confirm that buf
/// is large enough or switch to snprintf.
99 ostream
& operator << (ostream
& os
, const FactorGraph
& fg
) {
100 os
<< fg
.nrFactors() << endl
;
102 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
104 os
<< fg
.factor(I
).vars().size() << endl
;
// labels of the factor's variables
105 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
106 os
<< i
->label() << " ";
// number of states of each of the factor's variables
108 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
109 os
<< i
->states() << " ";
111 size_t nr_nonzeros
= 0;
112 for( size_t k
= 0; k
< fg
.factor(I
).stateSpace(); k
++ )
113 if( fg
.factor(I
)[k
] != 0.0 )
115 os
<< nr_nonzeros
<< endl
;
// one line per nonzero entry: linear index followed by the formatted value
116 for( size_t k
= 0; k
< fg
.factor(I
).stateSpace(); k
++ )
117 if( fg
.factor(I
)[k
] != 0.0 ) {
119 sprintf(buf
,"%18.14g", fg
.factor(I
)[k
]);
120 os
<< k
<< " " << buf
<< endl
;
/// Reads a FactorGraph from an input stream and assigns it to fg.
/// Visible behaviour: lines are tested for a leading '#' at several points; the
/// number of factors nr_f is read; for each factor the member count, the member
/// labels and the member dimensions are read; a Factor over the corresponding
/// VarSet is created with all entries 0.0; a permutation sigma from the sorted
/// VarSet order to the order in the file is computed; and each nonzero entry
/// (index li, value val) is stored at the permuted linear index sli. Finally
/// fg is replaced by FactorGraph(factors).
/// \param is source stream
/// \param fg factor graph that receives the result
/// Errors visible here are thrown as string literals (const char*).
/// NOTE(review): most read statements, the definitions of nr_f, nr_members,
/// labels, dims, I_vars, mi, smi, li, val and nr_nonzeros, and any condition
/// guarding the cout diagnostics are not visible in this listing.
128 istream
& operator >> (istream
& is
, FactorGraph
& fg
) {
132 vector
<Factor
> factors
;
136 while( (is
.peek()) == '#' )
140 throw "ReadFromFile: unable to read number of Factors";
142 cout
<< "Reading " << nr_f
<< " factors..." << endl
;
146 throw "ReadFromFile: empty line expected";
148 for( size_t I
= 0; I
< nr_f
; I
++ ) {
150 cout
<< "Reading factor " << I
<< "..." << endl
;
152 while( (is
.peek()) == '#' )
156 cout
<< " nr_members: " << nr_members
<< endl
;
// collect the label of every member variable
159 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
161 while( (is
.peek()) == '#' )
164 labels
.push_back(mi_label
);
168 copy (labels
.begin(), labels
.end(), ostream_iterator
<int>(cout
, " "));
// collect the dimension of every member variable
173 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
175 while( (is
.peek()) == '#' )
178 dims
.push_back(mi_dim
);
181 cout
<< " dimensions: ";
182 copy (dims
.begin(), dims
.end(), ostream_iterator
<int>(cout
, " "));
// build the variable set and append an all-zero factor over it
188 for( size_t mi
= 0; mi
< nr_members
; mi
++ )
189 I_vars
.insert( Var(labels
[mi
], dims
[mi
]) );
190 factors
.push_back(Factor(I_vars
,0.0));
192 // calculate permutation sigma (internally, members are sorted)
193 vector
<long> sigma(nr_members
,0);
194 VarSet::iterator j
= I_vars
.begin();
195 for( size_t mi
= 0; mi
< nr_members
; mi
++,j
++ ) {
196 long search_for
= j
->label();
197 vector
<long>::iterator j_loc
= find(labels
.begin(),labels
.end(),search_for
);
198 sigma
[mi
] = j_loc
- labels
.begin();
202 copy( sigma
.begin(), sigma
.end(), ostream_iterator
<int>(cout
," "));
206 // calculate multindices
207 vector
<size_t> sdims(nr_members
,0);
208 for( size_t k
= 0; k
< nr_members
; k
++ ) {
209 sdims
[k
] = dims
[sigma
[k
]];
214 cout
<< " mi.max(): " << mi
.max() << endl
;
216 for( size_t k
=0; k
< nr_members
; k
++ )
217 cout
<< labels
[k
] << " ";
219 for( size_t k
=0; k
< nr_members
; k
++ )
220 cout
<< labels
[sigma
[k
]] << " ";
226 while( (is
.peek()) == '#' )
230 cout
<< " nonzeroes: " << nr_nonzeros
<< endl
;
231 for( size_t k
= 0; k
< nr_nonzeros
; k
++ ) {
234 while( (is
.peek()) == '#' )
237 while( (is
.peek()) == '#' )
// convert the linear index li to a per-variable index vector, permute it by sigma,
// and convert it back to a linear index sli
241 vector
<size_t> vi
= mi
.vi(li
);
242 vector
<size_t> svi(vi
.size(),0);
243 for( size_t k
= 0; k
< vi
.size(); k
++ )
244 svi
[k
] = vi
[sigma
[k
]];
245 size_t sli
= smi
.li(svi
);
247 cout
<< " " << li
<< ": ";
248 copy( vi
.begin(), vi
.end(), ostream_iterator
<size_t>(cout
," "));
250 copy( svi
.begin(), svi
.end(), ostream_iterator
<size_t>(cout
," "));
251 cout
<< ": " << sli
<< endl
;
253 factors
.back()[sli
] = val
;
258 cout
<< "factors:" << endl
;
259 copy(factors
.begin(), factors
.end(), ostream_iterator
<Factor
>(cout
,"\n"));
262 fg
= FactorGraph(factors
);
/// Computes the Markov blanket of variable n (per the original comment below).
/// Looks up the index i of n with findVar, then visits every entry j of nb2(*I)
/// for every entry I of nb1(i).
/// \param n the variable whose neighbourhood is inspected
/// NOTE(review): the result variable, the body of the inner loop and the return
/// statement are not visible in this listing.
271 VarSet
FactorGraph::delta(const Var
& n
) const {
272 // calculate Markov Blanket
273 size_t i
= findVar( n
);
276 for( _nb_cit I
= nb1(i
).begin(); I
!= nb1(i
).end(); ++I
)
277 for( _nb_cit j
= nb2(*I
).begin(); j
!= nb2(*I
).end(); ++j
)
285 VarSet
FactorGraph::Delta(const Var
& n
) const {
286 return( delta(n
) | n
);
/// Fills factor I with ones, according to the original comment below.
/// \param I index of the factor
/// NOTE(review): the body of this function is not visible in this listing.
290 void FactorGraph::makeFactorCavity(size_t I
) {
291 // fill Factor I with ones
/// Sets every entry of each factor listed in nb1(i) to 1.0, where i is the
/// index of n as returned by findVar.
/// \param n the variable whose neighbouring factors are overwritten
/// NOTE(review): one line between the lookup and the loop, and the closing
/// brace, are not visible in this listing.
296 void FactorGraph::makeCavity(const Var
& n
) {
297 // fills all Factors that include Var n with ones
298 size_t i
= findVar( n
);
300 for( _nb_cit I
= nb1(i
).begin(); I
!= nb1(i
).end(); ++I
)
301 factor(*I
).fill(1.0);
/// Scans the factors, calling hasNegatives() on each, while result is false.
/// NOTE(review): the declaration of result, the statement executed when a
/// factor reports negatives, and the return statement are not visible in this
/// listing; presumably result is set to true and returned - confirm.
305 bool FactorGraph::hasNegatives() const {
307 for( size_t I
= 0; I
< nrFactors() && !result
; I
++ )
308 if( factor(I
).hasNegatives() )
314 /*FactorGraph & FactorGraph::DeleteFactor(size_t I) {
315 // Go through all edges
316 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
317 if( edge->second >= I ) {
318 if( edge->second == I )
323 // Remove all edges containing I
324 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
325 if( edge->second == -1UL )
326 edge = _E12.erase( edge );
327 // vector<_edge_t>::iterator new_end = _E12.remove_if( _E12.begin(), _E12.end(), compose1( bind2nd(equal_to<size_t>(), -1), select2nd<_edge_t>() ) );
328 // _E12.erase( new_end, _E12.end() );
331 _V2.erase( _V2.begin() + I );
339 FactorGraph & FactorGraph::DeleteVar(size_t i) {
340 // Go through all edges
341 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
342 if( edge->first >= i ) {
343 if( edge->first == i )
348 // Remove all edges containing i
349 for( vector<_edge_t>::iterator edge = _E12.begin(); edge != _E12.end(); edge++ )
350 if( edge->first == -1UL )
351 edge = _E12.erase( edge );
353 // vector<_edge_t>::iterator new_end = _E12.remove_if( _E12.begin(), _E12.end(), compose1( bind2nd(equal_to<size_t>(), -1), select1st<_edge_t>() ) );
354 // _E12.erase( new_end, _E12.end() );
356 // Erase the variable
357 _V1.erase( _V1.begin() + i );
/// Opens the file named filename for reading via infile.
/// When the file cannot be handled, "ERROR OPENING FILE" is printed to cout.
/// \param filename path of the file to open
/// NOTE(review): the declaration of infile, the work done while the file is
/// open, and the returned values are not visible in this listing.
365 long FactorGraph::ReadFromFile(const char *filename
) {
367 infile
.open (filename
);
368 if (infile
.is_open()) {
373 cout
<< "ERROR OPENING FILE" << endl
;
/// Opens the file named filename for writing via outfile.
/// When the file cannot be handled, "ERROR OPENING FILE" is printed to cout.
/// \param filename path of the file to open
/// NOTE(review): the declaration of outfile, the work done while the file is
/// open, and the returned values are not visible in this listing.
379 long FactorGraph::WriteToFile(const char *filename
) const {
381 outfile
.open (filename
);
382 if (outfile
.is_open()) {
392 cout
<< "ERROR OPENING FILE" << endl
;
/// Writes the graph structure to filename as an undirected "graph G { ... }"
/// description (the syntax looks like Graphviz DOT).
/// Variables are emitted as circle nodes named x<label>, factors as filled box
/// nodes named p<index>, and every edge as "x<label> -- p<index>".
/// When the file cannot be handled, "ERROR OPENING FILE" is printed to cout.
/// \param filename path of the file to write
/// NOTE(review): the declaration of outfile, the statements between the final
/// "}" line and the error message, and the returned values are not visible in
/// this listing.
398 long FactorGraph::WriteToDotFile(const char *filename
) const {
400 outfile
.open (filename
);
401 if (outfile
.is_open()) {
403 outfile
<< "graph G {" << endl
;
404 outfile
<< "graph[size=\"9,9\"];" << endl
;
// variable nodes
405 outfile
<< "node[shape=circle,width=0.4,fixedsize=true];" << endl
;
406 for( size_t i
= 0; i
< nrVars(); i
++ )
407 outfile
<< "\tx" << var(i
).label() << ";" << endl
;
// factor nodes
408 outfile
<< "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl
;
409 for( size_t I
= 0; I
< nrFactors(); I
++ )
410 outfile
<< "\tp" << I
<< ";" << endl
;
// one line per edge: variable label on the left, factor index on the right
411 for( size_t iI
= 0; iI
< nr_edges(); iI
++ )
412 outfile
<< "\tx" << var(edge(iI
).first
).label() << " -- p" << edge(iI
).second
<< ";" << endl
;
413 outfile
<< "}" << endl
;
421 cout
<< "ERROR OPENING FILE" << endl
;
/// Tests pairs of factors (I, J) from P for a variable-set intersection of size
/// two or more.
/// \param P the factors to inspect
/// NOTE(review): the initialisation of J, the statements executed when such a
/// pair is found, and the return statement are not visible in this listing.
427 bool hasShortLoops( const vector
<Factor
> &P
) {
429 vector
<Factor
>::const_iterator I
, J
;
430 for( I
= P
.begin(); I
!= P
.end(); I
++ ) {
433 for( ; J
!= P
.end(); J
++ )
434 if( (I
->vars() & J
->vars()).size() >= 2 ) {
/// Searches P for a pair of factors (I, J) whose variable sets intersect in two
/// or more variables and reports the pair as "Merging factors ..." on cout.
/// \param P the factors to process; taken by non-const reference
/// NOTE(review): the initialisation of J, the merge itself and any enclosing
/// repetition are not visible in this listing.
445 void RemoveShortLoops(vector
<Factor
> &P
) {
449 vector
<Factor
>::iterator I
, J
;
450 for( I
= P
.begin(); I
!= P
.end(); I
++ ) {
453 for( ; J
!= P
.end(); J
++ )
454 if( (I
->vars() & J
->vars()).size() >= 2 ) {
462 cout
<< "Merging factors " << I
->vars() << " and " << J
->vars() << endl
;
/// Returns P.marginal(x), where P is built in a loop over all factors.
/// \param x the variables to take the marginal on
/// NOTE(review): the declaration of P and the loop body are not visible in this
/// listing; presumably P is the product of all factors - confirm.
470 Factor
FactorGraph::ExactMarginal(const VarSet
& x
) const {
472 for( size_t I
= 0; I
< nrFactors(); I
++ )
474 return P
.marginal(x
);
/// Returns the natural logarithm of P.totalSum(), where P is built in a loop
/// over all factors.
/// NOTE(review): the declaration of P and the loop body are not visible in this
/// listing; presumably P is the product of all factors - confirm.
478 Real
FactorGraph::ExactlogZ() const {
480 for( size_t I
= 0; I
< nrFactors(); I
++ )
482 return std::log(P
.totalSum());
/// Collects factor variable sets into result.
/// For each factor I, the other factors J are scanned while maximal holds; the
/// test compares factor(J).vars() with factor(I).vars() using operator>> and
/// requires the two sets to differ. factor(I).vars() is appended to result.
/// NOTE(review): the declaration and updates of maximal, the condition guarding
/// the push_back, and the return statement are not visible in this listing;
/// presumably only maximal variable sets are kept - confirm.
486 vector
<VarSet
> FactorGraph::Cliques() const {
487 vector
<VarSet
> result
;
489 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
491 for( size_t J
= 0; (J
< nrFactors()) && maximal
; J
++ )
492 if( (factor(J
).vars() >> factor(I
).vars()) && !(factor(J
).vars() == factor(I
).vars()) )
496 result
.push_back( factor(I
).vars() );
/// Multiplies every factor that contains variable n by the single-variable
/// factor delta_n_i, which is created over n with all entries 0.0.
/// \param n the variable to clamp
/// \param i index used in the assertion against n.states()
/// The large block comment in the middle is disabled "surgery" code.
/// NOTE(review): the statement that sets an entry of delta_n_i is not visible
/// in this listing; presumably entry i is set to 1.0 - confirm.
/// NOTE(review): the assertion accepts i == n.states(); if i is a zero-based
/// state index, the bound should probably be i < n.states() - confirm.
503 void FactorGraph::clamp( const Var
& n
, size_t i
) {
504 assert( i
<= n
.states() );
506 /* if( do_surgery ) {
507 size_t ni = find( vars().begin(), vars().end(), n) - vars().begin();
509 if( ni != nrVars() ) {
510 for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); I++ ) {
511 if( factor(*I).size() == 1 )
512 // Remove this single-variable factor
513 // I = (_V2.erase(I))--;
514 _E12.erase( _E12.begin() + VV2E(ni, *I) );
516 // Replace it by the slice
517 Index ind_I_min_n( factor(*I), factor(*I) / n );
518 Index ind_n( factor(*I), n );
519 Factor slice_I( factor(*I) / n );
520 for( size_t ind_I = 0; ind_I < factor(*I).stateSpace(); ++ind_I, ++ind_I_min_n, ++ind_n )
522 slice_I[ind_I_min_n] = factor(*I)[ind_I];
523 factor(*I) = slice_I;
525 // Remove the edge between n and I
526 _E12.erase( _E12.begin() + VV2E(ni, *I) );
532 // remove all unconnected factors
533 for( size_t I = 0; I < nrFactors(); I++ )
534 if( nb2(I).size() == 0 )
543 // The cheap solution (at least in terms of coding time) is to multiply every factor
544 // that contains the variable with a delta function
546 Factor
delta_n_i(n
,0.0);
549 // For all factors that contain n
550 for( size_t I
= 0; I
< nrFactors(); I
++ )
551 if( factor(I
).vars() && n
)
552 // Multiply it with a delta function
553 factor(I
) *= delta_n_i
;
559 void FactorGraph::saveProb( size_t I
) {
560 map
<size_t,Prob
>::iterator it
= _undoProbs
.find( I
);
561 if( it
!= _undoProbs
.end() )
562 cout
<< "FactorGraph::saveProb: WARNING: _undoProbs[I] already defined!" << endl
;
563 _undoProbs
[I
] = factor(I
).p();
567 void FactorGraph::undoProb( size_t I
) {
568 map
<size_t,Prob
>::iterator it
= _undoProbs
.find( I
);
569 if( it
!= _undoProbs
.end() ) {
570 factor(I
).p() = (*it
).second
;
571 _undoProbs
.erase(it
);
576 void FactorGraph::saveProbs( const VarSet
&ns
) {
577 if( !_undoProbs
.empty() )
578 cout
<< "FactorGraph::saveProbs: WARNING: _undoProbs not empy!" << endl
;
579 for( size_t I
= 0; I
< nrFactors(); I
++ )
580 if( factor(I
).vars() && ns
)
581 _undoProbs
[I
] = factor(I
).p();
/// Restores the backed-up probability tables of the factors selected by ns.
/// Walks over _undoProbs; for every entry whose factor satisfies
/// <tt>factor(index).vars() && ns</tt>, the saved table is written back to the
/// factor and the entry is erased. The iterator is advanced with a
/// post-increment inside the erase call, so it is moved on before the erased
/// element is removed.
/// \param ns the variables that select the factors to restore
/// NOTE(review): the loop header has no increment and the branch taken for
/// non-matching entries is not visible in this listing; it must advance uI for
/// the loop to terminate - confirm.
585 void FactorGraph::undoProbs( const VarSet
&ns
) {
586 for( map
<size_t,Prob
>::iterator uI
= _undoProbs
.begin(); uI
!= _undoProbs
.end(); ) {
587 if( factor((*uI
).first
).vars() && ns
) {
588 // cout << "undoing " << factor((*uI).first).vars() << endl;
589 // cout << "from " << factor((*uI).first).p() << " to " << (*uI).second << endl;
590 factor((*uI
).first
).p() = (*uI
).second
;
591 _undoProbs
.erase(uI
++);
/// Reports whether the variable set remaining is empty after growing component.
/// component starts as the VarSet of n. While new variables are found, every
/// variable m in remaining whose delta(*m) satisfies operator&& with component
/// is considered; new_vars is then added to component and removed from
/// remaining. The function returns remaining.empty().
/// NOTE(review): the definitions of n, remaining and new_vars, the body of the
/// loop over i, and the statement that fills new_vars are not visible in this
/// listing.
598 bool FactorGraph::isConnected() const {
604 VarSet component
= n
;
607 for( size_t i
= 1; i
< nrVars(); i
++ )
610 bool found_new_vars
= true;
611 while( found_new_vars
) {
613 for( VarSet::const_iterator m
= remaining
.begin(); m
!= remaining
.end(); m
++ )
614 if( delta(*m
) && component
)
617 if( new_vars
.empty() )
618 found_new_vars
= false;
620 found_new_vars
= true;
622 component
|= new_vars
;
623 remaining
/= new_vars
;
625 return remaining
.empty();
630 } // end of namespace dai