updated text files
[RBC.git] / kernels.cu
1 /* This file is part of the Random Ball Cover (RBC) library.
2  * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3  */
4
5 #ifndef KERNELS_CU
6 #define KERNELS_CU
7
8 #include<cuda.h>
9 #include "defs.h"
10 #include "kernels.h"
11 #include<stdio.h>
12
13 // This kernel does the same thing as nnKernel, except it only considers pairs as 
14 // specified by the compPlan. 
15 __global__ void planNNKernel(const matrix Q, const unint *qMap, const matrix X, const intMatrix xMap, real *dMins, unint *dMinIDs, compPlan cP,  unint qStartPos ){
16   unint qB = qStartPos + blockIdx.y * BLOCK_SIZE;  //indexes Q
17   unint xB; //X (DB) Block;
18   unint cB; //column Block
19   unint offQ = threadIdx.y; //the offset of qPos in this block
20   unint offX = threadIdx.x; //ditto for x
21   unint i,j,k;
22   unint groupIts;
23   
24   __shared__ real min[BLOCK_SIZE][BLOCK_SIZE];
25   __shared__ unint minPos[BLOCK_SIZE][BLOCK_SIZE];
26
27   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
28   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
29
30   unint g; //query group of q
31   unint xG; //DB group currently being examined
32   unint numGroups;
33   unint groupCount;
34
35   g = cP.qToQGroup[qB]; 
36   numGroups = cP.numGroups[g];
37   
38   min[offQ][offX]=MAX_REAL;
39   __syncthreads();
40   
41
42   for(i=0; i<numGroups; i++){ //iterate over DB groups
43     xG = cP.qGroupToXGroup[IDX( g, i, cP.ld )];
44     groupCount = cP.groupCountX[IDX( g, i, cP.ld )];
45     groupIts = (groupCount+BLOCK_SIZE-1)/BLOCK_SIZE;
46
47     for(j=0; j<groupIts; j++){ //iterate over elements of group
48       xB=j*BLOCK_SIZE;
49
50       real ans=0;
51       for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){ // iterate over cols to compute distances
52
53         Xs[offX][offQ] = X.mat[IDX( xMap.mat[IDX( xG, xB+offQ, xMap.ld )], cB+offX, X.ld )];
54         Qs[offX][offQ] = ( (qMap[qB+offQ]==DUMMY_IDX) ? 0 : Q.mat[IDX( qMap[qB+offQ], cB+offX, Q.ld )] );
55         __syncthreads();
56         
57         for(k=0; k<BLOCK_SIZE; k++)
58           ans+=DIST( Xs[k][offX], Qs[k][offQ] );
59
60         __syncthreads();
61       }
62      
63       //compare to previous min and store into shared mem if needed.
64       if(xB+offX<groupCount && ans<min[offQ][offX]){
65         min[offQ][offX]=ans;
66         minPos[offQ][offX]= xMap.mat[IDX( xG, xB+offX, xMap.ld )];
67       }
68       __syncthreads();
69     }
70   }
71   
72   //Reduce across threads
73   for(i=BLOCK_SIZE/2; i>0; i/=2){
74     if( offX<i ){
75       if( min[offQ][offX+i] < min[offQ][offX] ){
76         min[offQ][offX] = min[offQ][offX+i];
77         minPos[offQ][offX] = minPos[offQ][offX+i];      
78       }
79     }
80     __syncthreads();
81   }
82
83   if(offX==0 && qMap[qB+offQ]!=DUMMY_IDX){
84     dMins[qMap[qB+offQ]] = min[offQ][0];
85     dMinIDs[qMap[qB+offQ]] = minPos[offQ][0];
86   }
87 }
88
89
90 //This is indentical to the planNNkernel, except that it maintains a list of 32-NNs.  At 
91 //each iteration-chunk, the next 16 distances are computed, then sorted, then merged 
92 //with the previously computed 32-NNs.
93 __global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X, const intMatrix xMap, matrix dMins, intMatrix dMinIDs, compPlan cP,  unint qStartPos ){
94   unint qB = qStartPos + blockIdx.y * BLOCK_SIZE;  //indexes Q
95   unint xB; //X (DB) Block;
96   unint cB; //column Block
97   unint offQ = threadIdx.y; //the offset of qPos in this block
98   unint offX = threadIdx.x; //ditto for x
99   unint i,j,k;
100   unint groupIts;
101   
102   __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
103   __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
104
105   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
106   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
107
108   unint g; //query group of q
109   unint xG; //DB group currently being examined
110   unint numGroups;
111   unint groupCount;
112
113   g = cP.qToQGroup[qB]; 
114   numGroups = cP.numGroups[g];
115   
116   dNN[offQ][offX] = MAX_REAL;
117   dNN[offQ][offX+16] = MAX_REAL;
118   idNN[offQ][offX] = DUMMY_IDX;
119   idNN[offQ][offX+16] = DUMMY_IDX;
120   __syncthreads();
121   
122   for(i=0; i<numGroups; i++){ //iterate over DB groups
123     xG = cP.qGroupToXGroup[IDX( g, i, cP.ld )];
124     groupCount = cP.groupCountX[IDX( g, i, cP.ld )];
125     groupIts = (groupCount+BLOCK_SIZE-1)/BLOCK_SIZE;
126
127     for(j=0; j<groupIts; j++){ //iterate over elements of group
128       xB=j*BLOCK_SIZE;
129
130       real ans=0;
131       for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){ // iterate over cols to compute distances
132
133         Xs[offX][offQ] = X.mat[IDX( xMap.mat[IDX( xG, xB+offQ, xMap.ld )], cB+offX, X.ld )];
134         Qs[offX][offQ] = ( (qMap[qB+offQ]==DUMMY_IDX) ? 0 : Q.mat[IDX( qMap[qB+offQ], cB+offX, Q.ld )] );
135         __syncthreads();
136         
137         for(k=0; k<BLOCK_SIZE; k++)
138           ans+=DIST( Xs[k][offX], Qs[k][offQ] );
139
140         __syncthreads();
141       }
142      
143       dNN[offQ][offX+32] = (xB+offX<groupCount)? ans:MAX_REAL;
144       idNN[offQ][offX+32] = (xB+offX<groupCount)? xMap.mat[IDX( xG, xB+offX, xMap.ld )]: DUMMY_IDX; 
145       __syncthreads();
146
147       sort16off( dNN, idNN );
148       __syncthreads();
149       
150       merge32x16( dNN, idNN );
151     }
152   }
153   __syncthreads();
154   
155   if(qMap[qB+offQ]!=DUMMY_IDX){
156     dMins.mat[IDX(qMap[qB+offQ], offX, dMins.ld)] = dNN[offQ][offX];
157     dMins.mat[IDX(qMap[qB+offQ], offX+16, dMins.ld)] = dNN[offQ][offX+16];
158     dMinIDs.mat[IDX(qMap[qB+offQ], offX, dMins.ld)] = idNN[offQ][offX];
159     dMinIDs.mat[IDX(qMap[qB+offQ], offX+16, dMinIDs.ld)] = idNN[offQ][offX+16];
160   }
161 }
162
163
164 //The basic 1-NN search kernel.
165 __global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dMins, unint *dMinIDs){
166   unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
167   unint xB; //indexes X;
168   unint cB; //colBlock
169   unint offQ = threadIdx.y; //the offset of qPos in this block
170   unint offX = threadIdx.x; //ditto for x
171   unint i;
172   real ans;
173
174   __shared__ real min[BLOCK_SIZE][BLOCK_SIZE];
175   __shared__ unint minPos[BLOCK_SIZE][BLOCK_SIZE];
176
177   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
178   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
179
180   min[offQ][offX]=MAX_REAL;
181   __syncthreads();
182
183   for(xB=0; xB<X.pr; xB+=BLOCK_SIZE){
184     ans=0;
185     for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){
186       
187       //Each thread loads one element of X and Q into memory.
188       Xs[offX][offQ] = X.mat[IDX( xB+offQ, cB+offX, X.ld )];
189       Qs[offX][offQ] = Q.mat[IDX( qB+offQ, cB+offX, Q.ld )];
190       
191       __syncthreads();
192       
193       for(i=0;i<BLOCK_SIZE;i++)
194         ans += DIST( Xs[i][offX], Qs[i][offQ] );
195       
196       __syncthreads();
197     }
198    
199     if( xB+offX<X.r && ans<min[offQ][offX] ){
200       minPos[offQ][offX] = xB+offX;
201       min[offQ][offX] = ans;
202     }
203   }
204   __syncthreads();
205   
206   
207   //reduce across threads
208   for(i=BLOCK_SIZE/2; i>0; i/=2){
209     if(offX<i){
210       if(min[offQ][offX+i]<min[offQ][offX]){
211         min[offQ][offX] = min[offQ][offX+i];
212         minPos[offQ][offX] = minPos[offQ][offX+i];      
213       }
214     }
215     __syncthreads();
216   }
217   
218   if(offX==0){
219     dMins[qB+offQ] = min[offQ][0];
220     dMinIDs[qB+offQ] = minPos[offQ][0];
221   }
222 }
223
224
225 //Computes the 32-NNs for each query in Q.  It is similar to nnKernel above, but maintains a 
226 //list of the 32 currently-closest points in the DB, instead of just the single NN.  After each 
227 //batch of 16 points is processed, it sorts these 16 points according to the distance from the 
228 //query, then merges this list with the other list.
229 __global__ void knnKernel(const matrix Q, unint numDone, const matrix X, matrix dMins, intMatrix dMinIDs){
230   unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
231   unint xB; //indexes X;
232   unint cB; //colBlock
233   unint offQ = threadIdx.y; //the offset of qPos in this block
234   unint offX = threadIdx.x; //ditto for x
235   unint i;
236   real ans;
237
238   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
239   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
240   
241   __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
242   __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
243
244   dNN[offQ][offX] = MAX_REAL;
245   dNN[offQ][offX+16] = MAX_REAL;
246   idNN[offQ][offX] = DUMMY_IDX;
247   idNN[offQ][offX+16] = DUMMY_IDX;
248   
249   __syncthreads();
250
251   for(xB=0; xB<X.pr; xB+=BLOCK_SIZE){
252     ans=0;
253     for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){
254       
255       //Each thread loads one element of X and Q into memory.
256       Xs[offX][offQ] = X.mat[IDX( xB+offQ, cB+offX, X.ld )];
257       Qs[offX][offQ] = Q.mat[IDX( qB+offQ, cB+offX, Q.ld )];
258       __syncthreads();
259       
260       for(i=0;i<BLOCK_SIZE;i++)
261         ans += DIST( Xs[i][offX], Qs[i][offQ] );
262       
263       __syncthreads();
264     }
265  
266     dNN[offQ][offX+32] = (xB+offX<X.r)? ans:MAX_REAL;
267     idNN[offQ][offX+32] = xB + offX;
268     __syncthreads();
269
270     sort16off( dNN, idNN );
271     __syncthreads();
272
273     merge32x16( dNN, idNN );
274   }
275   __syncthreads();
276   
277   dMins.mat[IDX(qB+offQ, offX, dMins.ld)] = dNN[offQ][offX];
278   dMins.mat[IDX(qB+offQ, offX+16, dMins.ld)] = dNN[offQ][offX+16];
279   dMinIDs.mat[IDX(qB+offQ, offX, dMins.ld)] = idNN[offQ][offX];
280   dMinIDs.mat[IDX(qB+offQ, offX+16, dMins.ld)] = idNN[offQ][offX+16];
281   
282 }
283
284 //Computes all pairs of distances between Q and X.
285 __global__ void dist1Kernel(const matrix Q, unint qStart, const matrix X, unint xStart, matrix D){
286   unint c, i, j;
287
288   unint qB = blockIdx.y*BLOCK_SIZE + qStart;
289   unint q  = threadIdx.y;
290   unint xB = blockIdx.x*BLOCK_SIZE + xStart;
291   unint x = threadIdx.x;
292
293   real ans=0;
294
295   //This thread is responsible for computing the dist between Q[qB+q] and X[xB+x]
296   
297   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
298   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
299
300
301   for(i=0 ; i<Q.pc/BLOCK_SIZE ; i++){
302     c=i*BLOCK_SIZE; //current col block
303
304     Qs[x][q] = Q.mat[ IDX(qB+q, c+x, Q.ld) ];
305     Xs[x][q] = X.mat[ IDX(xB+q, c+x, X.ld) ];
306
307     __syncthreads();
308
309     for(j=0 ; j<BLOCK_SIZE ; j++)
310       ans += DIST( Qs[j][q], Xs[j][x] );
311     
312     __syncthreads();
313   }
314   
315   D.mat[ IDX( qB+q, xB+x, D.ld ) ] = ans;
316
317 }
318
319
320 //This function is used by the rbc building routine.  It find an appropriate range 
321 //such that roughly cntWant points fall within this range.  D is a matrix of distances.
322 __global__ void findRangeKernel(const matrix D, unint numDone, real *ranges, unint cntWant){
323   unint row = blockIdx.y*(BLOCK_SIZE/4)+threadIdx.y + numDone;
324   unint ro = threadIdx.y;
325   unint co = threadIdx.x;
326   unint i, c;
327   real t;
328
329   const unint LB = (90*cntWant)/100 ;
330   const unint UB = cntWant; 
331
332   __shared__ real smin[BLOCK_SIZE/4][4*BLOCK_SIZE];
333   __shared__ real smax[BLOCK_SIZE/4][4*BLOCK_SIZE];
334   
335   real min=MAX_REAL;
336   real max=0;
337   for(c=0 ; c<D.pc ; c+=(4*BLOCK_SIZE)){
338     if( c+co < D.c ){
339       t = D.mat[ IDX( row, c+co, D.ld ) ];
340       min = MIN(t,min);
341       max = MAX(t,max);
342     }
343   }
344   
345   smin[ro][co] = min;
346   smax[ro][co] = max;
347   __syncthreads();
348   
349   for(i=2*BLOCK_SIZE ; i>0 ; i/=2){
350     if( co < i ){
351       smin[ro][co] = MIN( smin[ro][co], smin[ro][co+i] );
352       smax[ro][co] = MAX( smax[ro][co], smax[ro][co+i] );
353     }
354     __syncthreads();
355   }
356
357   //Now start range counting.
358
359   unint itcount=0;
360   unint cnt;
361   real rg;
362   __shared__ unint scnt[BLOCK_SIZE/4][4*BLOCK_SIZE];
363   __shared__ char cont[BLOCK_SIZE/4];
364   
365   if(co==0)
366     cont[ro]=1;
367   
368   do{
369     itcount++;
370     __syncthreads();
371
372     if( cont[ro] )  //if we didn't actually need to cont, leave rg as it was.
373       rg = ( smax[ro][0] + smin[ro][0] ) / ((real)2.0) ;
374
375     cnt=0;
376     for(c=0 ; c<D.pc ; c+=(4*BLOCK_SIZE)){
377       cnt += (c+co < D.c && row < D.r && D.mat[ IDX( row, c+co, D.ld ) ] <= rg);
378     }
379
380     scnt[ro][co] = cnt;
381     __syncthreads();
382     
383     for(i=2*BLOCK_SIZE ; i>0 ; i/=2){
384       if( co < i ){
385         scnt[ro][co] += scnt[ro][co+i];
386       }
387       __syncthreads();
388     }
389     
390     if(co==0){
391       if( scnt[ro][0] < cntWant )
392         smin[ro][0]=rg;
393       else
394         smax[ro][0]=rg;
395     }
396     
397     // cont[ro] == this row needs to continue
398     if(co==0)
399       cont[ro] = row<D.r && ( scnt[ro][0] < LB || scnt[ro][0] > UB ); 
400     __syncthreads();
401
402     // Determine if *any* of the rows need to continue
403     for(i=BLOCK_SIZE/8 ; i>0 ; i/=2){
404       if( ro < i && co==0)
405         cont[ro] |= cont[ro+i];
406       __syncthreads();
407     }
408     
409   } while(cont[0]);
410
411   if(co==0 && row<D.r )
412     ranges[row]=rg;
413   
414 }
415
416
417 __global__ void rangeSearchKernel(const matrix D, unint xOff, unint yOff, const real *ranges, charMatrix ir){
418   unint col = blockIdx.x*BLOCK_SIZE + threadIdx.x + xOff;
419   unint row = blockIdx.y*BLOCK_SIZE + threadIdx.y + yOff;
420
421   ir.mat[IDX( row, col, ir.ld )] = D.mat[IDX( row, col, D.ld )] < ranges[row];
422
423 }
424
425
426 __global__ void rangeCountKernel(const matrix Q, unint numDone, const matrix X, real *ranges, unint *counts){
427   unint q = blockIdx.y*BLOCK_SIZE + numDone;
428   unint qo = threadIdx.y;
429   unint xo = threadIdx.x;
430   
431   real rg = ranges[q+qo];
432   
433   unint r,c,i;
434
435   __shared__ unint scnt[BLOCK_SIZE][BLOCK_SIZE];
436
437   __shared__ real xs[BLOCK_SIZE][BLOCK_SIZE];
438   __shared__ real qs[BLOCK_SIZE][BLOCK_SIZE];
439   
440   unint cnt=0;
441   for( r=0; r<X.pr; r+=BLOCK_SIZE ){
442
443     real dist=0;
444     for( c=0; c<X.pc; c+=BLOCK_SIZE){
445       xs[xo][qo] = X.mat[IDX( r+qo, c+xo, X.ld )];
446       qs[xo][qo] = Q.mat[IDX( q+qo, c+xo, Q.ld )];
447       __syncthreads();
448       
449       for( i=0; i<BLOCK_SIZE; i++)
450         dist += DIST( xs[i][xo], qs[i][qo] );
451
452       __syncthreads();
453
454     }
455     cnt += r+xo<X.r && dist<rg;
456
457   }
458   
459   scnt[qo][xo]=cnt;
460   __syncthreads();
461   
462   for( i=BLOCK_SIZE/2; i>0; i/=2 ){
463     if( xo<i ){
464       scnt[qo][xo] += scnt[qo][xo+i];
465     }
466     __syncthreads();
467   }
468
469   if( xo==0 && q+qo<Q.r )
470     counts[q+qo] = scnt[qo][0];
471 }
472
473
474 //**************************************************************************
475 // The following functions are an implementation of Batcher's sorting network.  
476 // All computations take place in (on-chip) shared memory.
477
478 // The function name is descriptive; it sorts each row of x, whose indices are xi.
479 __device__ void sort16(real x[][16], unint xi[][16]){
480   int i = threadIdx.x;
481   int j = threadIdx.y;
482
483   if(i%2==0)
484     mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
485   __syncthreads();
486
487   if(i%4<2)
488     mmGateI( x[j]+i, x[j]+i+2, xi[j]+i, xi[j]+i+2 );
489   __syncthreads();
490
491   if(i%4==1)
492     mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
493   __syncthreads();
494   
495   if(i%8<4)
496     mmGateI( x[j]+i, x[j]+i+4, xi[j]+i, xi[j]+i+4 );
497   __syncthreads();
498   
499   if(i%8==2 || i%8==3)
500     mmGateI( x[j]+i, x[j]+i+2, xi[j]+i, xi[j]+i+2 );
501   __syncthreads();
502
503   if( i%2 && i%8 != 7 ) 
504     mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
505   __syncthreads();
506   
507   //0-7; 8-15 now sorted.  merge time.
508   if( i<8)
509     mmGateI( x[j]+i, x[j]+i+8, xi[j]+i, xi[j]+i+8 );
510   __syncthreads();
511   
512   if( i>3 && i<8 )
513     mmGateI( x[j]+i, x[j]+i+4, xi[j]+i, xi[j]+i+4 );
514   __syncthreads();
515   
516   int os = (i/2)*4+2 + i%2;
517   if(i<6)
518     mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
519   __syncthreads();
520   
521   if( i%2 && i<15)
522     mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
523
524 }
525
526
527 // This function takes an array of lists, each of length 48. It is assumed
528 // that the first 32 numbers are sorted, and the last 16 numbers.  The 
529 // routine then merges these lists into one sorted list of length 48.
530 __device__ void merge32x16(real x[][48], unint xi[][48]){
531   int i = threadIdx.x;
532   int j = threadIdx.y;
533
534   mmGateI( x[j]+i, x[j]+i+32, xi[j]+i, xi[j]+i+32 );
535   __syncthreads();
536
537   mmGateI( x[j]+i+16, x[j]+i+32, xi[j]+i+16, xi[j]+i+32 );
538   __syncthreads();
539
540   int os = (i<8)? 24: 0;
541   mmGateI( x[j]+os+i, x[j]+os+i+8, xi[j]+os+i, xi[j]+os+i+8 );
542   __syncthreads();
543   
544   os = (i/4)*8+4 + i%4;
545   mmGateI( x[j]+os, x[j]+os+4, xi[j]+os, xi[j]+os+4 );
546   if(i<4)
547     mmGateI(x[j]+36+i, x[j]+36+i+4, xi[j]+36+i, xi[j]+36+i+4 );
548   __syncthreads();
549
550   os = (i/2)*4+2 + i%2;
551   mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
552   
553   os = (i/2)*4+34 + i%2;
554   if(i<6)
555     mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
556   __syncthreads();
557
558   os = 2*i+1;
559   mmGateI(x[j]+os, x[j]+os+1, xi[j]+os, xi[j]+os+1 );
560
561   os = 2*i+33;
562   if(i<7)
563     mmGateI(x[j]+os, x[j]+os+1, xi[j]+os, xi[j]+os+1 );
564
565 }
566
567 //This is the same as sort16, but takes as input lists of length 48
568 //and sorts the last 16 entries.  This cleans up some of the NN code, 
569 //though it is inelegant.
570 __device__ void sort16off(real x[][48], unint xi[][48]){
571   int i = threadIdx.x;
572   int j = threadIdx.y;
573
574   if(i%2==0)
575     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
576   __syncthreads();
577
578   if(i%4<2)
579     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
580   __syncthreads();
581
582   if(i%4==1)
583     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
584   __syncthreads();
585   
586   if(i%8<4)
587     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
588   __syncthreads();
589   
590   if(i%8==2 || i%8==3)
591     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
592   __syncthreads();
593
594   if( i%2 && i%8 != 7 ) 
595     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
596   __syncthreads();
597   
598   //0-7; 8-15 now sorted.  merge time.
599   if( i<8)
600     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+8, xi[j]+KMAX+i, xi[j]+KMAX+i+8 );
601   __syncthreads();
602   
603   if( i>3 && i<8 )
604     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
605   __syncthreads();
606   
607   int os = (i/2)*4+2 + i%2;
608   if(i<6)
609     mmGateI( x[j]+KMAX+os, x[j]+KMAX+os+2, xi[j]+KMAX+os, xi[j]+KMAX+os+2 );
610   __syncthreads();
611   
612   if( i%2 && i<15)
613     mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
614 }
615
616 //min-max gate: it sets the minimum of x and y into x, the maximum into y, and 
617 //exchanges the indices (xi and yi) accordingly.
618 __device__ void mmGateI(real *x, real *y, unint *xi, unint *yi){
619   int ti = MINi( *x, *y, *xi, *yi );
620   *yi = MAXi( *x, *y, *xi, *yi );
621   *xi = ti;
622   real t = MIN( *x, *y );
623   *y = MAX( *x, *y );
624   *x = t;
625 }
626
627 #endif