Replaced some FactorGraph functionality in InfAlg by a function
[libdai.git] / include / dai / daialg.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_daialg_h
23 #define __defined_libdai_daialg_h
24
25
26 #include <string>
27 #include <iostream>
28 #include <vector>
29 #include <dai/factorgraph.h>
30 #include <dai/regiongraph.h>
31 #include <dai/properties.h>
32
33
34 namespace dai {
35
36
37 /// The InfAlg class is the common denominator of the various approximate inference algorithms.
38 /// A InfAlg object represents a discrete factorized probability distribution over multiple variables
39 /// together with an inference algorithm.
40 class InfAlg {
41 private:
42 /// Properties of the algorithm (replaces _tol, _maxiter, _verbose)
43 Properties _properties;
44
45 /// Maximum difference encountered so far
46 double _maxdiff;
47
48
49 public:
50 /// Default constructor
51 InfAlg() : _properties(), _maxdiff(0.0) {}
52
53 /// Constructor with options
54 InfAlg( const Properties &opts ) : _properties(opts), _maxdiff(0.0) {}
55
56 /// Copy constructor
57 InfAlg( const InfAlg & x ) : _properties(x._properties), _maxdiff(x._maxdiff) {}
58
59 /// Clone (virtual copy constructor)
60 virtual InfAlg* clone() const = 0;
61
62 /// Assignment operator
63 InfAlg & operator=( const InfAlg & x ) {
64 if( this != &x ) {
65 _properties = x._properties;
66 _maxdiff = x._maxdiff;
67 }
68 return *this;
69 }
70
71 /// Virtual desctructor
72 // (this is needed because this class contains virtual functions)
73 virtual ~InfAlg() {}
74
75 /// Returns true if a property with the given key is present
76 bool HasProperty(const PropertyKey &key) const { return _properties.hasKey(key); }
77
78 /// Gets a property
79 const PropertyValue & GetProperty(const PropertyKey &key) const { return _properties.Get(key); }
80
81 /// Gets a property, casted as ValueType
82 template<typename ValueType>
83 ValueType GetPropertyAs(const PropertyKey &key) const { return _properties.GetAs<ValueType>(key); }
84
85 /// Sets a property
86 void SetProperty(const PropertyKey &key, const PropertyValue &val) { _properties[key] = val; }
87
88 /// Converts a property from string to ValueType, if necessary
89 template<typename ValueType>
90 void ConvertPropertyTo(const PropertyKey &key) { _properties.ConvertTo<ValueType>(key); }
91
92 /// Gets all properties
93 const Properties & GetProperties() const { return _properties; }
94
95 /// Sets properties
96 void SetProperties(const Properties &p) { _properties = p; }
97
98 /// Sets tolerance
99 void Tol( double tol ) { SetProperty("tol", tol); }
100 /// Gets tolerance
101 double Tol() const { return GetPropertyAs<double>("tol"); }
102
103 /// Sets maximum number of iterations
104 void MaxIter( size_t maxiter ) { SetProperty("maxiter", maxiter); }
105 /// Gets maximum number of iterations
106 size_t MaxIter() const { return GetPropertyAs<size_t>("maxiter"); }
107
108 /// Sets verbosity
109 void Verbose( size_t verbose ) { SetProperty("verbose", verbose); }
110 /// Gets verbosity
111 size_t Verbose() const { return GetPropertyAs<size_t>("verbose"); }
112
113 /// Sets maximum difference encountered so far
114 void MaxDiff( double maxdiff ) { _maxdiff = maxdiff; }
115 /// Gets maximum difference encountered so far
116 double MaxDiff() const { return _maxdiff; }
117 /// Updates maximum difference encountered so far
118 void updateMaxDiff( double maxdiff ) { if( maxdiff > _maxdiff ) _maxdiff = maxdiff; }
119 /// Sets maximum difference encountered so far to zero
120 void clearMaxDiff() { _maxdiff = 0.0; }
121
122 /// Identifies itself for logging purposes
123 virtual std::string identify() const = 0;
124
125 /// Get single node belief
126 virtual Factor belief( const Var &n ) const = 0;
127
128 /// Get general belief
129 virtual Factor belief( const VarSet &n ) const = 0;
130
131 /// Get all beliefs
132 virtual std::vector<Factor> beliefs() const = 0;
133
134 /// Get log partition sum
135 virtual Complex logZ() const = 0;
136
137 /// Clear messages and beliefs
138 virtual void init() = 0;
139
140 /// The actual approximate inference algorithm
141 virtual double run() = 0;
142
143 /// Save factor I
144 virtual void saveProb( size_t I ) = 0;
145 /// Save Factors involving ns
146 virtual void saveProbs( const VarSet &ns ) = 0;
147
148 /// Restore factor I
149 virtual void undoProb( size_t I ) = 0;
150 /// Restore Factors involving ns
151 virtual void undoProbs( const VarSet &ns ) = 0;
152
153 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
154 virtual void clamp( const Var & n, size_t i ) = 0;
155
156 /// Set all factors interacting with var(i) to 1
157 virtual void makeCavity( size_t i ) = 0;
158
159 /// Factor I has been updated
160 virtual void updatedFactor( size_t I ) = 0;
161
162 /// Get reference to underlying FactorGraph
163 virtual FactorGraph &fg() = 0;
164
165 /// Get const reference to underlying FactorGraph
166 virtual const FactorGraph &fg() const = 0;
167
168 /// Checks whether all necessary properties have been set
169 /// and casts string-valued properties to other values if necessary
170 virtual bool checkProperties() = 0;
171 };
172
173
174 template <class T>
175 class DAIAlg : public InfAlg, public T {
176 public:
177 /// Default constructor
178 DAIAlg() : InfAlg(), T() {}
179
180 /// Construct DAIAlg with empty T but using the specified properties
181 DAIAlg( const Properties &opts ) : InfAlg( opts ), T() {}
182
183 /// Construct DAIAlg using the specified properties
184 DAIAlg( const T & t, const Properties &opts ) : InfAlg( opts ), T(t) {}
185
186 /// Copy constructor
187 DAIAlg( const DAIAlg & x ) : InfAlg(x), T(x) {}
188
189 /// Save factor I (using T::saveProb)
190 void saveProb( size_t I ) { T::saveProb( I ); }
191 /// Save Factors involving ns (using T::saveProbs)
192 void saveProbs( const VarSet &ns ) { T::saveProbs( ns ); }
193
194 /// Restore factor I (using T::undoProb)
195 void undoProb( size_t I ) { T::undoProb( I ); }
196 /// Restore Factors involving ns (using T::undoProbs)
197 void undoProbs( const VarSet &ns ) { T::undoProbs( ns ); }
198
199 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$) (using T::clamp)
200 void clamp( const Var & n, size_t i ) { T::clamp( n, i ); }
201
202 /// Set all factors interacting with var(i) to 1 (using T::makeCavity)
203 void makeCavity( size_t i ) { T::makeCavity( i ); }
204
205 /// Factor I has been updated (using T::updatedFactor)
206 void updatedFactor( size_t I ) { T::updatedFactor(I); }
207
208 /// Get reference to underlying FactorGraph
209 FactorGraph &fg() { return (FactorGraph &)(*this); }
210
211 /// Get const reference to underlying FactorGraph
212 const FactorGraph &fg() const { return (const FactorGraph &)(*this); }
213 };
214
215
216 typedef DAIAlg<FactorGraph> DAIAlgFG;
217 typedef DAIAlg<RegionGraph> DAIAlgRG;
218
219
220 /// Calculate the marginal of obj on ns by clamping
221 /// all variables in ns and calculating logZ for each joined state
222 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit );
223
224
225 /// Calculate beliefs of all pairs in ns (by clamping
226 /// nodes in ns and calculating logZ and the beliefs for each state)
227 std::vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit );
228
229
230 /// Calculate beliefs of all pairs in ns (by clamping
231 /// pairs in ns and calculating logZ for each joined state)
232 std::vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit );
233
234
235 /// Calculate 2nd order interactions of the marginal of obj on ns
236 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit );
237
238
239 } // end of namespace dai
240
241
242 #endif