#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import sys
import pydb
import pdb
import os
import os.path
import math


def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.
    """

    fh = open('result_table.tex','w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\\hline',
        'Quality & Splice & Intron & \\multicolumn{1}{c|}{Error on Positions} & \\multicolumn{1}{c|}{Error on Scores} \\\\',
        'information & site pred. & length & \\multicolumn{1}{c|}{rate} & \\multicolumn{1}{c|}{rate} \\\\',
        '\\hline',

        '- & - & - & %2.3f & %2.3f \\%%\\\\' % ( results['---'][0], results['---'][1] ),
        '+ & - & - & %2.3f & %2.3f \\%%\\\\' % ( results['+--'][0], results['+--'][1] ),

        '\\hline',

        '- & + & - & %2.3f & %2.3f \\%%\\\\' % ( results['-+-'][0], results['-+-'][1] ),
        '+ & + & - & %2.3f & %2.3f \\%%\\\\' % ( results['++-'][0], results['++-'][1] ),

        '\\hline\\hline',

        '- & - & + & %2.3f & %2.3f \\%%\\\\' % ( results['--+'][0], results['--+'][1] ),
        '+ & - & + & %2.3f & %2.3f \\%%\\\\' % ( results['+-+'][0], results['+-+'][1] ),

        '\\hline',

        '- & + & + & %2.3f & %2.3f \\%%\\\\' % ( results['-++'][0], results['-++'][1] ),
        '+ & + & + & %2.3f & %2.3f \\%%\\\\' % ( results['+++'][0], results['+++'][1] ),

        '\\hline',
        '\\end{tabular}']

    lines = [l+'\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()
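
# Usage sketch for createTable (the numbers are made up, not real results):
# `results` maps a three-character run id -- quality information, splice site
# predictions, intron length, each '+' or '-' -- to a (position error, score
# error) tuple as returned by prediction_on(), e.g.
#
#   results = dict((key, (0.05, 0.10)) for key in
#                  ['---', '+--', '-+-', '++-', '--+', '+-+', '-++', '+++'])
#   createTable(results)   # writes result_table.tex to the working directory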


def compare_scores_and_labels(scores,labels):
    """
    Iterate through all predictions. If we find a correct prediction check
    whether this correct prediction scores higher than the incorrect
    predictions for this example.
    """

    for currentPos,currentElem in enumerate(scores):
        if labels[currentPos] == True:
            for otherPos,otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if labels[otherPos] == False and otherElem > currentElem:
                    return False

    return True
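
# Illustrative behaviour (scores are hypothetical): an example only counts as
# correctly scored if every correct prediction outscores all incorrect ones.
#
#   compare_scores_and_labels([12.3, 7.1, 5.0], [True, False, False])   # -> True
#   compare_scores_and_labels([12.3, 14.2, 5.0], [True, False, False])  # -> False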


def compare_exons(predExons,trueExons):
    # predExons holds the four predicted boundaries (exon1 begin/end, exon2
    # begin/end); trueExons is the corresponding 2x2 matrix of annotated
    # boundaries.
    e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0

    if len(predExons) == 4:
        e1_begin,e1_end = predExons[0],predExons[1]
        e2_begin,e2_end = predExons[2],predExons[3]
    else:
        return False

    e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0\
       and e2_e_off == 0:
        return True

    return False
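
# Illustrative call (coordinates are made up; trueExons is assumed to be a 2x2
# NumPy matrix, as the flatten().tolist() calls elsewhere in this script
# suggest). A prediction only counts as correct if all four boundaries match
# exactly:
#
#   import numpy
#   compare_exons([10, 50, 80, 120], numpy.matrix([[10, 50], [80, 120]]))  # -> True
#   compare_exons([10, 50, 80, 119], numpy.matrix([[10, 50], [80, 120]]))  # -> False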


def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons,trueExons)
    return result


def evaluate_example(current_prediction):
    label = current_prediction['label']

    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have
    # to compare the score
    if label == False:
        return label,False,pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #   pdb.set_trace()

    return label,pos_comparison,pred_score
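
# evaluate_example returns a (label, positions_correct, score) triple; e.g.
# (True, False, 9.87) would describe a read that vmatch placed at the correct
# position but whose exon boundaries were predicted wrongly (the score value
# is made up).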


def prediction_on(filename):
    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    pos_correct_ctr = 0
    pos_incorrect_ctr = 0
    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1

        for elem_nr,current_pred in enumerate(current_example_pred[1:]):
            current_label,comparison_result,current_score = evaluate_example(current_pred)

            # if vmatch found the right read position we check for the right
            # exon boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the
        # incorrect ones
        cmp_res = compare_scores_and_labels(current_scores,current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
        (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
        (score_correct_ctr,score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr

    print pos_error,score_error

    return (pos_error,score_error)
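
# Each entry of the pickled allPredictions list is expected to hold the
# ground-truth alignment at index 0 followed by the vmatch-based alignments of
# the same read. Example call (the file name is a placeholder):
#
#   pos_error, score_error = prediction_on('run1_allPredictions_TEST')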


def collect_prediction(current_dir,run_name):
    """
    Given the directory of one experiment this function evaluates the
    training and the test predictions of that run.
    """
    train_suffix = '_allPredictions_TRAIN'
    test_suffix = '_allPredictions_TEST'

    jp = os.path.join
    b2s = ['-','+']

    currentRun = cPickle.load(open(jp(current_dir,'run_object.pickle')))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])

    filename = jp(current_dir,run_name)+train_suffix
    print 'Prediction on: %s' % filename
    train_result = prediction_on(filename)

    filename = jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result,test_result,currentRunId
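
# The returned run id encodes the three feature flags in the order quality
# information, splice site predictions, intron length; e.g. '+-+' stands for a
# run with quality scores and intron length enabled but splice signals
# disabled. These ids are exactly the keys createTable() expects.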


def perform_prediction(current_dir,run_name):
    """
    This function takes care of starting the jobs needed for the prediction
    phase of qpalma.
    """
    cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s | qsub -l h_vmem=12.0G -cwd -j y -N \"%s.log\"'%(current_dir,run_name)
    #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
    #print cmd
    os.system(cmd)


def forall_experiments(current_func,tl_dir):
    """
    Given the toplevel directory this function calls the function passed as
    first argument for each experiment subdirectory. At the moment these
    functions are:

    - perform_prediction, and
    - collect_prediction.
    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir,de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        #current_func(current_dir,run_name)

        train_result,test_result,currentRunId = current_func(current_dir,run_name)
        all_results[currentRunId] = test_result

    createTable(all_results)
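
# Note that this loop unpacks the (train_result, test_result, run_id) triple
# returned by collect_prediction; when dispatching perform_prediction instead,
# the commented-out single call above is the one to use.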


if __name__ == '__main__':
    dir = sys.argv[1]
    assert os.path.exists(dir), 'Error: Directory does not exist!'

    #forall_experiments(perform_prediction,dir)
    forall_experiments(collect_prediction,dir)
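
# Usage sketch (the path is a placeholder): point the script at a toplevel
# directory containing one subdirectory per experiment, each with its
# 'run_object.pickle' and the '<run_name>_allPredictions_TRAIN' and
# '<run_name>_allPredictions_TEST' pickles:
#
#   python Evaluation.py /path/to/prediction_runs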