minor bug fixes, updated readme
[RBC.git] / kernels.cu
index aeccf87..e6f6adf 100644 (file)
@@ -87,9 +87,82 @@ __global__ void planNNKernel(const matrix Q, const unint *qMap, const matrix X,
 }
 
 
+//This is indentical to the planNNkernel, except that it maintains a list of 32-NNs.  At 
+//each iteration-chunk, the next 16 distances are computed, then sorted, then merged 
+//with the previously computed 32-NNs.
+__global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X, const intMatrix xMap, matrix dMins, intMatrix dMinIDs, compPlan cP,  unint qStartPos ){
+  unint qB = qStartPos + blockIdx.y * BLOCK_SIZE;  //indexes Q
+  unint xB; //X (DB) Block;
+  unint cB; //column Block
+  unint offQ = threadIdx.y; //the offset of qPos in this block
+  unint offX = threadIdx.x; //ditto for x
+  unint i,j,k;
+  unint groupIts;
+  
+  __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
+  __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
 
-__global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dMins, unint *dMinIDs){
+  __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
+  __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
+
+  unint g; //query group of q
+  unint xG; //DB group currently being examined
+  unint numGroups;
+  unint groupCount;
+
+  g = cP.qToQGroup[qB]; 
+  numGroups = cP.numGroups[g];
+  
+  dNN[offQ][offX] = MAX_REAL;
+  dNN[offQ][offX+16] = MAX_REAL;
+  idNN[offQ][offX] = DUMMY_IDX;
+  idNN[offQ][offX+16] = DUMMY_IDX;
+  __syncthreads();
+  
+  for(i=0; i<numGroups; i++){ //iterate over DB groups
+    xG = cP.qGroupToXGroup[IDX( g, i, cP.ld )];
+    groupCount = cP.groupCountX[IDX( g, i, cP.ld )];
+    groupIts = (groupCount+BLOCK_SIZE-1)/BLOCK_SIZE;
+
+    for(j=0; j<groupIts; j++){ //iterate over elements of group
+      xB=j*BLOCK_SIZE;
+
+      real ans=0;
+      for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){ // iterate over cols to compute distances
+
+       Xs[offX][offQ] = X.mat[IDX( xMap.mat[IDX( xG, xB+offQ, xMap.ld )], cB+offX, X.ld )];
+       Qs[offX][offQ] = ( (qMap[qB+offQ]==DUMMY_IDX) ? 0 : Q.mat[IDX( qMap[qB+offQ], cB+offX, Q.ld )] );
+       __syncthreads();
+       
+       for(k=0; k<BLOCK_SIZE; k++)
+         ans+=DIST( Xs[k][offX], Qs[k][offQ] );
 
+       __syncthreads();
+      }
+     
+      dNN[offQ][offX+32] = (xB+offX<groupCount)? ans:MAX_REAL;
+      idNN[offQ][offX+32] = (xB+offX<groupCount)? xMap.mat[IDX( xG, xB+offX, xMap.ld )]: DUMMY_IDX; 
+      __syncthreads();
+
+      sort16off( dNN, idNN );
+      __syncthreads();
+      
+      merge32x16( dNN, idNN );
+    }
+  }
+  __syncthreads();
+  
+  if(qMap[qB+offQ]!=DUMMY_IDX){
+    dMins.mat[IDX(qMap[qB+offQ], offX, dMins.ld)] = dNN[offQ][offX];
+    dMins.mat[IDX(qMap[qB+offQ], offX+16, dMins.ld)] = dNN[offQ][offX+16];
+    dMinIDs.mat[IDX(qMap[qB+offQ], offX, dMins.ld)] = idNN[offQ][offX];
+    dMinIDs.mat[IDX(qMap[qB+offQ], offX+16, dMinIDs.ld)] = idNN[offQ][offX+16];
+  }
+}
+
+
+//The basic 1-NN search kernel.
+__global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dMins, unint *dMinIDs){
   unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
   unint xB; //indexes X;
   unint cB; //colBlock
@@ -149,8 +222,66 @@ __global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dM
 }
 
 
