improved driver
[RBC.git] / driver.cu
1 /* This file is part of the Random Ball Cover (RBC) library.
2 * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3 */
4
5 #include<stdio.h>
6 #include<stdlib.h>
7 #include<cuda.h>
8 #include<sys/time.h>
9 #include<math.h>
10 #include "defs.h"
11 #include "utils.h"
12 #include "utilsGPU.h"
13 #include "rbc.h"
14 #include "brute.h"
15 #include "sKernel.h"
16
/* Forward declarations for the helper routines defined later in this file. */
void parseInput(int argc, char **argv);
void readData(char *dataFile, matrix x);
void readDataText(char *dataFile, matrix x);
void evalNNerror(matrix x, matrix q, unint *NNs);
void evalKNNerror(matrix x, matrix q, intMatrix NNs);
void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs);

/* Command-line state filled in by parseInput(). File-scope objects are
 * zero-initialised, so an unset file name is NULL and an unset count is 0. */
char *dataFileX;     /* binary database file    (-x) */
char *dataFileQ;     /* binary query file       (-q) */
char *dataFileXtxt;  /* text database file      (-X) */
char *dataFileQtxt;  /* text query file         (-Q) */
char *outFile;       /* binary output file      (-o) */
char *outFiletxt;    /* text output file        (-O) */

char runBrute = 0;   /* set by -b: also run brute-force k-NN    */
char runEval = 0;    /* set by -e: run the evaluation routine   */

unint n = 0;         /* number of database points      (-n) */
unint m = 0;         /* number of queries              (-m) */
unint d = 0;         /* dimensionality                 (-d) */
unint numReps = 0;   /* number of representatives      (-r) */
unint deviceNum = 0; /* ID of the GPU to use           (-g) */
27
// Allocates zero-initialised host memory with calloc. If the allocation
// fails, prints a message naming `what` to stderr and exits with status 1.
static void *callocOrExit(size_t count, size_t size, const char *what){
  void *p = calloc( count, size );
  // calloc may legitimately return NULL for a zero-sized request.
  if( p == NULL && count != 0 && size != 0 ){
    fprintf(stderr,"unable to allocate host memory for %s.. exiting\n", what);
    exit(1);
  }
  return p;
}


