/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
   Radboud University Nijmegen, The Netherlands /
   Max Planck Institute for Biological Cybernetics, Germany

   This file is part of libDAI.

   libDAI is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   libDAI is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with libDAI; if not, write to the Free Software
   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include <dai/factorgraph.h>

#include <fstream>

#include <dai/exceptions.h>
43 FactorGraph::FactorGraph( const std::vector
<Factor
> &P
) : G(), _backup() {
44 // add factors, obtain variables
46 _factors
.reserve( P
.size() );
48 for( vector
<Factor
>::const_iterator p2
= P
.begin(); p2
!= P
.end(); p2
++ ) {
49 _factors
.push_back( *p2
);
50 copy( p2
->vars().begin(), p2
->vars().end(), inserter( varset
, varset
.begin() ) );
51 nrEdges
+= p2
->vars().size();
55 _vars
.reserve( varset
.size() );
56 for( set
<Var
>::const_iterator p1
= varset
.begin(); p1
!= varset
.end(); p1
++ )
57 _vars
.push_back( *p1
);
59 // create graph structure
60 constructGraph( nrEdges
);
64 void FactorGraph::constructGraph( size_t nrEdges
) {
65 // create a mapping for indices
66 hash_map
<size_t, size_t> hashmap
;
68 for( size_t i
= 0; i
< vars().size(); i
++ )
69 hashmap
[var(i
).label()] = i
;
73 edges
.reserve( nrEdges
);
74 for( size_t i2
= 0; i2
< nrFactors(); i2
++ ) {
75 const VarSet
& ns
= factor(i2
).vars();
76 for( VarSet::const_iterator q
= ns
.begin(); q
!= ns
.end(); q
++ )
77 edges
.push_back( Edge(hashmap
[q
->label()], i2
) );
80 // create bipartite graph
81 G
.construct( nrVars(), nrFactors(), edges
.begin(), edges
.end() );
85 /// Writes a FactorGraph to an output stream
86 ostream
& operator << (ostream
& os
, const FactorGraph
& fg
) {
87 os
<< fg
.nrFactors() << endl
;
89 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
91 os
<< fg
.factor(I
).vars().size() << endl
;
92 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
93 os
<< i
->label() << " ";
95 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
96 os
<< i
->states() << " ";
98 size_t nr_nonzeros
= 0;
99 for( size_t k
= 0; k
< fg
.factor(I
).states(); k
++ )
100 if( fg
.factor(I
)[k
] != 0.0 )
102 os
<< nr_nonzeros
<< endl
;
103 for( size_t k
= 0; k
< fg
.factor(I
).states(); k
++ )
104 if( fg
.factor(I
)[k
] != 0.0 )
105 os
<< k
<< " " << setw(os
.precision()+4) << fg
.factor(I
)[k
] << endl
;
112 /// Reads a FactorGraph from an input stream
113 istream
& operator >> (istream
& is
, FactorGraph
& fg
) {
// Appears to parse the same layout that operator<< above writes: the number
// of factors, then for every factor its number of members, the member
// labels, the member dimensions, the number of non-zero entries and that
// many (linear index, value) pairs. Each parsed factor is appended to `facs`
// and `fg` is finally assigned FactorGraph(facs).
// NOTE(review): this listing is incomplete. The embedded original line
// numbers skip many lines (e.g. 113 -> 120 -> 124), so the declarations of
// nr_Factors, facs, nr_members, labels, mi_label, dims, mi_dim, I_vars,
// nr_nonzeros, li and val, the bodies of the `is.peek() == '#'` loops
// (presumably skipping comment lines), the `is >> ...` reads, the conditions
// guarding the first two DAI_THROWs, the `else` of the dimension check, the
// closing braces and the return statement are not visible here and must be
// restored from the upstream source.
120 while( (is
.peek()) == '#' )
124 DAI_THROW(INVALID_FACTORGRAPH_FILE
);
// NOTE(review): as listed, these progress messages go to cerr
// unconditionally; check whether the missing preceding lines guarded them
// (e.g. by a verbosity level).
126 cerr
<< "Reading " << nr_Factors
<< " factors..." << endl
;
130 DAI_THROW(INVALID_FACTORGRAPH_FILE
);
// Dimension recorded for each variable label seen so far, across factors.
132 map
<long,size_t> vardims
;
133 for( size_t I
= 0; I
< nr_Factors
; I
++ ) {
135 cerr
<< "Reading factor " << I
<< "..." << endl
;
137 while( (is
.peek()) == '#' )
141 cerr
<< " nr_members: " << nr_members
<< endl
;
// Collect the member labels in file order.
144 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
146 while( (is
.peek()) == '#' )
149 labels
.push_back(mi_label
);
152 cerr
<< " labels: " << labels
<< endl
;
// Collect the member dimensions (numbers of states) in file order.
155 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
157 while( (is
.peek()) == '#' )
160 dims
.push_back(mi_dim
);
163 cerr
<< " dimensions: " << dims
<< endl
;
// Build the factor's VarSet; a label that was already seen in an earlier
// factor must come with the same dimension, otherwise the file is rejected.
167 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
168 map
<long,size_t>::iterator vdi
= vardims
.find( labels
[mi
] );
169 if( vdi
!= vardims
.end() ) {
170 // check whether dimensions are consistent
171 if( vdi
->second
!= dims
[mi
] )
172 DAI_THROW(INVALID_FACTORGRAPH_FILE
);
174 vardims
[labels
[mi
]] = dims
[mi
];
175 I_vars
|= Var(labels
[mi
], dims
[mi
]);
// Start from an all-zero factor; only the listed non-zero entries are filled.
177 facs
.push_back( Factor( I_vars
, 0.0 ) );
// I_vars iterates its variables in sorted order, while the file lists them
// in `labels` order: sigma[mi] is the position in `labels` of the mi-th
// variable of I_vars.
179 // calculate permutation sigma (internally, members are sorted)
180 vector
<size_t> sigma(nr_members
,0);
181 VarSet::iterator j
= I_vars
.begin();
182 for( size_t mi
= 0; mi
< nr_members
; mi
++,j
++ ) {
183 long search_for
= j
->label();
184 vector
<long>::iterator j_loc
= find(labels
.begin(),labels
.end(),search_for
);
185 sigma
[mi
] = j_loc
- labels
.begin();
188 cerr
<< " sigma: " << sigma
<< endl
;
190 // calculate multindices
191 Permute
permindex( dims
, sigma
);
195 while( (is
.peek()) == '#' )
199 cerr
<< " nonzeroes: " << nr_nonzeros
<< endl
;
200 for( size_t k
= 0; k
< nr_nonzeros
; k
++ ) {
203 while( (is
.peek()) == '#' )
206 while( (is
.peek()) == '#' )
// NOTE(review): the statements that presumably read `li` and `val` (used
// below) are not visible in this listing.
210 // store value, but permute indices first according
211 // to internal representation
212 facs
.back()[permindex
.convert_linear_index( li
)] = val
;
217 cerr
<< "factors:" << facs
<< endl
;
219 fg
= FactorGraph(facs
);
// NOTE(review): the remainder of this operator (closing braces and,
// presumably, `return is;`) is not visible in this listing.
225 VarSet
FactorGraph::delta( unsigned i
) const {
226 return( Delta(i
) / var(i
) );
230 VarSet
FactorGraph::Delta( unsigned i
) const {
231 // calculate Markov Blanket
233 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
234 foreach( const Neighbor
&j
, nbF(I
) ) // for all neighboring variables j of I
241 VarSet
FactorGraph::Delta( const VarSet
&ns
) const {
243 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
244 result
|= Delta(findVar(*n
));
249 void FactorGraph::makeCavity( unsigned i
, bool backup
) {
250 // fills all Factors that include var(i) with ones
251 map
<size_t,Factor
> newFacs
;
252 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
253 newFacs
[I
] = Factor(factor(I
).vars(), 1.0);
254 setFactors( newFacs
, backup
);
258 void FactorGraph::ReadFromFile( const char *filename
) {
260 infile
.open( filename
);
261 if( infile
.is_open() ) {
265 DAI_THROW(CANNOT_READ_FILE
);
269 void FactorGraph::WriteToFile( const char *filename
, size_t precision
) const {
271 outfile
.open( filename
);
272 if( outfile
.is_open() ) {
273 outfile
.precision( precision
);
277 DAI_THROW(CANNOT_WRITE_FILE
);
281 void FactorGraph::printDot( std::ostream
&os
) const {
282 os
<< "graph G {" << endl
;
283 os
<< "node[shape=circle,width=0.4,fixedsize=true];" << endl
;
284 for( size_t i
= 0; i
< nrVars(); i
++ )
285 os
<< "\tv" << var(i
).label() << ";" << endl
;
286 os
<< "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl
;
287 for( size_t I
= 0; I
< nrFactors(); I
++ )
288 os
<< "\tf" << I
<< ";" << endl
;
289 for( size_t i
= 0; i
< nrVars(); i
++ )
290 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
291 os
<< "\tv" << var(i
).label() << " -- f" << I
<< ";" << endl
;
296 vector
<VarSet
> FactorGraph::Cliques() const {
297 vector
<VarSet
> result
;
299 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
301 for( size_t J
= 0; (J
< nrFactors()) && maximal
; J
++ )
302 if( (factor(J
).vars() >> factor(I
).vars()) && (factor(J
).vars() != factor(I
).vars()) )
306 result
.push_back( factor(I
).vars() );
313 void FactorGraph::clamp( const Var
& n
, size_t i
, bool backup
) {
314 assert( i
<= n
.states() );
316 // Multiply each factor that contains the variable with a delta function
318 Factor
delta_n_i(n
,0.0);
321 map
<size_t, Factor
> newFacs
;
322 // For all factors that contain n
323 for( size_t I
= 0; I
< nrFactors(); I
++ )
324 if( factor(I
).vars().contains( n
) )
325 // Multiply it with a delta function
326 newFacs
[I
] = factor(I
) * delta_n_i
;
327 setFactors( newFacs
, backup
);
333 void FactorGraph::clampVar( size_t i
, const vector
<size_t> &is
, bool backup
) {
335 Factor
mask_n( n
, 0.0 );
337 foreach( size_t i
, is
) {
338 assert( i
<= n
.states() );
342 map
<size_t, Factor
> newFacs
;
343 for( size_t I
= 0; I
< nrFactors(); I
++ )
344 if( factor(I
).vars().contains( n
) ) {
345 newFacs
[I
] = factor(I
) * mask_n
;
347 setFactors( newFacs
, backup
);
351 void FactorGraph::clampFactor( size_t I
, const vector
<size_t> &is
, bool backup
) {
352 size_t st
= factor(I
).states();
353 Factor
newF( factor(I
).vars(), 0.0 );
355 foreach( size_t i
, is
) {
357 newF
[i
] = factor(I
)[i
];
360 setFactor( I
, newF
, backup
);
364 void FactorGraph::backupFactor( size_t I
) {
365 map
<size_t,Factor
>::iterator it
= _backup
.find( I
);
366 if( it
!= _backup
.end() )
367 DAI_THROW( MULTIPLE_UNDO
);
368 _backup
[I
] = factor(I
);
372 void FactorGraph::restoreFactor( size_t I
) {
373 map
<size_t,Factor
>::iterator it
= _backup
.find( I
);
374 if( it
!= _backup
.end() ) {
375 setFactor(I
, it
->second
);
381 void FactorGraph::backupFactors( const VarSet
&ns
) {
382 for( size_t I
= 0; I
< nrFactors(); I
++ )
383 if( factor(I
).vars().intersects( ns
) )
388 void FactorGraph::restoreFactors( const VarSet
&ns
) {
389 map
<size_t,Factor
> facs
;
390 for( map
<size_t,Factor
>::iterator uI
= _backup
.begin(); uI
!= _backup
.end(); ) {
391 if( factor(uI
->first
).vars().intersects( ns
) ) {
401 void FactorGraph::restoreFactors() {
402 setFactors( _backup
);
407 void FactorGraph::backupFactors( const std::set
<size_t> & facs
) {
408 for( std::set
<size_t>::const_iterator fac
= facs
.begin(); fac
!= facs
.end(); fac
++ )
409 backupFactor( *fac
);
413 bool FactorGraph::isPairwise() const {
414 bool pairwise
= true;
415 for( size_t I
= 0; I
< nrFactors() && pairwise
; I
++ )
416 if( factor(I
).vars().size() > 2 )
422 bool FactorGraph::isBinary() const {
424 for( size_t i
= 0; i
< nrVars() && binary
; i
++ )
425 if( var(i
).states() > 2 )
431 FactorGraph
FactorGraph::clamped( const Var
& v_i
, size_t state
) const {
432 Real zeroth_order
= 1.0;
433 vector
<Factor
> clamped_facs
;
434 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
435 VarSet v_I
= factor(I
).vars();
437 if( v_I
.intersects( v_i
) )
438 new_factor
= factor(I
).slice( v_i
, state
);
440 new_factor
= factor(I
);
442 if( new_factor
.vars().size() != 0 ) {
444 // if it can be merged with a previous one, do that
445 for( J
= 0; J
< clamped_facs
.size(); J
++ )
446 if( clamped_facs
[J
].vars() == new_factor
.vars() ) {
447 clamped_facs
[J
] *= new_factor
;
450 // otherwise, push it back
451 if( J
== clamped_facs
.size() || clamped_facs
.size() == 0 )
452 clamped_facs
.push_back( new_factor
);
454 zeroth_order
*= new_factor
[0];
456 *(clamped_facs
.begin()) *= zeroth_order
;
457 return FactorGraph( clamped_facs
);
461 FactorGraph
FactorGraph::maximalFactors() const {
462 vector
<size_t> maxfac( nrFactors() );
463 map
<size_t,size_t> newindex
;
465 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
467 VarSet maxfacvars
= factor(maxfac
[I
]).vars();
468 for( size_t J
= 0; J
< nrFactors(); J
++ ) {
469 VarSet Jvars
= factor(J
).vars();
470 if( Jvars
>> maxfacvars
&& (Jvars
!= maxfacvars
) ) {
472 maxfacvars
= factor(maxfac
[I
]).vars();
476 newindex
[I
] = nrmax
++;
479 vector
<Factor
> facs( nrmax
);
480 for( size_t I
= 0; I
< nrFactors(); I
++ )
481 facs
[newindex
[maxfac
[I
]]] *= factor(I
);
483 return FactorGraph( facs
.begin(), facs
.end(), vars().begin(), vars().end(), facs
.size(), nrVars() );
487 } // end of namespace dai