1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
# Global dataset tuple; prediction_on() reads it via `global data`.
# NOTE(review): nothing in the visible code assigns it -- presumably the
# commented-out cPickle loader in __main__ was meant to fill it. Confirm.
data = None
def createErrorVSCutPlot(results):
    """
    Write the per-position error counters of the '+++' experiment to
    'error_rates_table.tex' (one counter value per line).

    `results` maps a run id string (e.g. '+++') to a sequence whose
    element 1 is a mapping/sequence of error counters indexed by cut
    position; positions 0..36 are sampled and missing entries count as 0.

    Fixes vs. original: bare `except:` narrowed to the lookup errors that
    can actually occur, and the output file is managed with a context
    manager so it is closed even if writing fails.
    """
    lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos, key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            # Missing cut positions simply have a zero error count.
            try:
                ctr = res[1][i]
            except (KeyError, IndexError):
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    with open('error_rates_table.tex', 'w+') as fh:
        fh.writelines(l + '\n' for l in lines)
def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.

    Writes 'result_table.tex'. `results` maps each run id ('---' .. '+++',
    one char per enabled flag) to a sequence of error rates (fractions);
    entries 0-2 are rendered as percentages: first a block of rows with
    position/score errors, then a block with the third rate.

    Fixes vs. original: the run-id key list was duplicated verbatim in two
    loops (hoisted to one constant), and the output file is managed with a
    context manager so it is closed even if writing fails.
    """
    keys = ['---', '+--', '-+-', '++-', '--+', '+-+', '-++', '+++']

    lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']

    for pos, key in enumerate(keys):
        res = [e * 100 for e in results[key]]
        lines.append('%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % (key[0], key[1], key[2], res[0], res[1]))
        # A horizontal rule after every second row.
        if pos % 2 == 1:
            lines.append('\hline')

    for pos, key in enumerate(keys):
        res = [e * 100 for e in results[key]]
        lines.append('%s & %s & %s & %2.2f & x \\%%\\\\' % (key[0], key[1], key[2], res[2]))
        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    with open('result_table.tex', 'w+') as fh:
        fh.writelines(l + '\n' for l in lines)
def compare_scores_and_labels(scores,labels):
    """
    Check whether every correctly-labelled prediction outscores the
    incorrect ones: return False as soon as some incorrect prediction has
    a strictly higher score than a correct one, True otherwise (including
    when there are no correct predictions at all).
    """
    n = len(scores)
    # A "violation" is an incorrect prediction scoring strictly above a
    # correct one; the result is simply the absence of any violation.
    return not any(
        labels[wrong] == False and scores[wrong] > scores[good]
        for good in range(n) if labels[good] == True
        for wrong in range(n) if wrong != good
    )
def compare_exons(predExons,trueExons):
    """
    Return True iff `predExons` holds exactly four boundaries
    (exon1 begin/end, exon2 begin/end) that coincide with the true exon
    pairs stored in the 2x2-indexable `trueExons`, False otherwise.
    """
    # Anything but exactly two predicted exons is an immediate mismatch.
    if len(predExons) != 4:
        return False

    # Absolute deviation of each predicted boundary from its true value.
    deviations = [
        int(math.fabs(predExons[0] - trueExons[0,0])),
        int(math.fabs(predExons[1] - trueExons[0,1])),
        int(math.fabs(predExons[2] - trueExons[1,0])),
        int(math.fabs(predExons[3] - trueExons[1,1])),
    ]

    if deviations == [0, 0, 0, 0]:
        return True

    return False
def evaluate_unmapped_example(current_prediction):
    """
    Evaluate one ground-truth (unmapped) prediction dict: True iff the
    predicted exon boundaries exactly match the true ones.
    """
    return compare_exons(current_prediction['predExons'],
                         current_prediction['trueExons'])
def evaluate_example(current_prediction):
    """
    Evaluate one vmatch prediction dict.

    Returns (label, positions_match, score): `label` is the stored
    mapping-correctness flag, `positions_match` says whether the
    absolute predicted exon positions equal the true ones, and `score`
    is the top-left entry of the flattened DPScores matrix.
    """
    label = current_prediction['label']
    score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have to
    # compare the score
    if label == False:
        return label, False, score

    alt_start = current_prediction['alternative_start_pos']
    true_start = current_prediction['start_pos']

    # Shift both boundary lists into absolute coordinates before comparing.
    predicted = [exon + alt_start for exon in current_prediction['predExons']]
    actual = [exon + true_start
              for exon in current_prediction['trueExons'].flatten().tolist()[0]]

    return label, predicted == actual, score
def prediction_on(filename):
    """
    Evaluate one set of predictions and print summary statistics.

    Returns (pos_error, score_error, gt_error, incorrect_gt_cuts,
    incorrect_vmatch_cuts) where the first three are error fractions and
    the last two are {cut_position: count} histograms of wrong predictions.

    NOTE(review): `filename` is never used and `allPredictions` is never
    defined in this function -- the unpickling step that should load it
    (presumably from `filename`) appears to be missing, so as written this
    raises NameError. `data` must also have been populated by the caller.
    """
    global data

    OriginalEsts= data[2]
    # Duplicate the list contents from index 10000 onwards -- presumably so
    # example indices beyond the first 10000 map onto the same ESTs; confirm.
    OriginalEsts[10000:] = OriginalEsts

    SplitPositions = data[5]
    SplitPositions[10000:] = SplitPositions

    # Histograms: cut position -> number of incorrect predictions there.
    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    # NOTE(review): the six lists below are initialized but never filled or
    # read in this function -- dead state, probably left over from an
    # earlier version.
    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    # Counters over ground-truth examples and vmatch instances.
    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        # Entry 0 is the ground-truth alignment; the rest are vmatch hits.
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']
        originalEst = OriginalEsts[exampleIdx]

        cut_pos = gt_example['true_cut']
        #cut_pos = SplitPositions[exampleIdx]
        #ctr = 0
        #for elem in originalEst:
        #   if elem == '[':
        #      ctr += 3
        #
        #cut_pos -= ctr

        # Scores/labels of the ground truth plus all vmatch alternatives,
        # compared against each other further below.
        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            # Count this wrong ground-truth prediction under its cut position
            # (EAFP increment-or-initialize).
            try:
                incorrect_gt_cuts[cut_pos] += 1
            except:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read pos we check for right exons
            # boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the incorrect
        # ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
    (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
    (score_correct_ctr,score_incorrect_ctr)

    # NOTE(review): these divide by counters that can be zero (e.g. no true
    # vmatch instances) and would then raise ZeroDivisionError.
    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
def collect_prediction(current_dir,run_name):
    """
    Given the toplevel directory this function takes care that for each distinct
    experiment the training and test predictions are evaluated.

    Returns (train_result, test_result, currentRunId) where the results are
    the tuples produced by prediction_on() and currentRunId is a
    three-character '+'/'-' string encoding the enabled feature flags.

    NOTE(review): `currentRun` is not defined anywhere in this function or
    module -- it was presumably meant to be loaded from the run directory;
    as written this raises NameError.
    """
    train_suffix = '_allPredictions_TRAIN'
    test_suffix = '_allPredictions_TEST'

    jp = os.path.join
    # Maps a boolean flag (used as index) to its '-'/'+' run-id character.
    b2s = ['-','+']

    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])

    # Evaluate the training-set predictions ...
    filename = jp(current_dir,run_name)+train_suffix
    print 'Prediction on: %s' % filename
    train_result = prediction_on(filename)

    # ... then the test-set predictions.
    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId
def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma

    Submits five cluster jobs (one per split index 1..5) by piping the
    doPrediction.sh invocation into SGE's qsub; no value is returned.
    """
    for i in range(1,6):
        # The backslash continues the string literal, so the qsub part is one
        # shell pipeline: echo <script> <dir> <i> | qsub ...
        cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
qsub -l h_vmem=12.0G -cwd -j y -N \"%s_%d.log\"'%(current_dir,i,run_name,i)

        #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
        #print cmd
        os.system(cmd)
def forall_experiments(current_func,tl_dir):
    """
    Invoke `current_func(run_dir, run_name)` for every subdirectory of the
    toplevel directory `tl_dir`. Currently used with:

    - perform_prediction, and
    - collect_prediction.
    """
    all_results = {}
    all_error_rates = {}

    for entry in os.listdir(tl_dir):
        run_dir = os.path.join(tl_dir, entry)
        # Plain files in the toplevel directory are ignored.
        if not os.path.isdir(run_dir):
            continue
        current_func(run_dir, run_dir.split('/')[-1])
353 if __name__ == '__main__':
354 dir = sys.argv[1]
355 assert os.path.exists(dir), 'Error: Directory does not exist!'
357 #global data
358 #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'