+ fixed some index bugs in the evaluation
[qpalma.git] / scripts / Evaluation.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import sys
import pdb
import os
import os.path
import math

from qpalma.parsers import *


data = None


def createErrorVSCutPlot(results):
    """
    This function takes the results of the evaluation and writes the error
    counts per cut position to a tex file.
    """

    fh = open('error_rates_table.tex', 'w+')
    lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos, key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            # res[1] maps cut positions to error counts; positions without
            # errors are simply missing, so default to zero.
            try:
                ctr = res[1][i]
            except KeyError:
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    lines = [l + '\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


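# A sketch of the input createErrorVSCutPlot() expects (counts are made up):
# results maps a run id to the pair (incorrect_gt_cuts, incorrect_vmatch_cuts)
# collected by prediction_on(), each a dict from cut position to error count.
#
#   results = {'+++': ({12: 3, 20: 1}, {12: 5, 33: 2})}
#   createErrorVSCutPlot(results)

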
def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.
    """

    fh = open('result_table.tex', 'w+')
    lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']

    for pos, key in enumerate(['---', '+--', '-+-', '++-', '--+', '+-+', '-++', '+++']):
        res = [e * 100 for e in results[key]]

        lines.append('%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % (key[0], key[1], key[2], res[0], res[1]))
        if pos % 2 == 1:
            lines.append('\hline')

    for pos, key in enumerate(['---', '+--', '-+-', '++-', '--+', '+-+', '-++', '+++']):
        res = [e * 100 for e in results[key]]

        lines.append('%s & %s & %s & %2.2f & x \\%%\\\\' % (key[0], key[1], key[2], res[2]))
        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    lines = [l + '\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


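# A sketch of the input createTable() expects (error rates are made up): one
# numeric sequence (pos_error, score_error, gt_error) per run id, as computed
# by prediction_on(). The function iterates over all eight run ids; only one
# entry is shown here.
#
#   results['+++'] = (0.05, 0.12, 0.02)

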
def compare_scores_and_labels(scores, labels):
    """
    Iterate over all predictions of one example. If we find a correct
    prediction, check whether it scores higher than every incorrect prediction
    for this example.
    """

    for currentPos, currentElem in enumerate(scores):
        if labels[currentPos] == True:
            for otherPos, otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if labels[otherPos] == False and otherElem > currentElem:
                    return False

    return True


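# Illustrative call (scores and labels are made up): the correct alignment of a
# read scores 12.3, two incorrectly placed alignments score 10.1 and 14.2.
# Since an incorrect alignment outscores the correct one, the first example
# counts as a score error.
#
#   compare_scores_and_labels([12.3, 10.1, 14.2], [True, False, False])  # -> False
#   compare_scores_and_labels([12.3, 10.1, 11.2], [True, False, False])  # -> True

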
def compare_exons(predExons, trueExons):
    e1_b_off, e1_e_off, e2_b_off, e2_e_off = 0, 0, 0, 0

    if len(predExons) == 4:
        e1_begin, e1_end = predExons[0], predExons[1]
        e2_begin, e2_end = predExons[2], predExons[3]
    else:
        return False

    e1_b_off = int(math.fabs(e1_begin - trueExons[0, 0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0, 1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1, 0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1, 1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0 and e2_e_off == 0:
        return True

    return False


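# A minimal usage sketch (coordinates are made up; trueExons is assumed to be a
# numpy matrix with one row per exon, as the [i,j] indexing above suggests):
#
#   import numpy
#   true = numpy.matrix([[10, 55], [130, 180]])
#   compare_exons([10, 55, 130, 180], true)   # -> True (all offsets are zero)
#   compare_exons([10, 55, 130, 181], true)   # -> False

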
def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons, trueExons)
    return result


def evaluate_example(current_prediction):
    label = current_prediction['label']

    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have to
    # compare the score
    if label == False:
        return label, False, pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #    pdb.set_trace()

    return label, pos_comparison, pred_score


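# Reading of the return value above (an interpretation of the code, not a
# documented contract): a triple (vmatch_label, positions_match, dp_score);
# positions_match is only meaningful when vmatch_label is True, i.e. when
# vmatch mapped the read to the correct position in the first place.

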
def prediction_on(filename):

    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    # Each entry of allPredictions is expected (from the access pattern below)
    # to hold the ground truth alignment first, followed by the alternative
    # vmatch alignments of the same read.
    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']

        cut_pos = gt_example['true_cut']

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            try:
                incorrect_gt_cuts[cut_pos] += 1
            except KeyError:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr, current_pred in enumerate(current_example_pred[1:]):
            current_label, comparison_result, current_score = evaluate_example(current_pred)

            # if vmatch found the right read pos we check for right exon
            # boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except KeyError:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the incorrect
        # ones
        cmp_res = compare_scores_and_labels(current_scores, current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
        (pos_correct_ctr, pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
        (score_correct_ctr, score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error, score_error, gt_error, incorrect_gt_cuts, incorrect_vmatch_cuts)


def collect_prediction(current_dir, run_name):
    """
    Given the toplevel directory this function makes sure that for each
    distinct experiment the training and test predictions are evaluated.
    """
    idx = 5

    train_suffix = '_%d_allPredictions_TRAIN' % (idx)
    test_suffix = '_%d_allPredictions_TEST' % (idx)

    jp = os.path.join
    b2s = ['-', '+']

    currentRun = cPickle.load(open(jp(current_dir, 'run_object_%d.pickle' % (idx))))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag], b2s[SSFlag], b2s[ILFlag])

    #filename = jp(current_dir,run_name)+train_suffix
    #print 'Prediction on: %s' % filename
    #train_result = prediction_on(filename)
    train_result = []

    filename = jp(current_dir, run_name) + test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result, test_result, currentRunId


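# A minimal sketch of the run id encoding used above (flag values are made up):
# each of the three booleans is mapped to '-' or '+', so a run with quality
# scores and intron length enabled but splice signals disabled gets the id
# '+-+', matching the keys used by createTable() and createErrorVSCutPlot().
#
#   b2s = ['-', '+']
#   '%s%s%s' % (b2s[True], b2s[False], b2s[True])   # -> '+-+'

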
def perform_prediction(current_dir, run_name):
    """
    This function takes care of starting the jobs needed for the prediction
    phase of qpalma.
    """
    #for i in range(1,6):
    for i in range(1, 2):
        cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d | '\
              'qsub -l h_vmem=12.0G -cwd -j y -N "%s_%d.log"' % (current_dir, i, run_name, i)

        #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
        #print cmd
        os.system(cmd)


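# For illustration, with hypothetical arguments current_dir='/tmp/exp1',
# run_name='run1' and i=1, the command built above expands to:
#
#   echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh /tmp/exp1 1 | qsub -l h_vmem=12.0G -cwd -j y -N "run1_1.log"
#
# i.e. the prediction script call is piped into qsub as the job script.

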
def forall_experiments(current_func, tl_dir):
    """
    Given the toplevel directory this function calls the function given as
    first argument for each subdirectory. At the moment these are:

    - perform_prediction, and
    - collect_prediction.

    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir, de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        current_func(current_dir, run_name)

        #train_result,test_result,currentRunId = current_func(current_dir,run_name)
        #all_results[currentRunId] = test_result
        #pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
        #all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    #createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)


def predict_on(filename, filtered_reads):

    print 'parsing filtered reads..'
    all_filtered_reads = parse_filtered_reads(filtered_reads)
    print 'found %d filtered reads' % len(all_filtered_reads)

    allPredictions = cPickle.load(open(filename))

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_vmatch_instances_ctr = 0

    for current_prediction in allPredictions:
        id = current_prediction['id']
        current_ground_truth = all_filtered_reads[id]

        start_pos = current_prediction['start_pos']
        chr = current_prediction['chr']
        strand = current_prediction['strand']

        #score = current_prediction['DPScores'].flatten().tolist()[0][0]

        #pdb.set_trace()

        predExons = current_prediction['predExons']
        predExons = [e + start_pos for e in predExons]

        if len(predExons) == 4:
            predExons[1] -= 1
            predExons[3] -= 1

            cut_pos = current_ground_truth['true_cut']
            p_start = current_ground_truth['p_start']
            e_stop = current_ground_truth['exon_stop']
            e_start = current_ground_truth['exon_start']
            p_stop = current_ground_truth['p_stop']

            true_cut = current_ground_truth['true_cut']

            if p_start == predExons[0] and e_stop == predExons[1] and\
               e_start == predExons[2] and p_stop == predExons[3]:
                pos_correct_ctr += 1
            else:
                pos_incorrect_ctr += 1
                #pdb.set_trace()

        elif len(predExons) == 2:
            predExons[1] -= 1

            cut_pos = current_ground_truth['true_cut']
            p_start = current_ground_truth['p_start']
            p_stop = current_ground_truth['p_stop']

            true_cut = current_ground_truth['true_cut']

            if p_start == predExons[0] and p_stop == predExons[1]:
                pos_correct_ctr += 1
            else:
                pos_incorrect_ctr += 1
                #pdb.set_trace()

        else:
            pass

        ## check whether the correct predictions score higher than the incorrect
        ## ones
        #cmp_res = compare_scores_and_labels(current_scores,current_labels)
        #if cmp_res:
        #    score_correct_ctr += 1
        #else:
        #    score_incorrect_ctr += 1

    numPredictions = len(allPredictions)

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % numPredictions
    print 'Correct pos: %2.3f, incorrect pos: %2.3f' %\
        (pos_correct_ctr / (1.0 * numPredictions), pos_incorrect_ctr / (1.0 * numPredictions))

    #print 'Correct scores: %d, incorrect scores: %d' %\
    #    (score_correct_ctr,score_incorrect_ctr)

    #pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    #score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr


if __name__ == '__main__':
    #dir = sys.argv[1]
    #assert os.path.exists(dir), 'Error: Directory does not exist!'

    #global data
    #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
    #data = cPickle.load(open(data_fn))
    #forall_experiments(perform_prediction,dir)
    #forall_experiments(collect_prediction,dir)

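    # Expected invocation (a sketch; the file names are hypothetical): the
    # first argument is the pickled prediction list, the second the filtered
    # reads file parsed by parse_filtered_reads().
    #
    #   python Evaluation.py allPredictions_TEST filtered_reads.txt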
    predict_on(sys.argv[1], sys.argv[2])