Merged regiongraph.* and daialg.* from SVN head,
[libdai.git] / include / dai / daialg.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_daialg_h
23 #define __defined_libdai_daialg_h
24
25
26 #include <string>
27 #include <iostream>
28 #include <vector>
29 #include <dai/factorgraph.h>
30 #include <dai/regiongraph.h>
31
32
33 namespace dai {
34
35
36 /// The InfAlg class is the common denominator of the various approximate inference algorithms.
37 /// A InfAlg object represents a discrete factorized probability distribution over multiple variables
38 /// together with an inference algorithm.
39 class InfAlg {
40 public:
41 /// Clone (virtual copy constructor)
42 virtual InfAlg* clone() const = 0;
43
44 /// Create (virtual constructor)
45 virtual InfAlg* create() const = 0;
46
47 /// Virtual desctructor
48 // (this is needed because this class contains virtual functions)
49 virtual ~InfAlg() {}
50
51 /// Identifies itself for logging purposes
52 virtual std::string identify() const = 0;
53
54 /// Get single node belief
55 virtual Factor belief( const Var &n ) const = 0;
56
57 /// Get general belief
58 virtual Factor belief( const VarSet &n ) const = 0;
59
60 /// Get all beliefs
61 virtual std::vector<Factor> beliefs() const = 0;
62
63 /// Get log partition sum
64 virtual Real logZ() const = 0;
65
66 /// Clear messages and beliefs
67 virtual void init() = 0;
68
69 /// Clear messages and beliefs corresponding to the nodes in ns
70 virtual void init( const VarSet &ns ) = 0;
71
72 /// The actual approximate inference algorithm
73 virtual double run() = 0;
74
75 /// Save factor I
76 virtual void backupFactor( size_t I ) = 0;
77 /// Save Factors involving ns
78 virtual void backupFactors( const VarSet &ns ) = 0;
79
80 /// Restore factor I
81 virtual void restoreFactor( size_t I ) = 0;
82 /// Restore Factors involving ns
83 virtual void restoreFactors( const VarSet &ns ) = 0;
84
85 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
86 virtual void clamp( const Var & n, size_t i, bool backup = false ) = 0;
87
88 /// Set all factors interacting with var(i) to 1
89 virtual void makeCavity( size_t i, bool backup = false ) = 0;
90
91 /// Get reference to underlying FactorGraph
92 virtual FactorGraph &fg() = 0;
93
94 /// Get const reference to underlying FactorGraph
95 virtual const FactorGraph &fg() const = 0;
96
97 /// Return maximum difference between beliefs in the last pass
98 virtual double maxDiff() const = 0;
99 };
100
101
102 template <class T>
103 class DAIAlg : public InfAlg, public T {
104 public:
105 /// Default constructor
106 DAIAlg() : InfAlg(), T() {}
107
108 /// Construct from T
109 DAIAlg( const T &t ) : InfAlg(), T(t) {}
110
111 /// Copy constructor
112 DAIAlg( const DAIAlg & x ) : InfAlg(x), T(x) {}
113
114 /// Assignment operator
115 DAIAlg & operator=( const DAIAlg &x ) {
116 if( this != &x ) {
117 InfAlg::operator=(x);
118 T::operator=(x);
119 }
120 return *this;
121 }
122
123 /// Save factor I (using T::backupFactor)
124 void backupFactor( size_t I ) { T::backupFactor( I ); }
125 /// Save Factors involving ns (using T::backupFactors)
126 void backupFactors( const VarSet &ns ) { T::backupFactors( ns ); }
127
128 /// Restore factor I (using T::restoreFactor)
129 void restoreFactor( size_t I ) { T::restoreFactor( I ); }
130 /// Restore Factors involving ns (using T::restoreFactors)
131 void restoreFactors( const VarSet &ns ) { T::restoreFactors( ns ); }
132
133 /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$) (using T::clamp)
134 void clamp( const Var & n, size_t i, bool backup = false ) { T::clamp( n, i, backup ); }
135
136 /// Set all factors interacting with var(i) to 1 (using T::makeCavity)
137 void makeCavity( size_t i, bool backup = false ) { T::makeCavity( i, backup ); }
138
139 /// Get reference to underlying FactorGraph
140 FactorGraph &fg() { return (FactorGraph &)(*this); }
141
142 /// Get const reference to underlying FactorGraph
143 const FactorGraph &fg() const { return (const FactorGraph &)(*this); }
144 };
145
146
147 typedef DAIAlg<FactorGraph> DAIAlgFG;
148 typedef DAIAlg<RegionGraph> DAIAlgRG;
149
150
151 /// Calculate the marginal of obj on ns by clamping
152 /// all variables in ns and calculating logZ for each joined state
153 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit );
154
155
156 /// Calculate beliefs of all pairs in ns (by clamping
157 /// nodes in ns and calculating logZ and the beliefs for each state)
158 std::vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reInit );
159
160
161 /// Calculate beliefs of all pairs in ns (by clamping
162 /// pairs in ns and calculating logZ for each joined state)
163 std::vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool reInit );
164
165
166 /// Calculate 2nd order interactions of the marginal of obj on ns
167 Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit );
168
169
170 } // end of namespace dai
171
172
173 #endif