+ modified Lookup table to work with negative strands
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.computeSpliceWeights import *
13 from qpalma.set_param_palma import *
14 from qpalma.computeSpliceAlignWithQuality import *
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy import inf,mean
18
19 from qpalma.parsers import *
20
21 from ParaParser import *
22
23 from Lookup import Lookup
24
25 from qpalma.sequence_utils import reverse_complement,unbracket_seq
26
27
28
class BracketWrapper:
    """
    Sequence- and iterator-style access to the records of a parsed read file.

    The file is parsed once, in the constructor, by a ParaParser configured
    with the eleven column names in ``fields``. Afterwards ``len(obj)``,
    ``obj[i]`` and ``for entry in obj`` all delegate to that parser.

    Iteration state (``counter``/``size``) lives on the instance itself, so
    only one iteration over a given wrapper can be in progress at a time.
    """

    # Column names handed to ParaParser, in file order.
    fields = ['id', 'chr', 'pos', 'strand', 'mismatches', 'length',
              'offset', 'seq', 'prb', 'cal_prb', 'chastity']

    def __init__(self,filename):
        """Create the parser and parse ``filename`` immediately."""
        # One conversion specifier per entry of ``fields``.
        column_format = "%lu%d%d%s%d%d%d%s%s%s%s"
        self.parser = ParaParser(column_format,self.fields,len(self.fields),IN_VECTOR)
        self.parser.parseFile(filename)

    def __len__(self):
        """Return the number of entries reported by the parser."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self,key):
        """Return whatever the parser yields for entry ``key``."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """Reset the cursor, snapshot the current size and return self."""
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """Return the entry under the cursor and advance (Python 2 iterator protocol)."""
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry
54
55
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the unspliced vmatch alignment is compared
    with the scores of a handful of candidate spliced alignments (see
    calcAlternativeAlignments). The original input line of the read is then
    written either to '<result_fname>.spliced' or to
    '<result_fname>.unspliced'.
    """

    def __init__(self,run_fname,data_fname,param_fname,result_fname):
        """
        We need a run object holding information about the nr. of support points
        etc.

        run_fname    -- pickled run dictionary
        data_fname   -- read file, parsed through BracketWrapper
        param_fname  -- pickled parameter vector handed to set_param_palma
        result_fname -- prefix of the two result files ('.spliced' and
                        '.unspliced' are appended)
        """

        run = self._loadPickle(run_fname)
        self.run = run

        # NOTE(review): hard-coded genome location; filter() and
        # calcAlternativeAlignments() pass run['dna_flat_files'] instead --
        # confirm that both point to the same data.
        dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

        start = cpu()
        self.all_remapped_reads = BracketWrapper(data_fname)
        stop = cpu()

        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

        start = cpu()
        self.lt1 = Lookup(dna_flat_files)
        stop = cpu()
        print('prefetched sequence and splice data in %f sec' % (stop-start))

        self.result_spliced_fh = open('%s.spliced'%result_fname,'w+')
        self.result_unspliced_fh = open('%s.unspliced'%result_fname,'w+')

        start = cpu()

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.01
        self.max_intron_size = 2000
        self.max_mismatch = 2
        self.splice_stop_thresh = 0.8
        self.spliced_bias = 0.0

        # Record the heuristic settings as '#' header lines in both result
        # files so that a result can be traced back to its configuration.
        param_lines = [\
        ('%f','splice_thresh',self.splice_thresh),\
        ('%d','max_intron_size',self.max_intron_size),\
        ('%d','max_mismatch',self.max_mismatch),\
        ('%f','splice_stop_thresh',self.splice_stop_thresh),\
        ('%f','spliced_bias',self.spliced_bias)]

        param_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in param_lines]
        param_lines.append('# data: %s\n'%self.data_fname)
        param_lines.append('# param: %s\n'%param_fname)

        for p_line in param_lines:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # The parameter vector is the concatenation
        # [ intron length | donor | acceptor | match matrix | quality ].
        donStart = lengthSP
        accStart = donStart+donSP
        mmStart = accStart+accSP
        qualStart = mmStart+mmatrixSP

        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[accStart:mmStart] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[mmStart:qualStart] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qualStart:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spend for get seq and scores
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    def _loadPickle(self,fname):
        """
        Unpickle and return the object stored in fname; the file is closed
        again even if unpickling fails.
        """
        # NOTE: unpickling can execute arbitrary code -- only use trusted files.
        fh = open(fname)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    def filter(self):
        """
        Classify every read of self.all_remapped_reads as spliced or
        unspliced.

        A read goes to the unspliced result file when its best alternative
        (spliced) score is strictly lower than the score of the vmatch
        alignment, otherwise to the spliced result file. The
        true/false positive/negative counters and the timing attributes are
        updated along the way.
        """
        run = self.run

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        strand_map = {'D':'+', 'P':'-'}
        valid_chromosomes = range(1,6)

        for location,original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chromo = location['chr']
            pos = location['pos']
            seq = location['seq'].lower()
            prb = location['prb']

            strand = strand_map[location['strand']]

            if not chromo in valid_chromosomes:
                continue

            unb_seq = unbracket_seq(seq)

            if strand == '-':
                # The original line referenced an undefined name and the
                # never-assigned self.read_size, so every '-' read crashed.
                # The read length is taken from the unbracketed sequence,
                # like in calcAlternativeAlignments.
                # NOTE(review): the '+7' index offset into chromo_sizes is
                # kept as found -- confirm against Lookup/seqInfo. Also note
                # that calcAlternativeAlignments still works on the
                # untransformed location['pos'].
                pos = self.lt1.seqInfo.chromo_sizes[chromo+7]-pos-len(unb_seq)
                unb_seq = reverse_complement(unb_seq)

            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos+effective_len-1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop-start

            # The unspliced alignment consists of one single exon covering
            # the whole read.
            exons = zeros((2,1))
            exons[0,0] = 0
            exons[1,0] = effective_len

            currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
            currentAcc, currentDon

            # Best effort: a read whose alternatives cannot be computed is
            # treated as having none. Only ordinary errors are swallowed, so
            # Ctrl-C and sys.exit() still get through.
            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                alternativeAlignmentScores = []

            if not alternativeAlignmentScores:
                # no alignment necessary
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment for vmatch unspliced read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            # Ground truth used for the statistics: ids above this offset
            # are counted as unspliced.
            # NOTE(review): presumably an artefact of how the evaluation
            # reads were numbered -- confirm.
            unspliced = (read_id - 1000000300000) > 0

            # Seems that according to our learned parameters VMatch found a good
            # alignment of the current read
            if maxAlternativeAlignmentScore < vMatchScore:
                unspliced_ctr += 1

                self.result_unspliced_fh.write(original_line+'\n')

                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1

            # We found an alternative alignment considering splice sites that scores
            # higher than the VMatch alignment
            else:
                spliced_ctr += 1

                self.result_spliced_fh.write(original_line+'\n')

                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulate like all other timers (was overwritten per read)
            self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        # Make sure the results are on disk once filtering has finished.
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites in a window in which the read occupies
        the positions max_intron_size .. max_intron_size+read_size-1.

        Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

        proximal_acc -- up to two (idx, score) acceptors from the first half
                        of the read, ordered by ascending score
        proximal_don -- up to two (idx, score) donors from the second half of
                        the read, ordered by ascending score
        distal_acc   -- every (idx, score, DNA[idx+1:idx+read_size]) acceptor
                        behind the read that leaves read_size positions up to
                        the window end, ordered by ascending idx
        distal_don   -- every (idx, score, DNA[idx-read_size:idx]) donor in
                        front of the read with idx > read_size, ordered by
                        descending idx

        Only sites with a score >= splice_thresh are considered.
        """
        read_start = max_intron_size
        read_mid = max_intron_size + read_size // 2
        read_end = max_intron_size + read_size
        window_len = len(currentAcc)

        def bestTwo(scores, first, last):
            # Stable sort by score; the two best sites end up last.
            sites = [(idx, scores[idx]) for idx in range(first, last)
                     if scores[idx] >= splice_thresh]
            sites.sort(key=lambda site: site[1])
            return sites[-2:]

        proximal_acc = bestTwo(currentAcc, read_start, read_mid)

        distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                      for idx in range(read_end, window_len)
                      if currentAcc[idx] >= splice_thresh and idx+read_size < window_len]

        proximal_don = bestTwo(currentDon, read_mid, read_end)

        distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                      for idx in range(1, max_intron_size)
                      if currentDon[idx] >= splice_thresh and idx > read_size]
        # Donors closest to the read come first.
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc,proximal_don,distal_acc,distal_don

    def _rethreadRead(self, original_est, splice_pos, distal_dna, distal_ptr, intron_at_3prime):
        """
        Rebuild the bracket-annotated read for one candidate spliced alignment.

        original_est is walked letter by letter ('[xy]' denotes DNA letter x
        aligned to read letter y, '-' a gap), starting at window position
        self.max_intron_size. Letters on the proximal side of splice_pos keep
        their vmatch annotation; all other letters are compared against
        distal_dna, starting at index distal_ptr.

        intron_at_3prime -- True:  positions <  splice_pos (the donor) are proximal
                            False: positions >  splice_pos (the acceptor) are proximal

        Returns (original_est_cut, num_mismatch, dna_ptr, est_ptr).
        """
        pieces = []
        est_ptr = 0
        dna_ptr = self.max_intron_size
        ptr = 0
        num_mismatch = 0

        while ptr < len(original_est):
            bracketed = original_est[ptr] == '['
            if bracketed:
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                width = 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter
                width = 1

            if intron_at_3prime:
                proximal = dna_ptr < splice_pos
            else:
                proximal = dna_ptr > splice_pos

            if proximal:
                if bracketed:
                    # keep the mismatch vmatch already reported
                    pieces.append(original_est[ptr:ptr+4])
                    num_mismatch += 1
                else:
                    pieces.append(estletter)
            else:
                distal_letter = distal_dna[distal_ptr]
                if distal_letter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('['+distal_letter+estletter+']')
                    num_mismatch += 1
                distal_ptr += 1

            ptr += width

            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        Two families of candidates are scored: an intron towards the 3' end
        (proximal donor inside the read, distal acceptor behind it) and an
        intron towards the 5' end (proximal acceptor inside the read, distal
        donor in front of it). Candidates with more than self.max_mismatch
        mismatches are dropped. Returns the list of candidate scores, which
        may be empty.
        """

        run = self.run

        chromo = location['chr']
        pos = location['pos']
        quality = location['prb']

        original_est = location['seq'].lower()
        est = unbracket_seq(original_est)

        # window: max_intron_size positions on either side of the read
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        strand_map = {'D':'+', 'P':'-'}
        strand = strand_map[location['strand']]

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start
        dna = currentDNASeq

        proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos,don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos,acc_score,acc_dna) in distal_acc:

                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # construct a new "original_est"
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._rethreadRead(original_est, don_pos, acc_dna, 0, True)

                if num_mismatch>self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a confident acceptor ends the search for this donor
                if acc_score>=self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos,acc_score) in proximal_acc:

            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos,don_score,don_dna) in distal_don:

                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

                # The read letters up to and including acc_pos are matched
                # against the tail of don_dna.
                don_dna_ptr = len(don_dna)-(acc_pos-self.max_intron_size)-1

                # construct a new "original_est"
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._rethreadRead(original_est, acc_pos, don_dna, don_dna_ptr, False)

                if num_mismatch>self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a confident donor ends the search for this acceptor
                if don_score>=self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores

    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for QPalma
        this function calculates the dot product of the feature representation of
        the alignment and the parameter vector i.e the alignment score.

        alignment is the 7-tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp); only original_est and quality enter the score.
        The time spent is added to self.calcAlignmentScoreTime.
        """

        start = cpu()

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop-start

        return score
605
606
def cpu():
    """
    Return the CPU time (user + system) this process has consumed so far, in
    seconds.

    Both components are read from one single getrusage() snapshot, so they
    describe the same instant.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
