New git HEAD version
[libdai.git] / src / clustergraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
#include <algorithm>
#include <iostream>
#include <limits>
#include <set>
#include <stdexcept>
#include <vector>
#include <dai/varset.h>
#include <dai/clustergraph.h>
14
15
16 namespace dai {
17
18
19 using namespace std;
20
21
22 ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : _G(), _vars(), _clusters() {
23 // construct vars, clusters and edge list
24 vector<Edge> edges;
25 bforeach( const VarSet &cl, cls ) {
26 if( find( clusters().begin(), clusters().end(), cl ) == clusters().end() ) {
27 // add cluster
28 size_t n2 = nrClusters();
29 _clusters.push_back( cl );
30 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
31 size_t n1 = find( vars().begin(), vars().end(), *n ) - vars().begin();
32 if( n1 == nrVars() )
33 // add variable
34 _vars.push_back( *n );
35 edges.push_back( Edge( n1, n2 ) );
36 }
37 } // disregard duplicate clusters
38 }
39
40 // Create bipartite graph
41 _G.construct( nrVars(), nrClusters(), edges.begin(), edges.end() );
42 }
43
44
45 ClusterGraph::ClusterGraph( const FactorGraph& fg, bool onlyMaximal ) : _G( fg.nrVars(), 0 ), _vars(), _clusters() {
46 // copy variables
47 _vars.reserve( fg.nrVars() );
48 for( size_t i = 0; i < fg.nrVars(); i++ )
49 _vars.push_back( fg.var(i) );
50
51 if( onlyMaximal ) {
52 for( size_t I = 0; I < fg.nrFactors(); I++ )
53 if( fg.isMaximal( I ) ) {
54 _clusters.push_back( fg.factor(I).vars() );
55 size_t clind = _G.addNode2();
56 bforeach( const Neighbor &i, fg.nbF(I) )
57 _G.addEdge( i, clind, true );
58 }
59 } else {
60 // copy clusters
61 _clusters.reserve( fg.nrFactors() );
62 for( size_t I = 0; I < fg.nrFactors(); I++ )
63 _clusters.push_back( fg.factor(I).vars() );
64 // copy bipartite graph
65 _G = fg.bipGraph();
66 }
67 }
68
69
70 size_t sequentialVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ) {
71 return cl.findVar( seq.at(i++) );
72 }
73
74
75 size_t greedyVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
76 set<size_t>::const_iterator lowest = remainingVars.end();
77 size_t lowest_cost = -1UL;
78 for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
79 size_t cost = heuristic( cl, *i );
80 if( lowest == remainingVars.end() || lowest_cost > cost ) {
81 lowest = i;
82 lowest_cost = cost;
83 }
84 }
85 return *lowest;
86 }
87
88
89 size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i ) {
90 return cl.bipGraph().delta1( i ).size();
91 }
92
93
94 size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i ) {
95 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
96
97 size_t cost = 1;
98 for( SmallSet<size_t>::const_iterator it = id_n.begin(); it != id_n.end(); it++ )
99 cost *= cl.vars()[*it].states();
100
101 return cost;
102 }
103
104
105 size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
106 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
107
108 size_t cost = 0;
109 // for each unordered pair {i1,i2} adjacent to n
110 for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
111 for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
112 if( it1 != it2 ) {
113 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
114 if( !cl.adj(*it1, *it2) )
115 cost++;
116 }
117
118 return cost;
119 }
120
121
122 size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i ) {
123 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
124
125 size_t cost = 0;
126 // for each unordered pair {i1,i2} adjacent to n
127 for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
128 for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
129 if( it1 != it2 ) {
130 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
131 if( !cl.adj(*it1, *it2) )
132 cost += cl.vars()[*it1].states() * cl.vars()[*it2].states();
133 }
134
135 return cost;
136 }
137
138
139 } // end of namespace dai