/* Source: RBC.git / driver.cu (commit subject: "updated text files") */
1 /* This file is part of the Random Ball Cover (RBC) library.
2 * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3 */
4
5 #include<stdio.h>
6 #include<stdlib.h>
7 #include<sys/time.h>
8
9 #include "rbc_include.h"
10
// Prototypes for the helper routines defined further down in this file.
void parseInput(int argc, char **argv);
void readData(char *dataFile, matrix x);
void readDataText(char *dataFile, matrix x);
void evalNNerror(matrix x, matrix q, unint *NNs);
void evalKNNerror(matrix x, matrix q, intMatrix NNs);
void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs);

// Input/output file names taken from the command line by parseInput().
// Lower-case switches (-x, -q, -o) fill the binary names, upper-case
// switches (-X, -Q, -O) fill the text names; unset names stay NULL.
char *dataFileX = NULL;
char *dataFileQ = NULL;
char *dataFileXtxt = NULL;
char *dataFileQtxt = NULL;
char *outFile = NULL;
char *outFiletxt = NULL;

// Flags set by the -b (brute force) and -e (evaluation) switches.
char runBrute = 0;
char runEval = 0;

// Problem sizes and device selection: -n, -m, -d, -r and -g respectively.
unint n = 0;
unint m = 0;
unint d = 0;
unint numReps = 0;
unint deviceNum = 0;
21
// Driver entry point.
//
// Parses the command line, selects the GPU, loads the database (x) and the
// queries (q), builds the RBC and runs the k-NN query. Optionally runs the
// brute-force search (-b), the evaluation routine (-e) and writes the
// neighbors/distances to disk (-o / -O).
//
// Returns 0 on success, 1 if a CUDA error was reported after the searches.
// Setup failures (bad device, out of host memory) exit with status 1.
int main(int argc, char**argv){
  matrix x, q;
  intMatrix nnsRBC;
  matrix distsRBC;
  struct timeval tvB,tvE;
  cudaError_t cE;
  rbcStruct rbcS;
  int status = 0;

  printf("*****************\n");
  printf("RANDOM BALL COVER\n");
  printf("*****************\n");

  parseInput(argc,argv);

  gettimeofday( &tvB, NULL );
  // deviceNum is unsigned, so print it with %u.
  printf("Using GPU #%u\n",deviceNum);
  if(cudaSetDevice(deviceNum) != cudaSuccess){
    printf("Unable to select device %u.. exiting. \n",deviceNum);
    exit(1);
  }

  size_t memFree, memTot;
  // The memory report is informational only: on failure skip it rather than
  // printing uninitialized values.
  if( cudaMemGetInfo(&memFree, &memTot) == cudaSuccess )
    printf("GPU memory free = %lu/%lu (MB) \n",(unsigned long)memFree/(1024*1024),(unsigned long)memTot/(1024*1024));
  else
    fprintf(stderr,"unable to query GPU memory\n");
  gettimeofday( &tvE, NULL );
  printf(" init time: %6.2f \n", timeDiff( tvB, tvE ) );


  //Setup matrices
  initMat( &x, n, d );
  initMat( &q, m, d );
  x.mat = (real*)calloc( sizeOfMat(x), sizeof(*(x.mat)) );
  q.mat = (real*)calloc( sizeOfMat(q), sizeof(*(q.mat)) );
  if( !x.mat || !q.mat ){
    fprintf(stderr,"unable to allocate memory for the data.. exiting\n");
    exit(1);
  }

  //Load data
  if( dataFileXtxt )
    readDataText( dataFileXtxt, x );
  else
    readData( dataFileX, x );
  if( dataFileQtxt )
    readDataText( dataFileQtxt, q );
  else
    readData( dataFileQ, q );


  //Allocate space for NNs and dists
  initIntMat( &nnsRBC, m, KMAX );  //KMAX is defined in defs.h
  initMat( &distsRBC, m, KMAX );
  nnsRBC.mat = (unint*)calloc( sizeOfIntMat(nnsRBC), sizeof(*nnsRBC.mat) );
  distsRBC.mat = (real*)calloc( sizeOfMat(distsRBC), sizeof(*distsRBC.mat) );
  if( !nnsRBC.mat || !distsRBC.mat ){
    fprintf(stderr,"unable to allocate memory for the results.. exiting\n");
    exit(1);
  }

  //Build the RBC
  printf("building the rbc..\n");
  gettimeofday( &tvB, NULL );
  buildRBC( x, &rbcS, numReps, numReps );
  gettimeofday( &tvE, NULL );
  printf( "\t.. build time = %6.4f \n", timeDiff(tvB,tvE) );

  //This finds the 32-NNs; if you are only interested in the 1-NN, use queryRBC(..) instead
  gettimeofday( &tvB, NULL );
  kqueryRBC( q, rbcS, nnsRBC, distsRBC );
  gettimeofday( &tvE, NULL );
  printf( "\t.. query time for krbc = %6.4f \n", timeDiff(tvB,tvE) );

  //Shows how to run brute force search
  if( runBrute ){
    intMatrix nnsBrute;
    matrix distsBrute;
    initIntMat( &nnsBrute, m, KMAX );
    nnsBrute.mat = (unint*)calloc( sizeOfIntMat(nnsBrute), sizeof(*nnsBrute.mat) );
    initMat( &distsBrute, m, KMAX );
    distsBrute.mat = (real*)calloc( sizeOfMat(distsBrute), sizeof(*distsBrute.mat) );
    if( !nnsBrute.mat || !distsBrute.mat ){
      fprintf(stderr,"unable to allocate memory for brute force.. exiting\n");
      exit(1);
    }

    printf("running k-brute force..\n");
    gettimeofday( &tvB, NULL );
    bruteK( x, q, nnsBrute, distsBrute );
    gettimeofday( &tvE, NULL );
    printf( "\t.. time elapsed = %6.4f \n", timeDiff(tvB,tvE) );

    free( nnsBrute.mat );
    free( distsBrute.mat );
  }

  cE = cudaGetLastError();
  if( cE != cudaSuccess ){
    printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
    status = 1;  // remember the failure so the exit code reflects it
  }

  if( runEval )
    evalKNNerror(x,q,nnsRBC);

  if( outFile || outFiletxt )
    writeNeighbs( outFile, outFiletxt, nnsRBC, distsRBC );

  destroyRBC( &rbcS );
  // cudaThreadExit() is deprecated; cudaDeviceReset() is its replacement.
  cudaDeviceReset();
  free( nnsRBC.mat );
  free( distsRBC.mat );
  free( x.mat );
  free( q.mat );

  return status;
}
123
124
// Returns the value that follows the option at argv[*i] and advances *i onto
// it. Exits with status 1 when the option is the last argument, instead of
// reading past the end of argv.
static char *optionValue(int argc, char **argv, int *i){
  if( *i + 1 >= argc ){
    fprintf(stderr,"%s : missing value.. exiting\n",argv[*i]);
    exit(1);
  }
  return argv[++(*i)];
}

