ee17503172e27fd0dc55db01a9e88318a6b30e58
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.computeSpliceWeights import *
13 from qpalma.set_param_palma import *
14 from qpalma.computeSpliceAlignWithQuality import *
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy import inf,mean
18
19 from qpalma.parsers import *
20
21 from ParaParser import *
22
23 from Lookup import Lookup
24
25 from qpalma.sequence_utils import reverse_complement,unbracket_seq
26
27
28
class BracketWrapper:
    """
    Container-style access to the entries of a read file parsed by ParaParser.

    The file is parsed once in the constructor; afterwards the object can be
    asked for its length, indexed, and iterated.  Iteration state lives on
    the instance itself (``counter`` / ``size``), so only one iteration at a
    time is supported.
    """

    # Column names handed to the parser, one per conversion in the format
    # string used in __init__.
    fields = ['id', 'chr', 'pos', 'strand', 'mismatches', 'length',
              'offset', 'seq', 'prb', 'cal_prb', 'chastity']

    def __init__(self,filename):
        """Create the parser for `filename` and parse the whole file."""
        line_format = "%lu%d%d%s%d%d%d%s%s%s%s"
        num_fields = len(self.fields)
        self.parser = ParaParser(line_format,self.fields,num_fields,IN_VECTOR)
        self.parser.parseFile(filename)

    def __len__(self):
        """Number of entries as reported by the parser."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self,key):
        """Return whatever the parser stores for entry `key`."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """Reset the cursor and return the instance as its own iterator."""
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """Return the entry under the cursor and advance the cursor by one."""
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry
54
55
56 class PipelineHeuristic:
57 """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.
60 """
61
def __init__(self,run_fname,data_fname,param_fname,result_fname,result_unspliced_fname=None):
    """
    Load the run settings, the remapped reads and the trained parameters,
    open the two result files and write a parameter header into both.

    We need a run object holding information about the nr. of support points
    etc.

    Arguments:
    run_fname   -- pickle file with the run object; it is indexed with keys
                   such as 'numLengthSuppPoints' and 'numFeatures'
    data_fname  -- read file that is handed to BracketWrapper
    param_fname -- pickle file with the trained parameter vector
    result_fname -- when `result_unspliced_fname` is omitted this is a
                   prefix: results go to '<result_fname>.spliced' and
                   '<result_fname>.unspliced'.  Otherwise it is the path of
                   the spliced result file itself.
    result_unspliced_fname -- optional path of the unspliced result file.
                   Accepting it makes the five-argument call issued by the
                   script entry point of this file valid.
    """

    def load_pickle(path):
        # NOTE(review): unpickling can execute arbitrary code -- only feed
        # this files from a trusted source.
        fh = open(path)
        try:
            return cPickle.load(fh)
        finally:
            # the original left these handles open
            fh.close()

    run = load_pickle(run_fname)
    self.run = run

    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    start = cpu()
    self.all_remapped_reads = BracketWrapper(data_fname)
    stop = cpu()

    print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

    start = cpu()
    self.lt1 = Lookup(dna_flat_files)
    stop = cpu()
    print('prefetched sequence and splice data in %f sec' % (stop-start))

    if result_unspliced_fname is None:
        # historic behaviour: one prefix, two fixed suffixes
        spliced_path = '%s.spliced'%result_fname
        unspliced_path = '%s.unspliced'%result_fname
    else:
        # both result files were named explicitly by the caller
        spliced_path = result_fname
        unspliced_path = result_unspliced_fname

    self.result_spliced_fh = open(spliced_path,'w+')
    self.result_unspliced_fh = open(unspliced_path,'w+')

    # everything from here on is accounted for in self.init_time
    start = cpu()

    self.data_fname = data_fname

    self.param = load_pickle(param_fname)

    # Set the parameters such as limits penalties for the Plifs
    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

    self.h = h
    self.d = d
    self.a = a
    self.mmatrix = mmatrix
    self.qualityPlifs = qualityPlifs

    # parameters of the heuristics to decide whether the read is spliced
    self.splice_thresh = 0.01
    self.max_intron_size = 2000
    self.max_mismatch = 2
    self.splice_stop_thresh = 0.8
    self.spliced_bias = 0.0

    # record the heuristic settings as '#' header lines in both result files
    header_params = [
        ('%f','splice_thresh',self.splice_thresh),
        ('%d','max_intron_size',self.max_intron_size),
        ('%d','max_mismatch',self.max_mismatch),
        ('%f','splice_stop_thresh',self.splice_stop_thresh),
        ('%f','spliced_bias',self.spliced_bias),
    ]

    header_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in header_params]
    header_lines.append('# data: %s\n'%self.data_fname)
    header_lines.append('# param: %s\n'%param_fname)

    for header_line in header_lines:
        self.result_spliced_fh.write(header_line)
        self.result_unspliced_fh.write(header_line)

    # sizes of the consecutive segments of the parameter vector
    lengthSP = run['numLengthSuppPoints']
    donSP = run['numDonSuppPoints']
    accSP = run['numAccSuppPoints']
    mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
    totalQualSP = run['totalQualSuppPoints']

    # start offsets of the donor, acceptor, match-matrix and quality segments
    don_begin = lengthSP
    acc_begin = don_begin+donSP
    mm_begin = acc_begin+accSP
    qual_begin = mm_begin+mmatrixSP

    currentPhi = zeros((run['numFeatures'],1))
    currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
    currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP,1)
    currentPhi[acc_begin:mm_begin] = mat(a.penalties[:]).reshape(accSP,1)
    currentPhi[mm_begin:qual_begin] = mmatrix[:]

    # the quality penalties are the last totalQualSP entries of the parameters
    totalQualityPenalties = self.param[-totalQualSP:]
    currentPhi[qual_begin:] = totalQualityPenalties[:]
    self.currentPhi = currentPhi

    # we want to identify spliced reads
    # so true pos are spliced reads that are predicted "spliced"
    self.true_pos = 0

    # as false positives we count all reads that are not spliced but predicted
    # as "spliced"
    self.false_pos = 0

    self.true_neg = 0
    self.false_neg = 0

    # total time spend for get seq and scores
    self.get_time = 0.0
    self.calcAlignmentScoreTime = 0.0
    self.alternativeScoresTime = 0.0

    self.count_time = 0.0
    self.main_loop = 0.0
    self.splice_site_time = 0.0
    self.computeSpliceAlignWithQualityTime = 0.0
    self.computeSpliceWeightsTime = 0.0
    self.DotProdTime = 0.0
    self.array_stuff = 0.0
    stop = cpu()

    self.init_time = stop-start
188
def filter(self):
    """
    Sort every remapped read into the 'unspliced' or the 'spliced' result
    file.

    For each read the score of the unspliced vmatch alignment is compared
    with the best score among the alternative (spliced) alignments returned
    by calcAlternativeAlignments.  If the vmatch alignment scores strictly
    higher, the original line is written to the unspliced file, otherwise to
    the spliced file.  Reads for which no alternative alignment exists, or
    for which computing the alternatives fails, are treated as unspliced.

    Side effects: writes to both result files and flushes them, updates the
    true/false positive/negative counters and the timing attributes, and
    prints progress and a summary to stdout.
    """
    run = self.run

    ctr = 0
    unspliced_ctr = 0
    spliced_ctr = 0

    strand_map = {'D':'+', 'P':'-'}
    # only chromosomes 1..5 are processed; built once instead of per read
    valid_chromosomes = (1, 2, 3, 4, 5)

    print('Starting filtering...')
    _start = cpu()

    for location,original_line in self.all_remapped_reads:

        if ctr % 1000 == 0:
            print(ctr)

        read_id = int(location['id'])
        chromosome = location['chr']
        pos = location['pos']
        strand = strand_map[location['strand']]
        seq = location['seq'].lower()
        prb = location['prb']

        if chromosome not in valid_chromosomes:
            continue

        unb_seq = unbracket_seq(seq)

        if strand == '-':
            unb_seq = reverse_complement(unb_seq)

        effective_len = len(unb_seq)

        genomicSeq_start = pos
        genomicSeq_stop = pos+effective_len-1

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromosome,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start

        # a single exon spanning the whole read
        exons = zeros((2,1))
        exons[0,0] = 0
        exons[1,0] = effective_len

        # (dna, exons, est, original_est, quality, acc, don)
        currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
        currentAcc, currentDon

        try:
            alternativeAlignmentScores = self.calcAlternativeAlignments(location)
        except Exception:
            # Best effort: a read whose alternatives cannot be computed is
            # handled like a read without alternatives.  Unlike a bare
            # 'except:' this lets KeyboardInterrupt/SystemExit through.
            alternativeAlignmentScores = []

        if not alternativeAlignmentScores:
            # no alignment necessary
            maxAlternativeAlignmentScore = -inf
            vMatchScore = 0.0
        else:
            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
            # compute alignment for vmatch unspliced read
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

        start = cpu()

        # Ids above this offset are counted as unspliced when updating the
        # evaluation counters.
        # NOTE(review): presumably an id convention of the read generator -- confirm.
        unspliced = (read_id - 1000000300000) > 0

        # Seems that according to our learned parameters VMatch found a good
        # alignment of the current read
        if maxAlternativeAlignmentScore < vMatchScore:
            unspliced_ctr += 1

            self.result_unspliced_fh.write(original_line+'\n')

            if unspliced:
                self.true_neg += 1
            else:
                self.false_neg += 1

        # We found an alternative alignment considering splice sites that scores
        # higher than the VMatch alignment
        else:
            spliced_ctr += 1

            self.result_spliced_fh.write(original_line+'\n')

            if unspliced:
                self.false_pos += 1
            else:
                self.true_pos += 1

        ctr += 1
        stop = cpu()
        # accumulate like every other timer (was overwritten per iteration)
        self.count_time += stop-start

    _stop = cpu()
    self.main_loop = _stop-_start

    # make sure everything written so far reaches the files
    self.result_spliced_fh.flush()
    self.result_unspliced_fh.flush()

    print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
    print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
    print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))
328
329
def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
    """
    Collect candidate splice sites inside and around the read.

    The read is taken to start at index `max_intron_size` of the three
    parallel sequences `currentAcc`, `currentDon` and `DNA`.  Only positions
    whose score is at least `splice_thresh` are considered.

    Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

    proximal_acc -- up to two (idx, score) acceptors from the first half of
                    the read, ordered by ascending score
    proximal_don -- up to two (idx, score) donors from the second half of
                    the read, ordered by ascending score
    distal_acc   -- all (idx, score, dna) acceptors behind the read that
                    still leave `read_size` positions before the end of
                    `currentAcc`, in ascending position order; dna is
                    DNA[idx+1:idx+read_size]
    distal_don   -- all (idx, score, dna) donors in front of the read with
                    idx > read_size, in descending position order; dna is
                    DNA[idx-read_size:idx]
    """
    # explicit floor division: independent of the interpreter's '/' semantics
    read_mid = max_intron_size + read_size // 2
    read_end = max_intron_size + read_size

    def by_score(candidate):
        return candidate[1]

    def by_position(candidate):
        return candidate[0]

    proximal_acc = [(idx, currentAcc[idx])
                    for idx in range(max_intron_size, read_mid)
                    if currentAcc[idx] >= splice_thresh]
    # stable ascending sort, then keep the two best
    proximal_acc.sort(key=by_score)
    proximal_acc = proximal_acc[-2:]

    # idx + read_size < len(currentAcc)  <=>  idx < len(currentAcc) - read_size
    distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                  for idx in range(read_end, len(currentAcc) - read_size)
                  if currentAcc[idx] >= splice_thresh]

    proximal_don = [(idx, currentDon[idx])
                    for idx in range(read_mid, read_end)
                    if currentDon[idx] >= splice_thresh]
    proximal_don.sort(key=by_score)
    proximal_don = proximal_don[-2:]

    distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                  for idx in range(1, max_intron_size)
                  if currentDon[idx] >= splice_thresh and idx > read_size]
    # donors closest to the read come first
    distal_don.sort(key=by_position, reverse=True)

    return proximal_acc,proximal_don,distal_acc,distal_don
374
def calcAlternativeAlignments(self,location):
    """
    Given an alignment proposed by Vmatch this function calculates possible
    alternative alignments taking into account for example matched
    donor/acceptor positions.

    `location` is indexed with the keys 'id', 'chr', 'pos', 'strand', 'seq',
    'prb' and 'cal_prb'.  The genomic window that is fetched reaches
    self.max_intron_size positions to either side of the read, so the read
    itself starts at index self.max_intron_size of that window.

    Two families of candidates are scored:
    - an intron at the 3' end: a donor inside the read (proximal_don)
      combined with an acceptor behind the read (distal_acc)
    - an intron at the 5' end: an acceptor inside the read (proximal_acc)
      combined with a donor in front of the read (distal_don)

    For each combination the bracketed read is rewritten against the DNA on
    the far side of the intron; combinations with more than
    self.max_mismatch mismatches are skipped.

    Returns the list of scores of all candidates that were kept (possibly
    empty).  Adds the elapsed CPU time to self.get_time and
    self.alternativeScoresTime.

    NOTE(review): the nesting of the pointer increments inside the two while
    loops was reconstructed from a listing without indentation -- confirm
    against the original file.
    """

    run = self.run

    id = location['id']
    chr = location['chr']
    pos = location['pos']
    strand = location['strand']
    original_est = location['seq']
    quality = location['prb']
    cal_prb = location['cal_prb']

    original_est = original_est.lower()
    est = unbracket_seq(original_est)
    effective_len = len(est)

    # window: max_intron_size positions before and after the read
    genomicSeq_start = pos - self.max_intron_size
    genomicSeq_stop = pos + self.max_intron_size + len(est)

    strand_map = {'D':'+', 'P':'-'}
    strand = strand_map[strand]

    start = cpu()
    #currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
    currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
    stop = cpu()
    self.get_time += stop-start
    dna = currentDNASeq

    proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

    alternativeScores = []

    # inlined
    h = self.h
    d = self.d
    a = self.a
    mmatrix = self.mmatrix
    qualityPlifs = self.qualityPlifs
    # inlined

    # find an intron on the 3' end
    _start = cpu()
    for (don_pos,don_score) in proximal_don:
        DonorScore = calculatePlif(d, [don_score])[0]

        for (acc_pos,acc_score,acc_dna) in distal_acc:

            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            #print 'don splice: ', (don_pos,don_score), (acc_pos,acc_score,acc_dna), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est"
            original_est_cut=''

            est_ptr=0
            # the read starts at index max_intron_size of the window
            dna_ptr=self.max_intron_size
            ptr=0
            acc_dna_ptr=0
            num_mismatch = 0

            # Walk over the bracketed read: a '[XY]' group holds the DNA
            # letter X and the read letter Y of a mismatch, any other
            # character is a match.
            while ptr<len(original_est):
                #print acc_dna_ptr,len(acc_dna),acc_pos,don_pos

                if original_est[ptr]=='[':
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr < don_pos:
                        # before the donor: keep the existing mismatch group
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        # from the donor on: compare against the DNA behind the acceptor
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+acc_dna[acc_dna_ptr]+estletter+']'
                        acc_dna_ptr+=1
                    ptr+=4
                else:
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr < don_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            num_mismatch += 1
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            #print '('+acc_dna[acc_dna_ptr]+estletter+')'
                        acc_dna_ptr+=1

                    ptr+=1

                # '-' marks a gap: only the sequence that has a letter advances
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1
            # too many mismatches: drop this donor/acceptor combination
            if num_mismatch>self.max_mismatch:
                continue

            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # stop trying further acceptors once this one scores high enough
            if acc_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    # find an intron on the 5' end
    _start = cpu()
    for (acc_pos,acc_score) in proximal_acc:

        AcceptorScore = calculatePlif(a, [acc_score])[0]

        for (don_pos,don_score,don_dna) in distal_don:

            DonorScore = calculatePlif(d, [don_score])[0]
            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

            #print 'acc splice: ', (don_pos,don_score,don_dna), (acc_pos,acc_score), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est"
            original_est_cut=''

            est_ptr=0
            dna_ptr=self.max_intron_size
            ptr=0
            num_mismatch = 0
            # start offset into the DNA in front of the donor
            don_dna_ptr=len(don_dna)-(acc_pos-self.max_intron_size)-1
            while ptr<len(original_est):

                if original_est[ptr]=='[':
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr > acc_pos:
                        # behind the acceptor: keep the existing mismatch group
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        # up to the acceptor: compare against the DNA in front of the donor
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+don_dna[don_dna_ptr]+estletter+']'
                        don_dna_ptr+=1
                    ptr+=4
                else:
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr > acc_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '('+don_dna[don_dna_ptr]+estletter+')'
                        don_dna_ptr+=1

                    ptr+=1

                # '-' marks a gap: only the sequence that has a letter advances
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1

            # too many mismatches: drop this acceptor/donor combination
            if num_mismatch>self.max_mismatch:
                continue
            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # stop trying further donors once this one scores high enough
            if don_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    return alternativeScores
585
586
def calcAlignmentScore(self,alignment):
    """
    Score an unspliced alignment with the current QPalma parameters.

    `alignment` is the 7-tuple
    (dna, exons, est, original_est, quality, acc_supp, don_supp).  Only the
    bracketed read (original_est) and its quality values are passed on to
    computeSpliceAlignScoreWithQuality, together with the quality Plifs, the
    run object and the parameter vector self.currentPhi.

    The CPU time spent here is added to self.calcAlignmentScoreTime.
    """

    t_begin = cpu()
    run_settings = self.run

    # The tuple is unpacked in full although most members are unused, so an
    # alignment of the wrong length still fails at this point.
    _dna, _exons, _est, bracketed_read, read_quality, _acc_supp, _don_supp = alignment

    alignment_score = computeSpliceAlignScoreWithQuality(
        bracketed_read, read_quality, self.qualityPlifs, run_settings, self.currentPhi)

    t_end = cpu()
    self.calcAlignmentScoreTime += t_end-t_begin

    return alignment_score
606
607
def cpu():
    """
    Return the CPU time in seconds (user plus system) that this process has
    consumed so far.

    Both components are read from a single getrusage() snapshot, so they
    refer to the same instant and only one system call is made.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
611
612
if __name__ == '__main__':
    # Command line driver: build the heuristic, run the filter and report
    # the timers collected on the way.
    if len(sys.argv) != 6:
        print('Usage: ./%s data param run spliced.results unspliced.results' % (sys.argv[0]))
        sys.exit(1)

    data_fname, param_fname, run_fname = sys.argv[1:4]
    result_spliced_fname, result_unspliced_fname = sys.argv[4:6]

    jp = os.path.join

    # NOTE(review): five file names are passed here; confirm that
    # PipelineHeuristic.__init__ accepts a separate unspliced result name.
    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname,result_spliced_fname,result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop-start))

    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f'% ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f'% ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f'% ph1.DotProdTime)
    print('time spend forarray_stuff time: %f'% ph1.array_stuff)