// Driver entry point: parses the command line, selects the GPU, loads the
// database and query matrices, builds the RBC, runs the k-NN query and then
// optionally runs brute force (-b), the evaluation routine (-e) and writes
// the neighbors to file (-o / -O).
// Returns 0 on success, or 1 if cudaGetLastError() reported an error after
// the searches ran.
int main(int argc, char**argv){
  matrix x, q;
  intMatrix nnsRBC;
  matrix distsRBC;
  struct timeval tvB,tvE;
  cudaError_t cE;
  rbcStruct rbcS;
  int status = 0;

  printf("*****************\n");
  printf("RANDOM BALL COVER\n");
  printf("*****************\n");

  parseInput(argc,argv);

  gettimeofday( &tvB, NULL );
  // deviceNum is an unint, so print it with %u (as writeNeighbs does for unint).
  printf("Using GPU #%u\n",deviceNum);
  if(cudaSetDevice(deviceNum) != cudaSuccess){
    printf("Unable to select device %u.. exiting. \n",deviceNum);
    exit(1);
  }

  size_t memFree, memTot;
  // Only report the numbers if the query succeeded; otherwise they are
  // uninitialised.
  if(cudaMemGetInfo(&memFree, &memTot) == cudaSuccess)
    printf("GPU memory free = %lu/%lu (MB) \n",(unsigned long)memFree/(1024*1024),(unsigned long)memTot/(1024*1024));
  else
    fprintf(stderr,"warning: unable to query GPU memory\n");
  gettimeofday( &tvE, NULL );
  printf(" init time: %6.2f \n", timeDiff( tvB, tvE ) );

  initMat( &x, n, d );
  initMat( &q, m, d );
  x.mat = (real*)callocOrExit( sizeOfMat(x), sizeof(*(x.mat)), "the database" );
  q.mat = (real*)callocOrExit( sizeOfMat(q), sizeof(*(q.mat)), "the queries" );

  //Load data
  if(dataFileXtxt)
    readDataText(dataFileXtxt, x);
  else
    readData(dataFileX, x);
  if(dataFileQtxt)
    readDataText(dataFileQtxt, q);
  else
    readData(dataFileQ, q);

  //Allocate space for NNs and dists
  initIntMat( &nnsRBC, m, K ); //K is defined in defs.h
  initMat( &distsRBC, m, K );
  nnsRBC.mat = (unint*)callocOrExit( sizeOfIntMat(nnsRBC), sizeof(*nnsRBC.mat), "the neighbor indices" );
  distsRBC.mat = (real*)callocOrExit( sizeOfMat(distsRBC), sizeof(*distsRBC.mat), "the neighbor distances" );

  //Build the RBC
  printf("building the rbc..\n");
  gettimeofday(&tvB,NULL);
  buildRBC(x, &rbcS, numReps, numReps);
  gettimeofday(&tvE,NULL);
  printf("\t.. build time = %6.4f \n",timeDiff(tvB,tvE));

  //This finds the 32-NNs; if you are only interested in the 1-NN, use queryRBC(..) instead
  gettimeofday(&tvB,NULL);
  kqueryRBC(q, rbcS, nnsRBC, distsRBC);
  gettimeofday(&tvE,NULL);
  printf("\t.. query time for krbc = %6.4f \n",timeDiff(tvB,tvE));

  if( runBrute ){
    // The brute-force results are only timed here; they are not compared
    // against the RBC results in this branch.
    intMatrix nnsBrute;
    matrix distsBrute;
    initIntMat( &nnsBrute, m, K );
    nnsBrute.mat = (unint*)callocOrExit( sizeOfIntMat(nnsBrute), sizeof(*nnsBrute.mat), "the brute-force indices" );
    initMat( &distsBrute, m, K );
    distsBrute.mat = (real*)callocOrExit( sizeOfMat(distsBrute), sizeof(*distsBrute.mat), "the brute-force distances" );

    printf("running k-brute force..\n");
    gettimeofday(&tvB,NULL);
    bruteK(x,q,nnsBrute,distsBrute);
    gettimeofday(&tvE,NULL);
    printf("\t.. time elapsed = %6.4f \n",timeDiff(tvB,tvE));

    free(nnsBrute.mat);
    free(distsBrute.mat);
  }

  cE = cudaGetLastError();
  if( cE != cudaSuccess ){
    printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
    status = 1; // remember the failure so the exit status reflects it
  }

  if( runEval )
    evalKNNerror(x,q,nnsRBC);

  if( outFile || outFiletxt )
    writeNeighbs( outFile, outFiletxt, nnsRBC, distsRBC );

  destroyRBC(&rbcS);
  // NOTE(review): cudaThreadExit() is deprecated in newer CUDA toolkits in
  // favour of cudaDeviceReset(); kept here for compatibility with old ones.
  cudaThreadExit();
  free(nnsRBC.mat);
  free(distsRBC.mat);
  free(x.mat);
  free(q.mat);

  return status;
}
126
127
// Returns the value that follows the option at argv[*i] and advances *i to
// it. If the option is the last argument (no value present), prints an
// error to stderr and exits with status 1 instead of reading past argv.
static char *optionValue(int argc, char **argv, int *i){
  if( *i + 1 >= argc ){
    fprintf(stderr,"%s : option requires a value.. exiting\n",argv[*i]);
    exit(1);
  }
  return argv[++(*i)];
}


