git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8648 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Strip the bracket annotation from an aligned read and lower-case it.

    A four character group "[XY]" in est is replaced by its third character
    (Y); every other character is copied unchanged.  (Elsewhere in this file
    X is read as the DNA letter and Y as the read letter of a mismatch.)

    Raises IndexError if est ends in the middle of a bracket group.
    """
    letters = []
    pos = 0
    length = len(est)

    while pos < length:
        if est[pos] == '[':
            # skip '[' and the first letter, keep the second, skip ']'
            letters.append(est[pos+2])
            pos += 4
        else:
            letters.append(est[pos])
            pos += 1

    return "".join(letters).lower()
48
49
class PipelineHeuristic(object):
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.
    """

    # Default file holding the original reads: one read per line, five
    # whitespace separated fields of which the first is the integer read id
    # and the second the sequence.
    ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # filter() counts a read as truly unspliced when its id is larger than
    # this value.
    # NOTE(review): presumably an id-range convention of the data set -- confirm.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self, run_fname, data_fname, param_fname, reads_fname=None):
        """
        We need a run object holding information about the nr. of support points
        etc.

        run_fname   -- pickle file containing the run dictionary
        data_fname  -- file with the remapped reads, parsed later by filter()
        param_fname -- pickle file containing the learned parameter vector
        reads_fname -- file with the original reads; defaults to
                       ORIGINAL_READS_FNAME
        """

        # NOTE(review): unpickling can execute arbitrary code; only use run and
        # parameter files from a trusted source.
        run = self._loadPickle(run_fname)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is the
        # mean of intron size
        self.intron_size = 150

        self.read_size = 36

        if reads_fname is None:
            reads_fname = self.ORIGINAL_READS_FNAME

        self.original_reads = {}

        reads_file = open(reads_fname)
        try:
            for line in reads_file:
                # exactly five fields are expected, anything else raises
                # ValueError
                read_id, seq, _q1, _q2, _q3 = line.strip().split()
                self.original_reads[int(read_id)] = seq
        finally:
            reads_file.close()

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # The parameter vector is the concatenation of the intron length,
        # donor, acceptor, match matrix and quality penalties (in that order).
        don_begin = lengthSP
        acc_begin = don_begin+donSP
        mm_begin = acc_begin+accSP
        qual_begin = mm_begin+mmatrixSP

        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:don_begin] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[acc_begin:mm_begin] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mm_begin:qual_begin] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qual_begin:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spend for get seq and scores
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    @staticmethod
    def _loadPickle(fname):
        """Unpickle the object stored in fname and close the file again."""
        handle = open(fname)
        try:
            return cPickle.load(handle)
        finally:
            handle.close()

    def filter(self, max_reads=100):
        """
        Classify the remapped reads of self.data_fname as spliced/unspliced.

        For every read only its first location is considered and locations on
        the '-' strand are skipped.  The score of the vmatch alignment is
        compared with the best score returned by calcAlternativeAlignments:
        a strictly higher vmatch score counts as "unspliced", everything else
        as "spliced".  Reads without any alternative score are not counted.
        The confusion counters (true_pos, false_pos, true_neg, false_neg) and
        the timing attributes are updated and a summary is printed.

        max_reads -- stop after this many reads have been classified (None
                     means no limit)
        """
        run = self.run

        start = cpu()

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        stop = cpu()

        self.read_parsing = stop-start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            # leave the whole loop once enough reads were classified instead
            # of walking over all remaining reads without doing anything
            if max_reads is not None and ctr >= max_reads:
                break

            for location in currentReadLocations[:1]:
                read_id = int(location['id'])
                chrom = location['chr']
                pos = location['pos']
                strand = location['strand']
                seq = location['seq']
                prb = location['prb']

                if strand == '-':
                    continue

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos+effective_len-1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
                stop = cpu()
                self.get_time += stop-start

                # the vmatch alignment is one exon covering the whole read
                exons = zeros((2, 1))
                exons[0, 0] = 0
                exons[1, 0] = effective_len

                currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
                    currentAcc, currentDon
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                start = cpu()
                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # ground truth label derived from the read id
                unspliced = (read_id - self.UNSPLICED_ID_OFFSET) > 0

                # Seems that according to our learned parameters VMatch found a good
                # alignment of the current read
                if maxAlternativeAlignmentScore < vMatchScore:
                    unspliced_ctr += 1

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1

                # We found an alternative alignment considering splice sites that scores
                # higher than the VMatch alignment
                else:
                    spliced_ctr += 1

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                stop = cpu()
                # accumulate like all other timers (a plain assignment would
                # only keep the time of the last read)
                self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def _topScoringSites(self, scores):
        """
        Return up to two (index, score) pairs with the highest score.

        Only the entries with index 0 .. self.read_size+1 are looked at and
        entries equal to -inf are ignored.  The result is ordered by ascending
        score; equal scores keep their index order.
        """
        candidates = []
        for idx, score in enumerate(scores):
            if score > -inf:
                candidates.append((idx, score))
            if idx > self.read_size:
                break

        # Sort by score using a key function: ordering by index would be a
        # no-op on this list and a cmp function returning a float is rejected
        # by Python 2.
        candidates.sort(key=lambda site: site[1])
        return candidates[-2:]

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Pick the best donor and acceptor candidates near the read start.

        currentAcc, currentDon -- per position acceptor / donor scores, -inf
                                  marks positions without a splice site

        Returns the pair (don, acc) -- note the order -- where each element is
        a list of at most two (position, score) tuples sorted by ascending
        score.
        """
        acc = self._topScoringSites(currentAcc)
        don = self._topScoringSites(currentDon)

        return don, acc

    @staticmethod
    def _maskMismatches(original_est, keep_mismatch):
        """
        Walk over a bracketed alignment string and drop selected mismatches.

        A group "[DE]" (D: DNA letter, E: read letter) is copied unchanged if
        keep_mismatch(dna_ptr, est_ptr) is true for the pointers reached so
        far, otherwise it is replaced by the read letter E.  All other
        characters are copied.

        Returns (masked string, final dna_ptr, final est_ptr).
        """
        pieces = []

        est_ptr = 0
        dna_ptr = 0
        ptr = 0
        while ptr < len(original_est):

            if original_est[ptr] == '[':
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                if keep_mismatch(dna_ptr, est_ptr):
                    pieces.append(original_est[ptr:ptr+4])
                else:
                    pieces.append(estletter)  # EST letter
                ptr += 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter

                pieces.append(estletter)  # EST letter
                ptr += 1

            # a '-' marks a gap, in which case only one pointer advances
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        location -- dictionary describing one read location; the keys 'chr',
                    'pos', 'strand', 'seq' and 'prb' are used

        Returns a list with one score per donor and per acceptor candidate
        delivered by findHighestScoringSpliceSites (possibly empty).
        """

        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        original_est = location['seq']
        quality = location['prb']

        est = unbracket_est(original_est)

        # fetch a fixed window of 1000 positions starting at the read
        genomicSeq_start = pos
        genomicSeq_stop = pos+1000

        start = cpu()
        currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start
        dna = currentDNASeq

        alt_don, alt_acc = self.findHighestScoringSpliceSites(currentAcc, currentDon)

        # Best acceptor of the whole window.  A key function is used because a
        # cmp function returning a float raises a TypeError; the stable sort
        # keeps the last position among equal scores.
        (max_acc_pos, max_acc_score) = sorted(enumerate(currentAcc), key=lambda site: site[1])[-1]

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # compute dummy scores
        IntronScore = calculatePlif(h, [math.fabs(max_acc_pos-30)])[0]
        dummyAcceptorScore = calculatePlif(a, [max_acc_score])[0]
        dummyDonorScore = calculatePlif(d, [0.25])[0]

        _start = cpu()
        for (don_pos, don_score) in alt_don:
            # remove mismatching positions in the second exon, i.e. keep only
            # the mismatches in front of the donor position
            original_est_cut, dna_ptr, est_ptr = self._maskMismatches(
                original_est, lambda dna_ptr, est_ptr: dna_ptr < don_pos)

            assert(dna_ptr <= len(dna))
            assert(est_ptr <= len(est))

            DonorScore = calculatePlif(d, [don_score])[0]

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += dummyAcceptorScore + IntronScore + DonorScore

            alternativeScores.append(score)

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        _start = cpu()
        for (acc_pos, acc_score) in alt_acc:
            # remove mismatching positions in the first exon, i.e. keep only
            # the mismatches from the acceptor position on
            original_est_cut, dna_ptr, est_ptr = self._maskMismatches(
                original_est, lambda dna_ptr, est_ptr: est_ptr >= acc_pos)

            assert(dna_ptr <= len(dna))
            assert(est_ptr <= len(est))

            # acceptor scores have to be evaluated with the acceptor Plif
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + dummyDonorScore

            alternativeScores.append(score)

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for QPalma
        this function calculates the dot product of the feature representation of
        the alignment and the parameter vector i.e the alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality, acc_supp,
                     don_supp); only original_est and quality enter the score

        The time spent here is added to self.calcAlignmentScoreTime.
        """

        start = cpu()
        run = self.run

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop-start

        return score
469
470
def cpu():
    """
    Return the CPU time (user plus system) used so far by this process as
    reported by resource.getrusage.
    """
    # take a single snapshot so that both values stem from the same sample
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
474
475
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    # Without exactly three arguments the hard-coded test setup below is used.

    run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
    jp = os.path.join

    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    else:
        run_fname = jp(run_dir, 'run_object.pickle')

        # alternative data sets live next to this one, e.g.
        # .../solexa/current_data/map.vm_unspliced_flag or
        # .../solexa/pipeline_data/map.vm_2k
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'

        param_fname = jp(run_dir, 'param_500.pickle')

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop-start))
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for read_parsing time: %f' % ph1.read_parsing)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f' % ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f' % ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f' % ph1.DotProdTime)
    print('time spend forarray_stuff time: %f' % ph1.array_stuff)

    # For profiling run ph1.filter through cProfile.run('ph1.filter()') or
    # hotshot.Profile('profile.log').runcall(ph1.filter) instead.