+ changed PipelineHeuristic to support new data access functions
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import resource
10
11 from qpalma.computeSpliceWeights import *
12 from qpalma.set_param_palma import *
13 from qpalma.computeSpliceAlignWithQuality import *
14
15 from numpy.matlib import mat,zeros,ones,inf
16 from numpy import inf,mean
17
18 from qpalma.parsers import *
19
20 from ParaParser import *
21
22 from qpalma.Lookup import LookupTable
23
24 from qpalma.sequence_utils import reverse_complement,unbracket_seq
25
26
class BracketWrapper:
    """
    List-like, read-only access to a read file parsed with ParaParser.

    Supports len(), indexing and iteration. Every entry is whatever
    ParaParser.fetchEntry returns for a record (PipelineHeuristic.filter
    unpacks it as a (location, original_line) pair whose location is keyed
    by the names in `fields`).
    """

    # Column names in file order. A longer layout with 'mismatches',
    # 'length', 'offset' (before 'seq') and 'cal_prb', 'chastity' (after
    # 'prb') is not handled by this wrapper's format string.
    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self, filename):
        # one conversion per field: %lu id, %d chr, %d pos, %s strand/seq/prb
        self.parser = ParaParser("%lu%d%d%s%s%s", self.fields, len(self.fields), IN_VECTOR)
        self.parser.parseFile(filename)

    def __len__(self):
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self, key):
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """
        Yield all entries in file order.

        Each call returns a fresh generator, so nested or repeated loops over
        the same wrapper do not share (and reset) one cursor.
        """
        size = self.parser.getSize(IN_VECTOR)
        idx = 0
        while idx < size:
            yield self.parser.fetchEntry(idx)
            idx += 1

    def next(self):
        """
        Legacy cursor-style access: return the next entry or raise
        StopIteration. The cursor is independent of __iter__ and starts at
        the first entry on first use.
        """
        if not hasattr(self, 'counter'):
            self.counter = 0
            self.size = self.parser.getSize(IN_VECTOR)
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry
53
54
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the (unspliced) vmatch alignment is compared
    with the scores of alternative alignments that place one intron at the
    5' or at the 3' end of the read. Reads are written to the spliced or the
    unspliced result file depending on which score is higher.
    """

    def __init__(self, run_fname, data_fname, param_fname, result_fname,
                 result_unspliced_fname=None):
        """
        Load the run object, the reads, the genome lookup table and the
        QPalma parameters, and open the result files.

        run_fname     -- pickled run object holding information about the
                         nr. of support points etc.
        data_fname    -- read file, parsed with BracketWrapper
        param_fname   -- pickled QPalma parameter vector
        result_fname  -- if result_unspliced_fname is None: a prefix, the
                         results go to '<prefix>.spliced' and
                         '<prefix>.unspliced'; otherwise the path of the
                         spliced result file
        result_unspliced_fname -- optional path of the unspliced result file
                         (the command-line entry point passes both paths)

        Both result files start with '# name: value' header lines recording
        the heuristic's thresholds and the input file names.
        """

        def load_pickle(fname):
            # Close the handle deterministically. Unpickling can execute
            # code, so only trusted files must be passed here.
            fh = open(fname)
            try:
                return cPickle.load(fh)
            finally:
                fh.close()

        run = load_pickle(run_fname)
        self.run = run

        start = cpu()
        self.all_remapped_reads = BracketWrapper(data_fname)
        stop = cpu()

        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads), stop - start))

        # Genome and acceptor/donor score locations (A. thaliana,
        # chromosomes 1-5); the lookup table prefetches sequence and scores.
        g_dir = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
        acc_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc'
        don_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don'

        g_fmt = 'chr%d.dna.flat'
        s_fmt = 'contig_%d%s'

        self.chromo_range = range(1, 6)

        start = cpu()
        self.lt1 = LookupTable(g_dir, acc_dir, don_dir, g_fmt, s_fmt, self.chromo_range)
        stop = cpu()

        print('prefetched sequence and splice data in %f sec' % (stop - start))

        if result_unspliced_fname is None:
            spliced_fname = '%s.spliced' % result_fname
            unspliced_fname = '%s.unspliced' % result_fname
        else:
            spliced_fname = result_fname
            unspliced_fname = result_unspliced_fname

        self.result_spliced_fh = open(spliced_fname, 'w+')
        self.result_unspliced_fh = open(unspliced_fname, 'w+')

        start = cpu()

        self.data_fname = data_fname
        self.param = load_pickle(param_fname)

        # Set the parameters such as limits and penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # fixed read length; used to map reverse-strand positions in filter()
        self.read_size = 38

        # parameters of the heuristic deciding whether a read is spliced
        self.splice_thresh = 0.01      # min. splice-site score of a candidate site
        self.max_intron_size = 2000    # genomic window added on each side of the read
        self.max_mismatch = 2          # max. mismatches of an alternative alignment
        self.splice_stop_thresh = 0.8  # stop scanning distal sites after scoring one this good
        self.spliced_bias = 0.0        # added to every alternative alignment score

        param_lines = [
            ('%f', 'splice_thresh', self.splice_thresh),
            ('%d', 'max_intron_size', self.max_intron_size),
            ('%d', 'max_mismatch', self.max_mismatch),
            ('%f', 'splice_stop_thresh', self.splice_stop_thresh),
            ('%f', 'spliced_bias', self.spliced_bias),
        ]
        header = [('# %s: ' + form + '\n') % (name, val) for form, name, val in param_lines]
        header.append('# data: %s\n' % self.data_fname)
        header.append('# param: %s\n' % param_fname)

        for p_line in header:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Assemble the parameter vector: intron length, donor, acceptor,
        # match matrix and quality support points, in that order.
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # We want to identify spliced reads: true positives are spliced reads
        # predicted "spliced", false positives are unspliced reads predicted
        # "spliced".
        self.true_pos = 0
        self.false_pos = 0
        self.true_neg = 0
        self.false_neg = 0

        # accumulated cpu times in seconds
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        # the following timers are not updated by this class
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop - start

    def filter(self):
        """
        Classify every read as spliced or unspliced.

        For each read on a chromosome in self.chromo_range, the best score
        of calcAlternativeAlignments is compared with the score of the
        vmatch alignment. The read's original line goes to the unspliced
        result file when the vmatch alignment scores strictly higher,
        otherwise to the spliced file. Updates the true/false pos/neg
        counters and the timing statistics, and flushes both result files.
        """
        run = self.run

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        strand_map = {'D': '+', 'P': '-'}

        print('Starting filtering...')
        _start = cpu()

        for location, original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chromo = location['chr']
            pos = location['pos']
            seq = location['seq'].lower()
            prb = location['prb']
            strand = strand_map[location['strand']]

            if chromo not in self.chromo_range:
                continue

            unb_seq = unbracket_seq(seq)

            # Reverse strand: map the position to reverse-strand coordinates
            # and reverse-complement the read.
            # NOTE(review): uses the fixed self.read_size rather than
            # len(unb_seq) -- confirm this is intended for other read lengths.
            if strand == '-':
                pos = self.lt1.seqInfo.chromo_sizes[chromo] - pos - self.read_size
                unb_seq = reverse_complement(unb_seq)

            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len - 1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(
                chromo, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop - start

            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len
            quality = [ord(x) - 50 for x in prb]

            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, quality,
                                      currentAcc, currentDon)

            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                # Best effort: a read whose candidates cannot be evaluated
                # (e.g. an IndexError on a too-short window) is treated as
                # having no spliced alternative.
                alternativeAlignmentScores = []

            if not alternativeAlignmentScores:
                # no alignment necessary
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment score for the unspliced vmatch read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            # NOTE(review): reads with an id above 1000000300000 are counted
            # as truly unspliced in the statistics -- presumably the id
            # encodes the label of simulated/labelled data; confirm.
            unspliced = read_id - 1000000300000 > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # vmatch found the better alignment: keep the read unspliced
                unspliced_ctr += 1
                self.result_unspliced_fh.write(original_line + '\n')
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # an alternative alignment using splice sites scores at least
                # as high as the vmatch alignment
                spliced_ctr += 1
                self.result_spliced_fh.write(original_line + '\n')
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            self.count_time += stop - start

        _stop = cpu()
        self.main_loop = _stop - _start

        # the handles stay open, but make sure the results reach the disk
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites around a read that starts at index
        max_intron_size of the genomic window DNA.

        Only sites whose score is >= splice_thresh are considered. Returns
        (proximal_acc, proximal_don, distal_acc, distal_don):
          proximal_acc -- up to two (idx, score) acceptors in the first half
                          of the read, ascending by score
          proximal_don -- up to two (idx, score) donors in the second half
                          of the read, ascending by score
          distal_acc   -- all (idx, score, DNA[idx+1:idx+read_size])
                          acceptors downstream of the read with
                          idx + read_size < len(currentAcc), by position
          distal_don   -- all (idx, score, DNA[idx-read_size:idx]) donors with
                          read_size < idx < max_intron_size, nearest
                          (largest idx) first
        """
        half = read_size // 2
        by_score = lambda site: site[1]

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(max_intron_size, max_intron_size + half)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=by_score)
        proximal_acc = proximal_acc[-2:]

        # the upper bound enforces idx + read_size < len(currentAcc)
        distal_acc = [(idx, currentAcc[idx], DNA[idx + 1:idx + read_size])
                      for idx in range(max_intron_size + read_size, len(currentAcc) - read_size)
                      if currentAcc[idx] >= splice_thresh]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(max_intron_size + half, max_intron_size + read_size)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=by_score)
        proximal_don = proximal_don[-2:]

        distal_don = [(idx, currentDon[idx], DNA[idx - read_size:idx])
                      for idx in range(read_size + 1, max_intron_size)
                      if currentDon[idx] >= splice_thresh]
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc, proximal_don, distal_acc, distal_don

    def _recut_est(self, original_est, keep_original, new_dna, new_dna_ptr):
        """
        Re-align a bracketed read string against a spliced genomic sequence.

        original_est  -- read in bracket notation: a matching base is one
                         letter, a mismatch/gap is '[' + genome letter +
                         read letter + ']' ('-' denotes a gap)
        keep_original -- predicate on the current genome position (starting
                         at self.max_intron_size); where it is true the
                         original token is copied unchanged
        new_dna       -- genome on the other side of the intron; every token
                         where keep_original is false is compared against
                         new_dna[new_dna_ptr], which then advances by one

        Returns (est_cut, num_mismatch, dna_ptr, est_ptr): the rebuilt
        bracketed string, the number of bracketed tokens in it, and the final
        genome and read positions.
        """
        pieces = []
        num_mismatch = 0
        est_ptr = 0
        dna_ptr = self.max_intron_size
        ptr = 0

        while ptr < len(original_est):
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                token = original_est[ptr:ptr + 4]
                bracketed = True
                ptr += 4
            else:
                dnaletter = estletter = token = original_est[ptr]
                bracketed = False
                ptr += 1

            if keep_original(dna_ptr):
                pieces.append(token)
                if bracketed:
                    num_mismatch += 1
            else:
                genome_letter = new_dna[new_dna_ptr]
                if genome_letter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('[' + genome_letter + estletter + ']')
                    num_mismatch += 1
                new_dna_ptr += 1

            # advance genome / read positions ('-' marks a gap)
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, compute the scores of
        alternative alignments that splice the read at a donor/acceptor pair.

        The window [pos - max_intron_size, pos + max_intron_size + len(read)]
        is fetched. The read is re-aligned for every pair of
          - a donor in the second half of the read and an acceptor downstream
            of the read (intron at the 3' end), and
          - an acceptor in the first half of the read and a donor upstream of
            the read (intron at the 5' end).
        Candidates with more than max_mismatch mismatches are dropped. The
        others are scored with computeSpliceAlignScoreWithQuality plus the
        donor, acceptor and intron-length Plif scores and spliced_bias.

        Returns the list of candidate scores (possibly empty).
        """
        run = self.run

        chromo = location['chr']
        pos = location['pos']
        strand = {'D': '+', 'P': '-'}[location['strand']]
        quality = [ord(x) - 50 for x in location['prb']]

        original_est = location['seq'].lower()
        est = unbracket_seq(original_est)

        # NOTE(review): unlike filter(), pos is not mapped to reverse-strand
        # coordinates for '-' reads here -- confirm this is intended.
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        start = cpu()
        dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(
            chromo, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop - start

        proximal_acc, proximal_don, distal_acc, distal_don = self.findHighestScoringSpliceSites(
            currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        alternativeScores = []

        # intron at the 3' end: donor inside the read, acceptor downstream
        _start = cpu()
        for don_pos, don_score in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for acc_pos, acc_score, acc_dna in distal_acc:
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                est_cut, num_mismatch, dna_ptr, est_ptr = self._recut_est(
                    original_est, lambda p, limit=don_pos: p < limit, acc_dna, 0)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if acc_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        # intron at the 5' end: acceptor inside the read, donor upstream
        _start = cpu()
        for acc_pos, acc_score in proximal_acc:
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for don_pos, don_score, don_dna in distal_don:
                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]

                # align the bases before the acceptor against the tail of
                # the donor-side genome
                est_cut, num_mismatch, dna_ptr, est_ptr = self._recut_est(
                    original_est, lambda p, limit=acc_pos: p > limit, don_dna,
                    len(don_dna) - (acc_pos - self.max_intron_size) - 1)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if don_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp) and the current QPalma parameters, return the
        alignment score computed by computeSpliceAlignScoreWithQuality from
        original_est and quality. Adds the elapsed cpu time to
        self.calcAlignmentScoreTime.
        """
        start = cpu()
        run = self.run

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop - start

        return score
618
619
def cpu():
    """
    Return the CPU time (user + system, in seconds) consumed so far by this
    process.

    Both values come from a single getrusage() call, so they describe the
    same instant and each timing costs one system call instead of two.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
623
624
if __name__ == '__main__':
    # command line: data param run spliced.results unspliced.results
    argv = sys.argv
    if len(argv) != 6:
        print('Usage: ./%s data param run spliced.results unspliced.results' % (argv[0]))
        sys.exit(1)

    data_fname, param_fname, run_fname, result_spliced_fname, result_unspliced_fname = argv[1:6]

    jp = os.path.join

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname, result_spliced_fname, result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop - start))

    # timing report, one line per accumulated timer
    timing_report = [
        ('time spend for get seq: %f', ph1.get_time),
        ('time spend for calcAlignmentScoreTime: %f', ph1.calcAlignmentScoreTime),
        ('time spend for alternativeScoresTime: %f', ph1.alternativeScoresTime),
        ('time spend for count time: %f', ph1.count_time),
        ('time spend for init time: %f', ph1.init_time),
        ('time spend for main_loop time: %f', ph1.main_loop),
        ('time spend for splice_site_time time: %f', ph1.splice_site_time),
        ('time spend for computeSpliceAlignWithQualityTime time: %f', ph1.computeSpliceAlignWithQualityTime),
        ('time spend for computeSpliceWeightsTime time: %f', ph1.computeSpliceWeightsTime),
        ('time spend for DotProdTime time: %f', ph1.DotProdTime),
        ('time spend forarray_stuff time: %f', ph1.array_stuff),
    ]
    for label, seconds in timing_report:
        print(label % seconds)