// Parses the command line into the file-scope option variables.
//   argc, argv : as passed to main.
// With no arguments, prints the usage text and exits with status 0.
// Exits with status 1 on an unrecognized option, an option that is missing
// its value, missing required arguments, both a binary and a text file for
// the same input, or more representatives than database points.
void parseInput(int argc, char **argv){
  int i=1;
  if(argc <= 1){
    printf("\nusage: \n testRBC -x datafileX -q datafileQ -n numPts (DB) -m numQueries -d dim -r numReps [-o outFile] [-g GPU num] [-b] [-e]\n\n");
    printf("\tdatafileX = binary file containing the database\n");
    printf("\tdatafileQ = binary file containing the queries\n");
    printf("\tnumPts = size of database\n");
    printf("\tnumQueries = number of queries\n");
    printf("\tdim = dimensionality\n");
    printf("\tnumReps = number of representatives\n");
    printf("\toutFile = binary output file (optional)\n");
    printf("\tGPU num = ID # of the GPU to use (optional) for multi-GPU machines\n");
    printf("\n\tuse -b to run brute force in addition the RBC\n");
    printf("\tuse -e to run the evaluation routine (implicitly runs brute force)\n");
    printf("\n\n\tTo input/output data in text format (instead of bin), use the \n\t-X and -Q and -O switches in place of -x and -q and -o (respectively).\n");
    printf("\n\n");
    exit(0);
  }

  while(i<argc){
    if(!strcmp(argv[i], "-x"))
      dataFileX = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-q"))
      dataFileQ = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-X"))
      // Text-format database: must set the *txt variable so that main()
      // picks readDataText() and the one-file-only check below can fire.
      dataFileXtxt = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-Q"))
      dataFileQtxt = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-n"))
      n = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(argv[i], "-m"))
      m = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(argv[i], "-d"))
      d = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(argv[i], "-r"))
      numReps = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(argv[i], "-o"))
      outFile = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-O"))
      outFiletxt = optionValue(argc,argv,&i);
    else if(!strcmp(argv[i], "-g"))
      deviceNum = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(argv[i], "-b"))
      runBrute=1;
    else if(!strcmp(argv[i], "-e"))
      runEval=1;
    else{
      fprintf(stderr,"%s : unrecognized option.. exiting\n",argv[i]);
      exit(1);
    }
    i++;
  }

  if( !n || !m || !d || !numReps ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  if( (!dataFileX && !dataFileXtxt) || (!dataFileQ && !dataFileQtxt) ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  if( (dataFileX && dataFileXtxt) || (dataFileQ && dataFileQtxt) ){
    fprintf(stderr,"you can only give one database file and one query file.. exiting\n");
    exit(1);
  }
  if(numReps>n){
    fprintf(stderr,"can't have more representatives than points.. exiting\n");
    exit(1);
  }
}
198
199
// Reads a binary file of `real` values into the matrix x, one row at a time.
//   dataFile : path of the binary input file.
//   x        : destination; x.r rows of x.c values are read, row i being
//              stored starting at x.mat[IDX(i, 0, x.ld)].
// Exits with status 1 if the file cannot be opened or holds too few values.
void readData(char *dataFile, matrix x){
  unint i;
  FILE *fp;
  size_t numRead; // fread returns a size_t element count

  // "rb": the file is binary, so open it in binary mode (matches the "wb"
  // used for the binary output in writeNeighbs).
  fp = fopen(dataFile,"rb");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for( i=0; i<x.r; i++ ){ //can't load everything in one fread
                          //because matrix is padded.
    numRead = fread( &x.mat[IDX( i, 0, x.ld )], sizeof(real), x.c, fp );
    if(numRead != x.c){
      fprintf(stderr,"error reading file.. exiting \n");
      fclose(fp);
      exit(1);
    }
  }
  fclose(fp);
}
221
222
// Reads whitespace-separated text values into the matrix x in row order.
//   dataFile : path of the text input file.
//   x        : destination; x.r * x.c values are read and value (i,j) is
//              stored at x.mat[IDX(i, j, x.ld)].
// Exits with status 1 if the file cannot be opened, ends early, or contains
// a token that does not parse as a number.
void readDataText(char *dataFile, matrix x){
  FILE *fp;
  float t;   // "%f" requires a float*, whatever `real` is defined as
  unint i,j; // unsigned, like the x.r / x.c bounds they are compared with

  fp = fopen(dataFile,"r");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for(i=0; i<x.r; i++){
    for(j=0; j<x.c; j++){
      // fscanf returns the number of converted items: anything other than
      // 1 means EOF *or* a malformed token, and t would not be valid.
      if(fscanf(fp,"%f ", &t)!=1){
        fprintf(stderr,"error reading file.. exiting \n");
        fclose(fp);
        exit(1);
      }
      x.mat[IDX( i, j, x.ld )]=(real)t;
    }
  }
  fclose(fp);
}
245
246
//find the error rate of a set of NNs, then print it.
//  x   : database matrix.
//  q   : query matrix.
//  NNs : NNs[i] is the index (row of x) returned as the NN of query i.
// For every query the distance to its returned NN (minus a small slack) is
// used as a range; bruteRangeCount then fills in one count per query, and
// the mean and standard deviation of those counts are printed as the
// "avg rank". If outFile is set, a summary line is appended to that file.
void evalNNerror(matrix x, matrix q, unint *NNs){
  struct timeval tvB, tvE;
  unint i;
  const unint numQ = q.r; // use the query matrix's own row count throughout

  printf("\nComputing error rates (this might take a while)\n");
  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
  if( (!ranges || !cnts) && q.pr ){
    fprintf(stderr,"unable to allocate memory for the evaluation.. exiting\n");
    exit(1);
  }

  for(i=0;i<numQ;i++){
    // Valid rows of x are 0 .. x.r-1, so x.r itself is already out of range.
    if(NNs[i]>=x.r){
      printf("error");
      continue; // do not read outside x; the range stays at calloc's zero
    }
    ranges[i] = distVec(q,x,i,NNs[i]) - 10e-6;
  }

  gettimeofday(&tvB,NULL);
  bruteRangeCount(x,q,ranges,cnts);
  gettimeofday(&tvE,NULL);

  long int nc=0;
  for(i=0;i<numQ;i++){
    nc += cnts[i];
  }
  double mean = ((double)nc)/((double)numQ);
  double var = 0.0;
  for(i=0;i<numQ;i++) {
    var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)numQ);
  }
  printf("\tavg rank = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
  printf("(range count took %6.4f) \n", timeDiff(tvB, tvE));

  if(outFile){
    FILE* fp = fopen(outFile, "a");
    if(fp){
      fprintf( fp, "%u %6.5f %6.5f \n", numReps, mean, sqrt(var) );
      fclose(fp);
    }
    else{
      fprintf(stderr, "can't open output file\n");
    }
  }

  free(ranges);
  free(cnts);
}
285
286
// Computes the mean and (population) standard deviation of vals[0..count-1].
static void meanStdDev(const unint *vals, unint count, double *mean, double *stdDev){
  unint i;
  long int total=0;
  for(i=0;i<count;i++){
    total += vals[i];
  }
  *mean = ((double)total)/((double)count);

  double var = 0.0;
  for(i=0;i<count;i++){
    double dev = ((double)vals[i]) - *mean;
    var += dev*dev/((double)count);
  }
  *stdDev = sqrt(var);
}


