updated NN functions so that they return the distances as well as the indices
[RBC.git] / rbc.cu
diff --git a/rbc.cu b/rbc.cu
index 61fad14..3237fe1 100644 (file)
--- a/rbc.cu
+++ b/rbc.cu
@@ -16,7 +16,7 @@
 #include "kernelWrap.h"
 #include "sKernelWrap.h"
 
-void queryRBC(const matrix q, const rbcStruct rbcS, unint *NNs){
+void queryRBC(const matrix q, const rbcStruct rbcS, unint *NNs, real* NNdists){
   unint m = q.r;
   unint numReps = rbcS.dr.r;
   unint compLength;
@@ -55,7 +55,7 @@ void queryRBC(const matrix q, const rbcStruct rbcS, unint *NNs){
   checkErr( cudaMalloc( (void**)&dqMap, compLength*sizeof(*dqMap) ) );
   cudaMemcpy( dqMap, qMap, compLength*sizeof(*dqMap), cudaMemcpyHostToDevice );
   
-  computeNNs(rbcS.dx, rbcS.dxMap, dq, dqMap, dcP, NNs, compLength);
+  computeNNs(rbcS.dx, rbcS.dxMap, dq, dqMap, dcP, NNs, NNdists, compLength);
   
   free(qMap);
   freeCompPlan(&dcP);
@@ -68,7 +68,7 @@ void queryRBC(const matrix q, const rbcStruct rbcS, unint *NNs){
 
 //This function is very similar to queryRBC, with a couple of basic changes to handle
 //k-nn.  
-void kqueryRBC(const matrix q, const rbcStruct rbcS, intMatrix NNs){
+void kqueryRBC(const matrix q, const rbcStruct rbcS, intMatrix NNs, matrix NNdists){
   unint m = q.r;
   unint numReps = rbcS.dr.r;
   unint compLength;
@@ -110,7 +110,7 @@ void kqueryRBC(const matrix q, const rbcStruct rbcS, intMatrix NNs){
   checkErr( cudaMalloc( (void**)&dqMap, compLength*sizeof(*dqMap) ) );
   cudaMemcpy( dqMap, qMap, compLength*sizeof(*dqMap), cudaMemcpyHostToDevice );
   
-  computeKNNs(rbcS.dx, rbcS.dxMap, dq, dqMap, dcP, NNs, compLength);
+  computeKNNs(rbcS.dx, rbcS.dxMap, dq, dqMap, dcP, NNs, NNdists, compLength);
 
   free(qMap);
   freeCompPlan(&dcP);
@@ -311,34 +311,36 @@ void fullIntersection(charMatrix cM){
 }
 
 
-void computeNNs(matrix dx, intMatrix dxMap, matrix dq, unint *dqMap, compPlan dcP, unint *NNs, unint compLength){
-  real *dMins;
+void computeNNs(matrix dx, intMatrix dxMap, matrix dq, unint *dqMap, compPlan dcP, unint *NNs, real *NNdists, unint compLength){
+  real *dNNdists;
   unint *dMinIDs;
   
-  checkErr( cudaMalloc((void**)&dMins,compLength*sizeof(*dMins)) );
+  checkErr( cudaMalloc((void**)&dNNdists,compLength*sizeof(*dNNdists)) );
   checkErr( cudaMalloc((void**)&dMinIDs,compLength*sizeof(*dMinIDs)) );
 
-  planNNWrap(dq, dqMap, dx, dxMap, dMins, dMinIDs, dcP, compLength);
-  cudaMemcpy( NNs, dMinIDs, dq.r*sizeof(*NNs), cudaMemcpyDeviceToHost);
-  
-  cudaFree(dMins);
+  planNNWrap(dq, dqMap, dx, dxMap, dNNdists, dMinIDs, dcP, compLength );
+  cudaMemcpy( NNs, dMinIDs, dq.r*sizeof(*NNs), cudaMemcpyDeviceToHost );
+  cudaMemcpy( NNdists, dNNdists, dq.r*sizeof(*dNNdists), cudaMemcpyDeviceToHost );
+
+  cudaFree(dNNdists);
   cudaFree(dMinIDs);
 }
 
 
-void computeKNNs(matrix dx, intMatrix dxMap, matrix dq, unint *dqMap, compPlan dcP, intMatrix NNs, unint compLength){
-  matrix dMins;
+void computeKNNs(matrix dx, intMatrix dxMap, matrix dq, unint *dqMap, compPlan dcP, intMatrix NNs, matrix NNdists, unint compLength){
+  matrix dNNdists;
   intMatrix dMinIDs;
-  dMins.r=compLength; dMins.pr=compLength; dMins.c=K; dMins.pc=K; dMins.ld=dMins.pc;
+  dNNdists.r=compLength; dNNdists.pr=compLength; dNNdists.c=K; dNNdists.pc=K; dNNdists.ld=dNNdists.pc;
   dMinIDs.r=compLength; dMinIDs.pr=compLength; dMinIDs.c=K; dMinIDs.pc=K; dMinIDs.ld=dMinIDs.pc;
 
-  checkErr( cudaMalloc((void**)&dMins.mat,dMins.pr*dMins.pc*sizeof(*dMins.mat)) );
+  checkErr( cudaMalloc((void**)&dNNdists.mat,dNNdists.pr*dNNdists.pc*sizeof(*dNNdists.mat)) );
   checkErr( cudaMalloc((void**)&dMinIDs.mat,dMinIDs.pr*dMinIDs.pc*sizeof(*dMinIDs.mat)) );
 
-  planKNNWrap(dq, dqMap, dx, dxMap, dMins, dMinIDs, dcP, compLength);
-  cudaMemcpy( NNs.mat, dMinIDs.mat, dq.r*K*sizeof(*NNs.mat), cudaMemcpyDeviceToHost);
+  planKNNWrap(dq, dqMap, dx, dxMap, dNNdists, dMinIDs, dcP, compLength);
+  cudaMemcpy( NNs.mat, dMinIDs.mat, dq.r*K*sizeof(*NNs.mat), cudaMemcpyDeviceToHost );
+  cudaMemcpy( NNdists.mat, dNNdists.mat, dq.r*K*sizeof(*NNdists.mat), cudaMemcpyDeviceToHost );
 
-  cudaFree(dMins.mat);
+  cudaFree(dNNdists.mat);
   cudaFree(dMinIDs.mat);
 }