Some small documentation updates
[libdai.git] / src / clustergraph.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <set>
#include <stdexcept>
#include <vector>
#include <iostream>
#include <dai/varset.h>
#include <dai/clustergraph.h>
17
18
19 namespace dai {
20
21
22 using namespace std;
23
24
25 ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : _G(), _vars(), _clusters() {
26 // construct vars, clusters and edge list
27 vector<Edge> edges;
28 foreach( const VarSet &cl, cls ) {
29 if( find( clusters().begin(), clusters().end(), cl ) == clusters().end() ) {
30 // add cluster
31 size_t n2 = nrClusters();
32 _clusters.push_back( cl );
33 for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
34 size_t n1 = find( vars().begin(), vars().end(), *n ) - vars().begin();
35 if( n1 == nrVars() )
36 // add variable
37 _vars.push_back( *n );
38 edges.push_back( Edge( n1, n2 ) );
39 }
40 } // disregard duplicate clusters
41 }
42
43 // Create bipartite graph
44 _G.construct( nrVars(), nrClusters(), edges.begin(), edges.end() );
45 }
46
47
48 ClusterGraph::ClusterGraph( const FactorGraph& fg, bool onlyMaximal ) : _G( fg.bipGraph() ), _vars(), _clusters() {
49 // copy variables
50 _vars.reserve( fg.nrVars() );
51 for( size_t i = 0; i < fg.nrVars(); i++ )
52 _vars.push_back( fg.var(i) );
53
54 // copy clusters
55 _clusters.reserve( fg.nrFactors() );
56 for( size_t I = 0; I < fg.nrFactors(); I++ )
57 _clusters.push_back( fg.factor(I).vars() );
58
59 if( onlyMaximal )
60 eraseNonMaximal();
61 }
62
63
64 size_t sequentialVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &/*remainingVars*/ ) {
65 return cl.findVar( seq.at(i++) );
66 }
67
68
69 size_t greedyVariableElimination::operator()( const ClusterGraph &cl, const std::set<size_t> &remainingVars ) {
70 set<size_t>::const_iterator lowest = remainingVars.end();
71 size_t lowest_cost = -1UL;
72 for( set<size_t>::const_iterator i = remainingVars.begin(); i != remainingVars.end(); i++ ) {
73 size_t cost = heuristic( cl, *i );
74 if( lowest == remainingVars.end() || lowest_cost > cost ) {
75 lowest = i;
76 lowest_cost = cost;
77 }
78 }
79 return *lowest;
80 }
81
82
83 size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i ) {
84 return cl.bipGraph().delta1( i ).size();
85 }
86
87
88 size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i ) {
89 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
90
91 size_t cost = 1;
92 for( SmallSet<size_t>::const_iterator it = id_n.begin(); it != id_n.end(); it++ )
93 cost *= cl.vars()[*it].states();
94
95 return cost;
96 }
97
98
99 size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
100 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
101
102 size_t cost = 0;
103 // for each unordered pair {i1,i2} adjacent to n
104 for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
105 for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
106 if( it1 != it2 ) {
107 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
108 if( !cl.adj(*it1, *it2) )
109 cost++;
110 }
111
112 return cost;
113 }
114
115
116 size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i ) {
117 SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
118
119 size_t cost = 0;
120 // for each unordered pair {i1,i2} adjacent to n
121 for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
122 for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
123 if( it1 != it2 ) {
124 // if i1 and i2 are not adjacent, eliminating n would make them adjacent
125 if( !cl.adj(*it1, *it2) )
126 cost += cl.vars()[*it1].states() * cl.vars()[*it2].states();
127 }
128
129 return cost;
130 }
131
132
133 } // end of namespace dai