#!/usr/bin/env python
# -*- coding: utf-8 -*-

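# Evaluation of QPalma prediction results: for every run directory below a
# given toplevel directory the training and test predictions are loaded and
# summarized as error rates (see forall_experiments below).
#
# Usage: Evaluation.py <toplevel_result_dir>
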
import cPickle
import sys
import pydb
import pdb
import os
import os.path
import math

data = None

def createErrorVSCutPlot(results):
    """
    This function takes the results of the evaluation and writes the error
    counts per cut position to a tex file.
    """

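    # `results` is expected to map a run id such as '+++' to a tuple whose
    # second entry is a dict of error counts per cut position (see
    # forall_experiments below).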
    fh = open('error_rates_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
             'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos,key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            ctr = 0
            try:
                ctr = res[1][i]
            except KeyError:
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.
    """

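    # `results` is expected to map a run id such as '+++' to the tuple
    # returned by prediction_on(); only its first three entries
    # (pos_error, score_error, gt_error) are used here.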
    fh = open('result_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
             'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate}\\\\', '\\hline']

    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        # only the error rates are needed, not the cut position dicts
        res = [e*100 for e in results[key][:3]]

        lines.append('%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % (key[0], key[1], key[2], res[0], res[1]))
        if pos % 2 == 1:
            lines.append('\\hline')

    for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e*100 for e in results[key][:3]]

        lines.append('%s & %s & %s & %2.2f & x \\%%\\\\' % (key[0], key[1], key[2], res[2]))
        if pos % 2 == 1:
            lines.append('\\hline')

    lines.append('\\end{tabular}')

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()


def compare_scores_and_labels(scores,labels):
    """
    Iterate through all predictions. If we find a correct prediction check
    whether this correct prediction scores higher than the incorrect
    predictions for this example.
    """

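    # A minimal sketch of the expected behaviour (hypothetical values):
    #   compare_scores_and_labels([5.0, 7.0], [True, False]) -> False
    #   compare_scores_and_labels([7.0, 5.0], [True, False]) -> True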
    for currentPos,currentElem in enumerate(scores):
        if labels[currentPos] == True:
            for otherPos,otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if labels[otherPos] == False and otherElem > currentElem:
                    return False

    return True


def compare_exons(predExons,trueExons):
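    # predExons is expected to be a flat list [exon1_begin, exon1_end,
    # exon2_begin, exon2_end]; trueExons a 2x2 matrix of the true exon
    # boundaries. Returns True only if all four boundaries match exactly.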
    e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0

    if len(predExons) == 4:
        e1_begin,e1_end = predExons[0],predExons[1]
        e2_begin,e2_end = predExons[2],predExons[3]
    else:
        return False

    e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0 and e2_e_off == 0:
        return True

    return False


def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons,trueExons)
    return result


def evaluate_example(current_prediction):
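    # Returns a triple (label, positions_correct, prediction_score); the label
    # tells whether vmatch mapped the read at the correct position.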
    label = current_prediction['label']

    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have
    # to compare the score
    if label == False:
        return label,False,pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #    pdb.set_trace()

    return label,pos_comparison,pred_score


def prediction_on(filename):

    global data

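    # The dataset lists are duplicated starting at index 10000; presumably the
    # test examples carry exampleIdx values offset by 10000, so this makes
    # OriginalEsts[exampleIdx] work for them as well.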
    OriginalEsts = data[2]
    OriginalEsts[10000:] = OriginalEsts

    SplitPositions = data[5]
    SplitPositions[10000:] = SplitPositions

    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']
        originalEst = OriginalEsts[exampleIdx]

        cut_pos = gt_example['true_cut']
        #cut_pos = SplitPositions[exampleIdx]
        #ctr = 0
        #for elem in originalEst:
        #    if elem == '[':
        #        ctr += 3
        #
        #cut_pos -= ctr

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            try:
                incorrect_gt_cuts[cut_pos] += 1
            except KeyError:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read pos we check for the right exon
            # boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except KeyError:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the
        # incorrect ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances: %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
        (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances: %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
        (score_correct_ctr,score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)


def collect_prediction(current_dir,run_name):
    """
    Given the directory of one experiment this function evaluates the
    training and the test predictions of that run.
    """
    train_suffix = '_allPredictions_TRAIN'
    test_suffix = '_allPredictions_TEST'

    jp = os.path.join
    b2s = ['-','+']

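    # The run id encodes the three flags (quality scores, splice signals,
    # intron length) as '+'/'-' characters, e.g. '+-+'.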
    currentRun = cPickle.load(open(jp(current_dir,'run_object.pickle')))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])

    filename = jp(current_dir,run_name)+train_suffix
    print 'Prediction on: %s' % filename
    train_result = prediction_on(filename)

    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId


def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction
    phase of QPalma.
    """
    cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s | qsub -l h_vmem=12.0G -cwd -j y -N \"%s.log\"' % (current_dir,run_name)
    #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
    #print cmd
    os.system(cmd)


def forall_experiments(current_func,tl_dir):
    """
    Given the toplevel directory this function calls the function given as
    first argument for each subdirectory. At the moment this is either:

    - perform_prediction, or
    - collect_prediction.

    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir,de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        #current_func(current_dir,run_name)

        train_result,test_result,currentRunId = current_func(current_dir,run_name)
        all_results[currentRunId] = test_result
        pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
        all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)


if __name__ == '__main__':
    dir = sys.argv[1]
    assert os.path.exists(dir), 'Error: Directory does not exist!'

    data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
    data = cPickle.load(open(data_fn))

    #forall_experiments(perform_prediction,dir)
    forall_experiments(collect_prediction,dir)