1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
11 from qpalma.parsers import *
# Module-level placeholder; not referenced anywhere in the code visible here.
data = None
def createErrorVSCutPlot(results):
    """
    Write a LaTeX table 'error_rates_table.tex' with the per-cut-position
    error counts of the evaluation.

    results -- dict mapping a three-char flag key (e.g. '+++') to a result
               tuple whose second element maps positions 0..36 to counts;
               positions without an entry are written as 0.
    """
    lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
    'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
    'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos, key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            # default to zero when no count was recorded for this position
            # (was a bare `except:`; narrowed to the lookup failures)
            try:
                ctr = res[1][i]
            except (KeyError, IndexError, TypeError):
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    # write all rows at once; close the handle even if writing fails
    fh = open('error_rates_table.tex', 'w+')
    try:
        fh.writelines([l + '\n' for l in lines])
    finally:
        fh.close()
50 def createTable(results):
51 """
52 This function takes the results of the evaluation and creates a tex table.
53 """
55 fh = open('result_table.tex','w+')
56 lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
57 'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
58 'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']
60 for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
61 res = [e*100 for e in results[key]]
63 lines.append( '%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % ( key[0], key[1], key[2], res[0], res[1] ) )
64 if pos % 2 == 1:
65 lines.append('\hline')
67 for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
68 res = [e*100 for e in results[key]]
70 lines.append( '%s & %s & %s & %2.2f & x \\%%\\\\' % ( key[0], key[1], key[2], res[2] ) )
71 if pos % 2 == 1:
72 lines.append('\hline')
74 lines.append('\end{tabular}')
76 lines = [l+'\n' for l in lines]
77 for l in lines:
78 fh.write(l)
79 fh.close()
def compare_scores_and_labels(scores,labels):
    """
    Check whether every correct prediction scores at least as high as every
    incorrect one for this example.

    scores -- list of prediction scores
    labels -- parallel list of booleans; True marks a correct prediction

    Returns True when no incorrectly labelled prediction strictly outscores
    any correctly labelled one; also True when there are no correct or no
    incorrect predictions at all (matching the original pairwise loop).
    """
    # The original all-pairs scan was O(n^2); it is equivalent to comparing
    # the best incorrect score against the worst correct score.
    correct_scores = [s for s, lab in zip(scores, labels) if lab]
    incorrect_scores = [s for s, lab in zip(scores, labels) if not lab]

    if not correct_scores or not incorrect_scores:
        return True

    return max(incorrect_scores) <= min(correct_scores)
def compare_exons(predExons,trueExons):
    """
    Compare predicted exon boundaries against the ground truth.

    predExons -- flat list of predicted boundaries; exactly four entries
                 (begin/end of two exons) are required for a match
    trueExons -- 2x2 array-like indexed as trueExons[exon, begin_or_end]

    Returns True iff all four truncated absolute offsets are zero.
    """
    if len(predExons) != 4:
        return False

    true_bounds = [trueExons[0, 0], trueExons[0, 1],
                   trueExons[1, 0], trueExons[1, 1]]

    # int(math.fabs(...)) truncates toward zero, so sub-unit offsets
    # (|diff| < 1) still count as an exact hit — preserved from the original
    offsets = [int(math.fabs(pred - true))
               for pred, true in zip(predExons, true_bounds)]

    return offsets == [0, 0, 0, 0]
def evaluate_unmapped_example(current_prediction):
    """
    Evaluate one ground-truth (unmapped) prediction entry by comparing its
    predicted exon boundaries with the true ones.

    current_prediction -- dict with 'predExons' and 'trueExons' entries

    Returns the boolean result of compare_exons.
    """
    return compare_exons(current_prediction['predExons'],
                         current_prediction['trueExons'])
def evaluate_example(current_prediction):
    """
    Evaluate one vmatch prediction entry.

    current_prediction -- dict with at least 'label' and 'DPScores'; when
    the label is True it also needs 'predExons', 'trueExons',
    'alternative_start_pos' and 'start_pos'.

    Returns a (label, positions_match, score) triple.
    """
    label = current_prediction['label']
    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only
    # have to compare the score
    if label == False:
        return label, False, pred_score

    alt_start = current_prediction['alternative_start_pos']
    true_start = current_prediction['start_pos']

    # shift both boundary lists into absolute coordinates before comparing
    predPositions = [boundary + alt_start
                     for boundary in current_prediction['predExons']]
    truePositions = [boundary + true_start
                     for boundary in current_prediction['trueExons'].flatten().tolist()[0]]

    return label, predPositions == truePositions, pred_score
def prediction_on(filename):
    """
    Evaluate all predictions stored in the given prediction file.

    Returns a tuple (pos_error, score_error, gt_error, incorrect_gt_cuts,
    incorrect_vmatch_cuts): three error rates plus two dicts mapping a cut
    position to the number of incorrect predictions at that position.

    NOTE(review): `allPredictions` is used below but never assigned in the
    visible code -- presumably it is unpickled from `filename` in lines
    elided from this view; confirm against the full file.
    """
    # cut position -> number of wrong predictions (ground truth / vmatch)
    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    # ground-truth alignments evaluated as correct / incorrect
    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    # vmatch alignments with correct / incorrect exon positions
    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    # examples where correct predictions outscore incorrect ones (or not)
    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        # the first entry of each example is the ground-truth alignment
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']

        cut_pos = gt_example['true_cut']

        # collect scores/labels of all candidate alignments of this example
        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            # count how often this cut position was predicted wrongly
            try:
                incorrect_gt_cuts[cut_pos] += 1
            except:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        # the remaining entries are the vmatch alignments of the same read
        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read pos we check for right exons
            # boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the incorrect
        # ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
    (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
    (score_correct_ctr,score_incorrect_ctr)

    # error rates as fractions; raises ZeroDivisionError when a counter is 0
    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
def collect_prediction(current_dir,run_name):
    """
    Given the toplevel directory this function takes care that for each distinct
    experiment the training and test predictions are evaluated.

    Returns (train_result, test_result, currentRunId).

    NOTE(review): `currentRun` is used below but never assigned in the
    visible code -- presumably it is loaded from the run directory in lines
    elided from this view; confirm against the full file.
    """
    idx = 5

    train_suffix = '_%d_allPredictions_TRAIN' % (idx)
    test_suffix = '_%d_allPredictions_TEST' % (idx)

    jp = os.path.join
    b2s = ['-','+']  # render a boolean flag as '-' (off) or '+' (on)

    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    # run id like '+-+' encoding the three feature switches
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])

    # training-set evaluation is currently disabled
    #filename = jp(current_dir,run_name)+train_suffix
    #print 'Prediction on: %s' % filename
    #train_result = prediction_on(filename)
    train_result = []

    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId
