f4050bcda0b2b58310d3b45a7e9b08599c97c741
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import resource
10
11 from qpalma.computeSpliceWeights import *
12 from qpalma.set_param_palma import *
13 from qpalma.computeSpliceAlignWithQuality import *
14
15 from numpy.matlib import mat,zeros,ones,inf
16 from numpy import inf,mean
17
18 from ParaParser import ParaParser,IN_VECTOR
19
20 from qpalma.Lookup import LookupTable
21 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,reverse_complement,unbracket_seq
22
23
def cpu():
    """
    Return the CPU time consumed so far by this process, in seconds.

    The value is the sum of user and system time (ru_utime + ru_stime) as
    reported by getrusage for RUSAGE_SELF.
    """
    # Take a single snapshot so both components refer to the same instant
    # (and only one system call is made).
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
27
28
class BracketWrapper:
    """
    Sequence-like, iterable view on a read file parsed by ParaParser.

    The file is parsed once in the constructor. Entries can then be accessed
    by index (``wrapper[i]``), counted (``len(wrapper)``) or iterated over.
    Each entry is whatever ``ParaParser.fetchEntry`` returns.

    Iteration state is kept on the object itself (``__iter__`` returns
    ``self``), so only one iteration over a wrapper can be active at a time.
    """

    # names of the columns handed to the parser, in file order
    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self,filename):
        """
        Parse `filename` and prepare the wrapper for access/iteration.
        """
        self.parser = ParaParser("%lu%d%d%s%s%s",self.fields,len(self.fields),IN_VECTOR)
        self.parser.parseFile(filename)

        # Initialise the iteration cursor here as well, so that calling
        # next() before iter() does not fail with an AttributeError.
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)

    def __len__(self):
        """Return the number of parsed entries."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self,key):
        """Return the entry stored at index `key`."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """Restart the iteration at the first entry and return self."""
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """
        Return the entry at the current cursor position and advance the
        cursor; raise StopIteration once all entries have been returned.
        """
        if self.counter >= self.size:
            raise StopIteration
        return_val = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return return_val

    # Python 3 spells the iterator protocol method __next__
    __next__ = next
53
54
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the unspliced vmatch alignment is compared
    with the scores of alternative alignments that place one intron between
    a splice site inside the read and a splice site up to `max_intron_size`
    positions away. Reads are written to '<result_fname>.spliced' or
    '<result_fname>.unspliced' accordingly.
    """

    # Locations and file name formats of the genome and splice site score
    # data. They are class attributes so that they can be overridden without
    # editing the constructor.
    GENOME_DIR = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    ACC_DIR = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc'
    DON_DIR = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don'
    GENOME_FMT = 'chr%d.dna.flat'
    SCORE_FMT = 'contig_%d%s'

    # Reads whose id is larger than this value are counted as "unspliced"
    # when the true/false positive/negative statistics are collected.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self,run_fname,data_fname,param_fname,result_fname):
        """
        Load all data needed for filtering.

        run_fname    -- pickle file holding the run object (information
                        about the nr. of support points etc.)
        data_fname   -- file with the remapped reads, parsed via
                        BracketWrapper
        param_fname  -- pickle file holding the parameter vector handed to
                        set_param_palma
        result_fname -- prefix of the two result files
                        ('<prefix>.spliced' and '<prefix>.unspliced')
        """

        run = self._loadPickle(run_fname)
        self.run = run

        start = cpu()
        self.all_remapped_reads = BracketWrapper(data_fname)
        stop = cpu()

        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

        g_dir = self.GENOME_DIR
        acc_dir = self.ACC_DIR
        don_dir = self.DON_DIR

        g_fmt = self.GENOME_FMT
        s_fmt = self.SCORE_FMT

        # reads located on other chromosomes are skipped by filter()
        self.chromo_range = range(1,6)

        num_chromo = 2
        accessWrapper = DataAccessWrapper(g_dir,acc_dir,don_dir,g_fmt,s_fmt)
        self.seqInfo = SeqSpliceInfo(accessWrapper,range(1,num_chromo))

        start = cpu()
        self.lt1 = LookupTable(g_dir,acc_dir,don_dir,g_fmt,s_fmt,self.chromo_range)
        stop = cpu()

        print('prefetched sequence and splice data in %f sec' % (stop-start))

        self.result_spliced_fh = open('%s.spliced'%result_fname,'w+')
        self.result_unspliced_fh = open('%s.unspliced'%result_fname,'w+')

        # init_time only covers the work done from here on
        start = cpu()

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 38
        # subtracted from the ordinal of every quality character
        self.prb_offset = 64

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.01
        self.max_intron_size = 2000
        self.max_mismatch = 2
        self.splice_stop_thresh = 0.8
        self.spliced_bias = 0.0

        param_lines = [
            ('%f','splice_thresh',self.splice_thresh),
            ('%d','max_intron_size',self.max_intron_size),
            ('%d','max_mismatch',self.max_mismatch),
            ('%f','splice_stop_thresh',self.splice_stop_thresh),
            ('%f','spliced_bias',self.spliced_bias)]

        param_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in param_lines]
        param_lines.append('# data: %s\n'%self.data_fname)
        param_lines.append('# param: %s\n'%param_fname)

        # both result files start with the same header of '#' lines
        for p_line in param_lines:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Assemble the parameter vector block by block: intron length
        # penalties, donor penalties, acceptor penalties, match matrix and
        # finally the quality penalties.
        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spent for get seq and scores
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    def _loadPickle(self,fname):
        """
        Return the object pickled in the file `fname`; the file is closed
        again even if unpickling fails.
        """
        # NOTE: unpickling can execute arbitrary code - only load files from
        # a trusted source.
        fh = open(fname)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    def filter(self):
        """
        Classify every remapped read as spliced or unspliced.

        A read is written to the unspliced result file if the score of its
        vmatch alignment is higher than the best alternative (spliced)
        alignment score, otherwise it is written to the spliced result file.
        Reads for which no alternative alignment could be scored are written
        to the spliced file as well (their alternative score is -inf and
        their vmatch score is set to 0.0). The true/false positive/negative
        counters and the timers of the object are updated on the way.
        """
        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0
        failed_ctr = 0

        strand_map = {'D':'+', 'P':'-'}

        print('Starting filtering...')
        _start = cpu()

        for location,original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chromo = location['chr']
            pos = location['pos']
            seq = location['seq'].lower()
            prb = location['prb']
            strand = strand_map[location['strand']]

            if not chromo in self.chromo_range:
                continue

            unb_seq = unbracket_seq(seq)

            if strand == '-':
                unb_seq = reverse_complement(unb_seq)
                seq = reverse_complement(seq)

            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos+effective_len

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
            stop = cpu()
            self.get_time += stop-start

            # the vmatch alignment consists of one exon spanning the read
            exons = zeros((2,1))
            exons[0,0] = 0
            exons[1,0] = effective_len
            quality = [ord(x)-self.prb_offset for x in prb]

            currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, quality,\
            currentAcc, currentDon

            # Best effort: a read for which the alternative alignments cannot
            # be computed is treated like a read without alternatives.
            # Catching Exception (instead of everything) keeps Ctrl-C and
            # sys.exit() working.
            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                failed_ctr += 1
                alternativeAlignmentScores = []

            if not alternativeAlignmentScores:
                # no alignment necessary
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment for vmatch unspliced read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            unspliced = (read_id - self.UNSPLICED_ID_OFFSET) > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # Seems that according to our learned parameters VMatch found
                # a good alignment of the current read
                unspliced_ctr += 1

                self.result_unspliced_fh.write(original_line+'\n')

                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # We found an alternative alignment considering splice sites
                # that scores higher than the VMatch alignment
                spliced_ctr += 1

                self.result_spliced_fh.write(original_line+'\n')

                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulate like all other timers (was overwritten per read)
            self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        # make sure everything written so far reaches the result files
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))

        if failed_ctr > 0:
            sys.stderr.write('calcAlternativeAlignments failed for %d reads\n'%failed_ctr)

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites with a score >= splice_thresh.

        The score arrays and DNA are indexed such that the read starts at
        index max_intron_size. Returned are four lists:

        proximal_acc -- (idx,score) of the (at most) two highest scoring
                        acceptors in the first half of the read, sorted by
                        ascending score
        proximal_don -- (idx,score) of the (at most) two highest scoring
                        donors in the second half of the read, sorted by
                        ascending score
        distal_acc   -- (idx,score,dna) of all acceptors behind the read
                        that have at least read_size positions after them,
                        in ascending order of idx; dna is
                        DNA[idx+1:idx+read_size]
        distal_don   -- (idx,score,dna) of all donors in front of the read
                        with idx > read_size, in descending order of idx;
                        dna is DNA[idx-read_size:idx]
        """
        read_start = max_intron_size
        read_middle = max_intron_size + read_size//2
        read_stop = max_intron_size + read_size
        num_positions = len(currentAcc)

        def score_of(site):
            return site[1]

        proximal_acc = [(idx,currentAcc[idx])
                        for idx in range(read_start, read_middle)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=score_of)
        proximal_acc = proximal_acc[-2:]

        distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                      for idx in range(read_stop, num_positions)
                      if currentAcc[idx] >= splice_thresh and idx+read_size < num_positions]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_middle, read_stop)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=score_of)
        proximal_don = proximal_don[-2:]

        distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                      for idx in range(1, max_intron_size)
                      if currentDon[idx] >= splice_thresh and idx > read_size]
        # closest donors (largest idx) first
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc,proximal_don,distal_acc,distal_don

    def _rebuildReadAcrossIntron(self, original_est, exon_dna, exon_ptr, boundary, exon_first):
        """
        Construct a new bracketed read ("original_est") for an alignment in
        which one part of the read is moved onto the distal exon.

        original_est -- bracketed read; '[xy]' encodes DNA letter x aligned
                        to read letter y, any other character stands for a
                        match
        exon_dna     -- DNA of the distal exon
        exon_ptr     -- index into exon_dna of the first letter to compare
        boundary     -- position of the proximal splice site
        exon_first   -- if True the part of the read up to and including
                        `boundary` is compared against exon_dna (intron on
                        the 5' end), otherwise the part starting at
                        `boundary` (intron on the 3' end)

        The remaining part of the read is copied unchanged. Returns the tuple
        (new_read, num_mismatch, dna_ptr, est_ptr), where num_mismatch counts
        the bracketed positions of the new read.
        """
        pieces = []
        est_ptr = 0
        dna_ptr = self.max_intron_size
        ptr = 0
        num_mismatch = 0

        while ptr < len(original_est):
            bracketed = original_est[ptr] == '['
            if bracketed:
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                step = 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter
                step = 1

            if exon_first:
                keep_original = dna_ptr > boundary
            else:
                keep_original = dna_ptr < boundary

            if keep_original:
                # this position stays where vmatch aligned it
                pieces.append(original_est[ptr:ptr+step])
                if bracketed:
                    num_mismatch += 1
            else:
                # this position is compared against the distal exon
                exonletter = exon_dna[exon_ptr]
                if exonletter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('['+exonletter+estletter+']')
                    num_mismatch += 1
                exon_ptr += 1

            ptr += step

            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        `location` is the entry describing the read (keys 'chr', 'pos',
        'strand', 'seq' and 'prb' are used). Returned is the list of scores
        of all alternative alignments with at most `max_mismatch` mismatches;
        the list is empty if there is none.

        Any exception raised on the way propagates to the caller, i.e. scores
        computed up to that point are lost.
        """

        run = self.run

        chromo = location['chr']
        pos = location['pos']
        quality = [ord(x)-self.prb_offset for x in location['prb']]

        original_est = location['seq'].lower()
        est = unbracket_seq(original_est)

        # fetch max_intron_size positions on both sides of the read
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        strand_map = {'D':'+', 'P':'-'}
        strand = strand_map[location['strand']]

        if strand == '-':
            est = reverse_complement(est)
            original_est = reverse_complement(original_est)

        start = cpu()
        dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo, strand, genomicSeq_start, genomicSeq_stop)
        stop = cpu()
        self.get_time += stop-start

        proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos,don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos,acc_score,acc_dna) in distal_acc:

                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # the read part starting at the donor is moved behind the
                # acceptor
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._rebuildReadAcrossIntron(original_est, acc_dna, 0, don_pos, False)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a sufficiently strong acceptor was scored - do not look at
                # acceptors further away
                if acc_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos,acc_score) in proximal_acc:

            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos,don_score,don_dna) in distal_don:

                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

                # the read part up to the acceptor is moved in front of the
                # donor, i.e. it is compared with the end of don_dna
                don_dna_ptr = len(don_dna)-(acc_pos-self.max_intron_size)-1
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._rebuildReadAcrossIntron(original_est, don_dna, don_dna_ptr, acc_pos, True)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a sufficiently strong donor was scored - do not look at
                # donors further away
                if don_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores

    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for QPalma
        this function calculates the dot product of the feature representation of
        the alignment and the parameter vector i.e the alignment score.

        `alignment` is the tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp); only original_est and quality enter the score.
        """

        start = cpu()

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop-start

        return score