git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8617 e1793c9e...
[qpalma.git] / dyn_prog / Mathmatics.cpp
// Mathmatics.cpp: implementation of the CMath class.
//
//////////////////////////////////////////////////////////////////////
4
5
6 #include "Mathmatics.h"
7 #include "io.h"
8
9 #include <sys/time.h>
10 #include <sys/types.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <math.h>
14 #include <time.h>
15 #include <unistd.h>
16 #include <assert.h>
17
18 //////////////////////////////////////////////////////////////////////
19 // Construction/Destruction
20 //////////////////////////////////////////////////////////////////////
21
22 #ifdef USE_LOGCACHE
23 //gene/math specific constants
24 #ifdef USE_HMMDEBUG
25 #define MAX_LOG_TABLE_SIZE 10*1024*1024
26 #define LOG_TABLE_PRECISION 1e-6
27 #else
28 #define MAX_LOG_TABLE_SIZE 123*1024*1024
29 #define LOG_TABLE_PRECISION 1e-15
30 #endif
31
32 INT CMath::LOGACCURACY = 0; // 100000 steps per integer
33 #endif
34
35 INT CMath::LOGRANGE = 0; // range for logtable: log(1+exp(x)) -25 <= x <= 0
36
37 #ifdef USE_PATHDEBUG
38 const REAL CMath::INFTY = 1e11; // infinity
39 #else
40 const REAL CMath::INFTY = -log(0.0); // infinity
41 #endif
42 const REAL CMath::ALMOST_NEG_INFTY = -1000;
43
44 CHAR CMath::rand_state[256];
45
46 CMath::CMath()
47 {
48 struct timeval tv;
49 gettimeofday(&tv, NULL);
50 UINT seed=(UINT) (4223517*getpid()*tv.tv_sec*tv.tv_usec);
51 initstate(seed, CMath::rand_state, sizeof(CMath::rand_state));
52 CIO::message(M_INFO, "seeding random number generator with %u\n", seed);
53
54 #ifdef USE_LOGCACHE
55 LOGRANGE=CMath::determine_logrange();
56 LOGACCURACY=CMath::determine_logaccuracy(LOGRANGE);
57 CIO::message(M_INFO, "Initializing log-table (size=%i*%i*%i=%2.1fMB) ...",LOGRANGE,LOGACCURACY,sizeof(REAL),LOGRANGE*LOGACCURACY*sizeof(REAL)/(1024.0*1024.0)) ;
58
59 CMath::logtable=new REAL[LOGRANGE*LOGACCURACY];
60 init_log_table();
61 CIO::message(M_INFO, "Done.\n") ;
62 #else
63 INT i=0;
64 while ((REAL)log(1+((REAL)exp(-REAL(i)))))
65 i++;
66 CIO::message(M_INFO, "determined range for x in log(1+exp(-x)) is:%d\n", i);
67 LOGRANGE=i;
68 #endif
69 }
70
71 CMath::~CMath()
72 {
73 #ifdef USE_LOGCACHE
74 delete[] logtable;
75 #endif
76 }
77
78 #ifdef USE_LOGCACHE
79 INT CMath::determine_logrange()
80 {
81 INT i;
82 REAL acc=0;
83 for (i=0; i<50; i++)
84 {
85 acc=((REAL)log(1+((REAL)exp(-REAL(i)))));
86 if (acc<=(REAL)LOG_TABLE_PRECISION)
87 break;
88 }
89
90 CIO::message(M_INFO, "determined range for x in table log(1+exp(-x)) is:%d (error:%G)\n",i,acc);
91 return i;
92 }
93
94 INT CMath::determine_logaccuracy(INT range)
95 {
96 range=MAX_LOG_TABLE_SIZE/range/((int)sizeof(REAL));
97 CIO::message(M_INFO, "determined accuracy for x in table log(1+exp(-x)) is:%d (error:%G)\n",range,1.0/(double) range);
98 return range;
99 }
100
101 //init log table of form log(1+exp(x))
102 void CMath::init_log_table()
103 {
104 for (INT i=0; i< LOGACCURACY*LOGRANGE; i++)
105 logtable[i]=log(1+exp(REAL(-i)/REAL(LOGACCURACY)));
106 }
107 #endif
108
109 void CMath::sort(INT *a, INT cols, INT sort_col)
110 {
111 INT changed=1;
112 if (a[0]==-1) return ;
113 while (changed)
114 {
115 changed=0; INT i=0 ;
116 while ((a[(i+1)*cols]!=-1) && (a[(i+1)*cols+1]!=-1)) // to be sure
117 {
118 if (a[i*cols+sort_col]>a[(i+1)*cols+sort_col])
119 {
120 for (INT j=0; j<cols; j++)
121 CMath::swap(a[i*cols+j],a[(i+1)*cols+j]) ;
122 changed=1 ;
123 } ;
124 i++ ;
125 } ;
126 } ;
127 }
128
129 void CMath::sort(REAL *a, INT* idx, INT N)
130 {
131
132 INT changed=1;
133 while (changed)
134 {
135 changed=0;
136 for (INT i=0; i<N-1; i++)
137 {
138 if (a[i]>a[i+1])
139 {
140 swap(a[i],a[i+1]) ;
141 swap(idx[i],idx[i+1]) ;
142 changed=1 ;
143 } ;
144 } ;
145 } ;
146
147 }
148
149
150
151 //plot x- axis false positives (fp) 1-Specificity
152 //plot y- axis true positives (tp) Sensitivity
153 INT CMath::calcroc(REAL* fp, REAL* tp, REAL* output, INT* label, INT& size, INT& possize, INT& negsize, REAL& tresh, FILE* rocfile)
154 {
155 INT left=0;
156 INT right=size-1;
157 INT i;
158
159 for (i=0; i<size; i++)
160 {
161 if (!(label[i]== -1 || label[i]==1))
162 return -1;
163 }
164
165 //sort data such that -1 labels first +1 behind
166 while (left<right)
167 {
168 while ((label[left] < 0) && (left<right))
169 left++;
170 while ((label[right] > 0) && (left<right)) //warning: label must be either -1 or +1
171 right--;
172
173 swap(output[left],output[right]);
174 swap(label[left], label[right]);
175 }
176
177 // neg/pos sizes
178 negsize=left;
179 possize=size-left;
180 REAL* negout=output;
181 REAL* posout=&output[left];
182
183 // sort the pos and neg outputs separately
184 qsort(negout, negsize);
185 qsort(posout, possize);
186
187 // set minimum+maximum values for decision-treshhold
188 REAL minimum=min(negout[0], posout[0]);
189 REAL maximum=minimum;
190 if (negsize>0)
191 maximum=max(maximum,negout[negsize-1]);
192 if (possize>0)
193 maximum=max(maximum,posout[possize-1]);
194
195 REAL treshhold=minimum;
196 REAL old_treshhold=treshhold;
197
198 //clear array.
199 for (i=0; i<size; i++)
200 {
201 fp[i]=1.0;
202 tp[i]=1.0;
203 }
204
205 //start with fp=1.0 tp=1.0 which is posidx=0, negidx=0
206 //everything right of {pos,neg}idx is considered to beLONG to +1
207 INT posidx=0;
208 INT negidx=0;
209 INT iteration=1;
210 INT returnidx=-1;
211
212 REAL minerr=10;
213
214 while (iteration < size && treshhold<=maximum)
215 {
216 old_treshhold=treshhold;
217
218 while (treshhold==old_treshhold && treshhold<=maximum)
219 {
220 if (posidx<possize && negidx<negsize)
221 {
222 if (posout[posidx]<negout[negidx])
223 {
224 if (posout[posidx]==treshhold)
225 posidx++;
226 else
227 treshhold=posout[posidx];
228 }
229 else
230 {
231 if (negout[negidx]==treshhold)
232 negidx++;
233 else
234 treshhold=negout[negidx];
235 }
236 }
237 else
238 {
239 if (posidx>=possize && negidx<negsize-1)
240 {
241 negidx++;
242 treshhold=negout[negidx];
243 }
244 else if (negidx>=negsize && posidx<possize-1)
245 {
246 posidx++;
247 treshhold=posout[posidx];
248 }
249 else if (negidx<negsize && treshhold!=negout[negidx])
250 treshhold=negout[negidx];
251 else if (posidx<possize && treshhold!=posout[posidx])
252 treshhold=posout[posidx];
253 else
254 {
255 treshhold=2*(maximum+100); // force termination
256 posidx=possize;
257 negidx=negsize;
258 break;
259 }
260 }
261 }
262
263 //calc tp,fp
264 tp[iteration]=(possize-posidx)/(REAL (possize));
265 fp[iteration]=(negsize-negidx)/(REAL (negsize));
266
267 //choose poINT with minimal error
268 if (minerr > negsize*fp[iteration]/size+(1-tp[iteration])*possize/size )
269 {
270 minerr=negsize*fp[iteration]/size+(1-tp[iteration])*possize/size;
271 tresh=(old_treshhold+treshhold)/2;
272 returnidx=iteration;
273 }
274
275 iteration++;
276 }
277
278 // set new size
279 size=iteration;
280
281 if (rocfile)
282 {
283 const CHAR id[]="ROC";
284 fwrite(id, sizeof(char), sizeof(id), rocfile);
285 fwrite(fp, sizeof(REAL), size, rocfile);
286 fwrite(tp, sizeof(REAL), size, rocfile);
287 }
288
289 return returnidx;
290 }
291
292 UINT CMath::crc32(BYTE *data, INT len)
293 {
294 UINT result;
295 INT i,j;
296 BYTE octet;
297
298 result = 0-1;
299
300 for (i=0; i<len; i++)
301 {
302 octet = *(data++);
303 for (j=0; j<8; j++)
304 {
305 if ((octet >> 7) ^ (result >> 31))
306 {
307 result = (result << 1) ^ 0x04c11db7;
308 }
309 else
310 {
311 result = (result << 1);
312 }
313 octet <<= 1;
314 }
315 }
316
317 return ~result;
318 }
319
320 double CMath::mutual_info(REAL* p1, REAL* p2, INT len)
321 {
322 double e=0;
323
324 for (INT i=0; i<len; i++)
325 for (INT j=0; j<len; j++)
326 e+=exp(p2[j*len+i])*(p2[j*len+i]-p1[i]-p1[j]);
327
328 return e;
329 }
330
331 double CMath::relative_entropy(REAL* p, REAL* q, INT len)
332 {
333 double e=0;
334
335 for (INT i=0; i<len; i++)
336 e+=exp(p[i])*(p[i]-q[i]);
337
338 return e;
339 }
340
341 double CMath::entropy(REAL* p, INT len)
342 {
343 double e=0;
344
345 for (INT i=0; i<len; i++)
346 e-=exp(p[i])*p[i];
347
348 return e;
349 }