1 /* This file is part of libDAI - http://www.libdai.org/
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
10 /// \brief Defines class ClusterGraph, which is used by JTree, TreeEP and HAK
11 /// \todo The "MinFill" and "WeightedMinFill" variable elimination heuristics seem designed for Markov graphs;
12 /// add similar heuristics which are designed for factor graphs.
15 #ifndef __defined_libdai_clustergraph_h
16 #define __defined_libdai_clustergraph_h
21 #include <dai/varset.h>
22 #include <dai/bipgraph.h>
23 #include <dai/factorgraph.h>
29 /// A ClusterGraph is a hypergraph with variables as nodes, and "clusters" (sets of variables) as hyperedges.
30 /** It is implemented as a bipartite graph with variable (Var) nodes and cluster (VarSet) nodes.
31 * One may think of a ClusterGraph as a FactorGraph without the actual factor values.
32 * \todo Remove the _vars and _clusters variables and use only the graph and a contextual factor graph.
36 /// Stores the neighborhood structure
/// Stores the variables corresponding to the nodes
/** Variable index \c i refers to \c _vars[i]; insert() appends newly encountered variables at the end. */
std::vector<Var> _vars;

/// Stores the clusters corresponding to the hyperedges
/** Cluster index \c I refers to \c _clusters[I]; this vector must be kept in sync with the type-2 nodes of the bipartite graph. */
std::vector<VarSet> _clusters;
46 /// \name Constructors and destructors
48 /// Default constructor
49 ClusterGraph() : _G(), _vars(), _clusters() {}
/// Construct from vector of VarSet 's
/** \param cls The clusters (hyperedges) of the new cluster graph.
 */
ClusterGraph( const std::vector<VarSet>& cls );
/// Construct from a factor graph
/** Creates cluster graph which has factors in \a fg as clusters if \a onlyMaximal == \c false,
 *  and only the maximal factors in \a fg if \a onlyMaximal == \c true.
 *  \param fg The factor graph whose factors define the clusters.
 *  \param onlyMaximal If \c true, only keep clusters that are not contained in a larger cluster.
 */
