updated NN functions so that they return the distances as well as the indices
[RBC.git] / driver.cu
index ff05bb1..aee8ee5 100644 (file)
--- a/driver.cu
+++ b/driver.cu
@@ -27,9 +27,8 @@ unint deviceNum=0;
 int main(int argc, char**argv){
   real *data;
   matrix x, q;
 int main(int argc, char**argv){
   real *data;
   matrix x, q;
-  unint *NNs;
-  intMatrix NNsK, kNNsRBC;
-  unint i;
+  intMatrix nnsBrute, nnsRBC;
+  matrix distsBrute, distsRBC;
   struct timeval tvB,tvE;
   cudaError_t cE;
   rbcStruct rbcS;
   struct timeval tvB,tvE;
   cudaError_t cE;
   rbcStruct rbcS;
@@ -58,24 +57,27 @@ int main(int argc, char**argv){
   x.r = n; x.c = d; x.pr = PAD(n); x.pc = PAD(d); x.ld = x.pc;
   q.r = m; q.c = d; q.pr = PAD(m); q.pc = PAD(d); q.ld = q.pc;
 
   x.r = n; x.c = d; x.pr = PAD(n); x.pc = PAD(d); x.ld = x.pc;
   q.r = m; q.c = d; q.pr = PAD(m); q.pc = PAD(d); q.ld = q.pc;
 
-  NNs = (unint*)calloc( m, sizeof(*NNs) );
-  for(i=0; i<m; i++)
-    NNs[i]=DUMMY_IDX;
-  
+  //Load data 
   readData(dataFile, (n+m), d, data);
   orgData(data, (n+m), d, x, q);
   free(data);
 
   readData(dataFile, (n+m), d, data);
   orgData(data, (n+m), d, x, q);
   free(data);
 
-  NNsK.r=q.r; NNsK.pr=q.pr; NNsK.pc=NNsK.c=K; NNsK.ld=NNsK.pc;
-  kNNsRBC.r=q.r; kNNsRBC.pr=q.pr; kNNsRBC.pc=kNNsRBC.c=K; kNNsRBC.ld=kNNsRBC.pc;
-  kNNsRBC.mat = (unint*)calloc(kNNsRBC.pr*kNNsRBC.pc, sizeof(*kNNsRBC.mat));
-  NNsK.mat = (unint*)calloc(NNsK.pr*NNsK.pc, sizeof(*NNsK.mat));
+  //Allocate space for NNs and dists
+  nnsBrute.r=q.r; nnsBrute.pr=q.pr; nnsBrute.pc=nnsBrute.c=K; nnsBrute.ld=nnsBrute.pc;
+  nnsBrute.mat = (unint*)calloc(nnsBrute.pr*nnsBrute.pc, sizeof(*nnsBrute.mat));
+  nnsRBC.r=q.r; nnsRBC.pr=q.pr; nnsRBC.pc=nnsRBC.c=K; nnsRBC.ld=nnsRBC.pc;
+  nnsRBC.mat = (unint*)calloc(nnsRBC.pr*nnsRBC.pc, sizeof(*nnsRBC.mat));
   
   
-  /* printf("running k-brute force..\n"); */
-  /* gettimeofday(&tvB,NULL); */
-  /* bruteK(x,q,NNsK); */
-  /* gettimeofday(&tvE,NULL); */
-  /* printf("\t.. time elapsed = %6.4f \n",timeDiff(tvB,tvE)); */
+  distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.pc=distsBrute.c=K; distsBrute.ld=distsBrute.pc;
+  distsBrute.mat = (real*)calloc(distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat));
+  distsRBC.r=q.r; distsRBC.pr=q.pr; distsRBC.pc=distsRBC.c=K; distsRBC.ld=distsRBC.pc;
+  distsRBC.mat = (real*)calloc(distsRBC.pr*distsRBC.pc, sizeof(*distsRBC.mat));
+
+  printf("running k-brute force..\n");
+  gettimeofday(&tvB,NULL);
+  bruteK(x,q,nnsBrute,distsBrute);
+  gettimeofday(&tvE,NULL);
+  printf("\t.. time elapsed = %6.4f \n",timeDiff(tvB,tvE));
 
   printf("\nrunning rbc..\n");
   gettimeofday(&tvB,NULL);
 
   printf("\nrunning rbc..\n");
   gettimeofday(&tvB,NULL);
@@ -85,7 +87,7 @@ int main(int argc, char**argv){
 
   //This finds the 32-NN; if you are only interested in the 1-NN, use queryRBC(..) instead
   gettimeofday(&tvB,NULL);
 
   //This finds the 32-NN; if you are only interested in the 1-NN, use queryRBC(..) instead
   gettimeofday(&tvB,NULL);
-  kqueryRBC(q, rbcS, kNNsRBC);
+  kqueryRBC(q, rbcS, nnsRBC, distsRBC);
   gettimeofday(&tvE,NULL);
   printf("\t.. query time for krbc = %6.4f \n",timeDiff(tvB,tvE));
   
   gettimeofday(&tvE,NULL);
   printf("\t.. query time for krbc = %6.4f \n",timeDiff(tvB,tvE));
   
@@ -97,13 +99,14 @@ int main(int argc, char**argv){
     printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
   }
   
     printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
   }
   
-  evalKNNerror(x,q,kNNsRBC);
+  evalKNNerror(x,q,nnsRBC);
   
   cudaThreadExit();
   
   
   cudaThreadExit();
   
-  free(NNs);
-  free(NNsK.mat);
-  free(kNNsRBC.mat);
+  free(nnsBrute.mat);
+  free(nnsRBC.mat);
+  free(distsBrute.mat);
+  free(distsRBC.mat);
   free(x.mat);
   free(q.mat);
 }
   free(x.mat);
   free(q.mat);
 }
@@ -283,9 +286,12 @@ void evalKNNerror(matrix x, matrix q, intMatrix NNs){
   intMatrix NNsB;
   NNsB.r=q.r; NNsB.pr=q.pr; NNsB.c=NNsB.pc=32; NNsB.ld=NNsB.pc;
   NNsB.mat = (unint*)calloc( NNsB.pr*NNsB.pc, sizeof(*NNsB.mat) );
   intMatrix NNsB;
   NNsB.r=q.r; NNsB.pr=q.pr; NNsB.c=NNsB.pc=32; NNsB.ld=NNsB.pc;
   NNsB.mat = (unint*)calloc( NNsB.pr*NNsB.pc, sizeof(*NNsB.mat) );
-  
+  matrix distsBrute;
+  distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.c=distsBrute.pc=K; distsBrute.ld=distsBrute.pc;
+  distsBrute.mat = (real*)calloc( distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat) );
+
   gettimeofday(&tvB,NULL);
   gettimeofday(&tvB,NULL);
-  bruteK(x,q,NNsB);
+  bruteK(x,q,NNsB,distsBrute);
   gettimeofday(&tvE,NULL);
 
    //calc overlap
   gettimeofday(&tvE,NULL);
 
    //calc overlap
@@ -343,4 +349,5 @@ void evalKNNerror(matrix x, matrix q, intMatrix NNs){
   free(cnts);
   free(ol);
   free(NNsB.mat);
   free(cnts);
   free(ol);
   free(NNsB.mat);
+  free(distsBrute.mat);
 }
 }