updated text files
[RBC.git] / kernels.cu
index 4924638..e6f6adf 100644 (file)
@@ -87,6 +87,9 @@ __global__ void planNNKernel(const matrix Q, const unint *qMap, const matrix X,
 }
 
 
+//This is indentical to the planNNkernel, except that it maintains a list of 32-NNs.  At 
+//each iteration-chunk, the next 16 distances are computed, then sorted, then merged 
+//with the previously computed 32-NNs.
 __global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X, const intMatrix xMap, matrix dMins, intMatrix dMinIDs, compPlan cP,  unint qStartPos ){
   unint qB = qStartPos + blockIdx.y * BLOCK_SIZE;  //indexes Q
   unint xB; //X (DB) Block;
@@ -96,8 +99,8 @@ __global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X,
   unint i,j,k;
   unint groupIts;
   
-  __shared__ real dNN[BLOCK_SIZE][K+BLOCK_SIZE];
-  __shared__ unint idNN[BLOCK_SIZE][K+BLOCK_SIZE];
+  __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
+  __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
 
   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
@@ -138,7 +141,7 @@ __global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X,
       }
      
       dNN[offQ][offX+32] = (xB+offX<groupCount)? ans:MAX_REAL;
-      idNN[offQ][offX+32] = xB + offX;
+      idNN[offQ][offX+32] = (xB+offX<groupCount)? xMap.mat[IDX( xG, xB+offX, xMap.ld )]: DUMMY_IDX; 
       __syncthreads();
 
       sort16off( dNN, idNN );
@@ -158,9 +161,8 @@ __global__ void planKNNKernel(const matrix Q, const unint *qMap, const matrix X,
 }
 
 