610
611
if __name__ == '__main__':
    # PipelineHeuristic takes ONE result prefix and appends '.spliced' /
    # '.unspliced' itself. The script used to pass two result names, which
    # made the constructor call fail with a TypeError on every run.
    if len(sys.argv) != 5:
        print('Usage: ./%s data param run result_prefix' % (sys.argv[0]))
        sys.exit(1)

    data_fname = sys.argv[1]
    param_fname = sys.argv[2]
    run_fname = sys.argv[3]
    result_fname = sys.argv[4]

    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname,result_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop-start))

    # Timing report: one line per counter collected by PipelineHeuristic.
    timings = [
        ('time spend for get seq: %f', ph1.get_time),
        ('time spend for calcAlignmentScoreTime: %f', ph1.calcAlignmentScoreTime),
        ('time spend for alternativeScoresTime: %f', ph1.alternativeScoresTime),
        ('time spend for count time: %f', ph1.count_time),
        ('time spend for init time: %f', ph1.init_time),
        ('time spend for main_loop time: %f', ph1.main_loop),
        ('time spend for splice_site_time time: %f', ph1.splice_site_time),
        ('time spend for computeSpliceAlignWithQualityTime time: %f', ph1.computeSpliceAlignWithQualityTime),
        ('time spend for computeSpliceWeightsTime time: %f', ph1.computeSpliceWeightsTime),
        ('time spend for DotProdTime time: %f', ph1.DotProdTime),
        ('time spend forarray_stuff time: %f', ph1.array_stuff),
    ]
    for template, seconds in timings:
        print(template % seconds)