#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import sys
import pydb
import pdb
import os
import os.path
import math

data = None

def createErrorVSCutPlot(results):
    """
    This function takes the error counts per cut position collected during the
    evaluation and writes them to a tex file.
    """

    fh = open('error_rates_table.tex', 'w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate} & \multicolumn{1}{c|}{rate}\\\\',
             '\hline']

    #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
    for pos, key in enumerate(['+++']):
        res = results[key]
        for i in range(37):
            try:
                ctr = res[1][i]
            except KeyError:
                ctr = 0

            lines.append('%d\n' % ctr)

        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    lines = [l + '\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()

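# Note on the expected input (inferred from the commented-out code in
# forall_experiments below): `results` is assumed to map a run id such as
# '+++' to a pair (incorrect_gt_cuts, incorrect_vmatch_cuts), i.e. two dicts
# counting wrong predictions per read cut position; res[1][i] above is then
# the number of incorrectly predicted vmatch instances with cut position i.
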
def createTable(results):
    """
    This function takes the results of the evaluation and creates a tex table.
    """

    fh = open('result_table.tex', 'w+')
    lines = ['\\begin{tabular}{|c|c|c|r|r|}', '\hline',
             'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores}\\\\',
             'information & site pred. & length & \multicolumn{1}{c|}{rate} & \multicolumn{1}{c|}{rate}\\\\',
             '\hline']

    for pos, key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e * 100 for e in results[key]]

        lines.append('%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % (key[0], key[1], key[2], res[0], res[1]))
        if pos % 2 == 1:
            lines.append('\hline')

    for pos, key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
        res = [e * 100 for e in results[key]]

        lines.append('%s & %s & %s & %2.2f & x \\%%\\\\' % (key[0], key[1], key[2], res[2]))
        if pos % 2 == 1:
            lines.append('\hline')

    lines.append('\end{tabular}')

    lines = [l + '\n' for l in lines]
    for l in lines:
        fh.write(l)
    fh.close()

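# Sketch of the expected input (assumption, matching the error rates returned
# by prediction_on below): `results` maps a run id to error rates, e.g.
#   results = {'+++': (0.031, 0.094, 0.012), '---': (0.102, 0.187, 0.045), ...}
# The first loop above would then emit rows such as
#   + & + & + & 3.10 & 9.40 \%\\
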
def compare_scores_and_labels(scores, labels):
    """
    Iterate through all predictions. If we find a correct prediction, check
    whether this correct prediction scores higher than the incorrect
    predictions for this example.
    """

    for currentPos, currentElem in enumerate(scores):
        if labels[currentPos]:
            for otherPos, otherElem in enumerate(scores):
                if otherPos == currentPos:
                    continue

                if not labels[otherPos] and otherElem > currentElem:
                    return False

    return True

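# Illustrative example (hypothetical values): for scores [12.3, 10.1, 11.7]
# with labels [True, False, False] the correct prediction has the highest
# score, so the function returns True; raising the second score above 12.3
# would make it return False. If no prediction is labelled correct the
# function returns True by default.
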
def compare_exons(predExons, trueExons):
    """
    Check whether the predicted exon boundaries agree exactly with the true
    exon boundaries (two exons, i.e. four positions).
    """
    e1_b_off, e1_e_off, e2_b_off, e2_e_off = 0, 0, 0, 0

    if len(predExons) == 4:
        e1_begin, e1_end = predExons[0], predExons[1]
        e2_begin, e2_end = predExons[2], predExons[3]
    else:
        return False

    e1_b_off = int(math.fabs(e1_begin - trueExons[0, 0]))
    e1_e_off = int(math.fabs(e1_end - trueExons[0, 1]))

    e2_b_off = int(math.fabs(e2_begin - trueExons[1, 0]))
    e2_e_off = int(math.fabs(e2_end - trueExons[1, 1]))

    if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0 and e2_e_off == 0:
        return True

    return False

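# Example (hypothetical coordinates): predExons = [10, 55, 300, 350] compared
# against trueExons = [[10, 55], [300, 350]] (a numpy matrix, judging from the
# trueExons[0,0] indexing) gives zero offsets everywhere, hence True; any
# non-zero boundary offset, or a prediction with more or fewer than four
# boundaries, yields False.
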
def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    result = compare_exons(predExons, trueExons)
    return result

def evaluate_example(current_prediction):
    label = current_prediction['label']

    pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]

    # if the read was mapped by vmatch at an incorrect position we only have to
    # compare the score
    if not label:
        return label, False, pred_score

    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']

    # the predicted boundaries are offsets relative to 'alternative_start_pos',
    # the true boundaries are offsets relative to 'start_pos'; shift both to
    # absolute positions before comparing
    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]

    #pdb.set_trace()

    pos_comparison = (predPositions == truePositions)

    #if label == True and pos_comparison == False:
    #    pdb.set_trace()

    return label, pos_comparison, pred_score

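# Worked coordinate example (hypothetical numbers): with
# alternative_start_pos = 1000 and predExons = [10, 55, 300, 350], the
# absolute predicted positions are [1010, 1055, 1300, 1350]; these are
# compared element-wise against the absolute true positions built from
# start_pos and trueExons.
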
def prediction_on(filename):

    incorrect_gt_cuts = {}
    incorrect_vmatch_cuts = {}

    allPredictions = cPickle.load(open(filename))

    exon1Begin = []
    exon1End = []
    exon2Begin = []
    exon2End = []
    allWrongExons = []
    allDoubleScores = []

    gt_correct_ctr = 0
    gt_incorrect_ctr = 0

    pos_correct_ctr = 0
    pos_incorrect_ctr = 0

    score_correct_ctr = 0
    score_incorrect_ctr = 0

    total_gt_examples = 0

    total_vmatch_instances_ctr = 0
    true_vmatch_instances_ctr = 0

    for current_example_pred in allPredictions:
        gt_example = current_example_pred[0]
        gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
        gt_correct = evaluate_unmapped_example(gt_example)

        exampleIdx = gt_example['exampleIdx']

        cut_pos = gt_example['true_cut']

        current_scores = []
        current_labels = []
        current_scores.append(gt_score)
        current_labels.append(gt_correct)

        if gt_correct:
            gt_correct_ctr += 1
        else:
            gt_incorrect_ctr += 1

            try:
                incorrect_gt_cuts[cut_pos] += 1
            except KeyError:
                incorrect_gt_cuts[cut_pos] = 1

        total_gt_examples += 1

        #pdb.set_trace()

        for elem_nr, current_pred in enumerate(current_example_pred[1:]):
            current_label, comparison_result, current_score = evaluate_example(current_pred)

            # if vmatch found the right read pos we check for the right exon
            # boundaries
            if current_label:
                if comparison_result:
                    pos_correct_ctr += 1
                else:
                    pos_incorrect_ctr += 1

                    try:
                        incorrect_vmatch_cuts[cut_pos] += 1
                    except KeyError:
                        incorrect_vmatch_cuts[cut_pos] = 1

                true_vmatch_instances_ctr += 1

            current_scores.append(current_score)
            current_labels.append(current_label)

            total_vmatch_instances_ctr += 1

        # check whether the correct predictions score higher than the incorrect
        # ones
        cmp_res = compare_scores_and_labels(current_scores, current_labels)
        if cmp_res:
            score_correct_ctr += 1
        else:
            score_incorrect_ctr += 1

    # now that we have evaluated all instances print out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
    print 'Correct pos: %d, incorrect pos: %d' %\
        (pos_correct_ctr, pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
        (score_correct_ctr, score_incorrect_ctr)

    pos_error = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error = 1.0 * gt_incorrect_ctr / total_gt_examples

    #print pos_error,score_error,gt_error

    return (pos_error, score_error, gt_error, incorrect_gt_cuts, incorrect_vmatch_cuts)

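# The returned error rates are simple ratios of the counters printed above:
#   pos_error   = pos_incorrect_ctr   / true_vmatch_instances_ctr
#   score_error = score_incorrect_ctr / total_vmatch_instances_ctr
#   gt_error    = gt_incorrect_ctr    / total_gt_examples
# Note (assumption about the intended use): the first three entries of the
# return tuple are what createTable above indexes as res[0], res[1], res[2].
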
def collect_prediction(current_dir, run_name):
    """
    Given the toplevel directory this function takes care that for each distinct
    experiment the training and test predictions are evaluated.
    """
    idx = 5

    train_suffix = '_%d_allPredictions_TRAIN' % (idx)
    test_suffix = '_%d_allPredictions_TEST' % (idx)

    jp = os.path.join
    b2s = ['-', '+']

    currentRun = cPickle.load(open(jp(current_dir, 'run_object_%d.pickle' % (idx))))
    QFlag = currentRun['enable_quality_scores']
    SSFlag = currentRun['enable_splice_signals']
    ILFlag = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag], b2s[SSFlag], b2s[ILFlag])

    #filename = jp(current_dir,run_name)+train_suffix
    #print 'Prediction on: %s' % filename
    #train_result = prediction_on(filename)
    train_result = []

    filename = jp(current_dir, run_name) + test_suffix
    print 'Prediction on: %s' % filename
    test_result = prediction_on(filename)

    return train_result, test_result, currentRunId

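# Run id example (flag values assumed): a run object with
# enable_quality_scores = True, enable_splice_signals = False and
# enable_intron_length = True is mapped via b2s to currentRunId = '+-+',
# matching the keys used by createTable above.
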
def perform_prediction(current_dir, run_name):
    """
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma.
    """
    #for i in range(1,6):
    for i in range(1, 2):
        cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d | '\
              'qsub -l h_vmem=12.0G -cwd -j y -N "%s_%d.log"' % (current_dir, i, run_name, i)

        #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
        #print cmd
        os.system(cmd)

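# For i = 1 the submitted command line is roughly
#   echo .../doPrediction.sh <current_dir> 1 | qsub -l h_vmem=12.0G -cwd -j y -N "<run_name>_1.log"
# i.e. one grid engine job per prediction index; the range is currently
# restricted to a single index (originally 1..5, see the commented-out loop).
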
def forall_experiments(current_func, tl_dir):
    """
    Given the toplevel directory this function calls the function given as
    first argument for each subdirectory. At the moment these are:

    - perform_prediction, and
    - collect_prediction.
    """

    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir, de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]

    all_results = {}
    all_error_rates = {}

    for current_dir in run_dirs:
        run_name = current_dir.split('/')[-1]

        current_func(current_dir, run_name)

        #train_result,test_result,currentRunId = current_func(current_dir,run_name)
        #all_results[currentRunId] = test_result
        #pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
        #all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)

    #createErrorVSCutPlot(all_error_rates)
    #createTable(all_results)

if __name__ == '__main__':
    dir = sys.argv[1]
    assert os.path.exists(dir), 'Error: Directory does not exist!'

    #global data
    #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
    #data = cPickle.load(open(data_fn))

    forall_experiments(perform_prediction, dir)
    #forall_experiments(collect_prediction,dir)
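
# Typical invocation (usage sketch, directory layout assumed): first run
#   python Evaluation.py /path/to/toplevel_run_dir
# to submit the prediction jobs, then swap the forall_experiments call above
# to collect_prediction and re-run to evaluate the finished predictions.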