36869916e877d70a0bb877d0a7944900c54a34bc
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 #from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Collapse a bracket-annotated read into a plain, lower-cased sequence.

    The annotated read may contain four-character tokens of the form
    '[XY]'. Each such token contributes only its second letter Y (the
    alignment code in this file reads X as the DNA letter and Y as the
    read letter); every other character is copied unchanged.

    est     -- annotated read string, e.g. 'AC[GT]A'
    returns -- the lower-cased read, e.g. 'acta'

    A token that is cut off at the end of the string raises IndexError.
    """
    # Collect the letters in a list and join once instead of growing a
    # string character by character.
    letters = []
    e = 0
    est_len = len(est)

    while e < est_len:
        if est[e] == '[':
            # '[XY]' -> keep Y and skip the whole token
            letters.append(est[e+2])
            e += 4
        else:
            letters.append(est[e])
            e += 1

    return "".join(letters).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.
    """

    # Default location of the file holding the original reads
    # (one read per line: id, sequence and three further columns).
    ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Reads whose id is larger than this value are counted as "unspliced"
    # ground truth when the confusion counters are updated in filter().
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self,run_fname,data_fname,param_fname,result_spliced_fname,result_unspliced_fname,original_reads_fname=None):
        """
        We need a run object holding information about the nr. of support
        points etc.

        run_fname              -- pickle file holding the run object
        data_fname             -- file handed to PipelineReadParser in filter()
        param_fname            -- pickle file holding the parameter vector
        result_spliced_fname   -- output file for reads predicted "spliced"
        result_unspliced_fname -- output file for reads predicted "unspliced"
        original_reads_fname   -- file with the original reads; defaults to
                                  ORIGINAL_READS_FNAME

        Both result files are opened (and truncated) right away.
        """

        run = self._loadPickle(run_fname)
        self.run = run

        self.result_spliced_fh = open(result_spliced_fname,'w+')
        self.result_unspliced_fh = open(result_unspliced_fname,'w+')

        start = cpu()

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 36

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.005
        self.max_intron_size = 2000
        self.max_mismatch = 2
        self.splice_stop_thresh = 0.99
        self.spliced_bias = 0.0

        # read id -> sequence of the original read
        self.original_reads = {}

        if original_reads_fname is None:
            original_reads_fname = self.ORIGINAL_READS_FNAME

        reads_fh = open(original_reads_fname)
        try:
            for line in reads_fh:
                read_id,seq,q1,q2,q3 = line.strip().split()
                self.original_reads[int(read_id)] = seq
        finally:
            reads_fh.close()

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # start offsets of the consecutive sections of the parameter vector
        donStart = lengthSP
        accStart = donStart+donSP
        mmatrixStart = accStart+accSP
        qualStart = mmatrixStart+mmatrixSP

        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[accStart:mmatrixStart] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[mmatrixStart:qualStart] = mmatrix[:]

        # the quality penalties are the last totalQualSP parameter entries
        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qualStart:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spent in get_seq_and_scores
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0

        # The following timers are only initialised here; no method of this
        # class updates them.
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    def _loadPickle(self, fname):
        """
        Unpickle and return the object stored in fname, closing the file
        afterwards.
        """
        # NOTE: unpickling executes code chosen by the file's author, so
        # only load pickle files from a trusted source.
        fh = open(fname)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    def filter(self, max_reads=1000):
        """
        Classify the remapped reads as "spliced" or "unspliced".

        For the first location of every read on the '+' strand the score of
        the vmatch alignment is compared with the best score of the
        alternative alignments found by calcAlternativeAlignments(). The
        original line of the read is written to the unspliced result file if
        the vmatch score is strictly higher, and to the spliced result file
        otherwise. Reads on the '-' strand are skipped.

        The counters true_pos / false_pos / true_neg / false_neg are updated
        using the read id as ground truth: ids above UNSPLICED_ID_OFFSET
        count as unspliced.

        max_reads -- stop after this many reads have been classified
                     (None means no limit)
        """
        run = self.run

        start = cpu()

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        stop = cpu()

        self.read_parsing = stop-start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        strand_map = {'+':'D', '-':'P'}

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            # Leave the outer loop as well once the limit is reached, so
            # the remaining reads are not iterated for nothing.
            if max_reads is not None and ctr >= max_reads:
                break

            # only the first reported location of a read is considered
            for location in currentReadLocations[:1]:

                read_id = location['id']
                chrom = location['chr']
                pos = location['pos']
                strand = location['strand']
                mismatch = location['mismatches']
                length = location['length']
                off = location['offset']
                seq = location['seq']
                prb = location['prb']
                cal_prb = location['cal_prb']
                chastity = location['chastity']

                read_id = int(read_id)

                original_line = '%d\t%d\t%d\t%s\t%d\t%d\t%d\t%s\t%s\t%s\t%s\n' %\
                (read_id,chrom,pos,strand_map[strand],int(mismatch),int(length),int(off),seq.upper(),prb,cal_prb,chastity)

                if strand == '-':
                    continue

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos+effective_len-1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
                stop = cpu()
                self.get_time += stop-start

                # the vmatch alignment is one single exon covering the read
                exons = zeros((2,1))
                exons[0,0] = 0
                exons[1,0] = effective_len

                currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
                currentAcc, currentDon

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                if not alternativeAlignmentScores:
                    # no alignment necessary
                    maxAlternativeAlignmentScore = -inf
                    vMatchScore = 0.0
                else:
                    maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                    # compute alignment for vmatch unspliced read
                    vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                start = cpu()

                # ground truth derived from the read id
                unspliced = (read_id - self.UNSPLICED_ID_OFFSET) > 0

                # Seems that according to our learned parameters VMatch found
                # a good alignment of the current read
                if maxAlternativeAlignmentScore < vMatchScore:
                    unspliced_ctr += 1

                    self.result_unspliced_fh.write(original_line)

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1

                # We found an alternative alignment considering splice sites
                # that scores higher than the VMatch alignment
                else:
                    spliced_ctr += 1

                    self.result_spliced_fh.write(original_line)

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                stop = cpu()
                # accumulate like the other timers instead of keeping only
                # the duration of the last read
                self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        # make the results visible to readers of the files right away
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites around a read.

        All positions are indices into currentAcc / currentDon / DNA; the
        read occupies the indices max_intron_size ... max_intron_size+read_size-1.
        Only sites with a score >= splice_thresh are considered.

        Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

        proximal_acc -- up to two (idx, score) acceptor sites in the first
                        half of the read, sorted by ascending score
        proximal_don -- up to two (idx, score) donor sites in the second
                        half of the read, sorted by ascending score
        distal_acc   -- all (idx, score, DNA[idx+1:idx+read_size]) acceptor
                        sites behind the read, in ascending index order
        distal_don   -- all (idx, score, DNA[idx-read_size:idx]) donor sites
                        in front of the read with idx > read_size, in
                        descending index order
        """
        read_start = max_intron_size
        read_mid = max_intron_size + read_size//2
        read_end = max_intron_size + read_size

        def score_of(site):
            return site[1]

        def index_of(site):
            return site[0]

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(read_start, read_mid)
                        if currentAcc[idx] >= splice_thresh]
        # list.sort is stable, so ties keep their index order
        proximal_acc.sort(key=score_of)
        proximal_acc = proximal_acc[-2:]

        num_acc = len(currentAcc)
        distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                      for idx in range(read_end, num_acc)
                      if currentAcc[idx] >= splice_thresh and idx+read_size < num_acc]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_mid, read_end)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=score_of)
        proximal_don = proximal_don[-2:]

        distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                      for idx in range(1, max_intron_size)
                      if currentDon[idx] >= splice_thresh and idx > read_size]
        # donor sites closest to the read come first
        distal_don.sort(key=index_of, reverse=True)

        return proximal_acc,proximal_don,distal_acc,distal_don

    def _rebuildEstAcrossIntron(self, original_est, exon_dna, exon_dna_ptr, splice_pos, intron_downstream):
        """
        Re-annotate a read for an alignment that is split at splice_pos.

        The annotated read is scanned token by token while dna_ptr tracks
        the position in the genomic window (starting at max_intron_size).
        Tokens on the near side of the splice site are copied as they are;
        tokens on the far side are compared against exon_dna instead,
        starting at exon_dna_ptr, and are emitted as '[XY]' (X from
        exon_dna, Y the read letter) when they differ.

        intron_downstream -- True: tokens with dna_ptr < splice_pos are kept
                             False: tokens with dna_ptr > splice_pos are kept

        Returns (new annotated read, number of mismatches, dna_ptr, est_ptr)
        with the pointers as they are after the last token.
        """
        pieces = []
        num_mismatch = 0
        est_ptr = 0
        dna_ptr = self.max_intron_size
        ptr = 0
        annotated_len = len(original_est)

        while ptr < annotated_len:
            bracketed = (original_est[ptr] == '[')
            if bracketed:
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                token = original_est[ptr:ptr+4]
                ptr += 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter
                token = estletter
                ptr += 1

            if intron_downstream:
                keep_original = dna_ptr < splice_pos
            else:
                keep_original = dna_ptr > splice_pos

            if keep_original:
                pieces.append(token)
                if bracketed:
                    num_mismatch += 1
            else:
                exonletter = exon_dna[exon_dna_ptr]
                if exonletter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('['+exonletter+estletter+']')
                    num_mismatch += 1
                exon_dna_ptr += 1

            # a '-' on one side advances only the pointer of the other side
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return "".join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        location -- dictionary describing one read location; the keys
                    'chr', 'pos', 'strand', 'seq' and 'prb' are used

        Returns the list of scores of all alternative (spliced) alignments
        that need at most max_mismatch mismatches; the list is empty if
        there is no such alignment.
        """

        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        original_est = location['seq']
        quality = location['prb']

        est = unbracket_est(original_est)
        est_len = len(est)

        # genomic window: max_intron_size positions on both sides of the read
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + est_len

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start

        proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, est_len, self.splice_thresh)

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos,don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos,acc_score,acc_dna) in distal_acc:

                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # construct a new "original_est": everything from the donor
                # on is matched against the DNA behind the acceptor
                original_est_cut,num_mismatch,dna_ptr,est_ptr = \
                    self._rebuildEstAcrossIntron(original_est, acc_dna, 0, don_pos, True)

                # NOTE(review): candidates with too many mismatches are
                # dropped here, after the whole read was scanned -- confirm
                # against the upstream file.
                if num_mismatch>self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # stop looking at further acceptors for this donor once one
                # with a score above the stop threshold has been scored
                if acc_score>=self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos,acc_score) in proximal_acc:

            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos,don_score,don_dna) in distal_don:

                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

                # construct a new "original_est": everything up to the
                # acceptor is matched against the tail of the DNA in front
                # of the donor
                don_dna_ptr = len(don_dna)-(acc_pos-self.max_intron_size)-1
                original_est_cut,num_mismatch,dna_ptr,est_ptr = \
                    self._rebuildEstAcrossIntron(original_est, don_dna, don_dna_ptr, acc_pos, False)

                if num_mismatch>self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                if don_score>=self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores

    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for QPalma
        this function calculates the dot product of the feature representation of
        the alignment and the parameter vector i.e the alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality,
                     acc_supp, don_supp); only original_est and quality
                     enter the score computed here
        """

        start = cpu()
        run = self.run

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop-start

        return score
555
556
def cpu():
    """
    Return the CPU time used so far by this process, i.e. the sum of the
    user and the system time reported by getrusage for RUSAGE_SELF.
    """
    # Query the usage once so that both values stem from the same snapshot.
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
560
561
if __name__ == '__main__':
    # Stand-alone run of the heuristic on a fixed test data set; prints the
    # timing statistics gathered by the PipelineHeuristic instance.

    #run_fname = sys.argv[1]
    #data_fname = sys.argv[2]
    #param_filename = sys.argv[3]

    base_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'

    run_fname = os.path.join(base_dir,'run_object.pickle')
    param_fname = os.path.join(base_dir,'param_500.pickle')

    data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_1k'

    #data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'
    #data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'

    result_spliced_fname = 'splicedReads.heuristic'
    result_unspliced_fname = 'unsplicedReads.heuristic'

    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname,result_spliced_fname,result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    # (format string, value) pairs, printed in this order
    timing_report = [
        ('total time elapsed: %f', stop-start),
        ('time spend for get seq: %f', ph1.get_time),
        ('time spend for calcAlignmentScoreTime: %f', ph1.calcAlignmentScoreTime),
        ('time spend for alternativeScoresTime: %f', ph1.alternativeScoresTime),
        ('time spend for count time: %f', ph1.count_time),
        ('time spend for init time: %f', ph1.init_time),
        ('time spend for read_parsing time: %f', ph1.read_parsing),
        ('time spend for main_loop time: %f', ph1.main_loop),
        ('time spend for splice_site_time time: %f', ph1.splice_site_time),
        ('time spend for computeSpliceAlignWithQualityTime time: %f', ph1.computeSpliceAlignWithQualityTime),
        ('time spend for computeSpliceWeightsTime time: %f', ph1.computeSpliceWeightsTime),
        ('time spend for DotProdTime time: %f', ph1.DotProdTime),
        ('time spend forarray_stuff time: %f', ph1.array_stuff),
    ]

    for line_format, seconds in timing_report:
        print(line_format % seconds)

    #import cProfile
    #cProfile.run('ph1.filter()')

    #import hotshot
    #p = hotshot.Profile('profile.log')
    #p.runcall(ph1.filter)