updated NN functions so that they return the distances as well as the indices
[RBC.git] / brute.cu
index 7041f28..f40d70a 100644 (file)
--- a/brute.cu
+++ b/brute.cu
@@ -68,17 +68,17 @@ void bruteSearch(matrix x, matrix q, unint *NNs){
 }
 
 
-void bruteK(matrix x, matrix q, intMatrix NNs){
-  matrix dMins;
+void bruteK(matrix x, matrix q, intMatrix NNs, matrix NNdists){
+  matrix dNNdists;
   intMatrix dMinIDs;
   matrix dx, dq;
   
   dx.r=x.r; dx.pr=x.pr; dx.c=x.c; dx.pc=x.pc; dx.ld=x.ld;
   dq.r=q.r; dq.pr=q.pr; dq.c=q.c; dq.pc=q.pc; dq.ld=q.ld;
-  dMins.r=q.r; dMins.pr=q.pr; dMins.c=K; dMins.pc=K; dMins.ld=dMins.pc;
+  dNNdists.r=q.r; dNNdists.pr=q.pr; dNNdists.c=K; dNNdists.pc=K; dNNdists.ld=dNNdists.pc;
   dMinIDs.r=q.r; dMinIDs.pr=q.pr; dMinIDs.c=K; dMinIDs.pc=K; dMinIDs.ld=dMinIDs.pc;
 
-  checkErr( cudaMalloc((void**)&dMins.mat, dMins.pc*dMins.pr*sizeof(*dMins.mat)) );
+  checkErr( cudaMalloc((void**)&dNNdists.mat, dNNdists.pc*dNNdists.pr*sizeof(*dNNdists.mat)) );
   checkErr( cudaMalloc((void**)&dMinIDs.mat, dMinIDs.pc*dMinIDs.pr*sizeof(*dMinIDs.mat)) );
   checkErr( cudaMalloc((void**)&dx.mat, dx.pr*dx.pc*sizeof(*dx.mat)) );
   checkErr( cudaMalloc((void**)&dq.mat, dq.pr*dq.pc*sizeof(*dq.mat)) );
@@ -86,11 +86,12 @@ void bruteK(matrix x, matrix q, intMatrix NNs){
   cudaMemcpy(dx.mat,x.mat,x.pr*x.pc*sizeof(*dx.mat),cudaMemcpyHostToDevice);
   cudaMemcpy(dq.mat,q.mat,q.pr*q.pc*sizeof(*dq.mat),cudaMemcpyHostToDevice);
   
-  knnWrap(dq,dx,dMins,dMinIDs);
+  knnWrap(dq,dx,dNNdists,dMinIDs);
 
   cudaMemcpy(NNs.mat,dMinIDs.mat,NNs.pr*NNs.pc*sizeof(*NNs.mat),cudaMemcpyDeviceToHost);
-  
-  cudaFree(dMins.mat);
+  cudaMemcpy(NNdists.mat,dNNdists.mat,NNdists.pr*NNdists.pc*sizeof(*NNdists.mat),cudaMemcpyDeviceToHost);
+
+  cudaFree(dNNdists.mat);
   cudaFree(dMinIDs.mat);
   cudaFree(dx.mat);
   cudaFree(dq.mat);