+ added some testcases
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import resource
10
11 from qpalma.computeSpliceWeights import *
12 from qpalma.set_param_palma import *
13 from qpalma.computeSpliceAlignWithQuality import *
14
15 from numpy.matlib import mat,zeros,ones,inf
16 from numpy import inf,mean
17
18 from qpalma.parsers import *
19
20 from ParaParser import *
21
22 from qpalma.Lookup import LookupTable
23
24 from qpalma.sequence_utils import reverse_complement,unbracket_seq
25
26
class BracketWrapper:
    """
    Sequence/iterator facade over a ParaParser-parsed read file.

    The file named by ``filename`` is parsed once in the constructor; entries
    are then fetched from the parser by index. Iteration state (``counter`` and
    ``size``) lives on the instance itself, so only one iteration over a given
    wrapper can be active at a time.
    """

    # Column names handed to ParaParser; one per conversion in the format
    # string used in __init__.
    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self, filename):
        """Parse ``filename`` with ParaParser using the columns in ``fields``."""
        self.parser = ParaParser("%lu%d%d%s%s%s", self.fields, len(self.fields), IN_VECTOR)
        self.parser.parseFile(filename)

    def __len__(self):
        """Return the number of entries reported by the parser."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self, key):
        """Return whatever the parser yields for entry ``key``."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """Reset the iteration cursor and return the wrapper itself."""
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """Return the next entry; raise StopIteration after the last one."""
        if not self.counter < self.size:
            raise StopIteration
        return_val = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return return_val

    # Python 3 spells the iterator protocol method ``__next__``; keep ``next``
    # for existing Python 2 callers and expose the same method under both names.
    __next__ = next
53
54
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the unspliced vmatch alignment is compared
    with the best score of a set of candidate spliced alignments built from
    nearby donor/acceptor sites. Reads are written to the "spliced" or the
    "unspliced" result file accordingly.
    """

    def __init__(self, run_fname, data_fname, param_fname, result_fname,
                 result_unspliced_fname=None):
        """
        We need a run object holding information about the nr. of support
        points etc.

        run_fname              -- pickled run dictionary
        data_fname             -- read file parsed through BracketWrapper
        param_fname            -- pickled parameter vector
        result_fname           -- if ``result_unspliced_fname`` is None this
                                  is a prefix and the results are written to
                                  ``<prefix>.spliced`` / ``<prefix>.unspliced``;
                                  otherwise it is the spliced result file
        result_unspliced_fname -- optional explicit unspliced result file
        """

        # NOTE(review): unpickling executes arbitrary code from the file; only
        # load run/parameter files from a trusted source.
        with open(run_fname) as run_fh:
            run = cPickle.load(run_fh)
        self.run = run

        start = cpu()
        self.all_remapped_reads = BracketWrapper(data_fname)
        stop = cpu()

        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads), stop - start))

        g_dir = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
        acc_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc'
        don_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don'

        g_fmt = 'chr%d.dna.flat'
        s_fmt = 'contig_%d%s'

        self.chromo_range = range(1, 6)

        start = cpu()
        self.lt1 = LookupTable(g_dir, acc_dir, don_dir, g_fmt, s_fmt, self.chromo_range)
        stop = cpu()

        print('prefetched sequence and splice data in %f sec' % (stop - start))

        # One name: treat it as a prefix (historic behaviour).
        # Two names: use them verbatim, as the command line entry point does.
        if result_unspliced_fname is None:
            spliced_path = '%s.spliced' % result_fname
            unspliced_path = '%s.unspliced' % result_fname
        else:
            spliced_path = result_fname
            unspliced_path = result_unspliced_fname

        self.result_spliced_fh = open(spliced_path, 'w+')
        self.result_unspliced_fh = open(unspliced_path, 'w+')

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname) as param_fh:
            self.param = cPickle.load(param_fh)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 38

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.01
        self.max_intron_size = 2000
        self.max_mismatch = 2
        self.splice_stop_thresh = 0.8
        self.spliced_bias = 0.0

        param_lines = [
            ('%f', 'splice_thresh', self.splice_thresh),
            ('%d', 'max_intron_size', self.max_intron_size),
            ('%d', 'max_mismatch', self.max_mismatch),
            ('%f', 'splice_stop_thresh', self.splice_stop_thresh),
            ('%f', 'spliced_bias', self.spliced_bias)]

        # Both result files start with a '#'-prefixed header recording the
        # heuristic parameters and the input files.
        param_lines = [('# %s: ' + form + '\n') % (name, val) for form, name, val in param_lines]
        param_lines.append('# data: %s\n' % self.data_fname)
        param_lines.append('# param: %s\n' % param_fname)

        for p_line in param_lines:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Parameter vector layout: intron length penalties, donor penalties,
        # acceptor penalties, match matrix, quality penalties.
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spend for get seq and scores
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop - start

    def filter(self):
        """
        Classify every parsed read as spliced or unspliced.

        Each read's original line is appended to the matching result file and
        the confusion counters (true_pos, false_pos, true_neg, false_neg) are
        updated. Reads whose chromosome is not in ``self.chromo_range`` are
        skipped without being counted.
        """
        run = self.run

        strand_map = {'D': '+', 'P': '-'}

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for location, original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chromo = location['chr']
            pos = location['pos']
            seq = location['seq'].lower()
            prb = location['prb']

            strand = strand_map[location['strand']]

            if chromo not in self.chromo_range:
                continue

            unb_seq = unbracket_seq(seq)

            if strand == '-':
                unb_seq = reverse_complement(unb_seq)

            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(
                chromo, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop - start

            dna = currentDNASeq
            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len
            est = unb_seq
            original_est = seq
            # presumably phred+64 encoded quality characters -- TODO confirm
            quality = [ord(c) - 64 for c in prb]

            currentVMatchAlignment = (dna, exons, est, original_est, quality,
                                      currentAcc, currentDon)

            # Best effort: a read whose candidate construction fails (e.g. an
            # index running off the fetched sequence) is treated as having no
            # alternative alignment. KeyboardInterrupt/SystemExit propagate.
            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                alternativeAlignmentScores = []

            if not alternativeAlignmentScores:
                # no alignment necessary
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment for vmatch unspliced read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            # Read ids above this offset are counted as truly unspliced when
            # updating the confusion counters below.
            new_id = read_id - 1000000300000
            unspliced = new_id > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # Seems that according to our learned parameters VMatch found
                # a good alignment of the current read
                unspliced_ctr += 1

                self.result_unspliced_fh.write(original_line + '\n')

                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # We found an alternative alignment considering splice sites
                # that scores higher than the VMatch alignment
                spliced_ctr += 1

                self.result_spliced_fh.write(original_line + '\n')

                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulate like the other timers instead of keeping only the
            # last iteration
            self.count_time += stop - start

        # Make sure everything written so far reaches the files even if the
        # handles are never closed explicitly.
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        _stop = cpu()
        self.main_loop = _stop - _start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites around a read that starts at index
        ``max_intron_size`` of the given score arrays / sequence.

        Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

        proximal_acc -- up to two (idx, score) acceptors from the first half
                        of the read, ordered by ascending score
        proximal_don -- up to two (idx, score) donors from the second half of
                        the read, ordered by ascending score
        distal_acc   -- every (idx, score, dna) acceptor behind the read that
                        leaves ``read_size`` positions before the end of the
                        array, in ascending position; ``dna`` is
                        DNA[idx+1:idx+read_size]
        distal_don   -- every (idx, score, dna) donor in front of the read
                        with idx > read_size, in descending position; ``dna``
                        is DNA[idx-read_size:idx]

        Only sites with a score >= ``splice_thresh`` are considered.
        """
        read_start = max_intron_size
        read_mid = max_intron_size + read_size // 2
        read_end = max_intron_size + read_size

        def score_of(site):
            return site[1]

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(read_start, read_mid)
                        if currentAcc[idx] >= splice_thresh]
        # stable ascending sort by score, keep the two best
        proximal_acc.sort(key=score_of)
        proximal_acc = proximal_acc[-2:]

        num_acc = len(currentAcc)
        distal_acc = [(idx, currentAcc[idx], DNA[idx + 1:idx + read_size])
                      for idx in range(read_end, num_acc)
                      if currentAcc[idx] >= splice_thresh and idx + read_size < num_acc]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_mid, read_end)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=score_of)
        proximal_don = proximal_don[-2:]

        distal_don = [(idx, currentDon[idx], DNA[idx - read_size:idx])
                      for idx in range(1, max_intron_size)
                      if currentDon[idx] >= splice_thresh and idx > read_size]
        # closest donor (largest index) first
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc, proximal_don, distal_acc, distal_don

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates
        possible alternative alignments taking into account for example
        matched donor/acceptor positions.

        ``location`` is one parsed read entry (keys 'chr', 'pos', 'strand',
        'seq', 'prb'). The read sequence uses a bracket notation in which
        ``[xy]`` stands for genomic letter x aligned against read letter y.

        Returns the list of scores of all candidate spliced alignments that
        stay within ``self.max_mismatch`` mismatches; the list may be empty.
        """

        run = self.run

        chromo = location['chr']
        pos = location['pos']
        strand = location['strand']
        original_est = location['seq'].lower()
        # presumably phred+64 encoded quality characters -- TODO confirm
        quality = [ord(c) - 64 for c in location['prb']]

        est = unbracket_seq(original_est)

        # Fetch the read region plus max_intron_size positions on either
        # side, so the read itself starts at index max_intron_size.
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        strand_map = {'D': '+', 'P': '-'}
        strand = strand_map[strand]

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(
            chromo, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop - start
        dna = currentDNASeq

        proximal_acc, proximal_don, distal_acc, distal_don = self.findHighestScoringSpliceSites(
            currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos, don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos, acc_score, acc_dna) in distal_acc:

                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # construct a new "original_est": keep the read up to the
                # donor, re-align the remainder against the sequence that
                # follows the distal acceptor
                original_est_cut = ''

                est_ptr = 0
                dna_ptr = self.max_intron_size
                ptr = 0
                acc_dna_ptr = 0
                num_mismatch = 0

                while ptr < len(original_est):

                    if original_est[ptr] == '[':
                        dnaletter = original_est[ptr + 1]
                        estletter = original_est[ptr + 2]
                        if dna_ptr < don_pos:
                            original_est_cut += original_est[ptr:ptr + 4]
                            num_mismatch += 1
                        else:
                            if acc_dna[acc_dna_ptr] == estletter:
                                original_est_cut += estletter
                            else:
                                original_est_cut += '[' + acc_dna[acc_dna_ptr] + estletter + ']'
                                num_mismatch += 1
                            acc_dna_ptr += 1
                        ptr += 4
                    else:
                        dnaletter = original_est[ptr]
                        estletter = dnaletter

                        if dna_ptr < don_pos:
                            original_est_cut += estletter
                        else:
                            if acc_dna[acc_dna_ptr] == estletter:
                                original_est_cut += estletter
                            else:
                                num_mismatch += 1
                                original_est_cut += '[' + acc_dna[acc_dna_ptr] + estletter + ']'
                            acc_dna_ptr += 1

                        ptr += 1

                    # advance the genomic / read cursors; '-' marks a gap
                    if estletter == '-':
                        dna_ptr += 1
                    elif dnaletter == '-':
                        est_ptr += 1
                    else:
                        dna_ptr += 1
                        est_ptr += 1

                # too many mismatches: discard this candidate
                if num_mismatch > self.max_mismatch:
                    continue

                assert(dna_ptr <= len(dna))
                assert(est_ptr <= len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a sufficiently strong acceptor ends the search for this donor
                if acc_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos, acc_score) in proximal_acc:

            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos, don_score, don_dna) in distal_don:

                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]

                # construct a new "original_est": re-align the part of the
                # read in front of the acceptor against the sequence that
                # precedes the distal donor, keep the rest
                original_est_cut = ''

                est_ptr = 0
                dna_ptr = self.max_intron_size
                ptr = 0
                num_mismatch = 0
                don_dna_ptr = len(don_dna) - (acc_pos - self.max_intron_size) - 1
                while ptr < len(original_est):

                    if original_est[ptr] == '[':
                        dnaletter = original_est[ptr + 1]
                        estletter = original_est[ptr + 2]
                        if dna_ptr > acc_pos:
                            original_est_cut += original_est[ptr:ptr + 4]
                            num_mismatch += 1
                        else:
                            if don_dna[don_dna_ptr] == estletter:
                                original_est_cut += estletter
                            else:
                                original_est_cut += '[' + don_dna[don_dna_ptr] + estletter + ']'
                                num_mismatch += 1
                            don_dna_ptr += 1
                        ptr += 4
                    else:
                        dnaletter = original_est[ptr]
                        estletter = dnaletter

                        if dna_ptr > acc_pos:
                            original_est_cut += estletter
                        else:
                            if don_dna[don_dna_ptr] == estletter:
                                original_est_cut += estletter
                            else:
                                original_est_cut += '[' + don_dna[don_dna_ptr] + estletter + ']'
                                num_mismatch += 1
                            don_dna_ptr += 1

                        ptr += 1

                    # advance the genomic / read cursors; '-' marks a gap
                    if estletter == '-':
                        dna_ptr += 1
                    elif dnaletter == '-':
                        est_ptr += 1
                    else:
                        dna_ptr += 1
                        est_ptr += 1

                # too many mismatches: discard this candidate
                if num_mismatch > self.max_mismatch:
                    continue

                assert(dna_ptr <= len(dna))
                assert(est_ptr <= len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                # a sufficiently strong donor ends the search for this acceptor
                if don_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for
        QPalma this function calculates the dot product of the feature
        representation of the alignment and the parameter vector i.e the
        alignment score.

        ``alignment`` is the 7-tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp); only ``original_est`` and ``quality`` enter the
        score.
        """

        start = cpu()
        run = self.run

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop - start

        return score
619
620
def cpu():
    """
    Return the CPU time (user + system, in seconds) consumed so far by this
    process.
    """
    # Take a single snapshot so both components describe the same instant.
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
624
625
if __name__ == '__main__':
    # Command line entry point: run the filter on one read file and print a
    # timing report afterwards.
    if len(sys.argv) != 6:
        print('Usage: ./%s data param run spliced.results unspliced.results' % (sys.argv[0]))
        sys.exit(1)

    data_fname, param_fname, run_fname = sys.argv[1:4]
    result_spliced_fname, result_unspliced_fname = sys.argv[4:6]

    jp = os.path.join

    # NOTE(review): five arguments are passed here -- confirm that the
    # PipelineHeuristic constructor accepts a separate unspliced result file.
    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname,
                            result_spliced_fname, result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop - start))

    # (message format, name of the timer attribute on ph1), in report order
    timing_report = (
        ('time spend for get seq: %f', 'get_time'),
        ('time spend for calcAlignmentScoreTime: %f', 'calcAlignmentScoreTime'),
        ('time spend for alternativeScoresTime: %f', 'alternativeScoresTime'),
        ('time spend for count time: %f', 'count_time'),
        ('time spend for init time: %f', 'init_time'),
        ('time spend for main_loop time: %f', 'main_loop'),
        ('time spend for splice_site_time time: %f', 'splice_site_time'),
        ('time spend for computeSpliceAlignWithQualityTime time: %f', 'computeSpliceAlignWithQualityTime'),
        ('time spend for computeSpliceWeightsTime time: %f', 'computeSpliceWeightsTime'),
        ('time spend for DotProdTime time: %f', 'DotProdTime'),
        ('time spend forarray_stuff time: %f', 'array_stuff'),
    )
    for message, timer_name in timing_report:
        print(message % getattr(ph1, timer_name))