+ added information on cut positions of spliced reads
[qpalma.git] / scripts / Evaluation.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import sys
import pydb
import pdb
import os
import os.path
import math

data = None

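# ---------------------------------------------------------------------------
# Overview: this script evaluates pickled QPalma prediction files. For every
# run directory it loads the '*_allPredictions_TRAIN' / '*_allPredictions_TEST'
# pickles, compares predicted exon boundaries and alignment scores against the
# ground truth, and writes summary numbers to LaTeX files
# ('result_table.tex', 'error_rates_table.tex'). The module-level `data`
# variable holds the remapped dataset loaded in __main__; prediction_on()
# uses it to look up the original ESTs and the read split (cut) positions.
# ---------------------------------------------------------------------------
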
def createErrorVSCutPlot(results):
    """
    This function takes the evaluation results and writes, for each cut
    position, the number of incorrect predictions to 'error_rates_table.tex'
    (the data behind the error vs. cut position plot).
    """

    fh = open('error_rates_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
             'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos,key in enumerate(['+++']):
        res = results[key]
        for i in range(45):
            try:
                ctr = res[1][i]
            except KeyError:
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()

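# The `results` argument of createErrorVSCutPlot() above is the
# `all_error_rates` dictionary built in forall_experiments(): it maps a run id
# such as '+++' to a tuple (incorrect_gt_cuts, incorrect_vmatch_cuts), where
# each element is a dict from cut position to the number of incorrect
# predictions observed at that position.
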
def createTable(results):
    """
    This function takes the results of the evaluation and writes a LaTeX
    result table to 'result_table.tex'.
    """

    fh = open('result_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
             'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        # only the three error rates are scaled; the trailing cut-position
        # dicts returned by prediction_on() are ignored here
        res = [e*100 for e in results[key][:3]]

        lines.append('%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % (key[0], key[1], key[2], res[0], res[1]))
        if pos % 2 == 1:
            lines.append('\\hline')

    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e*100 for e in results[key][:3]]

        lines.append('%s & %s & %s & %2.2f & x \\%%\\\\' % (key[0], key[1], key[2], res[2]))
        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()

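# Sketch of the `results` structure createTable() expects (numbers are made
# up for illustration): one entry per run id, for all eight flag combinations
# '---' ... '+++', whose first three values are the error rates returned by
# prediction_on():
#
#     results['+++'] = (0.02, 0.05, 0.01, incorrect_gt_cuts, incorrect_vmatch_cuts)
#
# createTable(results) then writes 'result_table.tex'. In forall_experiments()
# below this call is currently commented out in favour of createErrorVSCutPlot().
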
def compare_scores_and_labels(scores,labels):
    """
    Iterate through all predictions. If we find a correct prediction, check
    whether this correct prediction scores higher than the incorrect
    predictions for this example.
    """

    for currentPos,currentElem in enumerate(scores):
        if labels[currentPos] == True:
            for otherPos,otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if labels[otherPos] == False and otherElem > currentElem:
                    return False

    return True

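# Illustrative behaviour of compare_scores_and_labels(): for
# scores = [0.9, 0.5, 0.7] and labels = [True, False, False] the correct
# prediction (0.9) outscores both incorrect ones, so the result is True;
# with scores = [0.5, 0.9, 0.7] it would be False. If no label is True,
# no comparison is ever triggered and the function also returns True.
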
def compare_exons(predExons,trueExons):
    """
    Compare predicted and true exon boundaries; return True only if all four
    boundaries match exactly (only the two-exon case is handled).
    """
    e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0

    if len(predExons) == 4:
        e1_begin,e1_end = predExons[0],predExons[1]
        e2_begin,e2_end = predExons[2],predExons[3]
    else:
        return False

    e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0 and e2_e_off == 0:
        return True

    return False

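# Note (assumption about the data layout): `trueExons` appears to be a
# numpy matrix-like object with one row per exon -- compare_exons() indexes
# it as trueExons[0,0] .. trueExons[1,1] and evaluate_example() calls
# .flatten().tolist()[0] on it -- whereas `predExons` is a flat list of four
# boundary positions [exon1_begin, exon1_end, exon2_begin, exon2_end].
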
def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons,trueExons)
    return result

def evaluate_example(current_prediction):
    """
    Evaluate a single vmatch-based prediction and return the tuple
    (label, positions_correct, score).
    """
    label = current_prediction['label']

    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have
    # to compare the score
    if label == False:
        return label,False,pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #    pdb.set_trace()

    return label,pos_comparison,pred_score

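# The per-prediction dictionaries consumed above are expected to carry at
# least these keys (as used in this module): 'label', 'DPScores', 'predExons',
# 'trueExons', 'alternative_start_pos', 'start_pos'; the ground-truth entries
# additionally provide 'exampleIdx' (see prediction_on() below).
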
def prediction_on(filename):

    global data

    # The first 10000 entries of the dataset are presumably the training
    # examples; duplicating the lists below makes example indices >= 10000
    # map onto the same ESTs / split positions again.
    OriginalEsts = data[2]
    OriginalEsts[10000:] = OriginalEsts

    SplitPositions = data[5]
    SplitPositions[10000:] = SplitPositions

    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']
        originalEst = OriginalEsts[exampleIdx]
        cut_pos = SplitPositions[exampleIdx]

        #pdb.set_trace()

        # Adjust the cut position for annotation characters in the original
        # EST string: every '[' apparently marks a three-character annotation
        # block, so the raw split position is shifted back accordingly.
        ctr = 0
        for elem in originalEst:
            if elem == '[':
                ctr += 3

        cut_pos -= ctr

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            try:
                incorrect_gt_cuts[cut_pos] += 1
            except KeyError:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read position we check for the right
            # exon boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except KeyError:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the incorrect
        # ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
        (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
        (score_correct_ctr,score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)

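# Structure of the pickled '*_allPredictions_*' files as consumed by
# prediction_on(): a list with one entry per read, where entry[0] is the
# ground-truth alignment (evaluated with evaluate_unmapped_example) and
# entry[1:] are the alternative vmatch-based alignments (evaluated with
# evaluate_example). The returned tuple is
# (pos_error, score_error, gt_error, incorrect_gt_cuts, incorrect_vmatch_cuts).
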
def collect_prediction(current_dir,run_name):
    """
    Given the top-level directory, this function makes sure that for each
    distinct experiment the training and the test predictions are evaluated.
    """
    train_suffix = '_allPredictions_TRAIN'
    test_suffix = '_allPredictions_TEST'

    jp = os.path.join
    b2s = ['-','+']

    currentRun = cPickle.load(open(jp(current_dir,'run_object.pickle')))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])

    filename = jp(current_dir,run_name)+train_suffix
    print 'Prediction on: %s' % filename
    train_result = prediction_on(filename)

    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId

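# The run id built above encodes the three feature flags of a run as a
# three-character string of '+'/'-' signs, in the order quality scores,
# splice signals, intron length -- e.g. '++-' means quality and splice-signal
# features enabled and the intron-length feature disabled. These ids are the
# keys used by createTable() and createErrorVSCutPlot().
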
def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction
    phase of QPalma.
    """
    cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s | qsub -l h_vmem=12.0G -cwd -j y -N \"%s.log\"'%(current_dir,run_name)
    #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
    #print cmd
    os.system(cmd)

def forall_experiments(current_func,tl_dir):
    """
    Given the top-level directory, this function calls the function passed as
    first argument for each subdirectory. At the moment these are:

    - perform_prediction, and
    - collect_prediction.
    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir,de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        #current_func(current_dir,run_name)

        train_result,test_result,currentRunId = current_func(current_dir,run_name)
        all_results[currentRunId] = test_result
        pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result

        all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)

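# Usage (as assumed from the code below): the script expects the top-level
# result directory as its single command-line argument, e.g.
#
#     python Evaluation.py /path/to/prediction_runs
#
# and the hard-coded dataset pickle referenced below must exist on the machine.
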
if __name__ == '__main__':
    dir = sys.argv[1]
    assert os.path.exists(dir), 'Error: Directory does not exist!'

    global data
    data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
    data = cPickle.load(open(data_fn))

    #forall_experiments(perform_prediction,dir)
    forall_experiments(collect_prediction,dir)