Merged regiongraph.* and daialg.* from SVN head,
[libdai.git] / include / dai / mr.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_mr_h
23 #define __defined_libdai_mr_h
24
25
26 #include <vector>
27 #include <string>
28 #include <dai/factorgraph.h>
29 #include <dai/daialg.h>
30 #include <dai/enum.h>
31 #include <dai/properties.h>
32 #include <dai/exceptions.h>
33
34
35 namespace dai {
36
37
38 class sub_nb;
39
40
41 class MR : public DAIAlgFG {
42 private:
43 bool supported; // is the underlying factor graph supported?
44
45 std::vector<size_t> con; // con[i] = connectivity of spin i
46 std::vector<std::vector<size_t> > nb; // nb[i] are the neighbours of spin i
47 std::vector<std::vector<double> > tJ; // tJ[i][_j] is the tanh of the interaction between spin i and its neighbour nb[i][_j]
48 std::vector<double> theta; // theta[i] is the local field on spin i
49 std::vector<std::vector<double> > M; // M[i][_j] is M^{(i)}_j
50 std::vector<std::vector<size_t> > kindex; // the _j'th neighbour of spin i has spin i as its kindex[i][_j]'th neighbour
51 std::vector<std::vector<std::vector<double> > > cors;
52
53 static const size_t kmax = 31;
54
55 size_t N;
56
57 std::vector<double> Mag;
58
59 public:
60 struct Properties {
61 size_t verbose;
62 double tol;
63 DAI_ENUM(UpdateType,FULL,LINEAR)
64 DAI_ENUM(InitType,RESPPROP,CLAMPING,EXACT)
65 UpdateType updates;
66 InitType inits;
67 } props;
68 double maxdiff;
69
70 public:
71 MR() {}
72 MR( const FactorGraph & fg, const PropertySet &opts );
73 void init(size_t Nin, double *_w, double *_th);
74 void makekindex();
75 void read_files();
76 void init_cor();
77 double init_cor_resp();
78 void solvemcav();
79 void solveM();
80 double run();
81 Factor belief( const Var &n ) const;
82 Factor belief( const VarSet &/*ns*/ ) const {
83 DAI_THROW(NOT_IMPLEMENTED);
84 return Factor();
85 }
86 std::vector<Factor> beliefs() const;
87 Real logZ() const {
88 DAI_THROW(NOT_IMPLEMENTED);
89 return 0.0;
90 }
91 void init() {}
92 /// Clear messages and beliefs corresponding to the nodes in ns
93 virtual void init( const VarSet &/*ns*/ ) {
94 DAI_THROW(NOT_IMPLEMENTED);
95 }
96 static const char *Name;
97 std::string identify() const;
98 double _tJ(size_t i, sub_nb A);
99
100 double Omega(size_t i, size_t _j, size_t _l);
101 double T(size_t i, sub_nb A);
102 double T(size_t i, size_t _j);
103 double Gamma(size_t i, size_t _j, size_t _l1, size_t _l2);
104 double Gamma(size_t i, size_t _l1, size_t _l2);
105
106 double appM(size_t i, sub_nb A);
107 void sum_subs(size_t j, sub_nb A, double *sum_even, double *sum_odd);
108
109 double sign(double a) { return (a >= 0) ? 1.0 : -1.0; }
110 MR* clone() const { return new MR(*this); }
111 /// Create (virtual constructor)
112 virtual MR* create() const { return new MR(); }
113
114 void setProperties( const PropertySet &opts );
115 PropertySet getProperties() const;
116 std::string printProperties() const;
117 double maxDiff() const { return maxdiff; }
118 };
119
120
121 // represents a subset of nb[i] as a binary number
122 // the elements of a subset should be thought of as indices in nb[i]
123 class sub_nb {
124 private:
125 size_t s;
126 size_t bits;
127
128 public:
129 // construct full subset containing nr_elmt elements
130 sub_nb(size_t nr_elmt) {
131 #ifdef DAI_DEBUG
132 assert( nr_elmt < sizeof(size_t) / sizeof(char) * 8 );
133 #endif
134 bits = nr_elmt;
135 s = (1U << bits) - 1;
136 }
137
138 // copy constructor
139 sub_nb( const sub_nb & x ) : s(x.s), bits(x.bits) {}
140
141 // assignment operator
142 sub_nb & operator=( const sub_nb &x ) {
143 if( this != &x ) {
144 s = x.s;
145 bits = x.bits;
146 }
147 return *this;
148 }
149
150 // returns number of elements
151 size_t size() {
152 size_t size = 0;
153 for(size_t bit = 0; bit < bits; bit++)
154 if( s & (1U << bit) )
155 size++;
156 return size;
157 }
158
159 // increases s by one (for enumeration in lexicographical order)
160 sub_nb operator++() {
161 s++;
162 if( s >= (1U << bits) )
163 s = 0;
164 return *this;
165 }
166
167 // return i'th element of this subset
168 size_t operator[](size_t i) {
169 size_t bit;
170 for(bit = 0; bit < bits; bit++ )
171 if( s & (1U << bit) ) {
172 if( i == 0 )
173 break;
174 else
175 i--;
176 }
177 #ifdef DAI_DEBUG
178 assert( bit < bits );
179 #endif
180 return bit;
181 }
182
183 // add index _j to this subset
184 sub_nb &operator +=(size_t _j) {
185 s |= (1U << _j);
186 return *this;
187 }
188
189 // return copy with index _j
190 sub_nb operator+(size_t _j) {
191 sub_nb x = *this;
192 x += _j;
193 return x;
194 }
195
196 // delete index _j from this subset
197 sub_nb &operator -=(size_t _j) {
198 s &= ~(1U << _j);
199 return *this;
200 }
201
202 // return copy without index _j
203 sub_nb operator-(size_t _j) {
204 sub_nb x = *this;
205 x -= _j;
206 return x;
207 }
208
209 // empty this subset
210 sub_nb & clear() {
211 s = 0;
212 return *this;
213 }
214
215 // returns true if subset is empty
216 bool empty() { return (s == 0); }
217
218 // return 1 if _j is contained, 0 otherwise ("is element of")
219 size_t operator&(size_t _j) { return s & (1U << _j); }
220
221 friend std::ostream& operator<< (std::ostream& os, const sub_nb x) {
222 if( x.bits == 0 )
223 os << "empty";
224 else {
225 for(size_t bit = x.bits; bit > 0; bit-- )
226 if( x.s & (1U << (bit-1)) )
227 os << "1";
228 else
229 os << "0";
230 }
231 return os;
232 }
233 };
234
235
236 } // end of namespace dai
237
238
239 #endif