// Reads the value of the numeric option at argv[*i] as a non-negative integer
// that fits in an unint. Exits with status 1 on a missing, malformed,
// negative or out-of-range value.
static unint optionCount(int argc, char **argv, int *i){
  const char *opt = argv[*i];
  const char *val = optionValue(argc, argv, i);
  char *end;
  long v = strtol(val, &end, 10);

  // The round trip through unint rejects values that would be truncated.
  if( end == val || *end != '\0' || v < 0 || (long)(unint)v != v ){
    fprintf(stderr,"%s : invalid value '%s'.. exiting\n",opt,val);
    exit(1);
  }
  return (unint)v;
}

// Parses the command line into the file-level globals (file names, n, m, d,
// numReps, deviceNum, runBrute, runEval).
//
// With no arguments, prints the usage text and exits with status 0.
// Exits with status 1 on an unrecognized option, a missing or invalid option
// value, missing required arguments, both a binary and a text file given for
// the same input, numReps > n, or numReps < 32.
void parseInput(int argc, char **argv){
  int i=1;
  if(argc <= 1){
    printf("\nusage: \n testRBC -x datafileX -q datafileQ -n numPts (DB) -m numQueries -d dim -r numReps [-o outFile] [-g GPU num] [-b] [-e]\n\n");
    printf("\tdatafileX = binary file containing the database\n");
    printf("\tdatafileQ = binary file containing the queries\n");
    printf("\tnumPts = size of database\n");
    printf("\tnumQueries = number of queries\n");
    printf("\tdim = dimensionality\n");
    printf("\tnumReps = number of representatives (must be at least 32)\n");
    printf("\toutFile = binary output file (optional)\n");
    printf("\tGPU num = ID # of the GPU to use (optional) for multi-GPU machines\n");
    printf("\n\tuse -b to run brute force in addition the RBC\n");
    printf("\tuse -e to run the evaluation routine (implicitly runs brute force)\n");
    printf("\n\n\tTo input/output data in text format (instead of bin), use the \n\t-X and -Q and -O switches in place of -x and -q and -o (respectively).\n");
    printf("\n\n");
    exit(0);
  }

  while(i<argc){
    if(!strcmp(argv[i], "-x"))
      dataFileX = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-q"))
      dataFileQ = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-X"))
      dataFileXtxt = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-Q"))
      dataFileQtxt = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-n"))
      n = optionCount(argc, argv, &i);
    else if(!strcmp(argv[i], "-m"))
      m = optionCount(argc, argv, &i);
    else if(!strcmp(argv[i], "-d"))
      d = optionCount(argc, argv, &i);
    else if(!strcmp(argv[i], "-r"))
      numReps = optionCount(argc, argv, &i);
    else if(!strcmp(argv[i], "-o"))
      outFile = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-O"))
      outFiletxt = optionValue(argc, argv, &i);
    else if(!strcmp(argv[i], "-g"))
      deviceNum = optionCount(argc, argv, &i);
    else if(!strcmp(argv[i], "-b"))
      runBrute=1;
    else if(!strcmp(argv[i], "-e"))
      runEval=1;
    else{
      fprintf(stderr,"%s : unrecognized option.. exiting\n",argv[i]);
      exit(1);
    }
    i++;
  }

  if( !n || !m || !d || !numReps ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  if( (!dataFileX && !dataFileXtxt) || (!dataFileQ && !dataFileQtxt) ){
    fprintf(stderr,"more arguments needed.. exiting\n");
    exit(1);
  }
  if( (dataFileX && dataFileXtxt) || (dataFileQ && dataFileQtxt) ){
    fprintf(stderr,"you can only give one database file and one query file.. exiting\n");
    exit(1);
  }
  if( numReps>n ){
    fprintf(stderr,"can't have more representatives than points.. exiting\n");
    exit(1);
  }
  if( numReps<32 ){
    fprintf(stderr, "number of representatives must be at least 32\n");
    exit(1);
  }
}
199
200
// Loads a binary file of `real` values into x, one row at a time.
//
// dataFile: path of the file to read; it must hold at least x.r * x.c values.
// x:        destination matrix; x.mat must already be allocated.
//
// Exits with status 1 if the file cannot be opened or is too short.
void readData(char *dataFile, matrix x){
  unint i;
  FILE *fp;
  size_t numRead;  // fread returns size_t; avoid narrowing it

  // "rb": the file is raw binary, so it must not go through text-mode
  // translation on platforms that distinguish the two.
  fp = fopen(dataFile,"rb");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for( i=0; i<x.r; i++ ){ //can't load everything in one fread
                          //because matrix is padded.
    numRead = fread( &x.mat[IDX( i, 0, x.ld )], sizeof(real), x.c, fp );
    if(numRead != (size_t)x.c){
      fprintf(stderr,"error reading file.. exiting \n");
      fclose(fp);
      exit(1);
    }
  }
  fclose(fp);
}
222
223
// Loads a whitespace-separated text file of numbers into x in row order.
//
// dataFile: path of the file to read; it must hold at least x.r * x.c values.
// x:        destination matrix; x.mat must already be allocated.
//
// Exits with status 1 if the file cannot be opened, ends early, or contains
// a token that is not a number.
void readDataText(char *dataFile, matrix x){
  FILE *fp;
  double t;
  unint i,j;

  fp = fopen(dataFile,"r");
  if(fp==NULL){
    fprintf(stderr,"error opening file.. exiting\n");
    exit(1);
  }

  for(i=0; i<x.r; i++){
    for(j=0; j<x.c; j++){
      // fscanf returns the number of converted items: it is 0 (not EOF) on a
      // malformed token, so anything other than 1 is a read failure.
      if(fscanf(fp,"%lf ", &t)!=1){
        fprintf(stderr,"error reading file.. exiting \n");
        fclose(fp);
        exit(1);
      }
      x.mat[IDX( i, j, x.ld )]=(real)t;
    }
  }
  fclose(fp);
}
246
247
//find the error rate of a set of NNs, then print it.
//
// x:   database matrix.
// q:   query matrix.
// NNs: one reported neighbor index per query row (q.r entries).
//
// For every query a range is derived from the distance to its reported
// neighbor, bruteRangeCount() fills one count per query for that range, and
// the mean and standard deviation of the counts are printed as the rank.
// The row counts are taken from the matrices passed in (q.r, x.r) rather
// than from the file-level globals.
void evalNNerror(matrix x, matrix q, unint *NNs){
  struct timeval tvB, tvE;
  unint i;
  unint numQ = q.r;

  printf("\nComputing error rates (this might take a while)\n");
  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
  if( !ranges || !cnts ){
    fprintf(stderr,"unable to allocate memory for the evaluation\n");
    free(ranges);
    free(cnts);
    return;
  }

  for(i=0;i<numQ;i++){
    // Valid indices are 0..x.r-1. An invalid one is reported and skipped
    // (its range stays 0 from calloc) instead of being handed to distVec.
    if(NNs[i]>=x.r){
      printf("error");
      continue;
    }
    ranges[i] = distVec(q,x,i,NNs[i]) - 10e-6;
  }

  gettimeofday(&tvB,NULL);
  bruteRangeCount(x,q,ranges,cnts);
  gettimeofday(&tvE,NULL);

  long int nc=0;
  for(i=0;i<numQ;i++){
    nc += cnts[i];
  }
  double mean = ((double)nc)/((double)numQ);
  double var = 0.0;
  for(i=0;i<numQ;i++) {
    var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)numQ);
  }
  printf("\tavg rank = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
  printf("(range count took %6.4f) \n", timeDiff(tvB, tvE));

  free(ranges);
  free(cnts);
}
280
281
//evals the error rate of k-nns
//
// x:   database matrix.
// q:   query matrix.
// NNs: neighbor indices to evaluate, KMAX per query row.
//
// Runs bruteK() to get reference neighbors, then prints
//   - the average overlap between each row of NNs and the brute-force row,
//   - the mean and standard deviation of the counts bruteRangeCount() returns
//     for a range equal to the distance to each query's last (KMAX-th)
//     reported neighbor,
//   - the time taken by the brute-force search.
void evalKNNerror(matrix x, matrix q, intMatrix NNs){
  struct timeval tvB, tvE;
  unint i,j,k;

  unint m = q.r;
  printf("\nComputing error rates (this might take a while)\n");

  intMatrix NNsB;
  matrix distsBrute;

  initIntMat( &NNsB, q.r, KMAX );
  initMat( &distsBrute, q.r, KMAX );

  // Allocate every buffer up front so one check covers them all.
  unint *ol = (unint*)calloc( q.r, sizeof(*ol) );
  NNsB.mat = (unint*)calloc( sizeOfIntMat(NNsB), sizeof(*NNsB.mat) );
  distsBrute.mat = (real*)calloc( sizeOfMat(distsBrute), sizeof(*distsBrute.mat) );
  real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
  unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
  if( !ol || !NNsB.mat || !distsBrute.mat || !ranges || !cnts ){
    fprintf(stderr,"unable to allocate memory for the evaluation\n");
    free(cnts);
    free(ranges);
    free(ol);
    free(NNsB.mat);
    free(distsBrute.mat);
    return;
  }

  gettimeofday(&tvB,NULL);
  bruteK(x,q,NNsB,distsBrute);
  gettimeofday(&tvE,NULL);

  //calc overlap
  for(i=0; i<m; i++){
    for(j=0; j<KMAX; j++){
      for(k=0; k<KMAX; k++){
        ol[i] += ( NNs.mat[IDX(i, j, NNs.ld)] == NNsB.mat[IDX(i, k, NNsB.ld)] );
      }
    }
  }

  long int nc=0;
  for(i=0;i<m;i++){
    nc += ol[i];
  }

  double mean = ((double)nc)/((double)m);
  double var = 0.0;
  for(i=0;i<m;i++) {
    var += (((double)ol[i])-mean)*(((double)ol[i])-mean)/((double)m);
  }
  printf("\tavg overlap = %6.4f/%d; std dev = %6.4f \n", mean, KMAX, sqrt(var));

  for(i=0;i<q.r;i++){
    ranges[i] = distVec(q,x,i,NNs.mat[IDX(i, KMAX-1, NNs.ld)]);
  }

  bruteRangeCount(x,q,ranges,cnts);

  nc=0;
  for(i=0;i<m;i++){
    nc += cnts[i];
  }
  mean = ((double)nc)/((double)m);
  var = 0.0;
  for(i=0;i<m;i++) {
    var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
  }
  printf("\tavg actual rank of 32nd NN returned by the RBC = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
  printf("(brute k-nn took %6.4f) \n", timeDiff(tvB, tvE));

  free(cnts);
  free(ranges);  // was leaked before
  free(ol);
  free(NNsB.mat);
  free(distsBrute.mat);
}
350
351
// Writes the neighbor indices and their distances to disk.
//
// file:    path for binary output, or NULL to skip it.
// filetxt: path for text output, or NULL to skip it.
// NNs:     neighbor indices, KMAX per row.
// dNNs:    matching distances, KMAX per row.
//
// Both formats hold m rows of KMAX indices followed by m rows of KMAX
// distances, where m is the file-level query count. The two outputs are
// independent: failing to open or write one is reported on stderr and does
// not prevent the other from being written.
void writeNeighbs(char *file, char *filetxt, intMatrix NNs, matrix dNNs){
  unint i,j;

  if( filetxt ) { //write text

    FILE *fp = fopen(filetxt,"w");
    if( !fp ){
      fprintf(stderr, "can't open output file\n");
    }
    else{
      for( i=0; i<m; i++ ){
        for( j=0; j<KMAX; j++ )
          fprintf( fp, "%u ", NNs.mat[IDX( i, j, NNs.ld )] );
        fprintf(fp, "\n");
      }

      for( i=0; i<m; i++ ){
        for( j=0; j<KMAX; j++ )
          fprintf( fp, "%f ", dNNs.mat[IDX( i, j, dNNs.ld )]);
        fprintf(fp, "\n");
      }

      // Check the stream before closing, then the close itself, which
      // flushes the remaining buffered output.
      int failed = ferror(fp);
      if( fclose(fp) != 0 )
        failed = 1;
      if( failed )
        fprintf(stderr, "error writing output file\n");
    }

  }

  if( file ){ //write binary

    FILE *fp = fopen(file,"wb");
    if( !fp ){
      fprintf(stderr, "can't open output file\n");
      return;
    }

    int ok = 1;
    // Stop at the first short write; a truncated file is reported below.
    for( i=0; ok && i<m; i++ )
      ok = ( fwrite( &NNs.mat[IDX( i, 0, NNs.ld )], sizeof(*NNs.mat), KMAX, fp ) == (size_t)KMAX );
    for( i=0; ok && i<m; i++ )
      ok = ( fwrite( &dNNs.mat[IDX( i, 0, dNNs.ld )], sizeof(*dNNs.mat), KMAX, fp ) == (size_t)KMAX );

    if( fclose(fp) != 0 )
      ok = 0;
    if( !ok )
      fprintf(stderr, "error writing output file\n");
  }
}