b6f0c51fa2b5a4ce2b56158173dfbb14dd8f72da
[libdai.git] / examples / example_imagesegmentation.cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <iterator>
#include <vector>
#include <dai/alldai.h>
#include <boost/numeric/ublas/matrix_sparse.hpp>
#include <boost/numeric/ublas/matrix_proxy.hpp>
#include <boost/numeric/ublas/vector.hpp>
#include <boost/numeric/ublas/io.hpp>
#include <CImg.h>

using namespace std;
using namespace cimg_library;
using namespace dai;
15
// Dense uBLAS vector of doubles (local fields th, magnetizations m)
typedef boost::numeric::ublas::vector<double> ublasvector;
// Sparse uBLAS matrix in compressed storage; the code below walks it row by row
// (index1_data = start position of each row, index2_data = column index)
typedef boost::numeric::ublas::compressed_matrix<double> ublasmatrix;
// Const iterators over the raw value / index arrays of a ublasmatrix
typedef ublasmatrix::value_array_type::const_iterator matrix_vcit;
typedef ublasmatrix::index_array_type::const_iterator matrix_icit;
20
21
22 class BinaryPairwiseGM {
23 public:
24 size_t N;
25 ublasmatrix w;
26 ublasvector th;
27 double logZ0;
28
29 BinaryPairwiseGM() {}
30 BinaryPairwiseGM( const FactorGraph &fg );
31 BinaryPairwiseGM( size_t _N, const ublasmatrix &_w, const ublasvector &_th, double _logZ0 ) : N(_N), w(_w), th(_th), logZ0(_logZ0) {}
32 BinaryPairwiseGM( const BinaryPairwiseGM &x ) : N(x.N), w(x.w), th(x.th), logZ0(x.logZ0) {};
33 BinaryPairwiseGM & operator=( const BinaryPairwiseGM &x ) {
34 if( this != &x ) {
35 N = x.N;
36 w = x.w;
37 th = x.th;
38 logZ0 = x.logZ0;
39 }
40 return *this;
41 }
42 double doBP( size_t maxiter, double tol, size_t verbose, ublasvector &m );
43 FactorGraph toFactorGraph();
44 };
45
46
47 // w should be upper triangular or lower triangular
48 void WTh2FG( const ublasmatrix &w, const vector<double> &th, FactorGraph &fg ) {
49 vector<Var> vars;
50 vector<Factor> factors;
51
52 size_t N = th.size();
53 assert( (w.size1() == N) && (w.size2() == N) );
54
55 vars.reserve(N);
56 for( size_t i = 0; i < N; i++ )
57 vars.push_back(Var(i,2));
58
59 factors.reserve( w.nnz() + N );
60 // walk through the sparse array structure
61 // this is similar to matlab sparse arrays
62 // index2 gives the column index
63 // index1 gives the starting indices for each row
64 size_t i = 0;
65 // cout << w << endl;
66 for( size_t pos = 0; pos < w.nnz(); pos++ ) {
67 while( pos == w.index1_data()[i+1] )
68 i++;
69 size_t j = w.index2_data()[pos];
70 double w_ij = w.value_data()[pos];
71 // cout << "(" << i << "," << j << "): " << w_ij << endl;
72 factors.push_back( createFactorIsing( vars[i], vars[j], w_ij ) );
73 }
74 for( size_t i = 0; i < N; i++ )
75 factors.push_back( createFactorIsing( vars[i], th[i] ) );
76
77 fg = FactorGraph(factors);
78 }
79
80
81 template<class T>
82 void Image2net( const CImg<T> &img, double J, double th_min, double th_plus, double th_tol, double p_background, BinaryPairwiseGM &net ) {
83 size_t dimx = img.dimx();
84 size_t dimy = img.dimy();
85
86 net.N = dimx * dimy;
87 net.w = ublasmatrix(net.N,net.N,4*net.N);
88 net.th = ublasvector(net.N);
89 for( size_t i = 0; i < net.N; i++ )
90 net.th[i] = 0.0;
91 net.logZ0 = 0.0;
92
93 CImg<float> hist = img.get_channel(0).get_histogram(256,0,255);
94 size_t cum_hist = 0;
95 size_t level = 0;
96 for( level = 0; level < 256; level++ ) {
97 cum_hist += (size_t)hist(level);
98 if( cum_hist > p_background * dimx * dimy )
99 break;
100 }
101
102 double th_avg = (th_min + th_plus) / 2.0;
103 double th_width = (th_plus - th_min) / 2.0;
104 for( size_t i = 0; i < dimx; i++ )
105 for( size_t j = 0; j < dimy; j++ ) {
106 if( i+1 < dimx )
107 net.w(i*dimy+j, (i+1)*dimy+j) = J;
108 if( i >= 1 )
109 net.w(i*dimy+j, (i-1)*dimy+j) = J;
110 if( j+1 < dimy )
111 net.w(i*dimy+j, i*dimy+(j+1)) = J;
112 if( j >= 1 )
113 net.w(i*dimy+j, i*dimy+(j-1)) = J;
114 double x = img(i,j);
115 net.th[i*dimy+j] = th_avg + th_width * tanh((x - level)/th_tol);
116 /* if( x < level )
117 x = x / level * 0.5;
118 else
119 x = 0.5 + 0.5 * ((x - level) / (255 - level));*/
120 /* if( x < level )
121 x = 0.01;
122 else
123 x = 0.99;
124 th[i*dimy+j] = 0.5 * (log(x) - log(1.0 - x));*/
125 }
126 }
127
128
129 template<class T>
130 FactorGraph img2fg( const CImg<T> &img, double J, double th_min, double th_plus, double th_tol, double p_background ) {
131 vector<Var> vars;
132 vector<Factor> factors;
133
134 size_t dimx = img.dimx();
135 size_t dimy = img.dimy();
136 size_t N = dimx * dimy;
137
138 // create variables
139 vars.reserve( N );
140 for( size_t i = 0; i < N; i++ )
141 vars.push_back( Var( i, 2 ) );
142
143 // build histogram
144 CImg<float> hist = img.get_channel(0).get_histogram(256,0,255);
145 size_t cum_hist = 0;
146 size_t level = 0;
147 for( level = 0; level < 256; level++ ) {
148 cum_hist += (size_t)hist(level);
149 if( cum_hist > p_background * dimx * dimy )
150 break;
151 }
152
153 // create factors
154 factors.reserve( 3 * N - dimx - dimy );
155 double th_avg = (th_min + th_plus) / 2.0;
156 double th_width = (th_plus - th_min) / 2.0;
157 for( size_t i = 0; i < dimx; i++ )
158 for( size_t j = 0; j < dimy; j++ ) {
159 if( i >= 1 )
160 factors.push_back( createFactorIsing( vars[i*dimy+j], vars[(i-1)*dimy+j], J ) );
161 if( j >= 1 )
162 factors.push_back( createFactorIsing( vars[i*dimy+j], vars[i*dimy+(j-1)], J ) );
163 double x = img(i,j);
164 factors.push_back( createFactorIsing( vars[i*dimy+j], th_avg + th_width * tanh((x - level)/th_tol) ) );
165 }
166 }
167
168 return FactorGraph( factors.begin(), factors.end(), vars.begin(), vars.end(), factors.size(), vars.size() );
169 }
170
171
// Inference drivers defined further below. Each fills m with one magnetization per
// variable and draws intermediate results into disp (an image of size dimx x dimy).
// myBP / myMF: hand-written loopy BP and mean-field on a BinaryPairwiseGM.
// doInference: runs any libdai algorithm given by the AlgOpts string on a FactorGraph.
double myBP( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
double myMF( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
double doInference( FactorGraph &fg, string AlgOpts, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
175
176 int main(int argc,char **argv) {
177 // Display program usage, when invoked from the command line with option '-h'.
178 cimg_usage("Usage: example_imagesegmentation -i <inputimage1> -j <inputimage2> -o <outputimage1> -p <outputimage2> -J <J> -t <t> -s <s> -u <u> -x <x>");
179 const char* file_i = cimg_option("-i","","Input image 1");
180 const char* file_j = cimg_option("-j","","Input image 2");
181 const char* file_o = cimg_option("-o","","Output image (with BP)");
182 const char* file_p = cimg_option("-p","","Output image (without BP)");
183 const double J = cimg_option("-J",0.0,"Coupling strength");
184 const double th_min = cimg_option("-t",0.0,"Local evidence strength background");
185 const double th_plus = cimg_option("-s",0.0,"Local evidence strength foreground");
186 const double th_tol = cimg_option("-u",0.0,"Sensitivity for fore/background");
187 const double p_background = cimg_option("-x",0.0,"Percentage of background in image");
188
189 CImg<unsigned char> image1 = CImg<>(file_i);
190 CImg<unsigned char> image2 = CImg<>(file_j);
191
192 CImg<int> image3(image1);
193 image3 -= image2;
194 image3.abs();
195 image3.norm_pointwise(1); // 1 = L1, 2 = L2, -1 = Linf
196 // normalize
197 for( size_t i = 0; i < image3.dimx(); i++ ) {
198 for( size_t j = 0; j < image3.dimy(); j++ ) {
199 int avg = 0;
200 for( size_t c = 0; c < image1.dimv(); c++ )
201 avg += image1(i,j,c);
202 avg /= image1.dimv();
203 image3(i,j,0) /= (1.0 + avg / 255.0);
204 }
205 }
206 image3.normalize(0,255);
207
208 CImgDisplay disp1(image1,"Input 1",0);
209 CImgDisplay disp2(image2,"Input 2",0);
210 CImgDisplay disp3(image3,"Absolute difference of both inputs",0);
211
212 //BinaryPairwiseGM net;
213 //Image2net( image3, J, th_min, th_plus, th_tol, p_background, net );
214 FactorGraph fg = img2fg( image3, J, th_min, th_plis, th_tol, p_background );
215
216 size_t dimx = image3.dimx();
217 size_t dimy = image3.dimy();
218 CImg<unsigned char> image4(dimx,dimy,1,3);
219
220 ublasvector m;
221 //net.doBP( 0, 1e-2, 3, m );
222 BP bp( fg, PropertySet("[updates=SEQFIX,maxiter=0,tol=1e-9,verbose=0,logdomain=0]") );
223 bp.init();
224 for( size_t i = 0; i < dimx; i++ )
225 for( size_t j = 0; j < dimy; j++ ) {
226 unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
227 // unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
228 if( g > 127 ) {
229 image4(i,j,0) = 255;
230 image4(i,j,1) = 2 * (g - 127);
231 image4(i,j,2) = 2 * (g - 127);
232 } else {
233 image4(i,j,0) = 0;
234 image4(i,j,1) = 0;
235 image4(i,j,2) = 2*g;
236 }
237 }
238 CImgDisplay disp4(image4,"Local evidence",0);
239 image4.save_jpeg(file_p,100);
240
241 // solve the problem and show intermediate steps
242 CImgDisplay disp5(dimx,dimy,"Beliefs during inference",0);
243 if( 1 ) {
244 //FactorGraph fg = net.toFactorGraph();
245 fg.WriteToFile( "joris.fg" );
246
247 doInference( fg, "BP[updates=SEQMAX,maxiter=1,tol=1e-9,verbose=0,logdomain=0]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
248 // doInference( fg, "HAK[doubleloop=0,clusters=LOOP,init=UNIFORM,loopdepth=4,tol=1e-9,maxiter=1,verbose=3]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
249 // doInference( fg, "HAK[doubleloop=0,clusters=BETHE,init=UNIFORM,maxiter=1,tol=1e-9,verbose=3]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
250 // doInference( fg, "MF[tol=1e-9,maxiter=1,damping=0.0,init=RANDOM,updates=NAIVE]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
251 } else {
252 // myBP( net, 1000, 1e-5, 3, m, dimx, dimy, disp5 );
253 // myMF( net, 1000, 1e-5, 3, m, dimx, dimy, disp5 );
254 }
255
256 for( size_t i = 0; i < dimx; i++ )
257 for( size_t j = 0; j < dimy; j++ ) {
258 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
259 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
260 if( g > 127 ) {
261 image4(i,j,0) = image2(i,j,0);
262 image4(i,j,1) = image2(i,j,1);
263 image4(i,j,2) = image2(i,j,2);
264 } else
265 for( size_t c = 0; c < (size_t)image4.dimv(); c++ )
266 image4(i,j,c) = 255;
267 /* if( g > 127 ) {
268 image4(i,j,0) = image4(i,j,1) = image4(i,j,2) = 0;
269 } else {
270 image4(i,j,0) = image4(i,j,1) = image4(i,j,2) = 255;
271 }*/
272 }
273 CImgDisplay main_disp(image4,"Segmentation result",0);
274 image4.save_jpeg(file_o,100);
275
276 while( !main_disp.is_closed )
277 cimg::wait( 40 );
278
279 return 0;
280 }
281
282
283 double myBP( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
284 clock_t tic = toc();
285
286 if( verbose >= 1 )
287 cout << "Starting myBP..." << endl;
288
289 size_t nr_messages = net.w.nnz();
290 ublasmatrix message( net.w );
291 for( size_t ij = 0; ij < nr_messages; ij++ )
292 message.value_data()[ij] = 0.0;
293 // NOTE: message(i,j) is \mu_{j\to i}
294
295 m = ublasvector(net.N);
296
297 size_t _iterations = 0;
298 double max_diff = 1.0;
299 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
300 // walk through the sparse array structure
301 // this is similar to matlab sparse arrays
302 // index2 gives the column index (ir in matlab)
303 // index1 gives the starting indices for each row (jc in matlab)
304 // for( size_t t = 0; t < 3; t++ ) {
305 size_t i = 0;
306 max_diff = 0.0;
307 for( size_t pos = 0; pos < nr_messages; pos++ ) {
308 while( pos == net.w.index1_data()[i+1] )
309 i++;
310 size_t j = net.w.index2_data()[pos];
311 double w_ij = net.w.value_data()[pos];
312 // \mu_{j\to i} = \atanh \tanh w_{ij} \tanh (\theta_j + \sum_{k\in\nb{j}\setm i} \mu_{k\to j})
313 double field = sum(row(message,j)) - message(j,i) + net.th[j];
314 double new_message = atanh( tanh( w_ij ) * tanh( field ) );
315 double diff = fabs(message(i,j) - new_message);
316 if( diff > max_diff )
317 max_diff = diff;
318 // if( (pos % 3) == t )
319 message(i,j) = new_message;
320 }
321 // }
322
323 if( verbose >= 3 )
324 cout << "myBP: maxdiff " << max_diff << " after " << _iterations+1 << " passes" << endl;
325
326 for( size_t j = 0; j < net.N; j++ ) {
327 // m_j = \tanh (\theta_j + \sum_{k\in\nb{j}} \mu_{k\to j})
328 double field = sum(row(message,j)) + net.th[j];
329 m[j] = tanh( field );
330 }
331 CImg<unsigned char> image(dimx,dimy,1,3);
332 for( size_t i = 0; i < dimx; i++ )
333 for( size_t j = 0; j < dimy; j++ ) {
334 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
335 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
336 if( g > 127 ) {
337 image(i,j,0) = 255;
338 image(i,j,1) = 2 * (g - 127);
339 image(i,j,2) = 2 * (g - 127);
340 } else {
341 image(i,j,0) = 0;
342 image(i,j,1) = 0;
343 image(i,j,2) = 2*g;
344 }
345 }
346 disp << image;
347 char filename[30] = "/tmp/movie000.jpg";
348 sprintf( &filename[10], "%03ld", (long)_iterations );
349 strcat( filename, ".jpg" );
350 image.save_jpeg(filename,100);
351 }
352
353 if( verbose >= 1 ) {
354 if( max_diff > tol ) {
355 if( verbose == 1 )
356 cout << endl;
357 cout << "myBP: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << max_diff << endl;
358 } else {
359 if( verbose >= 3 )
360 cout << "myBP: ";
361 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
362 }
363 }
364
365 return max_diff;
366 }
367
368
369 double doInference( FactorGraph& fg, string AlgOpts, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
370 InfAlg* ia = newInfAlgFromString( AlgOpts, fg );
371 ia->init();
372
373 m = ublasvector( fg.nrVars() );
374 CImg<unsigned char> image(dimx,dimy,1,3);
375
376 size_t _iterations = 0;
377 double max_diff = 1.0;
378 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
379 max_diff = ia->run();
380 for( size_t i = 0; i < fg.nrVars(); i++ )
381 m[i] = ia->beliefV(i)[1] - ia->beliefV(i)[0];
382 for( size_t i = 0; i < dimx; i++ )
383 for( size_t j = 0; j < dimy; j++ ) {
384 // unsigned char g = (unsigned char)(ia->beliefV(i*dimy+j)[1] * 255.0);
385 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
386 if( g > 127 ) {
387 image(i,j,0) = 255;
388 image(i,j,1) = 2 * (g - 127);
389 image(i,j,2) = 2 * (g - 127);
390 } else {
391 image(i,j,0) = 0;
392 image(i,j,1) = 0;
393 image(i,j,2) = 2*g;
394 }
395 }
396 disp << image;
397 /*
398 char filename[30] = "/tmp/movie000.jpg";
399 sprintf( &filename[10], "%03ld", (long)_iterations );
400 strcat( filename, ".jpg" );
401 image.save_jpeg(filename,100);
402 */
403 cout << "_iterations = " << _iterations << ", max_diff = " << max_diff << endl;
404 }
405
406 delete ia;
407
408 return max_diff;
409 }
410
411
412 double myMF( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
413 clock_t tic = toc();
414
415 if( verbose >= 1 )
416 cout << "Starting myMF..." << endl;
417
418 m = ublasvector(net.N);
419 for( size_t i = 0; i < net.N; i++ )
420 m[i] = 0.0;
421
422 size_t _iterations = 0;
423 double max_diff = 1.0;
424 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
425 max_diff = 0.0;
426 for( size_t t = 0; t < net.N; t++ ) {
427 size_t i = (size_t)(rnd_uniform() * net.N);
428 double new_m_i = tanh(net.th[i] + inner_prod(row(net.w,i), m));
429 double diff = fabs( new_m_i - m[i] );
430 if( diff > max_diff )
431 max_diff = diff;
432 m[i] = new_m_i;
433 }
434
435 if( verbose >= 3 )
436 cout << "myMF: maxdiff " << max_diff << " after " << _iterations+1 << " passes" << endl;
437
438 CImg<unsigned char> image(dimx,dimy,1,3);
439 for( size_t i = 0; i < dimx; i++ )
440 for( size_t j = 0; j < dimy; j++ ) {
441 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
442 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
443 if( g > 127 ) {
444 image(i,j,0) = 255;
445 image(i,j,1) = 2 * (g - 127);
446 image(i,j,2) = 2 * (g - 127);
447 } else {
448 image(i,j,0) = 0;
449 image(i,j,1) = 0;
450 image(i,j,2) = 2*g;
451 }
452 }
453 disp << image;
454 char filename[30] = "/tmp/movie000.jpg";
455 sprintf( &filename[10], "%03ld", (long)_iterations );
456 strcat( filename, ".jpg" );
457 image.save_jpeg(filename,100);
458 }
459
460 if( verbose >= 1 ) {
461 if( max_diff > tol ) {
462 if( verbose == 1 )
463 cout << endl;
464 cout << "myMF: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << max_diff << endl;
465 } else {
466 if( verbose >= 3 )
467 cout << "myMF: ";
468 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
469 }
470 }
471
472 return max_diff;
473 }
474
475
476 BinaryPairwiseGM::BinaryPairwiseGM( const FactorGraph &fg ) {
477 assert( fg.isPairwise() );
478 assert( fg.isBinary() );
479
480 // create w and th
481 N = fg.nrVars();
482
483 // count non_zeros in w
484 size_t non_zeros = 0;
485 for( size_t I = 0; I < fg.nrFactors(); I++ )
486 if( fg.factor(I).vars().size() == 2 )
487 non_zeros++;
488 w = ublasmatrix(N, N, non_zeros * 2);
489
490 th = ublasvector(N);
491 for( size_t i = 0; i < N; i++ )
492 th[i] = 0.0;
493
494 logZ0 = 0.0;
495
496 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
497 const Factor &psi = fg.factor(I);
498 if( psi.vars().size() == 0 )
499 logZ0 += dai::log( psi[0] );
500 else if( psi.vars().size() == 1 ) {
501 size_t i = fg.findVar( *(psi.vars().begin()) );
502 th[i] += 0.5 * (dai::log(psi[1]) - dai::log(psi[0]));
503 logZ0 += 0.5 * (dai::log(psi[0]) + dai::log(psi[1]));
504 } else if( psi.vars().size() == 2 ) {
505 size_t i = fg.findVar( *(psi.vars().begin()) );
506 VarSet::const_iterator jit = psi.vars().begin();
507 size_t j = fg.findVar( *(++jit) );
508
509 double w_ij = 0.25 * (dai::log(psi[3]) + dai::log(psi[0]) - dai::log(psi[2]) - dai::log(psi[1]));
510 w(i,j) += w_ij;
511 w(j,i) += w_ij;
512
513 th[i] += 0.25 * (dai::log(psi[3]) - dai::log(psi[2]) + dai::log(psi[1]) - dai::log(psi[0]));
514 th[j] += 0.25 * (dai::log(psi[3]) - dai::log(psi[1]) + dai::log(psi[2]) - dai::log(psi[0]));
515
516 logZ0 += 0.25 * (dai::log(psi[0]) + dai::log(psi[1]) + dai::log(psi[2]) + dai::log(psi[3]));
517 }
518 }
519 }
520
521
522 double BinaryPairwiseGM::doBP( size_t maxiter, double tol, size_t verbose, ublasvector &m ) {
523 double tic = toc();
524
525 if( verbose >= 1 )
526 cout << "Starting BinaryPairwiseGM::doBP..." << endl;
527
528 size_t nr_messages = w.nnz();
529 ublasmatrix message( w );
530 for( size_t ij = 0; ij < nr_messages; ij++ )
531 message.value_data()[ij] = 0.0;
532 // NOTE: message(i,j) is \mu_{j\to i}
533 Real maxDiff = INFINITY;
534
535 size_t _iterations = 0;
536 for( _iterations = 0; _iterations < maxiter && maxDiff > tol; _iterations++ ) {
537 // walk through the sparse array structure
538 // this is similar to matlab sparse arrays
539 // index2 gives the column index (ir in matlab)
540 // index1 gives the starting indices for each row (jc in matlab)
541 size_t i = 0;
542 maxDiff = -INFINITY;
543 for( size_t pos = 0; pos < nr_messages; pos++ ) {
544 while( pos == w.index1_data()[i+1] )
545 i++;
546 size_t j = w.index2_data()[pos];
547 double w_ij = w.value_data()[pos];
548 // \mu_{j\to i} = \atanh \tanh w_{ij} \tanh (\theta_j + \sum_{k\in\nb{j}\setm i} \mu_{k\to j})
549 double field = sum(row(message,j)) - message(j,i) + th[j];
550 double new_message = atanh( tanh( w_ij ) * tanh( field ) );
551 maxDiff = std::max( maxDiff, fabs(message(i,j) - new_message) );
552 message(i,j) = new_message;
553 }
554
555 if( verbose >= 3 )
556 cout << "BinaryPairwiseGM::doBP: maxdiff " << maxDiff << " after " << _iterations+1 << " passes" << endl;
557 }
558
559 m = ublasvector(N);
560 for( size_t j = 0; j < N; j++ ) {
561 // m_j = \tanh (\theta_j + \sum_{k\in\nb{j}} \mu_{k\to j})
562 double field = sum(row(message,j)) + th[j];
563 m[j] = tanh( field );
564 }
565
566 if( verbose >= 1 ) {
567 if( maxDiff > tol ) {
568 if( verbose == 1 )
569 cout << endl;
570 cout << "BinaryPairwiseGM::doBP: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << maxDiff << endl;
571 } else {
572 if( verbose >= 3 )
573 cout << "BinaryPairwiseGM::doBP: ";
574 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
575 }
576 }
577
578 return maxDiff;
579 }
580
581
582 FactorGraph BinaryPairwiseGM::toFactorGraph() {
583 vector<Var> vars;
584 vector<Factor> factors;
585
586 // create variables
587 vars.reserve( N );
588 for( size_t i = 0; i < N; i++ )
589 vars.push_back( Var( i, 2 ) );
590
591 // create single-variable factors
592 size_t nrE = w.nnz();
593 factors.reserve( N + nrE / 2 );
594 for( size_t i = 0; i < N; i++ )
595 factors.push_back( createFactorIsing( vars[i], th[i] ) );
596
597 // create pairwise factors
598 // walk through the sparse array structure
599 // this is similar to matlab sparse arrays
600 size_t i = 0;
601 for( size_t pos = 0; pos < nrE; pos++ ) {
602 while( pos == w.index1_data()[i+1] )
603 i++;
604 size_t j = w.index2_data()[pos];
605 double w_ij = w.value_data()[pos];
606 if( i < j )
607 factors.push_back( createFactorIsing( vars[i], vars[j], w_ij ) );
608 }
609
610 factors.front() *= dai::exp( logZ0 );
611
612 return FactorGraph( factors.begin(), factors.end(), vars.begin(), vars.end(), factors.size(), vars.size() );
613 }