c6b606cf7d37d110d218a3cc02e7933669f19942
[RBC.git] / driver.cu
1 /* This file is part of the Random Ball Cover (RBC) library.
2 * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3 */
4
5 #include<stdio.h>
6 #include<stdlib.h>
7 #include<cuda.h>
8 #include<sys/time.h>
9 #include<math.h>
10 #include "defs.h"
11 #include "utils.h"
12 #include "utilsGPU.h"
13 #include "rbc.h"
14 #include "brute.h"
15 #include "sKernel.h"
16
// Forward declarations of the driver helpers defined after main().
void parseInput( int argc, char **argv );
void readData( char *dataFile, matrix x );
void readDataText( char *dataFile, matrix x );
void evalNNerror( matrix x, matrix q, unint *NNs );
void evalKNNerror( matrix x, matrix q, intMatrix NNs );
void writeNeighbs( char *file, char *filetxt, intMatrix NNs, matrix dNNs );
23
// Run configuration shared by main() and the helpers (command-line state).
char *dataFileX    = NULL;  // database file, loaded with readData()
char *dataFileQ    = NULL;  // query file, loaded with readData()
char *dataFileXtxt = NULL;  // database file, loaded with readDataText() when set
char *dataFileQtxt = NULL;  // query file, loaded with readDataText() when set
char *outFile      = NULL;  // handed to writeNeighbs() as the binary output path
char *outFiletxt   = NULL;  // handed to writeNeighbs() as the text output path
char runBrute = 0;          // non-zero: main() also times bruteK()
char runEval  = 0;          // non-zero: main() calls evalKNNerror()
unint n = 0;                // number of rows of the database matrix x
unint m = 0;                // number of rows of the query matrix q
unint d = 0;                // number of columns of x and q
unint numReps = 0;          // passed to buildRBC()
unint deviceNum = 0;        // device index given to cudaSetDevice()
27
/* Allocate a zeroed host buffer of count elements of the given size.
 * Exits the program if a non-empty request cannot be satisfied. */
static void *callocOrExit( size_t count, size_t size ){
  void *buf = calloc( count, size );
  if( buf == NULL && count != 0 && size != 0 ){
    fprintf( stderr, "unable to allocate host memory.. exiting\n" );
    exit(1);
  }
  return buf;
}


/* Driver entry point.
 *
 * Parses the command line, selects the GPU, loads the database (x) and
 * query (q) matrices, builds the RBC, runs the k-NN query, and then
 * optionally runs brute force (-b), the evaluation routine (-e) and writes
 * the neighbors to the requested output file(s).
 *
 * Returns 0 on success and 1 if a CUDA error was pending after the searches.
 */