-
+//The basic 1-NN search kernel.
 __global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dMins, unint *dMinIDs){
-
   unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
   unint xB; //indexes X;
   unint cB; //colBlock
@@ -220,8 +222,11 @@ __global__ void nnKernel(const matrix Q, unint numDone, const matrix X, real *dM
 }
 
 
+//Computes the 32-NNs for each query in Q.  It is similar to nnKernel above, but maintains a 
+//list of the 32 currently-closest points in the DB, instead of just the single NN.  After each 
+//batch of 16 points is processed, it sorts these 16 points according to the distance from the 
+//query, then merges this list with the other list.
 __global__ void knnKernel(const matrix Q, unint numDone, const matrix X, matrix dMins, intMatrix dMinIDs){
-  
   unint qB = blockIdx.y * BLOCK_SIZE + numDone;  //indexes Q
   unint xB; //indexes X;
   unint cB; //colBlock
@@ -233,8 +238,8 @@ __global__ void knnKernel(const matrix Q, unint numDone, const matrix X, matrix
   __shared__ real Xs[BLOCK_SIZE][BLOCK_SIZE];
   __shared__ real Qs[BLOCK_SIZE][BLOCK_SIZE];
   
-  __shared__ real dNN[BLOCK_SIZE][K+BLOCK_SIZE];
-  __shared__ unint idNN[BLOCK_SIZE][K+BLOCK_SIZE];
+  __shared__ real dNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
+  __shared__ unint idNN[BLOCK_SIZE][KMAX+BLOCK_SIZE];
 
   dNN[offQ][offX] = MAX_REAL;
   dNN[offQ][offX+16] = MAX_REAL;
@@ -276,7 +281,7 @@ __global__ void knnKernel(const matrix Q, unint numDone, const matrix X, matrix
   
 }
 
-
+//Computes all pairs of distances between Q and X.
 __global__ void dist1Kernel(const matrix Q, unint qStart, const matrix X, unint xStart, matrix D){
   unint c, i, j;
 
@@ -312,9 +317,9 @@ __global__ void dist1Kernel(const matrix Q, unint qStart, const matrix X, unint
 }
 
 
-
+//This function is used by the rbc building routine.  It find an appropriate range 
+//such that roughly cntWant points fall within this range.  D is a matrix of distances.
 __global__ void findRangeKernel(const matrix D, unint numDone, real *ranges, unint cntWant){
-  
   unint row = blockIdx.y*(BLOCK_SIZE/4)+threadIdx.y + numDone;
   unint ro = threadIdx.y;
   unint co = threadIdx.x;
@@ -466,53 +471,11 @@ __global__ void rangeCountKernel(const matrix Q, unint numDone, const matrix X,
 }
 
 
-__device__ void sort16off(real x[][48], unint xi[][48]){
-  int i = threadIdx.x;
-  int j = threadIdx.y;
-
-  if(i%2==0)
-    mmGateI( x[j]+K+i, x[j]+K+i+1, xi[j]+K+i, xi[j]+K+i+1 );
-  __syncthreads();
-
-  if(i%4<2)
-    mmGateI( x[j]+K+i, x[j]+K+i+2, xi[j]+K+i, xi[j]+K+i+2 );
-  __syncthreads();
-
-  if(i%4==1)
-    mmGateI( x[j]+K+i, x[j]+K+i+1, xi[j]+K+i, xi[j]+K+i+1 );
-  __syncthreads();
-  
-  if(i%8<4)
-    mmGateI( x[j]+K+i, x[j]+K+i+4, xi[j]+K+i, xi[j]+K+i+4 );
-  __syncthreads();
-  
-  if(i%8==2 || i%8==3)
-    mmGateI( x[j]+K+i, x[j]+K+i+2, xi[j]+K+i, xi[j]+K+i+2 );
-  __syncthreads();
-
-  if( i%2 && i%8 != 7 ) 
-    mmGateI( x[j]+K+i, x[j]+K+i+1, xi[j]+K+i, xi[j]+K+i+1 );
-  __syncthreads();
-  
-  //0-7; 8-15 now sorted.  merge time.
-  if( i<8)
-    mmGateI( x[j]+K+i, x[j]+K+i+8, xi[j]+K+i, xi[j]+K+i+8 );
-  __syncthreads();
-  
-  if( i>3 && i<8 )
-    mmGateI( x[j]+K+i, x[j]+K+i+4, xi[j]+K+i, xi[j]+K+i+4 );
-  __syncthreads();
-  
-  int os = (i/2)*4+2 + i%2;
-  if(i<6)
-    mmGateI( x[j]+K+os, x[j]+K+os+2, xi[j]+K+os, xi[j]+K+os+2 );
-  __syncthreads();
-  
-  if( i%2 && i<15)
-    mmGateI( x[j]+K+i, x[j]+K+i+1, xi[j]+K+i, xi[j]+K+i+1 );
-}
-
+//**************************************************************************
+// The following functions are an implementation of Batcher's sorting network.  
+// All computations take place in (on-chip) shared memory.
 
+// The function name is descriptive; it sorts each row of x, whose indices are xi.
 __device__ void sort16(real x[][16], unint xi[][16]){
   int i = threadIdx.x;
   int j = threadIdx.y;
@@ -561,6 +524,9 @@ __device__ void sort16(real x[][16], unint xi[][16]){
 }
 
 
+// This function takes an array of lists, each of length 48. It is assumed
+// that the first 32 numbers are sorted, and the last 16 numbers.  The 
+// routine then merges these lists into one sorted list of length 48.
 __device__ void merge32x16(real x[][48], unint xi[][48]){
   int i = threadIdx.x;
   int j = threadIdx.y;
@@ -598,7 +564,57 @@ __device__ void merge32x16(real x[][48], unint xi[][48]){
 
 }
 
+//This is the same as sort16, but takes as input lists of length 48
+//and sorts the last 16 entries.  This cleans up some of the NN code, 
+//though it is inelegant.
+__device__ void sort16off(real x[][48], unint xi[][48]){
+  int i = threadIdx.x;
+  int j = threadIdx.y;
+
+  if(i%2==0)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+
+  if(i%4<2)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
+  __syncthreads();
+
+  if(i%4==1)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+  
+  if(i%8<4)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
+  __syncthreads();
+  
+  if(i%8==2 || i%8==3)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+2, xi[j]+KMAX+i, xi[j]+KMAX+i+2 );
+  __syncthreads();
+
+  if( i%2 && i%8 != 7 ) 
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+  __syncthreads();
+  
+  //0-7; 8-15 now sorted.  merge time.
+  if( i<8)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+8, xi[j]+KMAX+i, xi[j]+KMAX+i+8 );
+  __syncthreads();
+  
+  if( i>3 && i<8 )
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+4, xi[j]+KMAX+i, xi[j]+KMAX+i+4 );
+  __syncthreads();
+  
+  int os = (i/2)*4+2 + i%2;
+  if(i<6)
+    mmGateI( x[j]+KMAX+os, x[j]+KMAX+os+2, xi[j]+KMAX+os, xi[j]+KMAX+os+2 );
+  __syncthreads();
+  
+  if( i%2 && i<15)
+    mmGateI( x[j]+KMAX+i, x[j]+KMAX+i+1, xi[j]+KMAX+i, xi[j]+KMAX+i+1 );
+}
 
+//min-max gate: it sets the minimum of x and y into x, the maximum into y, and 
+//exchanges the indices (xi and yi) accordingly.
 __device__ void mmGateI(real *x, real *y, unint *xi, unint *yi){
   int ti = MINi( *x, *y, *xi, *yi );
   *yi = MAXi( *x, *y, *xi, *yi );