def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma

    Submits one Sun Grid Engine job (via qsub) per split; currently only
    split 1 is submitted.
    """
    #for i in range(1,6):
    for i in range(1,2):
        # NOTE: the backslash continues the string literal itself, so the
        # command contains the continuation line's leading whitespace --
        # harmless to the shell
        cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
              qsub -l h_vmem=12.0G -cwd -j y -N \"%s_%d.log\"'%(current_dir,i,run_name,i)

        #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
        #print cmd
        os.system(cmd)
def forall_experiments(current_func,tl_dir):
    """
    Apply current_func to every run subdirectory of the toplevel directory.

    current_func -- callable taking (run_directory, run_name); at the
                    moment either perform_prediction or collect_prediction
    tl_dir       -- toplevel directory whose immediate subdirectories are
                    the individual experiment runs
    """
    # keep only the entries of tl_dir that are directories
    candidates = [os.path.join(tl_dir, entry) for entry in os.listdir(tl_dir)]
    run_dirs = [candidate for candidate in candidates if os.path.isdir(candidate)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        # the run name is the last path component
        run_name = current_dir.split('/')[-1]

        current_func(current_dir,run_name)

        #train_result,test_result,currentRunId = current_func(current_dir,run_name)
        #all_results[currentRunId] = test_result
        #pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
        #all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    #createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)
# Read id -> coverage count, parsed from a precomputed whitespace-separated
# coverage file ("<id> <coverage>" per line).
coverage_map = {}

for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
    id,coverage_nr = line.strip().split()
    coverage_map[int(id)] = int(coverage_nr)
# NOTE(review): this script section is only partially visible here;
# `allPredictions` used below is presumably loaded in elided lines.

# counters for spliced/unspliced reads and their evaluation outcomes
spliced_ctr = 0
unspliced_ctr = 0

pos_correct_ctr = 0
pos_incorrect_ctr = 0

correct_spliced_ctr = 0
correct_unspliced_ctr = 0

incorrect_spliced_ctr = 0
incorrect_unspliced_ctr = 0

# spliced reads with sufficient coverage, split by outcome
correct_covered_splice_ctr = 0
incorrect_covered_splice_ctr = 0

total_vmatch_instances_ctr = 0

cut_pos_ctr = {}

total_ctr = 0
skipped_ctr = 0

is_spliced = False
# a spliced read counts as "covered" only at >= this coverage
min_coverage = 3

# keep only the highest-scoring prediction per read id
allUniqPredictions = {}

print 'Got %d predictions' % len(allPredictions)
# Deduplicate predictions by read id, keeping the one with the highest
# dynamic-programming score.
for new_prediction in allPredictions:
    id = new_prediction['id']
    id = int(id)

    if allUniqPredictions.has_key(id):
        current_prediction = allUniqPredictions[id]

        current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
        new_score = new_prediction['DPScores'].flatten().tolist()[0][0]

        # replace the stored prediction only if the new one scores higher
        if current_a_score < new_score :
            allUniqPredictions[id] = new_prediction

    else:
        allUniqPredictions[id] = new_prediction

print 'Got %d uniq predictions' % len(allUniqPredictions)
#for current_prediction in allPredictions:
for _id,current_prediction in allUniqPredictions.items():
    id = current_prediction['id']
    id = int(id)

    # skip reads below this id threshold
    if id < 1000000010006:
        continue

    # id ranges encode the read type: ids below 1000000300000 are treated
    # as spliced reads -- TODO confirm the encoding against the data set
    if not id >= 1000000300000:
        is_spliced = True
    else:
        is_spliced = False

    is_covered = False

    if is_spliced:
        # coverage information only exists for spliced reads
        try:
            current_coverage_nr = coverage_map[id]
            is_covered = True
        except:
            is_covered = False

    if is_spliced:
        spliced_ctr += 1
    else:
        unspliced_ctr += 1

    # NOTE(review): the body of this try statement is not visible here
    # (elided lines); presumably it fetches `current_ground_truth` for this
    # id -- confirm against the full file.
    try:
    except:
        skipped_ctr += 1
        continue

    start_pos = current_prediction['start_pos']
    chr = current_prediction['chr']
    strand = current_prediction['strand']

    #score = current_prediction['DPScores'].flatten().tolist()[0][0]
    #pdb.set_trace()

    predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
    # convert exon boundaries from read-relative to absolute coordinates
    predExons = [e+start_pos for e in predExons]

    spliced_flag = False

    if len(predExons) == 4:
        # four boundaries -> the predicted alignment is spliced
        spliced_flag = True
        predExons[1] -= 1
        predExons[3] -= 1

        cut_pos = current_ground_truth['true_cut']
        p_start = current_ground_truth['p_start']
        e_stop = current_ground_truth['exon_stop']
        e_start = current_ground_truth['exon_start']
        p_stop = current_ground_truth['p_stop']

        true_cut = current_ground_truth['true_cut']

        # exact match of all four boundaries required
        if p_start == predExons[0] and e_stop == predExons[1] and\
        e_start == predExons[2] and p_stop == predExons[3]:
            pos_correct = True
        else:
            pos_correct = False

    elif len(predExons) == 2:
        # two boundaries -> unspliced alignment
        spliced_flag = False
        predExons[1] -= 1

        cut_pos = current_ground_truth['true_cut']
        p_start = current_ground_truth['p_start']
        p_stop = current_ground_truth['p_stop']

        true_cut = current_ground_truth['true_cut']

        #pdb.set_trace()

        # only the start position is checked at the moment (stop check is
        # commented out)
        if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
            pos_correct = True
        else:
            pos_correct = False

    else:
        pos_correct = False

    # NOTE(review): the bodies of the following four if statements are not
    # visible here (elided lines); as shown this is not valid syntax.
    if is_spliced and not spliced_flag:

    if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:

    if not is_spliced and spliced_flag:

    if not is_spliced and not pos_correct:

    if pos_correct:
        pos_correct_ctr += 1

        if is_spliced:
            correct_spliced_ctr += 1
            if is_covered and current_coverage_nr >= min_coverage:
                correct_covered_splice_ctr += 1

        if not is_spliced:
            correct_unspliced_ctr += 1

    else:
        pos_incorrect_ctr += 1

        if is_spliced:
            incorrect_spliced_ctr += 1
            if is_covered and current_coverage_nr >= min_coverage:
                incorrect_covered_splice_ctr += 1

        if not is_spliced:
            incorrect_unspliced_ctr += 1

    # per-read report for spliced alignments: id, verdict, coverage
    if spliced_flag:
        if not is_covered:
            current_coverage_nr=0
        if pos_correct:
            print "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
        else:
            print "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)

    total_ctr += 1
# final summary of all counters collected above
numPredictions = len(allUniqPredictions)

# now that we have evaluated all instances put out all counters and sizes
print 'Total num. of examples: %d' % numPredictions

print "spliced/unspliced: %d,%d " % (spliced_ctr, unspliced_ctr )
print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr )
print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr , incorrect_unspliced_ctr )
print "Correct/incorrect covered spliced read: %d,%d " %\
(correct_covered_splice_ctr,incorrect_covered_splice_ctr)

print "pos_correct: %d,%d" % (pos_correct_ctr , pos_incorrect_ctr )