+//Computes the 32-NNs for each query in Q.  It is similar to nnKernel above, but maintains a 
+//list of the 32 currently-closest points in the DB, instead of just the single NN.  After each 
+//batch of 16 points is processed, it sorts these 16 points according to the distance from the 
+//query, then merges this list with the other list.
+__global__ void knnKernel(const matrix Q, unint numDone, const matrix X, matrix dMins, intMatrix dMinIDs){
+  unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
+  unint xB; //indexes X;
+  unint cB; //colBlock
+  unint offQ = threadIdx.y; //the offset of qPos in this block
+  unint offX = threadIdx.x; //ditto for x
+  unint i;
+  real ans;
+
+  __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
+  __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
+  
+  __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
+  __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
 
+  dNN[offQ][offX] = MAX_REAL;
+  dNN[offQ][offX+16] = MAX_REAL;
+  idNN[offQ][offX] = DUMMY_IDX;
+  idNN[offQ][offX+16] = DUMMY_IDX;
+  
+  __syncthreads();
 
+  for(xB=0; xB<X.pr; xB+=BLOCK_SIZE){
+    ans=0;
+    for(cB=0; cB<X.pc; cB+=BLOCK_SIZE){
+      
+      //Each thread loads one element of X and Q into memory.
+      Xs[offX][offQ] = X.mat[IDX( xB+offQ, cB+offX, X.ld )];
+      Qs[offX][offQ] = Q.mat[IDX( qB+offQ, cB+offX, Q.ld )];
+      __syncthreads();
+      
+      for(i=0;i<BLOCK_SIZE;i++)
+       ans += DIST( Xs[i][offX], Qs[i][offQ] );
+      
+      __syncthreads();
+    }
+    dNN[offQ][offX+32] = (xB+offX<X.r)? ans:MAX_REAL;
+    idNN[offQ][offX+32] = xB + offX;
+    __syncthreads();
+
+    sort16off( dNN, idNN );
+    __syncthreads();
+
+    merge32x16( dNN, idNN );
+  }
+  __syncthreads();
+  
+  dMins.mat[IDX(qB+offQ, offX, dMins.ld)] = dNN[offQ][offX];
+  dMins.mat[IDX(qB+offQ, offX+16, dMins.ld)] = dNN[offQ][offX+16];
+  dMinIDs.mat[IDX(qB+offQ, offX, dMins.ld)] = idNN[offQ][offX];
+  dMinIDs.mat[IDX(qB+offQ, offX+16, dMins.ld)] = idNN[offQ][offX+16];
+  
+}
+
+//Computes all pairs of distances between Q and X.
 __global__ void dist1Kernel(const matrix Q, unint qStart, const matrix X, unint xStart, matrix D){
   unint c, i, j;
 
@@ -186,9 +317,9 @@ __global__ void dist1Kernel(const matrix Q, unint qStart, const matrix X, unint
 }
 
 
-
+//This function is used by the rbc building routine.  It find an appropriate range 
+//such that roughly cntWant points fall within this range.  D is a matrix of distances.
 __global__ void findRangeKernel(const matrix D, unint numDone, real *ranges, unint cntWant){
-  
   unint row = blockIdx.y*(BLOCK_SIZE/4)+threadIdx.y + numDone;
   unint ro = threadIdx.y;
   unint co = threadIdx.x;
@@ -340,4 +471,157 @@ __global__ void rangeCountKernel(const matrix Q, unint numDone, const matrix X,
 }
 
 
+//**************************************************************************
+// The following functions are an implementation of Batcher's sorting network.  
+// All computations take place in (on-chip) shared memory.
+
+// The function name is descriptive; it sorts each row of x, whose indices are xi.
+__device__ void sort16(real x[][16], unint xi[][16]){
+  int i = threadIdx.x;
+  int j = threadIdx.y;
+
+  if(i%2==0)
+    mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
+  __syncthreads();
+
+  if(i%4<2)
+    mmGateI( x[j]+i, x[j]+i+2, xi[j]+i, xi[j]+i+2 );
+  __syncthreads();
+
+  if(i%4==1)
+    mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
+  __syncthreads();
+  
+  if(i%8<4)
+    mmGateI( x[j]+i, x[j]+i+4, xi[j]+i, xi[j]+i+4 );
+  __syncthreads();
+  
+  if(i%8==2 || i%8==3)
+    mmGateI( x[j]+i, x[j]+i+2, xi[j]+i, xi[j]+i+2 );
+  __syncthreads();
+
+  if( i%2 && i%8 != 7 ) 
+    mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
+  __syncthreads();
+  
+  //0-7; 8-15 now sorted.  merge time.
+  if( i<8)
+    mmGateI( x[j]+i, x[j]+i+8, xi[j]+i, xi[j]+i+8 );
+  __syncthreads();
+  
+  if( i>3 && i<8 )
+    mmGateI( x[j]+i, x[j]+i+4, xi[j]+i, xi[j]+i+4 );
+  __syncthreads();
+  
+  int os = (i/2)*4+2 + i%2;
+  if(i<6)
+    mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
+  __syncthreads();
+  
+  if( i%2 && i<15)
+    mmGateI( x[j]+i, x[j]+i+1, xi[j]+i, xi[j]+i+1 );
+
+}
+
+
+// This function takes an array of lists, each of length 48. It is assumed
+// that the first 32 numbers are sorted, and the last 16 numbers.  The 
+// routine then merges these lists into one sorted list of length 48.
+__device__ void merge32x16(real x[][48], unint xi[][48]){
+  int i = threadIdx.x;
+  int j = threadIdx.y;
+
+  mmGateI( x[j]+i, x[j]+i+32, xi[j]+i, xi[j]+i+32 );
+  __syncthreads();
+
+  mmGateI( x[j]+i+16, x[j]+i+32, xi[j]+i+16, xi[j]+i+32 );
+  __syncthreads();
+
+  int os = (i<8)? 24: 0;
+  mmGateI( x[j]+os+i, x[j]+os+i+8, xi[j]+os+i, xi[j]+os+i+8 );
+  __syncthreads();
+  
+  os = (i/4)*8+4 + i%4;
+  mmGateI( x[j]+os, x[j]+os+4, xi[j]+os, xi[j]+os+4 );
+  if(i<4)
+    mmGateI(x[j]+36+i, x[j]+36+i+4, xi[j]+36+i, xi[j]+36+i+4 );
+  __syncthreads();
+
+  os = (i/2)*4+2 + i%2;
+  mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
+  
+  os = (i/2)*4+34 + i%2;
+  if(i<6)
+    mmGateI( x[j]+os, x[j]+os+2, xi[j]+os, xi[j]+os+2 );
+  __syncthreads();
+
+  os = 2*i+1;
+  mmGateI(x[j]+os, x[j]+os+1, xi[j]+os, xi[j]+os+1 );
+
+  os = 2*i+33;
+  if(i<7)
+    mmGateI(x[j]+os, x[j]+os+1, xi[j]+os, xi[j]+os+1 );
+
+}
+
+//This is the same as sort16, but takes as input lists of length 48
+//and sorts the last 16 entries.  This cleans up some of the NN code, 
+//though it is inelegant.
+__device__ void sort16off(real x[][48], unint xi[][48]){
+  int i = threadIdx.x;
+  int j = threadIdx.y;
+
+  if(i%2==0)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+
+  if(i%4<2)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
+  __syncthreads();
+
+  if(i%4==1)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+  
+  if(i%8<4)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
+  __syncthreads();
+  
+  if(i%8==2 || i%8==3)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
+  __syncthreads();
+
+  if( i%2 && i%8 != 7 ) 
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+  
+  //0-7; 8-15 now sorted.  merge time.
+  if( i<8)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+8, xi[j]+KMAX+i, xi[j]+KMAX+i+8 );
+  __syncthreads();
+  
+  if( i>3 && i<8 )
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
+  __syncthreads();
+  
+  int os = (i/2)*4+2 + i%2;
+  if(i<6)
+    mmGateI( x[j]+KMAX+os, x[j]+KMAX+os+2, xi[j]+KMAX+os, xi[j]+KMAX+os+2 );
+  __syncthreads();
+  
+  if( i%2 && i<15)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+}
+
+//min-max gate: it sets the minimum of x and y into x, the maximum into y, and 
+//exchanges the indices (xi and yi) accordingly.
+__device__ void mmGateI(real *x, real *y, unint *xi, unint *yi){
+  int ti = MINi( *x, *y, *xi, *yi );
+  *yi = MAXi( *x, *y, *xi, *yi );
+  *xi = ti;
+  real t = MIN( *x, *y );
+  *y = MAX( *x, *y );
+  *x = t;
+}
+
 #endif