#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import sys
import pdb
import os
import os.path
import math

from qpalma.parsers import *


data = None
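# placeholder for the dataset; only referenced by the commented-out driver
# code in the __main__ block at the bottom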


def createErrorVSCutPlot(results):
    """
    This function takes the evaluation results and writes out the raw error
    counts per cut position, i.e. the data underlying an error vs. cut plot.
    """

    fh = open('error_rates_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
        'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
        'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos,key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            try:
                ctr = res[1][i]
            except KeyError:
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.
    """

    fh = open('result_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
        'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
        'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    # error rates on positions and on scores
    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e*100 for e in results[key]]

        lines.append( '%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % ( key[0], key[1], key[2], res[0], res[1] ) )
        if pos % 2 == 1:
            lines.append('\\hline')

    # ground truth error rates (no score error is available in this case)
    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e*100 for e in results[key]]

        lines.append( '%s & %s & %s & %2.2f & x \\%%\\\\' % ( key[0], key[1], key[2], res[2] ) )
        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


def compare_scores_and_labels(scores,labels):
    """
    Iterate through all predictions. If we find a correct prediction, check
    whether this correct prediction scores higher than the incorrect
    predictions for this example.
    """
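    # the pairwise check is quadratic in the number of candidate alignments,
    # which stays cheap as long as each read has only a few candidates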

    for currentPos,currentElem in enumerate(scores):
        if labels[currentPos] == True:
            for otherPos,otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if labels[otherPos] == False and otherElem > currentElem:
                    return False

    return True


def compare_exons(predExons,trueExons):
    e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0

    if len(predExons) == 4:
        e1_begin,e1_end = predExons[0],predExons[1]
        e2_begin,e2_end = predExons[2],predExons[3]
    else:
        return False

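    # absolute deviation of each predicted exon boundary from the true one;
    # a perfect prediction has all four offsets equal to zero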
    e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0 and e2_e_off == 0:
        return True

    return False


def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons,trueExons)
    return result


def evaluate_example(current_prediction):
    label = current_prediction['label']

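    # 'DPScores' is a matrix; flatten().tolist()[0][0] extracts the scalar
    # alignment score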
    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have
    # to compare the score
    if label == False:
        return label,False,pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

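    # shift both exon lists into genome coordinates before comparing: the
    # predicted exons are relative to the alternative (vmatch) start position,
    # the true exons to the original start position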
    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #    pdb.set_trace()

    return label,pos_comparison,pred_score


def prediction_on(filename):

    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

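    # each entry consists of the ground truth alignment followed by the
    # alignments for the alternative (vmatch) positions of the same read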
    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']

        cut_pos = gt_example['true_cut']

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            try:
                incorrect_gt_cuts[cut_pos] += 1
            except KeyError:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read position we check for the right
            # exon boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except KeyError:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the
        # incorrect ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' % (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' % (score_correct_ctr,score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)


def collect_prediction(current_dir,run_name):
    """
    Given the toplevel directory this function takes care that for each
    distinct experiment the training and test predictions are evaluated.
    """
    idx = 5

    train_suffix = '_%d_allPredictions_TRAIN' % (idx)
    test_suffix = '_%d_allPredictions_TEST' % (idx)

    jp = os.path.join
    b2s = ['-','+']

    currentRun = cPickle.load(open(jp(current_dir,'run_object_%d.pickle'%(idx))))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])
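    # e.g. '+-+' means quality information on, splice site predictions off,
    # intron length on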

    # the training predictions are currently not evaluated
    #filename = jp(current_dir,run_name)+train_suffix
    #print 'Prediction on: %s' % filename
    #train_result = prediction_on(filename)
    train_result = []

    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId


def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction
    phase of qpalma.
    """
    # originally all five chunks were submitted (range(1,6)); currently only
    # the first one is
    #for i in range(1,6):
    for i in range(1,2):
        cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d' % (current_dir,i)
        cmd += ' | qsub -l h_vmem=12.0G -cwd -j y -N "%s_%d.log"' % (run_name,i)

        #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
        #print cmd
        os.system(cmd)



def forall_experiments(current_func,tl_dir):
    """
    Given the toplevel directory this function calls the function given as
    first argument for each subdir. At the moment these are:

    - perform_prediction, and
    - collect_prediction.
    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir,de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        current_func(current_dir,run_name)

        #train_result,test_result,currentRunId = current_func(current_dir,run_name)
        #all_results[currentRunId] = test_result
        #pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
        #all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    #createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)


def predict_on(filename,filtered_reads):

    coverage_map = {}

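    # the coverage file is expected to contain one whitespace-separated
    # '<read id> <coverage>' pair per line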
    for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
        id,coverage_nr = line.strip().split()
        coverage_map[int(id)] = int(coverage_nr)

    print 'parsing filtered reads..'
    all_filtered_reads = parse_filtered_reads(filtered_reads)
    print 'found %d filtered reads' % len(all_filtered_reads)

    allPredictions = cPickle.load(open(filename))

    spliced_ctr = 0
    unspliced_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    correct_spliced_ctr = 0
    correct_unspliced_ctr = 0

    incorrect_spliced_ctr = 0
    incorrect_unspliced_ctr = 0

    correct_covered_splice_ctr = 0
    incorrect_covered_splice_ctr = 0

    total_vmatch_instances_ctr = 0

    unspliced_spliced_reads_ctr = 0
    wrong_spliced_reads_ctr = 0

    wrong_aligned_unspliced_reads_ctr = 0
    wrong_unspliced_reads_ctr = 0

    cut_pos_ctr = {}

    total_ctr = 0
    skipped_ctr = 0

    is_spliced = False
    min_coverage = 3

    allUniqPredictions = {}

    print 'Got %d predictions' % len(allPredictions)

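    # keep only the highest-scoring prediction for each read id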
    for new_prediction in allPredictions:
        id = int(new_prediction['id'])

        if id in allUniqPredictions:
            current_prediction = allUniqPredictions[id]

            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
            new_score = new_prediction['DPScores'].flatten().tolist()[0][0]

            if current_a_score < new_score:
                allUniqPredictions[id] = new_prediction

        else:
            allUniqPredictions[id] = new_prediction

    print 'Got %d uniq predictions' % len(allUniqPredictions)

    #for current_prediction in allPredictions:
    for _id,current_prediction in allUniqPredictions.items():
        id = int(current_prediction['id'])

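        # the numeric read ids appear to encode the read category: the first
        # id range is skipped entirely, ids below 1000000300000 count as
        # spliced reads and everything above as unspliced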
        if id < 1000000010006:
            continue

        if id < 1000000300000:
            is_spliced = True
        else:
            is_spliced = False

        is_covered = False

        if is_spliced:
            try:
                current_coverage_nr = coverage_map[id]
                is_covered = True
            except KeyError:
                is_covered = False

        if is_spliced:
            spliced_ctr += 1
        else:
            unspliced_ctr += 1

        # reads without a ground truth entry cannot be evaluated
        try:
            current_ground_truth = all_filtered_reads[id]
        except KeyError:
            skipped_ctr += 1
            continue

        start_pos = current_prediction['start_pos']
        chr = current_prediction['chr']
        strand = current_prediction['strand']

        #score = current_prediction['DPScores'].flatten().tolist()[0][0]
        #pdb.set_trace()

        # shift the predicted exon boundaries into genome coordinates
        predExons = current_prediction['predExons']
        predExons = [e+start_pos for e in predExons]

        spliced_flag = False

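        # four boundary values mean a spliced alignment (two exons), two
        # values an unspliced one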
        if len(predExons) == 4:
            spliced_flag = True
            # shift the exon end coordinates by one so they match the ground
            # truth convention
            predExons[1] -= 1
            predExons[3] -= 1

            cut_pos = current_ground_truth['true_cut']
            p_start = current_ground_truth['p_start']
            e_stop = current_ground_truth['exon_stop']
            e_start = current_ground_truth['exon_start']
            p_stop = current_ground_truth['p_stop']

            true_cut = current_ground_truth['true_cut']

            if p_start == predExons[0] and e_stop == predExons[1] and\
               e_start == predExons[2] and p_stop == predExons[3]:
                pos_correct = True
            else:
                pos_correct = False

        elif len(predExons) == 2:
            spliced_flag = False
            predExons[1] -= 1

            cut_pos = current_ground_truth['true_cut']
            p_start = current_ground_truth['p_start']
            p_stop = current_ground_truth['p_stop']

            true_cut = current_ground_truth['true_cut']

            #pdb.set_trace()

            if math.fabs(p_start - predExons[0]) <= 0: # and math.fabs(p_stop - predExons[1]) <= 2:
                pos_correct = True
            else:
                pos_correct = False

        else:
            pos_correct = False

        # spliced reads that were aligned without an intron
        if is_spliced and not spliced_flag:
            unspliced_spliced_reads_ctr += 1

        # spliced reads that were aligned with an intron but at a wrong position
        if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1] != -1:
            wrong_spliced_reads_ctr += 1

        # unspliced reads that were aligned with an intron
        if not is_spliced and spliced_flag:
            wrong_unspliced_reads_ctr += 1

        # unspliced reads that were aligned at a wrong position
        if not is_spliced and not pos_correct:
            wrong_aligned_unspliced_reads_ctr += 1

        if pos_correct:
            pos_correct_ctr += 1

            if is_spliced:
                correct_spliced_ctr += 1
                if is_covered and current_coverage_nr >= min_coverage:
                    correct_covered_splice_ctr += 1

            if not is_spliced:
                correct_unspliced_ctr += 1

        else:
            pos_incorrect_ctr += 1

            if is_spliced:
                incorrect_spliced_ctr += 1
                if is_covered and current_coverage_nr >= min_coverage:
                    incorrect_covered_splice_ctr += 1

            if not is_spliced:
                incorrect_unspliced_ctr += 1

        if spliced_flag:
            if not is_covered:
                current_coverage_nr = 0
            if pos_correct:
                print "%s\tcorrect\t%i" % (current_prediction['id'], current_coverage_nr)
            else:
                print "%s\twrong\t%i" % (current_prediction['id'], current_coverage_nr)

        total_ctr += 1


    numPredictions = len(allUniqPredictions)

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % numPredictions

    print "spliced/unspliced: %d,%d " % (spliced_ctr, unspliced_ctr)
    print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr)
    print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr, incorrect_unspliced_ctr)
    print "Correct/incorrect covered spliced read: %d,%d " % (correct_covered_splice_ctr,incorrect_covered_splice_ctr)

    print "pos_correct: %d,%d" % (pos_correct_ctr, pos_incorrect_ctr)

    print 'spliced reads aligned without an intron: %d' % unspliced_spliced_reads_ctr
    print 'spliced reads aligned at a wrong place: %d' % wrong_spliced_reads_ctr

    print 'unspliced reads aligned with an intron: %d' % wrong_unspliced_reads_ctr
    print 'unspliced reads aligned at a wrong position: %d' % wrong_aligned_unspliced_reads_ctr

    print 'total_ctr: %d' % total_ctr

    print "skipped: %d " % skipped_ctr
    print 'min. coverage: %d' % min_coverage

if __name__ == '__main__':
    #dir = sys.argv[1]
    #assert os.path.exists(dir), 'Error: Directory does not exist!'

    #global data
    #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
    #data = cPickle.load(open(data_fn))
    #forall_experiments(perform_prediction,dir)
    #forall_experiments(collect_prediction,dir)

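    # usage: ./Evaluation.py <predictions_pickle> <filtered_reads_file>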
    predict_on(sys.argv[1],sys.argv[2])