//evals the error rate of k-nns
//  x   : database matrix.
//  q   : query matrix.
//  NNs : K neighbor indices per query, as returned by the RBC search.
// Runs bruteK on the same data and reports two statistics:
//   1. for each query, how many of the K returned indices also appear in
//      the brute-force result (the "overlap");
//   2. the counts returned by bruteRangeCount when the range for each query
//      is the distance to its K-th (last) returned neighbor.
// Means and standard deviations are printed; if outFile is set they are
// also appended to that file on a single line.
void evalKNNerror(matrix x, matrix q, intMatrix NNs){
  struct timeval tvB, tvE;
  unint i,j,k;
  double mean, stdDev;

  const unint numQ = q.r;
  printf("\nComputing error rates (this might take a while)\n");

  unint *ol = (unint*)calloc( q.r, sizeof(*ol) );

  // Brute-force result buffers. Both are K columns wide; the index matrix
  // previously hard-coded 32, which only matches when K is 32.
  intMatrix NNsB;
  NNsB.r=q.r; NNsB.pr=q.pr; NNsB.c=NNsB.pc=K; NNsB.ld=NNsB.pc;
  NNsB.mat = (unint*)calloc( NNsB.pr*NNsB.pc, sizeof(*NNsB.mat) );
  matrix distsBrute;
  distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.c=distsBrute.pc=K; distsBrute.ld=distsBrute.pc;
  distsBrute.mat = (real*)calloc( distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat) );

  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));

  if( q.pr && (!ol || !NNsB.mat || !distsBrute.mat || !ranges || !cnts) ){
    fprintf(stderr,"unable to allocate memory for the evaluation.. exiting\n");
    exit(1);
  }

  gettimeofday(&tvB,NULL);
  bruteK(x,q,NNsB,distsBrute);
  gettimeofday(&tvE,NULL);

  //calc overlap
  for(i=0; i<numQ; i++){
    for(j=0; j<K; j++){
      for(k=0; k<K; k++){
        ol[i] += ( NNs.mat[IDX(i, j, NNs.ld)] == NNsB.mat[IDX(i, k, NNsB.ld)] );
      }
    }
  }

  meanStdDev(ol, numQ, &mean, &stdDev);
  printf("\tavg overlap = %6.4f/%d; std dev = %6.4f \n", mean, K, stdDev);

  // fp stays NULL when no output file was requested or it cannot be opened,
  // so the later write is skipped instead of dereferencing a bad handle.
  FILE* fp = NULL;
  if(outFile){
    fp = fopen(outFile, "a");
    if(fp)
      fprintf( fp, "%u %6.5f %6.5f ", numReps, mean, stdDev );
    else
      fprintf(stderr, "can't open output file\n");
  }

  // Range for query i = distance to the last (K-th) neighbor the RBC returned.
  for(i=0;i<numQ;i++){
    ranges[i] = distVec(q,x,i,NNs.mat[IDX(i, K-1, NNs.ld)]);
  }

  bruteRangeCount(x,q,ranges,cnts);

  meanStdDev(cnts, numQ, &mean, &stdDev);
  printf("\tavg actual rank of 32nd NN returned by the RBC = %6.4f; std dev = %6.4f \n\n", mean, stdDev);
  printf("(brute k-nn took %6.4f) \n", timeDiff(tvB, tvE));

  if(fp){
    fprintf( fp, "%6.5f %6.5f \n", mean, stdDev );
    fclose(fp);
  }

  free(ranges); // was leaked before
  free(cnts);
  free(ol);
  free(NNsB.mat);
  free(distsBrute.mat);
}
365
366
// Writes the neighbor indices and distances to file.
//   file    : path for binary output, or NULL to skip it.
//   filetxt : path for text output, or NULL to skip it.
//   NNs     : neighbor indices, K per row.
//   dNNs    : neighbor distances, K per row.
// In both formats all rows of indices are written first, followed by all
// rows of distances. A failure on one output is reported on stderr and does
// not prevent the other output from being attempted.
void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs){
  unint i,j;
  // Take the row count from the matrix being written rather than from the
  // file-scope query count.
  const unint numRows = NNs.r;

  if( filetxt ) { //write text
    FILE *fp = fopen(filetxt,"w");
    if( !fp ){
      // Do not return: a binary file may still have been requested.
      fprintf(stderr, "can't open output file\n");
    }
    else{
      for( i=0; i<numRows; i++ ){
        for( j=0; j<K; j++ )
          fprintf( fp, "%u ", NNs.mat[IDX( i, j, NNs.ld )] );
        fprintf(fp, "\n");
      }

      for( i=0; i<numRows; i++ ){
        for( j=0; j<K; j++ )
          fprintf( fp, "%f ", dNNs.mat[IDX( i, j, dNNs.ld )]);
        fprintf(fp, "\n");
      }
      if( fclose(fp) != 0 )
        fprintf(stderr, "error writing output file\n");
    }
  }

  if( file ){ //write binary
    FILE *fp = fopen(file,"wb");
    if( !fp ){
      fprintf(stderr, "can't open output file\n");
      return;
    }

    // Rows are written one at a time, K values from the start of each row.
    // Stop at the first short write.
    int ok = 1;
    for( i=0; ok && i<numRows; i++ )
      ok = ( fwrite( &NNs.mat[IDX( i, 0, NNs.ld )], sizeof(*NNs.mat), K, fp ) == (size_t)K );
    for( i=0; ok && i<numRows; i++ )
      ok = ( fwrite( &dNNs.mat[IDX( i, 0, dNNs.ld )], sizeof(*dNNs.mat), K, fp ) == (size_t)K );

    if( fclose(fp) != 0 )
      ok = 0;
    if( !ok )
      fprintf(stderr, "error writing output file\n");
  }
}