int main(int argc, char**argv){
  matrix x, q;
  intMatrix nnsRBC;
  matrix distsRBC;
  struct timeval tvB,tvE;
  cudaError_t cE;
  rbcStruct rbcS;
  size_t memFree = 0, memTot = 0;
  int status = 0;

  printf("*****************\n");
  printf("RANDOM BALL COVER\n");
  printf("*****************\n");

  parseInput(argc,argv);

  gettimeofday( &tvB, NULL );
  printf("Using GPU #%d\n",deviceNum);
  if(cudaSetDevice(deviceNum) != cudaSuccess){
    printf("Unable to select device %d.. exiting. \n",deviceNum);
    exit(1);
  }

  //Only report the memory figures if the query actually succeeded;
  //otherwise memFree/memTot hold no meaningful values.
  if( cudaMemGetInfo(&memFree, &memTot) == cudaSuccess )
    printf("GPU memory free = %lu/%lu (MB) \n",(unsigned long)memFree/(1024*1024),(unsigned long)memTot/(1024*1024));
  else
    fprintf(stderr,"unable to query GPU memory\n");
  gettimeofday( &tvE, NULL );
  printf(" init time: %6.2f \n", timeDiff( tvB, tvE ) );


  //Setup matrices
  initMat( &x, n, d );
  initMat( &q, m, d );
  x.mat = (real*)callocOrExit( sizeOfMat(x), sizeof(*(x.mat)) );
  q.mat = (real*)callocOrExit( sizeOfMat(q), sizeof(*(q.mat)) );

  //Load data
  if( dataFileXtxt )
    readDataText( dataFileXtxt, x );
  else
    readData( dataFileX, x );
  if( dataFileQtxt )
    readDataText( dataFileQtxt, q );
  else
    readData( dataFileQ, q );


  //Allocate space for NNs and dists
  initIntMat( &nnsRBC, m, K );  //K is defined in defs.h
  initMat( &distsRBC, m, K );
  nnsRBC.mat = (unint*)callocOrExit( sizeOfIntMat(nnsRBC), sizeof(*nnsRBC.mat) );
  distsRBC.mat = (real*)callocOrExit( sizeOfMat(distsRBC), sizeof(*distsRBC.mat) );

  //Build the RBC (numReps is passed for both trailing arguments)
  printf("building the rbc..\n");
  gettimeofday( &tvB, NULL );
  buildRBC( x, &rbcS, numReps, numReps );
  gettimeofday( &tvE, NULL );
  printf( "\t.. build time = %6.4f \n", timeDiff(tvB,tvE) );

  //This finds the 32-NNs; if you are only interested in the 1-NN, use queryRBC(..) instead
  gettimeofday( &tvB, NULL );
  kqueryRBC( q, rbcS, nnsRBC, distsRBC );
  gettimeofday( &tvE, NULL );
  printf( "\t.. query time for krbc = %6.4f \n", timeDiff(tvB,tvE) );

  if( runBrute ){
    //The brute-force results are only timed here; they are discarded afterwards.
    intMatrix nnsBrute;
    matrix distsBrute;
    initIntMat( &nnsBrute, m, K );
    nnsBrute.mat = (unint*)callocOrExit( sizeOfIntMat(nnsBrute), sizeof(*nnsBrute.mat) );
    initMat( &distsBrute, m, K );
    distsBrute.mat = (real*)callocOrExit( sizeOfMat(distsBrute), sizeof(*distsBrute.mat) );

    printf("running k-brute force..\n");
    gettimeofday( &tvB, NULL );
    bruteK( x, q, nnsBrute, distsBrute );
    gettimeofday( &tvE, NULL );
    printf( "\t.. time elapsed = %6.4f \n", timeDiff(tvB,tvE) );

    free( nnsBrute.mat );
    free( distsBrute.mat );
  }

  //Report any error left pending by the calls above and remember it for
  //the exit status; the remaining steps still run as before.
  cE = cudaGetLastError();
  if( cE != cudaSuccess ){
    printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
    status = 1;
  }

  if( runEval )
    evalKNNerror(x,q,nnsRBC);

  //NOTE(review): evalKNNerror appends its statistics to outFile, and
  //writeNeighbs then reopens outFile with "wb"; confirm that combining
  //-e with -o is intended to work this way.
  if( outFile || outFiletxt )
    writeNeighbs( outFile, outFiletxt, nnsRBC, distsRBC );

  destroyRBC( &rbcS );
  //NOTE(review): newer CUDA toolkits document cudaThreadExit as deprecated
  //in favour of cudaDeviceReset; confirm against the toolkit in use.
  cudaThreadExit();
  free( nnsRBC.mat );
  free( distsRBC.mat );
  free( x.mat );
  free( q.mat );

  return status;
}
128
129
/* Return the value that follows the option at argv[*i] and advance *i onto it.
 * Exits with an error message if the option is the last argument. */
static char *optionValue( int argc, char **argv, int *i ){
  if( *i + 1 >= argc ){
    fprintf( stderr, "%s : option requires a value.. exiting\n", argv[*i] );
    exit(1);
  }
  *i += 1;
  return argv[*i];
}


/* Parse the command line into the file-level configuration variables.
 *
 * With no arguments the usage text is printed and the program exits with
 * status 0. An unrecognized option, an option missing its value, missing
 * required arguments, both a binary and a text file for the same input,
 * or more representatives than points all terminate with status 1.
 *
 *   argc, argv : the arguments received by main().
 */