ClusterGraph( const FactorGraph& fg, bool onlyMaximal );
63 /// Returns a constant reference to the graph structure
64 const BipartiteGraph
& bipGraph() const { return _G
; }
66 /// Returns number of variables
67 size_t nrVars() const { return _vars
.size(); }
69 /// Returns a constant reference to the variables
70 const std::vector
<Var
>& vars() const { return _vars
; }
72 /// Returns a constant reference to the \a i'th variable
73 const Var
& var( size_t i
) const {
74 DAI_DEBASSERT( i
< nrVars() );
78 /// Returns number of clusters
79 size_t nrClusters() const { return _clusters
.size(); }
81 /// Returns a constant reference to the clusters
82 const std::vector
<VarSet
>& clusters() const { return _clusters
; }
84 /// Returns a constant reference to the \a I'th cluster
85 const VarSet
& cluster( size_t I
) const {
86 DAI_DEBASSERT( I
< nrClusters() );
90 /// Returns the index of variable \a n
91 size_t findVar( const Var
& n
) const {
92 return find( _vars
.begin(), _vars
.end(), n
) - _vars
.begin();
95 /// Returns the index of a cluster \a cl
96 size_t findCluster( const VarSet
& cl
) const {
97 return find( _clusters
.begin(), _clusters
.end(), cl
) - _clusters
.begin();
100 /// Returns union of clusters that contain the \a i 'th variable
101 VarSet
Delta( size_t i
) const {
103 foreach( const Neighbor
& I
, _G
.nb1(i
) )
104 result
|= _clusters
[I
];
/// Returns union of clusters that contain the \a i 'th variable, excluding the variable itself
109 VarSet
delta( size_t i
) const {
110 return Delta( i
) / _vars
[i
];
113 /// Returns \c true if variables with indices \a i1 and \a i2 are adjacent, i.e., both contained in the same cluster
114 bool adj( size_t i1
, size_t i2
) const {
118 foreach( const Neighbor
& I
, _G
.nb1(i1
) )
119 if( find( _G
.nb2(I
).begin(), _G
.nb2(I
).end(), i2
) != _G
.nb2(I
).end() ) {
126 /// Returns \c true if cluster \a I is not contained in a larger cluster
127 bool isMaximal( size_t I
) const {
128 DAI_DEBASSERT( I
< _G
.nrNodes2() );
129 const VarSet
& clI
= _clusters
[I
];
131 // The following may not be optimal, since it may repeatedly test the same cluster *J
132 foreach( const Neighbor
& i
, _G
.nb2(I
) ) {
133 foreach( const Neighbor
& J
, _G
.nb1(i
) )
134 if( (J
!= I
) && (clI
<< _clusters
[J
]) ) {
147 /// Inserts a cluster (if it does not already exist) and creates new variables, if necessary
148 /** \note This function could be better optimized if the index of one variable in \a cl would be known.
149 * If one could assume _vars to be ordered, a binary search could be used instead of a linear one.
151 size_t insert( const VarSet
& cl
) {
152 size_t index
= findCluster( cl
); // OPTIMIZE ME
153 if( index
== _clusters
.size() ) {
154 _clusters
.push_back( cl
);
155 // add variables (if necessary) and calculate neighborhood of new cluster
156 std::vector
<size_t> nbs
;
157 for( VarSet::const_iterator n
= cl
.begin(); n
!= cl
.end(); n
++ ) {
158 size_t iter
= findVar( *n
); // OPTIMIZE ME
159 nbs
.push_back( iter
);
160 if( iter
== _vars
.size() ) {
162 _vars
.push_back( *n
);
165 _G
.addNode2( nbs
.begin(), nbs
.end(), nbs
.size() );
170 /// Erases all clusters that are not maximal
171 ClusterGraph
& eraseNonMaximal() {
172 for( size_t I
= 0; I
< _G
.nrNodes2(); ) {
173 if( !isMaximal(I
) ) {
174 _clusters
.erase( _clusters
.begin() + I
);
182 /// Erases all clusters that contain the \a i 'th variable
183 ClusterGraph
& eraseSubsuming( size_t i
) {
184 DAI_ASSERT( i
< nrVars() );
185 while( _G
.nb1(i
).size() ) {
186 _clusters
.erase( _clusters
.begin() + _G
.nb1(i
)[0] );
187 _G
.eraseNode2( _G
.nb1(i
)[0] );
192 /// Eliminates variable with index \a i, without deleting the variable itself
193 /** \note This function can be better optimized
195 VarSet
elimVar( size_t i
) {
196 DAI_ASSERT( i
< nrVars() );
197 VarSet Di
= Delta( i
);
198 insert( Di
/ var(i
) );
/// \name Input/Output
207 /// Writes a ClusterGraph to an output stream
208 friend std::ostream
& operator << ( std::ostream
& os
, const ClusterGraph
& cl
) {
214 /// \name Variable elimination
216 /// Performs Variable Elimination, keeping track of the interactions that are created along the way.
217 /** \tparam EliminationChoice should support "size_t operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars )"
218 * \param f function object which returns the next variable index to eliminate; for example, a dai::greedyVariableElimination object.
219 * \param maxStates maximum total number of states of all clusters in the output cluster graph (0 means no limit).
220 * \throws OUT_OF_MEMORY if total number of states becomes larger than maxStates
221 * \return A set of elimination "cliques".
223 template<class EliminationChoice
>
224 ClusterGraph
VarElim( EliminationChoice f
, size_t maxStates
=0 ) const {
226 ClusterGraph
cl(*this);
227 cl
.eraseNonMaximal();
231 // Construct set of variable indices
232 std::set
<size_t> varindices
;
233 for( size_t i
= 0; i
< _vars
.size(); ++i
)
234 varindices
.insert( i
);
236 // Do variable elimination
237 long double totalStates
= 0.0;
238 while( !varindices
.empty() ) {
239 size_t i
= f( cl
, varindices
);
240 VarSet Di
= cl
.elimVar( i
);
243 totalStates
+= Di
.nrStates();
244 if( totalStates
> maxStates
)
245 DAI_THROW(OUT_OF_MEMORY
);
247 varindices
.erase( i
);
256 /// Helper object for dai::ClusterGraph::VarElim()
257 /** Chooses the next variable to eliminate by picking them sequentially from a given vector of variables.
259 class sequentialVariableElimination
{
261 /// The variable elimination sequence
262 std::vector
<Var
> seq
;
267 /// Construct from vector of variables
268 sequentialVariableElimination( const std::vector
<Var
> s
) : seq(s
), i(0) {}
/// Returns next variable in sequence
/** \param cl The current cluster graph; presumably used to translate the next variable of \c seq into its index in \a cl (confirm).
 *  The second parameter (the set of remaining variable indices) is ignored.
 *  \return Index (in \a cl) of the next variable to eliminate.
 */
size_t operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ );
275 /// Helper object for dai::ClusterGraph::VarElim()
276 /** Chooses the next variable to eliminate greedily by taking the one that minimizes
277 * a given heuristic cost function.
279 class greedyVariableElimination
{
/// Type of cost functions to be used for greedy variable elimination
/** A cost function receives a cluster graph and a variable index, and returns the cost of eliminating that variable. */
typedef size_t (*eliminationCostFunction)(const ClusterGraph &, size_t);
/// Pointer to the cost function used to score each candidate variable
eliminationCostFunction heuristic;
289 /// Construct from cost function
290 /** \note Examples of cost functions are eliminationCost_MinFill() and eliminationCost_WeightedMinFill().
292 greedyVariableElimination( eliminationCostFunction h
) : heuristic(h
) {}
/// Returns the best variable from \a remainingVars to eliminate in the cluster graph \a cl by greedily minimizing the cost function.
/** This function calculates the cost for eliminating each variable in \a remainingVars and returns the variable which has lowest cost.
 *  \param cl The current cluster graph.
 *  \param remainingVars Indices of the variables that have not been eliminated yet.
 *  \return Index of the variable with the lowest cost according to \c heuristic.
 */
size_t operator()( const ClusterGraph &cl, const std::set<size_t>& remainingVars );
/// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinNeighbors" criterion.
/** The cost is measured as "number of neighboring nodes in the current adjacency graph",
 *  where the adjacency graph has the variables as its nodes and connects
 *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 *  \param cl The cluster graph.
 *  \param i Index of the candidate variable.
 *  \return The elimination cost.
 */
size_t eliminationCost_MinNeighbors( const ClusterGraph& cl, size_t i );
/// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinWeight" criterion.
/** The cost is measured as "product of weights of neighboring nodes in the current adjacency graph",
 *  where the adjacency graph has the variables as its nodes and connects
 *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 *  The weight of a node is the number of states of the corresponding variable.
 *  \param cl The cluster graph.
 *  \param i Index of the candidate variable.
 *  \return The elimination cost.
 */
size_t eliminationCost_MinWeight( const ClusterGraph& cl, size_t i );
/// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "MinFill" criterion.
/** The cost is measured as "number of added edges in the adjacency graph",
 *  where the adjacency graph has the variables as its nodes and connects
 *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 *  \param cl The cluster graph.
 *  \param i Index of the candidate variable.
 *  \return The elimination cost.
 */
size_t eliminationCost_MinFill( const ClusterGraph& cl, size_t i );
/// Calculates cost of eliminating the \a i 'th variable from cluster graph \a cl according to the "WeightedMinFill" criterion.
/** The cost is measured as "total weight of added edges in the adjacency graph",
 *  where the adjacency graph has the variables as its nodes and connects
 *  nodes \a i1 and \a i2 iff \a i1 and \a i2 occur together in some common cluster.
 *  The weight of an edge is the product of the number of states of the variables corresponding with its nodes.
 *  \param cl The cluster graph.
 *  \param i Index of the candidate variable.
 *  \return The elimination cost.
 */
size_t eliminationCost_WeightedMinFill( const ClusterGraph& cl, size_t i );
335 } // end of namespace dai