Changed to newer CImg interface and fixed two regressions
[libdai.git] / examples / example_imagesegmentation.cpp
1 #include <iostream>
2 #include <vector>
3 #include <iterator>
4 #include <algorithm>
5 #include <dai/alldai.h>
6 #include <boost/numeric/ublas/matrix_sparse.hpp>
7 #include <boost/numeric/ublas/matrix_proxy.hpp>
8 #include <boost/numeric/ublas/vector.hpp>
9 #include <boost/numeric/ublas/io.hpp>
10 #include <CImg.h>
11
12 using namespace std;
13 using namespace cimg_library;
14 using namespace dai;
15
16 typedef boost::numeric::ublas::vector<double> ublasvector;
17 typedef boost::numeric::ublas::compressed_matrix<double> ublasmatrix;
18 typedef ublasmatrix::value_array_type::const_iterator matrix_vcit;
19 typedef ublasmatrix::index_array_type::const_iterator matrix_icit;
20
21
22 class BinaryPairwiseGM {
23 public:
24 size_t N;
25 ublasmatrix w;
26 ublasvector th;
27 double logZ0;
28
29 BinaryPairwiseGM() {}
30 BinaryPairwiseGM( const FactorGraph &fg );
31 BinaryPairwiseGM( size_t _N, const ublasmatrix &_w, const ublasvector &_th, double _logZ0 ) : N(_N), w(_w), th(_th), logZ0(_logZ0) {}
32 BinaryPairwiseGM( const BinaryPairwiseGM &x ) : N(x.N), w(x.w), th(x.th), logZ0(x.logZ0) {};
33 BinaryPairwiseGM & operator=( const BinaryPairwiseGM &x ) {
34 if( this != &x ) {
35 N = x.N;
36 w = x.w;
37 th = x.th;
38 logZ0 = x.logZ0;
39 }
40 return *this;
41 }
42 double doBP( size_t maxiter, double tol, size_t verbose, ublasvector &m );
43 FactorGraph toFactorGraph();
44 };
45
46
47 // w should be upper triangular or lower triangular
48 void WTh2FG( const ublasmatrix &w, const vector<double> &th, FactorGraph &fg ) {
49 vector<Var> vars;
50 vector<Factor> factors;
51
52 size_t N = th.size();
53 assert( (w.size1() == N) && (w.size2() == N) );
54
55 vars.reserve(N);
56 for( size_t i = 0; i < N; i++ )
57 vars.push_back(Var(i,2));
58
59 factors.reserve( w.nnz() + N );
60 // walk through the sparse array structure
61 // this is similar to matlab sparse arrays
62 // index2 gives the column index
63 // index1 gives the starting indices for each row
64 size_t i = 0;
65 // cout << w << endl;
66 for( size_t pos = 0; pos < w.nnz(); pos++ ) {
67 while( pos == w.index1_data()[i+1] )
68 i++;
69 size_t j = w.index2_data()[pos];
70 double w_ij = w.value_data()[pos];
71 // cout << "(" << i << "," << j << "): " << w_ij << endl;
72 factors.push_back( createFactorIsing( vars[i], vars[j], w_ij ) );
73 }
74 for( size_t i = 0; i < N; i++ )
75 factors.push_back( createFactorIsing( vars[i], th[i] ) );
76
77 fg = FactorGraph(factors);
78 }
79
80
81 template<class T>
82 void Image2net( const CImg<T> &img, double J, double th_min, double th_plus, double th_tol, double p_background, BinaryPairwiseGM &net ) {
83 size_t dimx = img.dimx();
84 size_t dimy = img.dimy();
85
86 net.N = dimx * dimy;
87 net.w = ublasmatrix(net.N,net.N,4*net.N);
88 net.th = ublasvector(net.N);
89 for( size_t i = 0; i < net.N; i++ )
90 net.th[i] = 0.0;
91 net.logZ0 = 0.0;
92
93 CImg<float> hist = img.get_channel(0).get_histogram(256,0,255);
94 size_t cum_hist = 0;
95 size_t level = 0;
96 for( level = 0; level < 256; level++ ) {
97 cum_hist += (size_t)hist(level);
98 if( cum_hist > p_background * dimx * dimy )
99 break;
100 }
101
102 double th_avg = (th_min + th_plus) / 2.0;
103 double th_width = (th_plus - th_min) / 2.0;
104 for( size_t i = 0; i < dimx; i++ )
105 for( size_t j = 0; j < dimy; j++ ) {
106 if( i+1 < dimx )
107 net.w(i*dimy+j, (i+1)*dimy+j) = J;
108 if( i >= 1 )
109 net.w(i*dimy+j, (i-1)*dimy+j) = J;
110 if( j+1 < dimy )
111 net.w(i*dimy+j, i*dimy+(j+1)) = J;
112 if( j >= 1 )
113 net.w(i*dimy+j, i*dimy+(j-1)) = J;
114 double x = img(i,j);
115 net.th[i*dimy+j] = th_avg + th_width * tanh((x - level)/th_tol);
116 /* if( x < level )
117 x = x / level * 0.5;
118 else
119 x = 0.5 + 0.5 * ((x - level) / (255 - level));*/
120 /* if( x < level )
121 x = 0.01;
122 else
123 x = 0.99;
124 th[i*dimy+j] = 0.5 * (log(x) - log(1.0 - x));*/
125 }
126 }
127
128
129 template<class T>
130 FactorGraph img2fg( const CImg<T> &img, double J, double th_min, double th_plus, double th_tol, double p_background ) {
131 vector<Var> vars;
132 vector<Factor> factors;
133
134 size_t dimx = img.width();
135 size_t dimy = img.height();
136 size_t N = dimx * dimy;
137
138 // create variables
139 cout << "Creating " << N << " variables..." << endl;
140 vars.reserve( N );
141 for( size_t i = 0; i < N; i++ )
142 vars.push_back( Var( i, 2 ) );
143
144 // build histogram
145 CImg<float> hist = img.get_channel(0).get_histogram(256,0,255);
146 size_t cum_hist = 0;
147 size_t level = 0;
148 for( level = 0; level < 256; level++ ) {
149 cum_hist += (size_t)hist(level);
150 if( cum_hist > p_background * dimx * dimy )
151 break;
152 }
153
154 // create factors
155 cout << "Creating " << (3 * N - dimx - dimy) << " factors..." << endl;
156 factors.reserve( 3 * N - dimx - dimy );
157 double th_avg = (th_min + th_plus) / 2.0;
158 double th_width = (th_plus - th_min) / 2.0;
159 for( size_t i = 0; i < dimx; i++ )
160 for( size_t j = 0; j < dimy; j++ ) {
161 if( i >= 1 )
162 factors.push_back( createFactorIsing( vars[i*dimy+j], vars[(i-1)*dimy+j], J ) );
163 if( j >= 1 )
164 factors.push_back( createFactorIsing( vars[i*dimy+j], vars[i*dimy+(j-1)], J ) );
165 double x = img(i,j);
166 factors.push_back( createFactorIsing( vars[i*dimy+j], th_avg + th_width * tanh((x - level)/th_tol) ) );
167 }
168
169 cout << "Creating factor graph..." << endl;
170 return FactorGraph( factors.begin(), factors.end(), vars.begin(), vars.end(), factors.size(), vars.size() );
171 }
172
173
174 double myBP( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
175 double myMF( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
176 double doInference( FactorGraph &fg, string AlgOpts, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp );
177
178 int main(int argc,char **argv) {
179 // Display program usage, when invoked from the command line with option '-h'.
180 cimg_usage("Usage: example_imagesegmentation -i <inputimage1> -j <inputimage2> -o <outputimage1> -p <outputimage2> -J <J> -t <t> -s <s> -u <u> -x <x>");
181 const char* file_i = cimg_option("-i","","Input image 1");
182 const char* file_j = cimg_option("-j","","Input image 2");
183 const char* file_o = cimg_option("-o","","Output image (with BP)");
184 const char* file_p = cimg_option("-p","","Output image (without BP)");
185 const double J = cimg_option("-J",0.0,"Coupling strength");
186 const double th_min = cimg_option("-t",0.0,"Local evidence strength background");
187 const double th_plus = cimg_option("-s",0.0,"Local evidence strength foreground");
188 const double th_tol = cimg_option("-u",0.0,"Sensitivity for fore/background");
189 const double p_background = cimg_option("-x",0.0,"Percentage of background in image");
190
191 CImg<unsigned char> image1 = CImg<>(file_i);
192 CImg<unsigned char> image2 = CImg<>(file_j);
193
194 CImg<int> image3(image1);
195 image3 -= image2;
196 image3.abs();
197 image3.norm(1); // 1 = L1, 2 = L2, -1 = Linf
198 // normalize
199 for( size_t i = 0; i < image3.width(); i++ ) {
200 for( size_t j = 0; j < image3.height(); j++ ) {
201 int avg = 0;
202 for( size_t c = 0; c < image1.spectrum(); c++ )
203 avg += image1(i,j,c);
204 avg /= image1.spectrum();
205 image3(i,j,0) /= (1.0 + avg / 255.0);
206 }
207 }
208 image3.normalize(0,255);
209
210 CImgDisplay disp1(image1,"Input 1",0);
211 CImgDisplay disp2(image2,"Input 2",0);
212 CImgDisplay disp3(image3,"Absolute difference of both inputs",0);
213
214 //BinaryPairwiseGM net;
215 //Image2net( image3, J, th_min, th_plus, th_tol, p_background, net );
216 FactorGraph fg = img2fg( image3, J, th_min, th_plus, th_tol, p_background );
217 cout << "Done" << endl;
218
219 size_t dimx = image3.width();
220 size_t dimy = image3.height();
221 CImg<unsigned char> image4(dimx,dimy,1,3);
222
223 ublasvector m;
224 //net.doBP( 0, 1e-2, 3, m );
225 BP bp( fg, PropertySet("[updates=SEQFIX,maxiter=0,tol=1e-9,verbose=0,logdomain=0]") );
226 bp.init();
227 for( size_t i = 0; i < dimx; i++ )
228 for( size_t j = 0; j < dimy; j++ ) {
229 unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
230 // unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
231 if( g > 127 ) {
232 image4(i,j,0) = 255;
233 image4(i,j,1) = 2 * (g - 127);
234 image4(i,j,2) = 2 * (g - 127);
235 } else {
236 image4(i,j,0) = 0;
237 image4(i,j,1) = 0;
238 image4(i,j,2) = 2*g;
239 }
240 }
241 CImgDisplay disp4(image4,"Local evidence",0);
242 image4.save_jpeg(file_p,100);
243
244 // solve the problem and show intermediate steps
245 CImgDisplay disp5(dimx,dimy,"Beliefs during inference",0);
246 if( 1 ) {
247 //FactorGraph fg = net.toFactorGraph();
248 fg.WriteToFile( "joris.fg" );
249
250 doInference( fg, "BP[updates=SEQMAX,maxiter=1,tol=1e-9,verbose=0,logdomain=0]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
251 // doInference( fg, "HAK[doubleloop=0,clusters=LOOP,init=UNIFORM,loopdepth=4,tol=1e-9,maxiter=1,verbose=3]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
252 // doInference( fg, "HAK[doubleloop=0,clusters=BETHE,init=UNIFORM,maxiter=1,tol=1e-9,verbose=3]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
253 // doInference( fg, "MF[tol=1e-9,maxiter=1,damping=0.0,init=RANDOM,updates=NAIVE]", 1000, 1e-5, 3, m, dimx, dimy, disp5 );
254 } else {
255 // myBP( net, 1000, 1e-5, 3, m, dimx, dimy, disp5 );
256 // myMF( net, 1000, 1e-5, 3, m, dimx, dimy, disp5 );
257 }
258
259 for( size_t i = 0; i < dimx; i++ )
260 for( size_t j = 0; j < dimy; j++ ) {
261 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
262 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
263 if( g > 127 ) {
264 image4(i,j,0) = image2(i,j,0);
265 image4(i,j,1) = image2(i,j,1);
266 image4(i,j,2) = image2(i,j,2);
267 } else
268 for( size_t c = 0; c < (size_t)image4.spectrum(); c++ )
269 image4(i,j,c) = 255;
270 }
271 CImgDisplay main_disp(image4,"Segmentation result",0);
272 image4.save_jpeg(file_o,100);
273
274 while( !main_disp.is_closed() )
275 cimg::wait( 40 );
276
277 return 0;
278 }
279
280
281 double myBP( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
282 clock_t tic = toc();
283
284 if( verbose >= 1 )
285 cout << "Starting myBP..." << endl;
286
287 size_t nr_messages = net.w.nnz();
288 ublasmatrix message( net.w );
289 for( size_t ij = 0; ij < nr_messages; ij++ )
290 message.value_data()[ij] = 0.0;
291 // NOTE: message(i,j) is \mu_{j\to i}
292
293 m = ublasvector(net.N);
294
295 size_t _iterations = 0;
296 double max_diff = 1.0;
297 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
298 // walk through the sparse array structure
299 // this is similar to matlab sparse arrays
300 // index2 gives the column index (ir in matlab)
301 // index1 gives the starting indices for each row (jc in matlab)
302 // for( size_t t = 0; t < 3; t++ ) {
303 size_t i = 0;
304 max_diff = 0.0;
305 for( size_t pos = 0; pos < nr_messages; pos++ ) {
306 while( pos == net.w.index1_data()[i+1] )
307 i++;
308 size_t j = net.w.index2_data()[pos];
309 double w_ij = net.w.value_data()[pos];
310 // \mu_{j\to i} = \atanh \tanh w_{ij} \tanh (\theta_j + \sum_{k\in\nb{j}\setm i} \mu_{k\to j})
311 double field = sum(row(message,j)) - message(j,i) + net.th[j];
312 double new_message = atanh( tanh( w_ij ) * tanh( field ) );
313 double diff = fabs(message(i,j) - new_message);
314 if( diff > max_diff )
315 max_diff = diff;
316 // if( (pos % 3) == t )
317 message(i,j) = new_message;
318 }
319 // }
320
321 if( verbose >= 3 )
322 cout << "myBP: maxdiff " << max_diff << " after " << _iterations+1 << " passes" << endl;
323
324 for( size_t j = 0; j < net.N; j++ ) {
325 // m_j = \tanh (\theta_j + \sum_{k\in\nb{j}} \mu_{k\to j})
326 double field = sum(row(message,j)) + net.th[j];
327 m[j] = tanh( field );
328 }
329 CImg<unsigned char> image(dimx,dimy,1,3);
330 for( size_t i = 0; i < dimx; i++ )
331 for( size_t j = 0; j < dimy; j++ ) {
332 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
333 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
334 if( g > 127 ) {
335 image(i,j,0) = 255;
336 image(i,j,1) = 2 * (g - 127);
337 image(i,j,2) = 2 * (g - 127);
338 } else {
339 image(i,j,0) = 0;
340 image(i,j,1) = 0;
341 image(i,j,2) = 2*g;
342 }
343 }
344 disp = image;
345 char filename[30] = "/tmp/movie000.jpg";
346 sprintf( &filename[10], "%03ld", (long)_iterations );
347 strcat( filename, ".jpg" );
348 image.save_jpeg(filename,100);
349 }
350
351 if( verbose >= 1 ) {
352 if( max_diff > tol ) {
353 if( verbose == 1 )
354 cout << endl;
355 cout << "myBP: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << max_diff << endl;
356 } else {
357 if( verbose >= 3 )
358 cout << "myBP: ";
359 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
360 }
361 }
362
363 return max_diff;
364 }
365
366
367 double doInference( FactorGraph& fg, string AlgOpts, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
368 InfAlg* ia = newInfAlgFromString( AlgOpts, fg );
369 ia->init();
370
371 m = ublasvector( fg.nrVars() );
372 CImg<unsigned char> image(dimx,dimy,1,3);
373
374 size_t _iterations = 0;
375 double max_diff = 1.0;
376 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
377 max_diff = ia->run();
378 for( size_t i = 0; i < fg.nrVars(); i++ )
379 m[i] = ia->beliefV(i)[1] - ia->beliefV(i)[0];
380 for( size_t i = 0; i < dimx; i++ )
381 for( size_t j = 0; j < dimy; j++ ) {
382 // unsigned char g = (unsigned char)(ia->beliefV(i*dimy+j)[1] * 255.0);
383 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
384 if( g > 127 ) {
385 image(i,j,0) = 255;
386 image(i,j,1) = 2 * (g - 127);
387 image(i,j,2) = 2 * (g - 127);
388 } else {
389 image(i,j,0) = 0;
390 image(i,j,1) = 0;
391 image(i,j,2) = 2*g;
392 }
393 }
394 disp = image;
395 /*
396 char filename[30] = "/tmp/movie000.jpg";
397 sprintf( &filename[10], "%03ld", (long)_iterations );
398 strcat( filename, ".jpg" );
399 image.save_jpeg(filename,100);
400 */
401 cout << "_iterations = " << _iterations << ", max_diff = " << max_diff << endl;
402 }
403
404 delete ia;
405
406 return max_diff;
407 }
408
409
410 double myMF( BinaryPairwiseGM &net, size_t maxiter, double tol, size_t verbose, ublasvector &m, size_t dimx, size_t dimy, CImgDisplay &disp ) {
411 clock_t tic = toc();
412
413 if( verbose >= 1 )
414 cout << "Starting myMF..." << endl;
415
416 m = ublasvector(net.N);
417 for( size_t i = 0; i < net.N; i++ )
418 m[i] = 0.0;
419
420 size_t _iterations = 0;
421 double max_diff = 1.0;
422 for( _iterations = 0; _iterations < maxiter && max_diff > tol; _iterations++ ) {
423 max_diff = 0.0;
424 for( size_t t = 0; t < net.N; t++ ) {
425 size_t i = (size_t)(rnd_uniform() * net.N);
426 double new_m_i = tanh(net.th[i] + inner_prod(row(net.w,i), m));
427 double diff = fabs( new_m_i - m[i] );
428 if( diff > max_diff )
429 max_diff = diff;
430 m[i] = new_m_i;
431 }
432
433 if( verbose >= 3 )
434 cout << "myMF: maxdiff " << max_diff << " after " << _iterations+1 << " passes" << endl;
435
436 CImg<unsigned char> image(dimx,dimy,1,3);
437 for( size_t i = 0; i < dimx; i++ )
438 for( size_t j = 0; j < dimy; j++ ) {
439 // unsigned char g = (unsigned char)(bp.belief(fg.var(i*dimy+j))[1] * 255.0);
440 unsigned char g = (unsigned char)((m[i*dimy+j] + 1.0) / 2.0 * 255.0);
441 if( g > 127 ) {
442 image(i,j,0) = 255;
443 image(i,j,1) = 2 * (g - 127);
444 image(i,j,2) = 2 * (g - 127);
445 } else {
446 image(i,j,0) = 0;
447 image(i,j,1) = 0;
448 image(i,j,2) = 2*g;
449 }
450 }
451 disp = image;
452 char filename[30] = "/tmp/movie000.jpg";
453 sprintf( &filename[10], "%03ld", (long)_iterations );
454 strcat( filename, ".jpg" );
455 image.save_jpeg(filename,100);
456 }
457
458 if( verbose >= 1 ) {
459 if( max_diff > tol ) {
460 if( verbose == 1 )
461 cout << endl;
462 cout << "myMF: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << max_diff << endl;
463 } else {
464 if( verbose >= 3 )
465 cout << "myMF: ";
466 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
467 }
468 }
469
470 return max_diff;
471 }
472
473
474 BinaryPairwiseGM::BinaryPairwiseGM( const FactorGraph &fg ) {
475 assert( fg.isPairwise() );
476 assert( fg.isBinary() );
477
478 // create w and th
479 N = fg.nrVars();
480
481 // count non_zeros in w
482 size_t non_zeros = 0;
483 for( size_t I = 0; I < fg.nrFactors(); I++ )
484 if( fg.factor(I).vars().size() == 2 )
485 non_zeros++;
486 w = ublasmatrix(N, N, non_zeros * 2);
487
488 th = ublasvector(N);
489 for( size_t i = 0; i < N; i++ )
490 th[i] = 0.0;
491
492 logZ0 = 0.0;
493
494 for( size_t I = 0; I < fg.nrFactors(); I++ ) {
495 const Factor &psi = fg.factor(I);
496 if( psi.vars().size() == 0 )
497 logZ0 += dai::log( psi[0] );
498 else if( psi.vars().size() == 1 ) {
499 size_t i = fg.findVar( *(psi.vars().begin()) );
500 th[i] += 0.5 * (dai::log(psi[1]) - dai::log(psi[0]));
501 logZ0 += 0.5 * (dai::log(psi[0]) + dai::log(psi[1]));
502 } else if( psi.vars().size() == 2 ) {
503 size_t i = fg.findVar( *(psi.vars().begin()) );
504 VarSet::const_iterator jit = psi.vars().begin();
505 size_t j = fg.findVar( *(++jit) );
506
507 double w_ij = 0.25 * (dai::log(psi[3]) + dai::log(psi[0]) - dai::log(psi[2]) - dai::log(psi[1]));
508 w(i,j) += w_ij;
509 w(j,i) += w_ij;
510
511 th[i] += 0.25 * (dai::log(psi[3]) - dai::log(psi[2]) + dai::log(psi[1]) - dai::log(psi[0]));
512 th[j] += 0.25 * (dai::log(psi[3]) - dai::log(psi[1]) + dai::log(psi[2]) - dai::log(psi[0]));
513
514 logZ0 += 0.25 * (dai::log(psi[0]) + dai::log(psi[1]) + dai::log(psi[2]) + dai::log(psi[3]));
515 }
516 }
517 }
518
519
520 double BinaryPairwiseGM::doBP( size_t maxiter, double tol, size_t verbose, ublasvector &m ) {
521 double tic = toc();
522
523 if( verbose >= 1 )
524 cout << "Starting BinaryPairwiseGM::doBP..." << endl;
525
526 size_t nr_messages = w.nnz();
527 ublasmatrix message( w );
528 for( size_t ij = 0; ij < nr_messages; ij++ )
529 message.value_data()[ij] = 0.0;
530 // NOTE: message(i,j) is \mu_{j\to i}
531 Real maxDiff = INFINITY;
532
533 size_t _iterations = 0;
534 for( _iterations = 0; _iterations < maxiter && maxDiff > tol; _iterations++ ) {
535 // walk through the sparse array structure
536 // this is similar to matlab sparse arrays
537 // index2 gives the column index (ir in matlab)
538 // index1 gives the starting indices for each row (jc in matlab)
539 size_t i = 0;
540 maxDiff = -INFINITY;
541 for( size_t pos = 0; pos < nr_messages; pos++ ) {
542 while( pos == w.index1_data()[i+1] )
543 i++;
544 size_t j = w.index2_data()[pos];
545 double w_ij = w.value_data()[pos];
546 // \mu_{j\to i} = \atanh \tanh w_{ij} \tanh (\theta_j + \sum_{k\in\nb{j}\setm i} \mu_{k\to j})
547 double field = sum(row(message,j)) - message(j,i) + th[j];
548 double new_message = atanh( tanh( w_ij ) * tanh( field ) );
549 maxDiff = std::max( maxDiff, fabs(message(i,j) - new_message) );
550 message(i,j) = new_message;
551 }
552
553 if( verbose >= 3 )
554 cout << "BinaryPairwiseGM::doBP: maxdiff " << maxDiff << " after " << _iterations+1 << " passes" << endl;
555 }
556
557 m = ublasvector(N);
558 for( size_t j = 0; j < N; j++ ) {
559 // m_j = \tanh (\theta_j + \sum_{k\in\nb{j}} \mu_{k\to j})
560 double field = sum(row(message,j)) + th[j];
561 m[j] = tanh( field );
562 }
563
564 if( verbose >= 1 ) {
565 if( maxDiff > tol ) {
566 if( verbose == 1 )
567 cout << endl;
568 cout << "BinaryPairwiseGM::doBP: WARNING: not converged within " << maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << maxDiff << endl;
569 } else {
570 if( verbose >= 3 )
571 cout << "BinaryPairwiseGM::doBP: ";
572 cout << "converged in " << _iterations << " passes (" << toc() - tic << " clocks)." << endl;
573 }
574 }
575
576 return maxDiff;
577 }
578
579
580 FactorGraph BinaryPairwiseGM::toFactorGraph() {
581 vector<Var> vars;
582 vector<Factor> factors;
583
584 // create variables
585 vars.reserve( N );
586 for( size_t i = 0; i < N; i++ )
587 vars.push_back( Var( i, 2 ) );
588
589 // create single-variable factors
590 size_t nrE = w.nnz();
591 factors.reserve( N + nrE / 2 );
592 for( size_t i = 0; i < N; i++ )
593 factors.push_back( createFactorIsing( vars[i], th[i] ) );
594
595 // create pairwise factors
596 // walk through the sparse array structure
597 // this is similar to matlab sparse arrays
598 size_t i = 0;
599 for( size_t pos = 0; pos < nrE; pos++ ) {
600 while( pos == w.index1_data()[i+1] )
601 i++;
602 size_t j = w.index2_data()[pos];
603 double w_ij = w.value_data()[pos];
604 if( i < j )
605 factors.push_back( createFactorIsing( vars[i], vars[j], w_ij ) );
606 }
607
608 factors.front() *= dai::exp( logZ0 );
609
610 return FactorGraph( factors.begin(), factors.end(), vars.begin(), vars.end(), factors.size(), vars.size() );
611 }