Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / include / dai / daialg.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_daialg_h
23 #define __defined_libdai_daialg_h
24
25
26 #include <string>
27 #include <iostream>
28 #include <vector>
29 #include <dai/factorgraph.h>
30 #include <dai/regiongraph.h>
31
32
33 namespace dai {
34
35
36 /// The InfAlg class is the common denominator of the various approximate inference algorithms.
37 /// A InfAlg object represents a discrete factorized probability distribution over multiple variables
38 /// together with an inference algorithm.
39 class InfAlg {
40 public:
41 /// Clone (virtual copy constructor)
42 virtual InfAlg* clone() const = 0;
43
44 /// Virtual desctructor
45 // (this is needed because this class contains virtual functions)
46 virtual ~InfAlg() {}
47
48 /// Identifies itself for logging purposes
49 virtual std::string identify() const = 0;
50
51 /// Get single node belief
52 virtual Factor belief( const Var &n ) const = 0;
53
54 /// Get general belief
55 virtual Factor belief( const VarSet &n ) const = 0;
56
57 /// Get all beliefs
58 virtual std::vector<Factor> beliefs() const = 0;
59
60 /// Get log partition sum
61 virtual Complex logZ() const = 0;
62
63 /// Clear messages and beliefs
64 virtual void init() = 0;
65
66 /// The actual approximate inference algorithm
67 virtual double run() = 0;
68
69 /// Save factor I
70 virtual void saveProb( size_t I ) = 0;
71 /// Save Factors involving ns
72 virtual void saveProbs( const VarSet &ns ) = 0;
73
74 /// Restore factor I
75 virtual void undoProb( size_t I ) = 0;
76 /// Restore Factors involving ns
77 virtual void undoProbs( const VarSet &ns ) = 0;
78
79 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
80 virtual void clamp( const Var & n, size_t i ) = 0;
81
82 /// Return all variables that interact with var(i)
83 virtual VarSet delta( size_t i ) const = 0;
84
85 /// Set all factors interacting with var(i) to 1
86 virtual void makeCavity( size_t i ) = 0;
87
88 /// Get index of variable n
89 virtual size_t findVar( const Var & n ) const = 0;
90
91 /// Get index of first factor involving ns
92 virtual size_t findFactor( const VarSet &ns ) const = 0;
93
94 /// Get number of variables
95 virtual size_t nrVars() const = 0;
96
97 /// Get number of factors
98 virtual size_t nrFactors() const = 0;
99
100 /// Get const reference to variable i
101 virtual const Var & var(size_t i) const = 0;
102
103 /// Get reference to variable i
104 virtual Var & var(size_t i) = 0;
105
106 /// Get const reference to factor I
107 virtual const Factor & factor( size_t I ) const = 0;
108
109 /// Get reference to factor I
110 virtual Factor & factor( size_t I ) = 0;
111
112 /// Factor I has been updated
113 virtual void updatedFactor( size_t I ) = 0;
114
115 /// Return maximum difference between beliefs in the last pass
116 virtual double maxDiff() const = 0;
117 };
118
119
120 template <class T>
121 class DAIAlg : public InfAlg, public T {
122 public:
123 /// Default constructor
124 DAIAlg() : InfAlg(), T() {}
125
126 /// Construct from T
127 DAIAlg( const T &t ) : InfAlg(), T(t) {}
128
129 /// Copy constructor
130 DAIAlg( const DAIAlg & x ) : InfAlg(x), T(x) {}
131
132 /// Save factor I (using T::saveProb)
133 void saveProb( size_t I ) { T::saveProb( I ); }
134 /// Save Factors involving ns (using T::saveProbs)
135 void saveProbs( const VarSet &ns ) { T::saveProbs( ns ); }
136
137 /// Restore factor I (using T::undoProb)
138 void undoProb( size_t I ) { T::undoProb( I ); }
139 /// Restore Factors involving ns (using T::undoProbs)
140 void undoProbs( const VarSet &ns ) { T::undoProbs( ns ); }
141
142 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$) (using T::clamp)
143 void clamp( const Var & n, size_t i ) { T::clamp( n, i ); }
144
145 /// Return all variables that interact with var(i) (using T::delta)
146 VarSet delta( size_t i ) const { return T::delta( i ); }
147
148 /// Set all factors interacting with var(i) to 1 (using T::makeCavity)
149 void makeCavity( size_t i ) { T::makeCavity( i ); }
150
151 /// Get index of variable n (using T::findVar)
152 size_t findVar( const Var & n ) const { return T::findVar(n); }
153
154 /// Get index of first factor involving ns (using T::findFactor)
155 size_t findFactor( const VarSet &ns ) const { return T::findFactor(ns); }
156
157 /// Get number of variables (using T::nrFactors)
158 size_t nrVars() const { return T::nrVars(); }
159
160 /// Get number of factors (using T::nrFactors)
161 size_t nrFactors() const { return T::nrFactors(); }
162
163 /// Get const reference to variable i (using T::var)
164 const Var & var( size_t i ) const { return T::var(i); }
165
166 /// Get reference to variable i (using T::var)
167 Var & var(size_t i) { return T::var(i); }
168
169 /// Get const reference to factor I (using T::factor)
170 const Factor & factor( size_t I ) const { return T::factor(I); }
171
172 /// Get reference to factor I (using T::factor)
173 Factor & factor( size_t I ) { return T::factor(I); }
174
175 /// Factor I has been updated (using T::updatedFactor)
176 void updatedFactor( size_t I ) { T::updatedFactor(I); }
177 };
178
179
180 typedef DAIAlg<FactorGraph> DAIAlgFG;
181 typedef DAIAlg<RegionGraph> DAIAlgRG;
182
183
184 /// Calculate the marginal of obj on ns by clamping
185 /// all variables in ns and calculating logZ for each joined state
186 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit );
187
188
189 /// Calculate beliefs of all pairs in ns (by clamping
190 /// nodes in ns and calculating logZ and the beliefs for each state)
191 std::vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit );
192
193
194 /// Calculate beliefs of all pairs in ns (by clamping
195 /// pairs in ns and calculating logZ for each joined state)
196 std::vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit );
197
198
199 /// Calculate 2nd order interactions of the marginal of obj on ns
200 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit );
201
202
203 } // end of namespace dai
204
205
206 #endif