/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
 *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
 */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include <dai/factorgraph.h>
#include <dai/exceptions.h>
#include <boost/lexical_cast.hpp>
33 FactorGraph::FactorGraph( const std::vector
<Factor
> &P
) : G(), _backup() {
34 // add factors, obtain variables
36 _factors
.reserve( P
.size() );
38 for( vector
<Factor
>::const_iterator p2
= P
.begin(); p2
!= P
.end(); p2
++ ) {
39 _factors
.push_back( *p2
);
40 copy( p2
->vars().begin(), p2
->vars().end(), inserter( varset
, varset
.begin() ) );
41 nrEdges
+= p2
->vars().size();
45 _vars
.reserve( varset
.size() );
46 for( set
<Var
>::const_iterator p1
= varset
.begin(); p1
!= varset
.end(); p1
++ )
47 _vars
.push_back( *p1
);
49 // create graph structure
50 constructGraph( nrEdges
);
54 void FactorGraph::constructGraph( size_t nrEdges
) {
55 // create a mapping for indices
56 hash_map
<size_t, size_t> hashmap
;
58 for( size_t i
= 0; i
< vars().size(); i
++ )
59 hashmap
[var(i
).label()] = i
;
63 edges
.reserve( nrEdges
);
64 for( size_t i2
= 0; i2
< nrFactors(); i2
++ ) {
65 const VarSet
& ns
= factor(i2
).vars();
66 for( VarSet::const_iterator q
= ns
.begin(); q
!= ns
.end(); q
++ )
67 edges
.push_back( Edge(hashmap
[q
->label()], i2
) );
70 // create bipartite graph
71 G
.construct( nrVars(), nrFactors(), edges
.begin(), edges
.end() );
75 /// Writes a FactorGraph to an output stream
76 std::ostream
& operator<< ( std::ostream
&os
, const FactorGraph
&fg
) {
77 os
<< fg
.nrFactors() << endl
;
79 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
81 os
<< fg
.factor(I
).vars().size() << endl
;
82 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
83 os
<< i
->label() << " ";
85 for( VarSet::const_iterator i
= fg
.factor(I
).vars().begin(); i
!= fg
.factor(I
).vars().end(); i
++ )
86 os
<< i
->states() << " ";
88 size_t nr_nonzeros
= 0;
89 for( size_t k
= 0; k
< fg
.factor(I
).states(); k
++ )
90 if( fg
.factor(I
)[k
] != (Real
)0 )
92 os
<< nr_nonzeros
<< endl
;
93 for( size_t k
= 0; k
< fg
.factor(I
).states(); k
++ )
94 if( fg
.factor(I
)[k
] != (Real
)0 )
95 os
<< k
<< " " << setw(os
.precision()+4) << fg
.factor(I
)[k
] << endl
;
102 /// Reads a FactorGraph from an input stream
103 std::istream
& operator>> ( std::istream
& is
, FactorGraph
&fg
) {
110 while( (is
.peek()) == '#' )
114 DAI_THROWE(INVALID_FACTORGRAPH_FILE
,"Cannot read number of factors");
116 cerr
<< "Reading " << nr_Factors
<< " factors..." << endl
;
120 DAI_THROWE(INVALID_FACTORGRAPH_FILE
,"Expecting empty line");
122 map
<long,size_t> vardims
;
123 for( size_t I
= 0; I
< nr_Factors
; I
++ ) {
125 cerr
<< "Reading factor " << I
<< "..." << endl
;
127 while( (is
.peek()) == '#' )
131 cerr
<< " nr_members: " << nr_members
<< endl
;
134 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
136 while( (is
.peek()) == '#' )
139 labels
.push_back(mi_label
);
142 cerr
<< " labels: " << labels
<< endl
;
145 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
147 while( (is
.peek()) == '#' )
150 dims
.push_back(mi_dim
);
153 cerr
<< " dimensions: " << dims
<< endl
;
157 Ivars
.reserve( nr_members
);
158 for( size_t mi
= 0; mi
< nr_members
; mi
++ ) {
159 map
<long,size_t>::iterator vdi
= vardims
.find( labels
[mi
] );
160 if( vdi
!= vardims
.end() ) {
161 // check whether dimensions are consistent
162 if( vdi
->second
!= dims
[mi
] )
163 DAI_THROWE(INVALID_FACTORGRAPH_FILE
,"Variable with label " + boost::lexical_cast
<string
>(labels
[mi
]) + " has inconsistent dimensions.");
165 vardims
[labels
[mi
]] = dims
[mi
];
166 Ivars
.push_back( Var(labels
[mi
], dims
[mi
]) );
168 facs
.push_back( Factor( VarSet( Ivars
.begin(), Ivars
.end(), Ivars
.size() ), (Real
)0 ) );
170 // calculate permutation object
171 Permute
permindex( Ivars
);
175 while( (is
.peek()) == '#' )
179 cerr
<< " nonzeroes: " << nr_nonzeros
<< endl
;
180 for( size_t k
= 0; k
< nr_nonzeros
; k
++ ) {
183 while( (is
.peek()) == '#' )
186 while( (is
.peek()) == '#' )
190 // store value, but permute indices first according to internal representation
191 facs
.back()[permindex
.convertLinearIndex( li
)] = val
;
196 cerr
<< "factors:" << facs
<< endl
;
198 fg
= FactorGraph(facs
);
204 VarSet
FactorGraph::delta( size_t i
) const {
205 return( Delta(i
) / var(i
) );
209 VarSet
FactorGraph::Delta( size_t i
) const {
210 // calculate Markov Blanket
212 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
213 foreach( const Neighbor
&j
, nbF(I
) ) // for all neighboring variables j of I
220 VarSet
FactorGraph::Delta( const VarSet
&ns
) const {
222 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
223 result
|= Delta(findVar(*n
));
228 void FactorGraph::makeCavity( size_t i
, bool backup
) {
229 // fills all Factors that include var(i) with ones
230 map
<size_t,Factor
> newFacs
;
231 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
232 newFacs
[I
] = Factor( factor(I
).vars(), (Real
)1 );
233 setFactors( newFacs
, backup
);
237 void FactorGraph::ReadFromFile( const char *filename
) {
239 infile
.open( filename
);
240 if( infile
.is_open() ) {
244 DAI_THROWE(CANNOT_READ_FILE
,"Cannot read from file " + std::string(filename
));
248 void FactorGraph::WriteToFile( const char *filename
, size_t precision
) const {
250 outfile
.open( filename
);
251 if( outfile
.is_open() ) {
252 outfile
.precision( precision
);
256 DAI_THROWE(CANNOT_WRITE_FILE
,"Cannot write to file " + std::string(filename
));
260 void FactorGraph::printDot( std::ostream
&os
) const {
261 os
<< "graph G {" << endl
;
262 os
<< "node[shape=circle,width=0.4,fixedsize=true];" << endl
;
263 for( size_t i
= 0; i
< nrVars(); i
++ )
264 os
<< "\tv" << var(i
).label() << ";" << endl
;
265 os
<< "node[shape=box,width=0.3,height=0.3,fixedsize=true];" << endl
;
266 for( size_t I
= 0; I
< nrFactors(); I
++ )
267 os
<< "\tf" << I
<< ";" << endl
;
268 for( size_t i
= 0; i
< nrVars(); i
++ )
269 foreach( const Neighbor
&I
, nbV(i
) ) // for all neighboring factors I of i
270 os
<< "\tv" << var(i
).label() << " -- f" << I
<< ";" << endl
;
275 vector
<VarSet
> FactorGraph::Cliques() const {
276 vector
<VarSet
> result
;
278 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
280 for( size_t J
= 0; (J
< nrFactors()) && maximal
; J
++ )
281 if( (factor(J
).vars() >> factor(I
).vars()) && (factor(J
).vars() != factor(I
).vars()) )
285 result
.push_back( factor(I
).vars() );
292 void FactorGraph::clamp( size_t i
, size_t x
, bool backup
) {
293 DAI_ASSERT( x
<= var(i
).states() );
294 Factor
mask( var(i
), (Real
)0 );
297 map
<size_t, Factor
> newFacs
;
298 foreach( const Neighbor
&I
, nbV(i
) )
299 newFacs
[I
] = factor(I
) * mask
;
300 setFactors( newFacs
, backup
);
306 void FactorGraph::clampVar( size_t i
, const vector
<size_t> &is
, bool backup
) {
308 Factor
mask_n( n
, (Real
)0 );
310 foreach( size_t i
, is
) {
311 DAI_ASSERT( i
<= n
.states() );
315 map
<size_t, Factor
> newFacs
;
316 foreach( const Neighbor
&I
, nbV(i
) )
317 newFacs
[I
] = factor(I
) * mask_n
;
318 setFactors( newFacs
, backup
);
322 void FactorGraph::clampFactor( size_t I
, const vector
<size_t> &is
, bool backup
) {
323 size_t st
= factor(I
).states();
324 Factor
newF( factor(I
).vars(), (Real
)0 );
326 foreach( size_t i
, is
) {
327 DAI_ASSERT( i
<= st
);
328 newF
[i
] = factor(I
)[i
];
331 setFactor( I
, newF
, backup
);
335 void FactorGraph::backupFactor( size_t I
) {
336 map
<size_t,Factor
>::iterator it
= _backup
.find( I
);
337 if( it
!= _backup
.end() )
338 DAI_THROW(MULTIPLE_UNDO
);
339 _backup
[I
] = factor(I
);
343 void FactorGraph::restoreFactor( size_t I
) {
344 map
<size_t,Factor
>::iterator it
= _backup
.find( I
);
345 if( it
!= _backup
.end() ) {
346 setFactor(I
, it
->second
);
352 void FactorGraph::backupFactors( const VarSet
&ns
) {
353 for( size_t I
= 0; I
< nrFactors(); I
++ )
354 if( factor(I
).vars().intersects( ns
) )
359 void FactorGraph::restoreFactors( const VarSet
&ns
) {
360 map
<size_t,Factor
> facs
;
361 for( map
<size_t,Factor
>::iterator uI
= _backup
.begin(); uI
!= _backup
.end(); ) {
362 if( factor(uI
->first
).vars().intersects( ns
) ) {
372 void FactorGraph::restoreFactors() {
373 setFactors( _backup
);
378 void FactorGraph::backupFactors( const std::set
<size_t> & facs
) {
379 for( std::set
<size_t>::const_iterator fac
= facs
.begin(); fac
!= facs
.end(); fac
++ )
380 backupFactor( *fac
);
384 bool FactorGraph::isPairwise() const {
385 bool pairwise
= true;
386 for( size_t I
= 0; I
< nrFactors() && pairwise
; I
++ )
387 if( factor(I
).vars().size() > 2 )
393 bool FactorGraph::isBinary() const {
395 for( size_t i
= 0; i
< nrVars() && binary
; i
++ )
396 if( var(i
).states() > 2 )
402 FactorGraph
FactorGraph::clamped( size_t i
, size_t state
) const {
404 Real zeroth_order
= (Real
)1;
405 vector
<Factor
> clamped_facs
;
406 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
407 VarSet v_I
= factor(I
).vars();
409 if( v_I
.intersects( v
) )
410 new_factor
= factor(I
).slice( v
, state
);
412 new_factor
= factor(I
);
414 if( new_factor
.vars().size() != 0 ) {
416 // if it can be merged with a previous one, do that
417 for( J
= 0; J
< clamped_facs
.size(); J
++ )
418 if( clamped_facs
[J
].vars() == new_factor
.vars() ) {
419 clamped_facs
[J
] *= new_factor
;
422 // otherwise, push it back
423 if( J
== clamped_facs
.size() || clamped_facs
.size() == 0 )
424 clamped_facs
.push_back( new_factor
);
426 zeroth_order
*= new_factor
[0];
428 *(clamped_facs
.begin()) *= zeroth_order
;
429 return FactorGraph( clamped_facs
);
433 FactorGraph
FactorGraph::maximalFactors() const {
434 vector
<size_t> maxfac( nrFactors() );
435 map
<size_t,size_t> newindex
;
437 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
439 VarSet maxfacvars
= factor(maxfac
[I
]).vars();
440 for( size_t J
= 0; J
< nrFactors(); J
++ ) {
441 VarSet Jvars
= factor(J
).vars();
442 if( Jvars
>> maxfacvars
&& (Jvars
!= maxfacvars
) ) {
444 maxfacvars
= factor(maxfac
[I
]).vars();
448 newindex
[I
] = nrmax
++;
451 vector
<Factor
> facs( nrmax
);
452 for( size_t I
= 0; I
< nrFactors(); I
++ )
453 facs
[newindex
[maxfac
[I
]]] *= factor(I
);
455 return FactorGraph( facs
.begin(), facs
.end(), vars().begin(), vars().end(), facs
.size(), nrVars() );
459 } // end of namespace dai