Improved ClusterGraph implementation and MaxSpanningTreePrims implementation.
[libdai.git] / include / dai / factorgraph.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_factorgraph_h
23 #define __defined_libdai_factorgraph_h
24
25
#include <cassert>
#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
#include <tr1/unordered_map>
#include <dai/bipgraph.h>
#include <dai/factor.h>
31
32
33 namespace dai {
34
35
36 bool hasShortLoops( const std::vector<Factor> &P );
37 void RemoveShortLoops( std::vector<Factor> &P );
38
39
40 class FactorGraph {
41 public:
42 BipartiteGraph G;
43 std::vector<Var> vars;
44 std::vector<Factor> factors;
45 typedef BipartiteGraph::Neighbor Neighbor;
46 typedef BipartiteGraph::Neighbors Neighbors;
47 typedef BipartiteGraph::Edge Edge;
48
49 protected:
50 std::map<size_t,Prob> _undoProbs;
51 Prob::NormType _normtype;
52
53 public:
54 /// Default constructor
55 FactorGraph() : G(), vars(), factors(), _undoProbs(), _normtype(Prob::NORMPROB) {};
56 /// Copy constructor
57 FactorGraph(const FactorGraph & x) : G(x.G), vars(x.vars), factors(x.factors), _undoProbs(x._undoProbs), _normtype(x._normtype) {};
58 /// Construct FactorGraph from vector of Factors
59 FactorGraph(const std::vector<Factor> &P);
60 // Construct a FactorGraph from given factor and variable iterators
61 template<typename FactorInputIterator, typename VarInputIterator>
62 FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint = 0, size_t nr_var_hint = 0 );
63
64 /// Assignment operator
65 FactorGraph & operator=(const FactorGraph & x) {
66 if( this != &x ) {
67 G = x.G;
68 vars = x.vars;
69 factors = x.factors;
70 _undoProbs = x._undoProbs;
71 _normtype = x._normtype;
72 }
73 return *this;
74 }
75 virtual ~FactorGraph() {}
76
77 // aliases
78 Var & var(size_t i) { return vars[i]; }
79 const Var & var(size_t i) const { return vars[i]; }
80 Factor & factor(size_t I) { return factors[I]; }
81 const Factor & factor(size_t I) const { return factors[I]; }
82
83 size_t nrVars() const { return vars.size(); }
84 size_t nrFactors() const { return factors.size(); }
85 size_t nrEdges() const { return G.nrEdges(); }
86
87 /// Provides read access to neighbors of variable
88 const Neighbors & nbV( size_t i ) const { return G.nb1(i); }
89 /// Provides full access to neighbors of variable
90 Neighbors & nbV( size_t i ) { return G.nb1(i); }
91 /// Provides read access to neighbors of factor
92 const Neighbors & nbF( size_t I ) const { return G.nb2(I); }
93 /// Provides full access to neighbors of factor
94 Neighbors & nbF( size_t I ) { return G.nb2(I); }
95 /// Provides read access to neighbor of variable
96 const Neighbor & nbV( size_t i, size_t _I ) const { return G.nb1(i)[_I]; }
97 /// Provides full access to neighbor of variable
98 Neighbor & nbV( size_t i, size_t _I ) { return G.nb1(i)[_I]; }
99 /// Provides read access to neighbor of factor
100 const Neighbor & nbF( size_t I, size_t _i ) const { return G.nb2(I)[_i]; }
101 /// Provides full access to neighbor of factor
102 Neighbor & nbF( size_t I, size_t _i ) { return G.nb2(I)[_i]; }
103
104 size_t findVar(const Var & n) const {
105 size_t i = find( vars.begin(), vars.end(), n ) - vars.begin();
106 assert( i != nrVars() );
107 return i;
108 }
109 size_t findFactor(const VarSet &ns) const {
110 size_t I;
111 for( I = 0; I < nrFactors(); I++ )
112 if( factor(I).vars() == ns )
113 break;
114 assert( I != nrFactors() );
115 return I;
116 }
117
118 friend std::ostream& operator << (std::ostream& os, const FactorGraph& fg);
119 friend std::istream& operator >> (std::istream& is, FactorGraph& fg);
120
121 VarSet delta( unsigned i ) const;
122 VarSet Delta( unsigned i ) const;
123 virtual void makeCavity( unsigned i );
124
125 long ReadFromFile(const char *filename);
126 long WriteToFile(const char *filename) const;
127 long WriteToDotFile(const char *filename) const;
128
129 virtual void clamp( const Var & n, size_t i );
130
131 bool hasNegatives() const;
132 Prob::NormType NormType() const { return _normtype; }
133
134 std::vector<VarSet> Cliques() const;
135
136 virtual void undoProbs( const VarSet &ns );
137 void saveProbs( const VarSet &ns );
138 virtual void undoProb( size_t I );
139 void saveProb( size_t I );
140
141 virtual void updatedFactor( size_t /*I*/ ) {};
142
143 private:
144 /// Part of constructors (creates edges, neighbors and adjacency matrix)
145 void createGraph( size_t nrEdges );
146 };
147
148
149 // assumes that the set of variables in [var_begin,var_end) is the union of the variables in the factors in [fact_begin, fact_end)
150 template<typename FactorInputIterator, typename VarInputIterator>
151 FactorGraph::FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint, size_t nr_var_hint ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
152 // add factors
153 size_t nrEdges = 0;
154 factors.reserve( nr_fact_hint );
155 for( FactorInputIterator p2 = fact_begin; p2 != fact_end; ++p2 ) {
156 factors.push_back( *p2 );
157 nrEdges += p2->vars().size();
158 }
159
160 // add variables
161 vars.reserve( nr_var_hint );
162 for( VarInputIterator p1 = var_begin; p1 != var_end; ++p1 )
163 vars.push_back( *p1 );
164
165 // create graph structure
166 createGraph( nrEdges );
167 }
168
169
170 } // end of namespace dai
171
172
173 #endif