void parseInput(int argc, char **argv){
  int i=1;
  if(argc <= 1){
    printf("\nusage: \n  testRBC -x datafileX -q datafileQ  -n numPts (DB) -m numQueries -d dim -r numReps [-o outFile] [-g GPU num] [-b] [-e]\n\n");
    printf("\tdatafileX    = binary file containing the database\n");
    printf("\tdatafileQ    = binary file containing the queries\n");
    printf("\tnumPts       = size of database\n");
    printf("\tnumQueries   = number of queries\n");
    printf("\tdim          = dimensionality\n");
    printf("\tnumReps      = number of representatives\n");
    printf("\toutFile      = binary output file (optional)\n");
    printf("\tGPU num      = ID # of the GPU to use (optional) for multi-GPU machines\n");
    printf("\n\tuse -b to run brute force in addition the RBC\n");
    printf("\tuse -e to run the evaluation routine (implicitly runs brute force)\n");
    printf("\n\n\tTo input/output data in text format (instead of bin), use the \n\t-X and -Q and -O switches in place of -x and -q and -o (respectively).\n");
    printf("\n\n");
    exit(0);
  }

  while(i<argc){
    const char *opt = argv[i];

    if(!strcmp(opt, "-x"))
      dataFileX = optionValue(argc,argv,&i);
    else if(!strcmp(opt, "-q"))
      dataFileQ = optionValue(argc,argv,&i);
    else if(!strcmp(opt, "-X"))
      dataFileXtxt = optionValue(argc,argv,&i);  //text database: must not land in dataFileX
    else if(!strcmp(opt, "-Q"))
      dataFileQtxt = optionValue(argc,argv,&i);  //text queries: must not land in dataFileQ
    else if(!strcmp(opt, "-n"))
      n = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(opt, "-m"))
      m = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(opt, "-d"))
      d = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(opt, "-r"))
      numReps = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(opt, "-o"))
      outFile = optionValue(argc,argv,&i);
    else if(!strcmp(opt, "-O"))
      outFiletxt = optionValue(argc,argv,&i);
    else if(!strcmp(opt, "-g"))
      deviceNum = atoi(optionValue(argc,argv,&i));
    else if(!strcmp(opt, "-b"))
      runBrute=1;
    else if(!strcmp(opt, "-e"))
      runEval=1;
    else{
      fprintf(stderr,"%s : unrecognized option.. exiting\n",opt);
      exit(1);
    }
    i++;
  }

  //All four sizes are mandatory and must be non-zero.
  if( !n || !m || !d || !numReps ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  //Each of the database and the queries needs a file in one of the two formats..
  if( (!dataFileX && !dataFileXtxt) || (!dataFileQ && !dataFileQtxt) ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  //.. but not in both.
  if( (dataFileX && dataFileXtxt) || (dataFileQ && dataFileQtxt) ){
    fprintf(stderr,"you can only give one database file and one query file.. exiting\n");
    exit(1);
  }
  if(numReps>n){
    fprintf(stderr,"can't have more representatives than points.. exiting\n");
    exit(1);
  }
}
200
201
/* Load a binary file of `real` values into the matrix x.
 *
 * Reads x.r rows of x.c values each; row i is stored starting at
 * x.mat[IDX(i,0,x.ld)]. Exits with status 1 if the file cannot be opened
 * or holds fewer values than requested.
 *
 *   dataFile : path of the binary input file.
 *   x        : destination matrix; x.mat must already be allocated.
 */
void readData(char *dataFile, matrix x){
  unint i;
  FILE *fp;
  size_t numRead;  //fread reports its count as a size_t

  //Binary mode, matching the "wb" used when writing binary output.
  fp = fopen(dataFile,"rb");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for( i=0; i<x.r; i++ ){ //can't load everything in one fread
                          //because matrix is padded.
    numRead = fread( &x.mat[IDX( i, 0, x.ld )], sizeof(real), x.c, fp );
    if(numRead != x.c){
      fprintf(stderr,"error reading file.. exiting \n");
      fclose(fp);
      exit(1);
    }
  }
  fclose(fp);
}
223
224
/* Load a whitespace-separated text file of numbers into the matrix x.
 *
 * Reads x.r * x.c values in row order and stores value (i,j) at
 * x.mat[IDX(i,j,x.ld)]. Exits with status 1 if the file cannot be opened,
 * ends early, or contains a token that is not a number.
 *
 *   dataFile : path of the text input file.
 *   x        : destination matrix; x.mat must already be allocated.
 */
void readDataText(char *dataFile, matrix x){
  FILE *fp;
  float t;    //"%f" requires a float target; converted to real on store
  unint i,j;

  fp = fopen(dataFile,"r");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for(i=0; i<x.r; i++){
    for(j=0; j<x.c; j++){
      //fscanf returns the number of converted items: anything other than 1
      //(EOF, or 0 on a malformed token) means no value was read.
      if(fscanf(fp,"%f ", &t) != 1){
        fprintf(stderr,"error reading file.. exiting \n");
        fclose(fp);
        exit(1);
      }
      x.mat[IDX( i, j, x.ld )]=(real)t;
    }
  }
  fclose(fp);
}
247
248
//find the error rate of a set of NNs, then print it.
//
//  x   : database matrix
//  q   : query matrix
//  NNs : NNs[i] is the database index reported for query i
//
//For each query the distance to its reported NN (minus a small slack) is
//used as a range; bruteRangeCount fills in a count per query for that range
//and the mean / std dev of the counts are printed. If outFile is set, a
//summary line is also appended to it.
void evalNNerror(matrix x, matrix q, unint *NNs){
  struct timeval tvB, tvE;
  unint i;
  unint m = q.r;  //number of queries, taken from q itself

  printf("\nComputing error rates (this might take a while)\n");
  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
  if( !ranges || !cnts ){
    fprintf(stderr,"unable to allocate host memory.. exiting\n");
    exit(1);
  }

  for(i=0;i<m;i++){
    if(NNs[i]>=x.r) printf("error");  //valid row indices are 0 .. x.r-1
    ranges[i] = distVec(q,x,i,NNs[i]) - 10e-6;
  }

  gettimeofday(&tvB,NULL);
  bruteRangeCount(x,q,ranges,cnts);
  gettimeofday(&tvE,NULL);

  long int nc=0;
  for(i=0;i<m;i++){
    nc += cnts[i];
  }
  double mean = ((double)nc)/((double)m);
  double var = 0.0;
  for(i=0;i<m;i++) {
    var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
  }
  printf("\tavg rank = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
  printf("(range count took %6.4f) \n", timeDiff(tvB, tvE));

  if(outFile){
    FILE* fp = fopen(outFile, "a");
    if( fp ){
      fprintf( fp, "%d %6.5f %6.5f \n", numReps, mean, sqrt(var) );
      fclose(fp);
    }
    else
      fprintf(stderr, "can't open output file\n");
  }

  free(ranges);
  free(cnts);
}
287
288
//evals the error rate of k-nns
//
//  x   : database matrix
//  q   : query matrix
//  NNs : row i holds the K database indices reported for query i
//
//Runs bruteK on the same inputs, prints the mean / std dev of the per-query
//overlap between NNs and the brute-force result, then prints the mean /
//std dev of the bruteRangeCount counts obtained using the distance to the
//last (K-th) reported neighbor as the range. If outFile is set, both
//summaries are appended to it on one line.
void evalKNNerror(matrix x, matrix q, intMatrix NNs){
  struct timeval tvB, tvE;
  unint i,j,k;

  unint m = q.r;
  printf("\nComputing error rates (this might take a while)\n");

  //Brute-force result matrices use K columns, like the matrix being evaluated.
  intMatrix NNsB;
  NNsB.r=q.r; NNsB.pr=q.pr; NNsB.c=NNsB.pc=K; NNsB.ld=NNsB.pc;
  matrix distsBrute;
  distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.c=distsBrute.pc=K; distsBrute.ld=distsBrute.pc;

  //Allocate all scratch buffers up front so one check covers them.
  unint *ol = (unint*)calloc( q.r, sizeof(*ol) );
  NNsB.mat = (unint*)calloc( NNsB.pr*NNsB.pc, sizeof(*NNsB.mat) );
  distsBrute.mat = (real*)calloc( distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat) );
  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
  if( !ol || !NNsB.mat || !distsBrute.mat || !ranges || !cnts ){
    fprintf(stderr,"unable to allocate host memory.. exiting\n");
    exit(1);
  }

  gettimeofday(&tvB,NULL);
  bruteK(x,q,NNsB,distsBrute);
  gettimeofday(&tvE,NULL);

  //calc overlap: count the (j,k) pairs where the two index lists agree
  for(i=0; i<m; i++){
    for(j=0; j<K; j++){
      for(k=0; k<K; k++){
        ol[i] += ( NNs.mat[IDX(i, j, NNs.ld)] == NNsB.mat[IDX(i, k, NNsB.ld)] );
      }
    }
  }

  long int nc=0;
  for(i=0;i<m;i++){
    nc += ol[i];
  }

  double mean = ((double)nc)/((double)m);
  double var = 0.0;
  for(i=0;i<m;i++) {
    var += (((double)ol[i])-mean)*(((double)ol[i])-mean)/((double)m);
  }
  printf("\tavg overlap = %6.4f/%d; std dev = %6.4f \n", mean, K, sqrt(var));

  //fp stays NULL when no output file was requested or it could not be opened.
  FILE* fp = NULL;
  if(outFile){
    fp = fopen(outFile, "a");
    if( fp )
      fprintf( fp, "%d %6.5f %6.5f ", numReps, mean, sqrt(var) );
    else
      fprintf(stderr, "can't open output file\n");
  }

  for(i=0;i<q.r;i++){
    ranges[i] = distVec(q,x,i,NNs.mat[IDX(i, K-1, NNs.ld)]);
  }

  bruteRangeCount(x,q,ranges,cnts);

  nc=0;
  for(i=0;i<m;i++){
    nc += cnts[i];
  }
  mean = ((double)nc)/((double)m);
  var = 0.0;
  for(i=0;i<m;i++) {
    var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
  }
  printf("\tavg actual rank of 32nd NN returned by the RBC = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
  printf("(brute k-nn took %6.4f) \n", timeDiff(tvB, tvE));

  if(fp){
    fprintf( fp, "%6.5f %6.5f \n", mean, sqrt(var) );
    fclose(fp);
  }

  free(ranges);
  free(cnts);
  free(ol);
  free(NNsB.mat);
  free(distsBrute.mat);
}
367
368
/* Write the neighbor indices and distances to the requested output files.
 *
 * For each of the m queries, K indices from NNs and K distances from dNNs
 * are written: first all index rows, then all distance rows.
 *
 *   file    : binary output path, or NULL to skip binary output.
 *   filetxt : text output path, or NULL to skip text output.
 *   NNs     : neighbor indices, one row per query.
 *   dNNs    : neighbor distances, one row per query.
 *
 * The two outputs are independent: a failure on one is reported on stderr
 * and does not prevent the other from being attempted.
 */
void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs){
  unint i,j;

  if( filetxt ) { //write text

    FILE *fp = fopen(filetxt,"w");
    if( !fp ){
      fprintf(stderr, "can't open output file\n");
    }
    else{
      for( i=0; i<m; i++ ){
        for( j=0; j<K; j++ )
          fprintf( fp, "%u ", NNs.mat[IDX( i, j, NNs.ld )] );
        fprintf(fp, "\n");
      }

      for( i=0; i<m; i++ ){
        for( j=0; j<K; j++ )
          fprintf( fp, "%f ", dNNs.mat[IDX( i, j, dNNs.ld )]);
        fprintf(fp, "\n");
      }
      fclose(fp);
    }

  }

  if( file ){ //write binary

    FILE *fp = fopen(file,"wb");
    if( !fp ){
      fprintf(stderr, "can't open output file\n");
      return;
    }

    int ok = 1;  //cleared on the first short write
    for( i=0; ok && i<m; i++ )
      ok = ( fwrite( &NNs.mat[IDX( i, 0, NNs.ld )], sizeof(*NNs.mat), K, fp ) == (size_t)K );
    for( i=0; ok && i<m; i++ )
      ok = ( fwrite( &dNNs.mat[IDX( i, 0, dNNs.ld )], sizeof(*dNNs.mat), K, fp ) == (size_t)K );

    if( !ok )
      fprintf(stderr, "error writing output file\n");

    fclose(fp);
  }
}