aee8ee5fafee56b765a285f826d45edb9a5ba647
[RBC.git] / driver.cu
1 /* This file is part of the Random Ball Cover (RBC) library.
2 * (C) Copyright 2010, Lawrence Cayton [lcayton@tuebingen.mpg.de]
3 */
4
5 #include<stdio.h>
6 #include<stdlib.h>
7 #include<cuda.h>
8 #include<sys/time.h>
9 #include<math.h>
10 #include "defs.h"
11 #include "utils.h"
12 #include "utilsGPU.h"
13 #include "rbc.h"
14 #include "brute.h"
15 #include "sKernel.h"
16
17 void parseInput(int,char**);
18 void readData(char*,unint,unint,real*);
19 void readDataText(char*,unint,unint,real*);
20 void orgData(real*,unint,unint,matrix,matrix);
21 void evalNNerror(matrix, matrix, unint*);
22 void evalKNNerror(matrix,matrix,intMatrix);
23
24 char *dataFile, *outFile;
25 unint n=0, m=0, d=0, numReps=0, s=0;
26 unint deviceNum=0;
27 int main(int argc, char**argv){
28 real *data;
29 matrix x, q;
30 intMatrix nnsBrute, nnsRBC;
31 matrix distsBrute, distsRBC;
32 struct timeval tvB,tvE;
33 cudaError_t cE;
34 rbcStruct rbcS;
35
36 printf("*****************\n");
37 printf("RANDOM BALL COVER\n");
38 printf("*****************\n");
39
40 parseInput(argc,argv);
41
42 printf("Using GPU #%d\n",deviceNum);
43 if(cudaSetDevice(deviceNum) != cudaSuccess){
44 printf("Unable to select device %d.. exiting. \n",deviceNum);
45 exit(1);
46 }
47
48 size_t memFree, memTot;
49 cudaMemGetInfo(&memFree, &memTot);
50 printf("GPU memory free = %lu/%lu (MB) \n",(unsigned long)memFree/(1024*1024),(unsigned long)memTot/(1024*1024));
51
52 data = (real*)calloc( (n+m)*d, sizeof(*data) );
53 x.mat = (real*)calloc( PAD(n)*PAD(d), sizeof(*(x.mat)) );
54
55 //Need to allocate extra space, as each group of q will be padded later.
56 q.mat = (real*)calloc( PAD(m)*PAD(d), sizeof(*(q.mat)) );
57 x.r = n; x.c = d; x.pr = PAD(n); x.pc = PAD(d); x.ld = x.pc;
58 q.r = m; q.c = d; q.pr = PAD(m); q.pc = PAD(d); q.ld = q.pc;
59
60 //Load data
61 readData(dataFile, (n+m), d, data);
62 orgData(data, (n+m), d, x, q);
63 free(data);
64
65 //Allocate space for NNs and dists
66 nnsBrute.r=q.r; nnsBrute.pr=q.pr; nnsBrute.pc=nnsBrute.c=K; nnsBrute.ld=nnsBrute.pc;
67 nnsBrute.mat = (unint*)calloc(nnsBrute.pr*nnsBrute.pc, sizeof(*nnsBrute.mat));
68 nnsRBC.r=q.r; nnsRBC.pr=q.pr; nnsRBC.pc=nnsRBC.c=K; nnsRBC.ld=nnsRBC.pc;
69 nnsRBC.mat = (unint*)calloc(nnsRBC.pr*nnsRBC.pc, sizeof(*nnsRBC.mat));
70
71 distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.pc=distsBrute.c=K; distsBrute.ld=distsBrute.pc;
72 distsBrute.mat = (real*)calloc(distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat));
73 distsRBC.r=q.r; distsRBC.pr=q.pr; distsRBC.pc=distsRBC.c=K; distsRBC.ld=distsRBC.pc;
74 distsRBC.mat = (real*)calloc(distsRBC.pr*distsRBC.pc, sizeof(*distsRBC.mat));
75
76 printf("running k-brute force..\n");
77 gettimeofday(&tvB,NULL);
78 bruteK(x,q,nnsBrute,distsBrute);
79 gettimeofday(&tvE,NULL);
80 printf("\t.. time elapsed = %6.4f \n",timeDiff(tvB,tvE));
81
82 printf("\nrunning rbc..\n");
83 gettimeofday(&tvB,NULL);
84 buildRBC(x, &rbcS, numReps, s);
85 gettimeofday(&tvE,NULL);
86 printf("\t.. build time for rbc = %6.4f \n",timeDiff(tvB,tvE));
87
88 //This finds the 32-NN; if you are only interested in the 1-NN, use queryRBC(..) instead
89 gettimeofday(&tvB,NULL);
90 kqueryRBC(q, rbcS, nnsRBC, distsRBC);
91 gettimeofday(&tvE,NULL);
92 printf("\t.. query time for krbc = %6.4f \n",timeDiff(tvB,tvE));
93
94 destroyRBC(&rbcS);
95 printf("finished \n");
96
97 cE = cudaGetLastError();
98 if( cE != cudaSuccess ){
99 printf("Execution failed; error type: %s \n", cudaGetErrorString(cE) );
100 }
101
102 evalKNNerror(x,q,nnsRBC);
103
104 cudaThreadExit();
105
106 free(nnsBrute.mat);
107 free(nnsRBC.mat);
108 free(distsBrute.mat);
109 free(distsRBC.mat);
110 free(x.mat);
111 free(q.mat);
112 }
113
114
115 void parseInput(int argc, char **argv){
116 int i=1;
117 if(argc <= 1){
118 printf("\nusage: \n testRBC -f datafile (bin) -n numPts (DB) -m numQueries -d dim -r numReps -s numPtsPerRep [-o outFile] [-g GPU num]\n\n");
119 printf("\tdatafile = binary file containing the data\n");
120 printf("\tnumPts = size of database\n");
121 printf("\tnumQueries = number of queries\n");
122 printf("\tdim = dimensionailty\n");
123 printf("\tnumReps = number of representatives\n");
124 printf("\tnumPtsPerRep = number of points assigned to each representative\n");
125 printf("\toutFile = output file (optional); stored in text format\n");
126 printf("\tGPU num = ID # of the GPU to use (optional) for multi-GPU machines\n");
127 printf("\n\n");
128 exit(0);
129 }
130
131 while(i<argc){
132 if(!strcmp(argv[i], "-f"))
133 dataFile = argv[++i];
134 else if(!strcmp(argv[i], "-n"))
135 n = atoi(argv[++i]);
136 else if(!strcmp(argv[i], "-m"))
137 m = atoi(argv[++i]);
138 else if(!strcmp(argv[i], "-d"))
139 d = atoi(argv[++i]);
140 else if(!strcmp(argv[i], "-r"))
141 numReps = atoi(argv[++i]);
142 else if(!strcmp(argv[i], "-s"))
143 s = atoi(argv[++i]);
144 else if(!strcmp(argv[i], "-o"))
145 outFile = argv[++i];
146 else if(!strcmp(argv[i], "-g"))
147 deviceNum = atoi(argv[++i]);
148 else{
149 fprintf(stderr,"%s : unrecognized option.. exiting\n",argv[i]);
150 exit(1);
151 }
152 i++;
153 }
154
155 if( !n || !m || !d || !numReps || !s || !dataFile){
156 fprintf(stderr,"more arguments needed.. exiting\n");
157 exit(1);
158 }
159
160 if(numReps>n){
161 fprintf(stderr,"can't have more representatives than points.. exiting\n");
162 exit(1);
163 }
164 }
165
166
167 void readData(char *dataFile, unint rows, unint cols, real *data){
168 FILE *fp;
169 unint numRead;
170
171 fp = fopen(dataFile,"r");
172 if(fp==NULL){
173 fprintf(stderr,"error opening file.. exiting\n");
174 exit(1);
175 }
176
177 numRead = fread(data,sizeof(real),rows*cols,fp);
178 if(numRead != rows*cols){
179 fprintf(stderr,"error reading file.. exiting \n");
180 exit(1);
181 }
182 fclose(fp);
183 }
184
185
186 void readDataText(char *dataFile, unint rows, unint cols, real *data){
187 FILE *fp;
188 real t;
189
190 fp = fopen(dataFile,"r");
191 if(fp==NULL){
192 fprintf(stderr,"error opening file.. exiting\n");
193 exit(1);
194 }
195
196 for(int i=0; i<rows; i++){
197 for(int j=0; j<cols; j++){
198 if(fscanf(fp,"%f ", &t)==EOF){
199 fprintf(stderr,"error reading file.. exiting \n");
200 exit(1);
201 }
202 data[IDX( i, j, cols )]=(real)t;
203 }
204 }
205 fclose(fp);
206 }
207
208 //This function splits the data into two matrices, x and q, of
209 //their specified dimensions. The data is split randomly.
210 //It is assumed that the number of rows of data (the parameter n)
211 //is at least as large as x.r+q.r
212 void orgData(real *data, unint n, unint d, matrix x, matrix q){
213
214 unint i,fi,j;
215 unint *p;
216 p = (unint*)calloc(n,sizeof(*p));
217
218 randPerm(n,p);
219
220 for(i=0,fi=0 ; i<x.r ; i++,fi++){
221 for(j=0;j<x.c;j++){
222 x.mat[IDX(i,j,x.ld)] = data[IDX(p[fi],j,d)];
223 }
224 }
225
226 for(i=0 ; i<q.r ; i++,fi++){
227 for(j=0;j<q.c;j++){
228 q.mat[IDX(i,j,q.ld)] = data[IDX(p[fi],j,d)];
229 }
230 }
231
232 free(p);
233 }
234
235
236 //find the error rate of a set of NNs, then print it.
237 void evalNNerror(matrix x, matrix q, unint *NNs){
238 struct timeval tvB, tvE;
239 unint i;
240
241 printf("\nComputing error rates (this might take a while)\n");
242 real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
243 for(i=0;i<q.r;i++){
244 if(NNs[i]>n) printf("error");
245 ranges[i] = distVec(q,x,i,NNs[i]) - 10e-6;
246 }
247
248 unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
249 gettimeofday(&tvB,NULL);
250 bruteRangeCount(x,q,ranges,cnts);
251 gettimeofday(&tvE,NULL);
252
253 long int nc=0;
254 for(i=0;i<m;i++){
255 nc += cnts[i];
256 }
257 double mean = ((double)nc)/((double)m);
258 double var = 0.0;
259 for(i=0;i<m;i++) {
260 var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
261 }
262 printf("\tavg rank = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
263 printf("(range count took %6.4f) \n", timeDiff(tvB, tvE));
264
265 if(outFile){
266 FILE* fp = fopen(outFile, "a");
267 fprintf( fp, "%d %d %6.5f %6.5f \n", numReps, s, mean, sqrt(var) );
268 fclose(fp);
269 }
270
271 free(ranges);
272 free(cnts);
273 }
274
275
276 //evals the error rate of k-nns
277 void evalKNNerror(matrix x, matrix q, intMatrix NNs){
278 struct timeval tvB, tvE;
279 unint i,j,k;
280
281 unint m = q.r;
282 printf("\nComputing error rates (this might take a while)\n");
283
284 unint *ol = (unint*)calloc( q.r, sizeof(*ol) );
285
286 intMatrix NNsB;
287 NNsB.r=q.r; NNsB.pr=q.pr; NNsB.c=NNsB.pc=32; NNsB.ld=NNsB.pc;
288 NNsB.mat = (unint*)calloc( NNsB.pr*NNsB.pc, sizeof(*NNsB.mat) );
289 matrix distsBrute;
290 distsBrute.r=q.r; distsBrute.pr=q.pr; distsBrute.c=distsBrute.pc=K; distsBrute.ld=distsBrute.pc;
291 distsBrute.mat = (real*)calloc( distsBrute.pr*distsBrute.pc, sizeof(*distsBrute.mat) );
292
293 gettimeofday(&tvB,NULL);
294 bruteK(x,q,NNsB,distsBrute);
295 gettimeofday(&tvE,NULL);
296
297 //calc overlap
298 for(i=0; i<m; i++){
299 for(j=0; j<K; j++){
300 for(k=0; k<K; k++){
301 ol[i] += ( NNs.mat[IDX(i, j, NNs.ld)] == NNsB.mat[IDX(i, k, NNsB.ld)] );
302 }
303 }
304 }
305
306 long int nc=0;
307 for(i=0;i<m;i++){
308 nc += ol[i];
309 }
310
311 double mean = ((double)nc)/((double)m);
312 double var = 0.0;
313 for(i=0;i<m;i++) {
314 var += (((double)ol[i])-mean)*(((double)ol[i])-mean)/((double)m);
315 }
316 printf("\tavg overlap = %6.4f/%d; std dev = %6.4f \n", mean, K, sqrt(var));
317
318 FILE* fp;
319 if(outFile){
320 fp = fopen(outFile, "a");
321 fprintf( fp, "%d %d %6.5f %6.5f ", numReps, s, mean, sqrt(var) );
322 }
323
324 real *ranges = (real*)calloc(q.pr,sizeof(*ranges));
325 for(i=0;i<q.r;i++){
326 ranges[i] = distVec(q,x,i,NNs.mat[IDX(i, K-1, NNs.ld)]);
327 }
328
329 unint *cnts = (unint*)calloc(q.pr,sizeof(*cnts));
330 bruteRangeCount(x,q,ranges,cnts);
331
332 nc=0;
333 for(i=0;i<m;i++){
334 nc += cnts[i];
335 }
336 mean = ((double)nc)/((double)m);
337 var = 0.0;
338 for(i=0;i<m;i++) {
339 var += (((double)cnts[i])-mean)*(((double)cnts[i])-mean)/((double)m);
340 }
341 printf("\tavg actual rank of 32nd NN returned by the RBC = %6.4f; std dev = %6.4f \n\n", mean, sqrt(var));
342 printf("(brute k-nn took %6.4f) \n", timeDiff(tvB, tvE));
343
344 if(outFile){
345 fprintf( fp, "%6.5f %6.5f \n", mean, sqrt(var) );
346 fclose(fp);
347 }
348
349 free(cnts);
350 free(ol);
351 free(NNsB.mat);
352 free(distsBrute.mat);
353 }