Merged SVN head ...
[libdai.git] / include / dai / mr.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_mr_h
23 #define __defined_libdai_mr_h
24
25
#include <cassert>
#include <cstddef>
#include <ostream>
#include <string>
#include <vector>
#include <dai/factorgraph.h>
#include <dai/daialg.h>
#include <dai/enum.h>
#include <dai/properties.h>
#include <dai/exceptions.h>
33
34
35 namespace dai {
36
37
38 class sub_nb;
39
40
41 class MR : public DAIAlgFG {
42 private:
43 bool supported; // is the underlying factor graph supported?
44
45 std::vector<size_t> con; // con[i] = connectivity of spin i
46 std::vector<std::vector<size_t> > nb; // nb[i] are the neighbours of spin i
47 std::vector<std::vector<double> > tJ; // tJ[i][_j] is the tanh of the interaction between spin i and its neighbour nb[i][_j]
48 std::vector<double> theta; // theta[i] is the local field on spin i
49 std::vector<std::vector<double> > M; // M[i][_j] is M^{(i)}_j
50 std::vector<std::vector<size_t> > kindex; // the _j'th neighbour of spin i has spin i as its kindex[i][_j]'th neighbour
51 std::vector<std::vector<std::vector<double> > > cors;
52
53 static const size_t kmax = 31;
54
55 size_t N;
56
57 std::vector<double> Mag;
58
59 double _maxdiff;
60 size_t _iters;
61
62 public:
63 struct Properties {
64 size_t verbose;
65 double tol;
66 DAI_ENUM(UpdateType,FULL,LINEAR)
67 DAI_ENUM(InitType,RESPPROP,CLAMPING,EXACT)
68 UpdateType updates;
69 InitType inits;
70 } props;
71 static const char *Name;
72
73 public:
74 /// Default constructor
75 MR() : DAIAlgFG(), supported(), con(), nb(), tJ(), theta(), M(), kindex(), cors(), N(), Mag(), _maxdiff(), _iters(), props() {}
76
77 /// Construct from FactorGraph fg and PropertySet opts
78 MR( const FactorGraph &fg, const PropertySet &opts );
79
80 /// Copy constructor
81 MR( const MR &x ) : DAIAlgFG(x), supported(x.supported), con(x.con), nb(x.nb), tJ(x.tJ), theta(x.theta), M(x.M), kindex(x.kindex), cors(x.cors), N(x.N), Mag(x.Mag), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
82
83 /// Clone *this (virtual copy constructor)
84 virtual MR* clone() const { return new MR(*this); }
85
86 /// Create (virtual default constructor)
87 virtual MR* create() const { return new MR(); }
88
89 /// Assignment operator
90 MR& operator=( const MR &x ) {
91 if( this != &x ) {
92 DAIAlgFG::operator=(x);
93 supported = x.supported;
94 con = x.con;
95 nb = x.nb;
96 tJ = x.tJ;
97 theta = x.theta;
98 M = x.M;
99 kindex = x.kindex;
100 cors = x.cors;
101 N = x.N;
102 Mag = x.Mag;
103 _maxdiff = x._maxdiff;
104 _iters = x._iters;
105 props = x.props;
106 }
107 return *this;
108 }
109
110 /// Identifies itself for logging purposes
111 virtual std::string identify() const;
112
113 /// Get single node belief
114 virtual Factor belief( const Var &n ) const;
115
116 /// Get general belief
117 virtual Factor belief( const VarSet &/*ns*/ ) const {
118 DAI_THROW(NOT_IMPLEMENTED);
119 return Factor();
120 }
121
122 /// Get all beliefs
123 virtual std::vector<Factor> beliefs() const;
124
125 /// Get log partition sum
126 virtual Real logZ() const {
127 DAI_THROW(NOT_IMPLEMENTED);
128 return 0.0;
129 }
130
131 /// Clear messages and beliefs
132 virtual void init() {}
133
134 /// Clear messages and beliefs corresponding to the nodes in ns
135 virtual void init( const VarSet &/*ns*/ ) {
136 DAI_THROW(NOT_IMPLEMENTED);
137 }
138
139 /// The actual approximate inference algorithm
140 virtual double run();
141
142 /// Return maximum difference between single node beliefs in the last pass
143 virtual double maxDiff() const { return _maxdiff; }
144
145 /// Return number of passes over the factorgraph
146 virtual size_t Iterations() const { return _iters; }
147
148
149 void init(size_t Nin, double *_w, double *_th);
150 void makekindex();
151 void read_files();
152 void init_cor();
153 double init_cor_resp();
154 void solvemcav();
155 void solveM();
156
157 double _tJ(size_t i, sub_nb A);
158
159 double Omega(size_t i, size_t _j, size_t _l);
160 double T(size_t i, sub_nb A);
161 double T(size_t i, size_t _j);
162 double Gamma(size_t i, size_t _j, size_t _l1, size_t _l2);
163 double Gamma(size_t i, size_t _l1, size_t _l2);
164
165 double appM(size_t i, sub_nb A);
166 void sum_subs(size_t j, sub_nb A, double *sum_even, double *sum_odd);
167
168 double sign(double a) { return (a >= 0) ? 1.0 : -1.0; }
169
170 void setProperties( const PropertySet &opts );
171 PropertySet getProperties() const;
172 std::string printProperties() const;
173 };
174
175
// Represents a subset of nb[i] as a binary number (bit mask).
// The elements of a subset should be thought of as indices in nb[i]:
// bit _j of the mask is set iff index _j belongs to the subset.
// The subset lives in a universe of `bits` indices 0, 1, ..., bits-1.
class sub_nb {
    private:
        /// Bit mask of the indices contained in this subset
        size_t s;
        /// Number of indices in the universe (width of the mask)
        size_t bits;

        /// Returns a mask with only bit number \a bit set.
        /// The shift is done in size_t rather than unsigned int. This makes indices
        /// >= 32 well-defined wherever size_t is 64 bits wide, which is the range
        /// the constructor's debug assertion allows.
        static size_t bit_mask( size_t bit ) { return static_cast<size_t>(1) << bit; }

    public:
        /// Construct full subset containing nr_elmt elements (indices 0, ..., nr_elmt-1)
        sub_nb( size_t nr_elmt ) : s(0), bits(nr_elmt) {
#ifdef DAI_DEBUG
            assert( nr_elmt < sizeof(size_t) / sizeof(char) * 8 );
#endif
            s = bit_mask(bits) - 1;
        }

        // Copying: the compiler-generated copy constructor and assignment
        // operator copy both data members, which is all that is needed.

        /// Returns the number of elements in this subset
        size_t size() const {
            size_t count = 0;
            for( size_t bit = 0; bit < bits; bit++ )
                if( s & bit_mask(bit) )
                    count++;
            return count;
        }

        /// Advances to the next subset in lexicographical (binary counting) order.
        /// The full subset wraps around to the empty subset.
        sub_nb & operator++() {
            s++;
            if( s >= bit_mask(bits) )
                s = 0;
            return *this;
        }

        /// Returns the index of the i'th (0-based) element of this subset.
        /// If the subset has at most i elements, \a bits is returned
        /// (and an assertion fails in DAI_DEBUG builds).
        size_t operator[]( size_t i ) const {
            size_t bit;
            for( bit = 0; bit < bits; bit++ )
                if( s & bit_mask(bit) ) {
                    if( i == 0 )
                        break;
                    else
                        i--;
                }
#ifdef DAI_DEBUG
            assert( bit < bits );
#endif
            return bit;
        }

        /// Adds index _j to this subset
        sub_nb & operator+=( size_t _j ) {
            s |= bit_mask(_j);
            return *this;
        }

        /// Returns a copy of this subset with index _j added
        sub_nb operator+( size_t _j ) const {
            sub_nb x = *this;
            x += _j;
            return x;
        }

        /// Deletes index _j from this subset
        sub_nb & operator-=( size_t _j ) {
            s &= ~bit_mask(_j);
            return *this;
        }

        /// Returns a copy of this subset with index _j removed
        sub_nb operator-( size_t _j ) const {
            sub_nb x = *this;
            x -= _j;
            return x;
        }

        /// Empties this subset (the universe size is kept)
        sub_nb & clear() {
            s = 0;
            return *this;
        }

        /// Returns true if this subset is empty
        bool empty() const { return (s == 0); }

        /// Returns 1 if _j is contained in this subset, 0 otherwise ("is element of")
        size_t operator&( size_t _j ) const { return (s >> _j) & static_cast<size_t>(1); }

        /// Writes the mask as a bit string, most significant index first,
        /// or "empty" if the universe has no indices
        friend std::ostream& operator<< (std::ostream& os, const sub_nb &x) {
            if( x.bits == 0 )
                os << "empty";
            else {
                for( size_t bit = x.bits; bit > 0; bit-- )
                    if( x.s & bit_mask(bit-1) )
                        os << "1";
                    else
                        os << "0";
            }
            return os;
        }
};
289
290
291 } // end of namespace dai
292
293
294 #endif