updated text files
[RBC.git] / driver.cu
1 /* This file is part of the Random Ball Cover (RBC) library.
2  * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3  */
4
5 #include<stdio.h>
6 #include<stdlib.h>
7 #include<sys/time.h>
8
9 #include "rbc_include.h"
10
11 void parseInput(int,char**);
12 void readData(char*,matrix);
13 void readDataText(char*,matrix);
14 void evalNNerror(matrix, matrix, unint*);
15 void evalKNNerror(matrix,matrix,intMatrix);
16 void writeNeighbs(char*,char*,intMatrix,matrix);
17
18 char *dataFileX, *dataFileQ, *dataFileXtxt, *dataFileQtxt, *outFile, *outFiletxt;
19 char runBrute=0, runEval=0;
20 unint n=0, m=0, d=0, numReps=0, deviceNum=0;
21
22 int main(int argc, char**argv){
23   matrix x, q;
24   intMatrix nnsRBC;
25   matrix distsRBC;
26   struct timeval tvB,tvE;
27   cudaError_t cE;
28   rbcStruct rbcS;
29
30   printf("*****************\n");
31   printf("RANDOM BALL COVER\n");
32   printf("*****************\n");
33   
34   parseInput(argc,argv);
35   
36   gettimeofday( &tvB, NULL );
37   printf("Using GPU #%d\n",deviceNum);  
38   if(cudaSetDevice(deviceNum) != cudaSuccess){  
39     printf("Unable to select device %d.. exiting. \n",deviceNum);  
40     exit(1);  
41   }  
42   
43   size_t memFree, memTot;
44   cudaMemGetInfo(&memFree, &memTot);
45   printf("GPU memory free = %lu/%lu (MB) \n",(unsigned long)memFree/(1024*1024),(unsigned long)memTot/(1024*1024));
46   gettimeofday( &tvE, NULL );
47   printf(" init time: %6.2f \n", timeDiff( tvB, tvE ) );
48   
49
50   //Setup matrices
51   initMat( &x, n, d );
52   initMat( &q, m, d );
53   x.mat = (real*)calloc( sizeOfMat(x), sizeof(*(x.mat)) );
54   q.mat = (real*)calloc( sizeOfMat(q), sizeof(*(q.mat)) );
55     
56   //Load data 
57   if( dataFileXtxt )
58     readDataText( dataFileXtxt, x );
59   else
60     readData( dataFileX, x );
61   if( dataFileQtxt )
62     readDataText( dataFileQtxt, q );
63   else
64     readData( dataFileQ, q );
65
66
67   //Allocate space for NNs and dists
68   initIntMat( &nnsRBC, m, KMAX );  //KMAX is defined in defs.h
69   initMat( &distsRBC, m, KMAX );
70   nnsRBC.mat = (unint*)calloc( sizeOfIntMat(nnsRBC), sizeof(*nnsRBC.mat) );
71   distsRBC.mat = (real*)calloc( sizeOfMat(distsRBC), sizeof(*distsRBC.mat) );
72
73   //Build the RBC
74   printf("building the rbc..\n");
75   gettimeofday( &tvB, NULL );
76   buildRBC( x, &rbcS, numReps, numReps );
77   gettimeofday( &tvE, NULL );
78   printf( "\t.. build time = %6.4f \n", timeDiff(tvB,tvE) );
79   
80   //This finds the 32-NNs; if you are only interested in the 1-NN, use queryRBC(..) instead
81   gettimeofday( &tvB, NULL );
82   kqueryRBC( q, rbcS, nnsRBC, distsRBC );
83   gettimeofday( &tvE, NULL );
84   printf( "\t.. query time for krbc = %6.4f \n", timeDiff(tvB,tvE) );
85   
86   //Shows how to run brute force search
87   if( runBrute ){
88     intMatrix nnsBrute;
89     matrix distsBrute;
90     initIntMat( &nnsBrute, m, KMAX );
91     nnsBrute.mat = (unint*)calloc( sizeOfIntMat(nnsBrute), sizeof(*nnsBrute.mat) );
92     initMat( &distsBrute, m, KMAX );
93     distsBrute.mat = (real*)calloc( sizeOfMat(distsBrute), sizeof(*distsBrute.mat) );
94     
95     printf("running k-brute force..\n");
96     gettimeofday( &tvB, NULL );
97     bruteK( x, q, nnsBrute, distsBrute );
98     gettimeofday( &tvE, NULL );
99     printf( "\t.. time elapsed = %6.4f \n", timeDiff(tvB,tvE) );
100     
101     free( nnsBrute.mat );
102     free( distsBrute.mat );
103   }
104
105   cE = cudaGetLastError();
106   if( cE != cudaSuccess ){
107     printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
108   }
109   
110   if( runEval )
111     evalKNNerror(x,q,nnsRBC);
112   
113   if( outFile || outFiletxt )
114     writeNeighbs( outFile, outFiletxt, nnsRBC, distsRBC );
115
116   destroyRBC( &rbcS );
117   cudaThreadExit();
118   free( nnsRBC.mat );
119   free( distsRBC.mat );
120   free( x.mat );
121   free( q.mat );
122 }
123
124
125 void parseInput(int argc, char **argv){
126   int i=1;
127   if(argc <= 1){
128     printf("\nusage: \n  testRBC -x datafileX -q datafileQ  -n numPts (DB) -m numQueries -d dim -r numReps [-o outFile] [-g GPU num] [-b] [-e]\n\n");
129     printf("\tdatafileX    = binary file containing the database\n");
130     printf("\tdatafileQ    = binary file containing the queries\n");
131     printf("\tnumPts       = size of database\n");
132     printf("\tnumQueries   = number of queries\n");
133     printf("\tdim          = dimensionality\n");
134     printf("\tnumReps      = number of representatives (must be at least 32)\n");
135     printf("\toutFile      = binary output file (optional)\n");
136     printf("\tGPU num      = ID # of the GPU to use (optional) for multi-GPU machines\n");
137     printf("\n\tuse -b to run brute force in addition the RBC\n");
138     printf("\tuse -e to run the evaluation routine (implicitly runs brute force)\n");
139     printf("\n\n\tTo input/output data in text format (instead of bin), use the \n\t-X and -Q and -O switches in place of -x and -q and -o (respectively).\n");
140     printf("\n\n");
141     exit(0);
142   }
143   
144   while(i<argc){
145     if(!strcmp(argv[i], "-x"))
146       dataFileX = argv[++i];
147     else if(!strcmp(argv[i], "-q"))
148       dataFileQ = argv[++i];
149     else if(!strcmp(argv[i], "-X"))
150       dataFileXtxt = argv[++i];
151     else if(!strcmp(argv[i], "-Q"))
152       dataFileQtxt = argv[++i];
153     else if(!strcmp(argv[i], "-n"))
154       n = atoi(argv[++i]);
155     else if(!strcmp(argv[i], "-m"))
156       m = atoi(argv[++i]);
157     else if(!strcmp(argv[i], "-d"))
158       d = atoi(argv[++i]);
159     else if(!strcmp(argv[i], "-r"))
160       numReps = atoi(argv[++i]);
161     else if(!strcmp(argv[i], "-o"))
162       outFile = argv[++i];
163     else if(!strcmp(argv[i], "-O"))
164       outFiletxt = argv[++i];
165     else if(!strcmp(argv[i], "-g"))
166       deviceNum = atoi(argv[++i]);
167     else if(!strcmp(argv[i], "-b"))
168       runBrute=1;
169     else if(!strcmp(argv[i], "-e"))
170       runEval=1;
171     else{
172       fprintf(stderr,"%s : unrecognized option.. exiting\n",argv[i]);
173       exit(1);
174     }
175     i++;
176   }
177
178   if( !n || !m || !d || !numReps  ){
179     fprintf(stderr,"more arguments needed.. exiting\n");
180     exit(1);
181   }
182   if( (!dataFileX && !dataFileXtxt) || (!dataFileQ && !dataFileQtxt) ){
183     fprintf(stderr,"more arguments needed.. exiting\n");
184     exit(1);
185   }
186   if( (dataFileX && dataFileXtxt) || (dataFileQ && dataFileQtxt) ){
187     fprintf(stderr,"you can only give one database file and one query file.. exiting\n");
188     exit(1); 
189   }
190   if( numReps>n ){
191     fprintf(stderr,"can't have more representatives than points.. exiting\n");
192     exit(1);
193   }
194   if( numReps<32 ){
195     fprintf(stderr, "number of representatives must be at least 32\n");
196     exit(1);
197   }
198 }
199
200
201 void readData(char *dataFile, matrix x){
202   unint i;
203   FILE *fp;
204   unint numRead;
205
206   fp = fopen(dataFile,"r");
207   if(fp==NULL){
208     fprintf(stderr,"error opening file.. exiting\n");
209     exit(1);
210   }
211     
212   for( i=0; i<x.r; i++ ){ //can't load everything in one fread
213                            //because matrix is padded.
214     numRead = fread( &x.mat[IDX( i, 0, x.ld )], sizeof(real), x.c, fp );
215     if(numRead != x.c){
216       fprintf(stderr,"error reading file.. exiting \n");
217       exit(1);
218     }
219   }
220   fclose(fp);
221 }
222
223
224 void readDataText(char *dataFile, matrix x){
225   FILE *fp;
226   double t;
227   unint i,j;
228
229   fp = fopen(dataFile,"r");
230   if(fp==NULL){
231     fprintf(stderr,"error opening file.. exiting\n");
232     exit(1);
233   }
234     
235   for(i=0; i<x.r; i++){
236     for(j=0; j<x.c; j++){
237       if(fscanf(fp,"%lf ", &t)==EOF){
238         fprintf(stderr,"error reading file.. exiting \n");
239         exit(1);
240       }
241       x.mat[IDX( i, j, x.ld )]=(real)t;
242     }
243   }
244   fclose(fp);
245 }
246
247
248 //find the error rate of a set of NNs, then print it.
249 void evalNNerror(matrix x, matrix q, unint *NNs){
250   struct timeval tvB, tvE;
251   unint i;
252
253   printf("\nComputing error rates (this might take a while)\n");
254   real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
255   for(i=0;i<q.r;i++){
256     if(NNs[i]>n) printf("error");
257     ranges[i] = distVec(q,x,i,NNs[i]) - 10e-6;
258   }
259
260   unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
261   gettimeofday(&tvB,NULL);
262   bruteRangeCount(x,q,ranges,cnts);
263   gettimeofday(&tvE,NULL);
264   
265   long int nc=0;
266   for(i=0;i<m;i++){
267     nc += cnts[i];
268   }
269   double mean = ((double)nc)/((double)m);
270   double var = 0.0;
271   for(i=0;i<m;i++) {
272     var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
273   }
274   printf("\tavg rank = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
275   printf("(range count took %6.4f) \n", timeDiff(tvB, tvE));
276   
277   free(ranges);
278   free(cnts);
279 }
280
281
282 //evals the error rate of k-nns
283 void evalKNNerror(matrix x, matrix q, intMatrix NNs){
284   struct timeval tvB, tvE;
285   unint i,j,k;
286
287   unint m = q.r;
288   printf("\nComputing error rates (this might take a while)\n");
289   
290   unint *ol = (unint*)calloc( q.r, sizeof(*ol) );
291   
292   intMatrix NNsB;
293   matrix distsBrute;
294
295   initIntMat( &NNsB, q.r, KMAX );
296   initMat( &distsBrute, q.r, KMAX );
297   NNsB.mat = (unint*)calloc( sizeOfIntMat(NNsB), sizeof(*NNsB.mat) );
298   distsBrute.mat = (real*)calloc( sizeOfMat(distsBrute), sizeof(*distsBrute.mat) );
299
300   gettimeofday(&tvB,NULL);
301   bruteK(x,q,NNsB,distsBrute);
302   gettimeofday(&tvE,NULL);
303
304    //calc overlap
305   for(i=0; i<m; i++){
306     for(j=0; j<KMAX; j++){
307       for(k=0; k<KMAX; k++){
308         ol[i] += ( NNs.mat[IDX(i, j, NNs.ld)] == NNsB.mat[IDX(i, k, NNsB.ld)] );
309       }
310     }
311   }
312
313   long int nc=0;
314   for(i=0;i<m;i++){
315     nc += ol[i];
316   }
317
318   double mean = ((double)nc)/((double)m);
319   double var = 0.0;
320   for(i=0;i<m;i++) {
321     var += (((double)ol[i])-mean)*(((double)ol[i])-mean)/((double)m);
322   }
323   printf("\tavg overlap = %6.4f/%d; std dev = %6.4f \n", mean, KMAX, sqrt(var));
324
325   real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
326   for(i=0;i<q.r;i++){
327     ranges[i] = distVec(q,x,i,NNs.mat[IDX(i, KMAX-1, NNs.ld)]);
328   }
329     
330   unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
331   bruteRangeCount(x,q,ranges,cnts);
332   
333   nc=0;
334   for(i=0;i<m;i++){
335     nc += cnts[i];
336   }
337   mean = ((double)nc)/((double)m);
338   var = 0.0;
339   for(i=0;i<m;i++) {
340     var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
341   }
342   printf("\tavg actual rank of 32nd NN returned by the RBC = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
343   printf("(brute k-nn took %6.4f) \n", timeDiff(tvB, tvE));
344
345   free(cnts);
346   free(ol);
347   free(NNsB.mat);
348   free(distsBrute.mat);
349 }
350
351
352 void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs){
353   unint i,j;
354   
355   if( filetxt ) { //write text
356
357     FILE *fp = fopen(filetxt,"w");
358     if( !fp ){
359       fprintf(stderr, "can't open output file\n");
360       return;
361     }
362     
363     for( i=0; i<m; i++ ){
364       for( j=0; j<KMAX; j++ )
365         fprintf( fp, "%u ", NNs.mat[IDX( i, j, NNs.ld )] );
366       fprintf(fp, "\n");
367     }
368     
369     for( i=0; i<m; i++ ){
370       for( j=0; j<KMAX; j++ )
371         fprintf( fp, "%f ", dNNs.mat[IDX( i, j, dNNs.ld )]); 
372       fprintf(fp, "\n");
373     }
374     fclose(fp);
375     
376   }
377
378   if( file ){ //write binary
379
380     FILE *fp = fopen(file,"wb");
381     if( !fp ){
382       fprintf(stderr, "can't open output file\n");
383       return;
384     }
385     
386     for( i=0; i<m; i++ )
387       fwrite( &NNs.mat[IDX( i, 0, NNs.ld )], sizeof(*NNs.mat), KMAX, fp );
388     for( i=0; i<m; i++ )
389       fwrite( &dNNs.mat[IDX( i, 0, dNNs.ld )], sizeof(*dNNs.mat), KMAX, fp );
390     
391     fclose(fp);
392   }
393 }