f68ea4607f65300009a8a4f8afb6cd5da51c03cf
[libdai.git] / include / dai / factorgraph.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_factorgraph_h
23 #define __defined_libdai_factorgraph_h
24
25
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <vector>
#include <tr1/unordered_map>
#include <dai/bipgraph.h>
#include <dai/factor.h>
31
32
33 namespace dai {
34
35
36 bool hasShortLoops( const std::vector<Factor> &P );
37 void RemoveShortLoops( std::vector<Factor> &P );
38
39
40 class FactorGraph : public BipartiteGraph<Var,Factor> {
41 protected:
42 std::map<size_t,Prob> _undoProbs;
43 Prob::NormType _normtype;
44
45 public:
46 /// Default constructor
47 FactorGraph() : BipartiteGraph<Var,Factor>(), _undoProbs(), _normtype(Prob::NORMPROB) {};
48 /// Copy constructor
49 FactorGraph(const FactorGraph & x) : BipartiteGraph<Var,Factor>(x), _undoProbs(), _normtype(x._normtype) {};
50 /// Construct FactorGraph from vector of Factors
51 FactorGraph(const std::vector<Factor> &P);
52 // Construct a FactorGraph from given factor and variable iterators
53 template<typename FactorInputIterator, typename VarInputIterator>
54 FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint = 0, size_t nr_var_hint = 0 );
55
56 /// Assignment operator
57 FactorGraph & operator=(const FactorGraph & x) {
58 if(this!=&x) {
59 BipartiteGraph<Var,Factor>::operator=(x);
60 _undoProbs = x._undoProbs;
61 _normtype = x._normtype;
62 }
63 return *this;
64 }
65 virtual ~FactorGraph() {}
66
67 // aliases
68 Var & var(size_t i) { return V1(i); }
69 const Var & var(size_t i) const { return V1(i); }
70 const std::vector<Var> & vars() const { return V1s(); }
71 std::vector<Var> & vars() { return V1s(); }
72 size_t nrVars() const { return V1s().size(); }
73 Factor & factor(size_t I) { return V2(I); }
74 const Factor & factor(size_t I) const { return V2(I); }
75 const std::vector<Factor> & factors() const { return V2s(); }
76 std::vector<Factor> & factors() { return V2s(); }
77 size_t nrFactors() const { return V2s().size(); }
78
79 /// Provides read access to neighbours of variable
80 const _nb_t & nbV( size_t i1 ) const { return nb1(i1); }
81 /// Provides full access to neighbours of variable
82 _nb_t & nbV( size_t i1 ) { return nb1(i1); }
83 /// Provides read access to neighbours of factor
84 const _nb_t & nbF( size_t i2 ) const { return nb2(i2); }
85 /// Provides full access to neighbours of factor
86 _nb_t & nbF( size_t i2 ) { return nb2(i2); }
87
88 size_t findVar(const Var & n) const {
89 size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
90 assert( i != nrVars() );
91 return i;
92 }
93 size_t findFactor(const VarSet &ns) const {
94 size_t I;
95 for( I = 0; I < nrFactors(); I++ )
96 if( factor(I).vars() == ns )
97 break;
98 assert( I != nrFactors() );
99 return I;
100 }
101
102 friend std::ostream& operator << (std::ostream& os, const FactorGraph& fg);
103 friend std::istream& operator >> (std::istream& is, FactorGraph& fg);
104
105 VarSet delta(const Var & n) const;
106 VarSet Delta(const Var & n) const;
107 virtual void makeFactorCavity(size_t I);
108 virtual void makeCavity(const Var & n);
109
110 long ReadFromFile(const char *filename);
111 long WriteToFile(const char *filename) const;
112 long WriteToDotFile(const char *filename) const;
113
114 Factor ExactMarginal(const VarSet & x) const;
115 Real ExactlogZ() const;
116
117 virtual void clamp( const Var & n, size_t i );
118
119 bool hasNegatives() const;
120 Prob::NormType NormType() const { return _normtype; }
121
122 std::vector<VarSet> Cliques() const;
123
124 virtual void undoProbs( const VarSet &ns );
125 void saveProbs( const VarSet &ns );
126 virtual void undoProb( size_t I );
127 void saveProb( size_t I );
128
129 bool isConnected() const;
130
131 virtual void updatedFactor( size_t /*I*/ ) {};
132
133 private:
134 /// Part of constructors (creates edges, neighbours and adjacency matrix)
135 void createGraph( size_t nrEdges );
136 };
137
138
139 // assumes that the set of variables in [var_begin,var_end) is the union of the variables in the factors in [fact_begin, fact_end)
140 template<typename FactorInputIterator, typename VarInputIterator>
141 FactorGraph::FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint, size_t nr_var_hint ) : BipartiteGraph<Var,Factor>(), _undoProbs(), _normtype(Prob::NORMPROB) {
142 // add factors
143 size_t nrEdges = 0;
144 V2s().reserve( nr_fact_hint );
145 for( FactorInputIterator p2 = fact_begin; p2 != fact_end; ++p2 ) {
146 V2s().push_back( *p2 );
147 nrEdges += p2->vars().size();
148 }
149
150 // add variables
151 V1s().reserve( nr_var_hint );
152 for( VarInputIterator p1 = var_begin; p1 != var_end; ++p1 )
153 V1s().push_back( *p1 );
154
155 // create graph structure
156 createGraph( nrEdges );
157 }
158
159
160 } // end of namespace dai
161
162
163 #endif