847652a4b09e4f4f0c792c065281d64cf8d3ff5e
[qpalma.git] / qpalma / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 2 of the License, or
7 # (at your option) any later version.
8 #
9 # Written (W) 2008 Gunnar Rätsch, Fabio De Bona
10 # Copyright (C) 2008 Max-Planck-Society
11
12 import cPickle
13 import sys
14 import pdb
15 import os
16 import os.path
17 import resource
18
19 from qpalma.computeSpliceWeights import *
20 from qpalma.set_param_palma import *
21 from qpalma.computeSpliceAlignWithQuality import *
22
23 from numpy.matlib import mat,zeros,ones,inf
24 from numpy import inf,mean
25
26 from ParaParser import ParaParser,IN_VECTOR
27
28 from qpalma.Lookup import LookupTable
29 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,reverse_complement,unbracket_seq
30
31
def cpu():
    """
    Return the CPU time consumed so far by this process.

    The value is the sum of the user-mode and system-mode times reported by
    ``resource.getrusage(resource.RUSAGE_SELF)``. It is used throughout this
    module as a stopwatch: the difference of two calls is the CPU time spent
    in between.
    """
    # Take a single snapshot so that both components describe the same
    # instant (and to avoid a second system call).
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
35
36
class BracketWrapper:
    """
    Sequence-like, iterable view on a file of remapped reads.

    The file is parsed once by ``ParaParser`` in the constructor; afterwards
    entries are served through ``parser.fetchEntry``. The wrapper is its own
    iterator: the cursor lives on the instance (``counter`` / ``size``), so
    two loops over the same wrapper share one cursor and must not be nested.
    """

    # Column names handed to ParaParser, in the order of the format string
    # used in __init__.
    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self,filename):
        """Parse `filename` and position the iteration cursor at the start."""
        self.parser = ParaParser("%lu%d%d%s%s%s",self.fields,len(self.fields),IN_VECTOR)
        self.parser.parseFile(filename)

        # Prime the cursor here as well, so that calling next() without a
        # preceding iter() walks the entries instead of failing with an
        # AttributeError on the missing cursor attributes.
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)

    def __len__(self):
        """Return the number of parsed entries as reported by the parser."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self,key):
        """Return the entry `key` as produced by ``parser.fetchEntry``."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """Rewind the shared cursor and return the wrapper itself."""
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """Return the entry under the cursor and advance it.

        Raises StopIteration once the cursor has reached ``size``.
        """
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry

    # Python 3 spelling of the iterator protocol; `next` stays available for
    # existing callers.
    __next__ = next
61
62
63 class PipelineHeuristic:
64 """
65 This class wraps the filter which decides whether an alignment found by
66 vmatch is spliced and should therefore be realigned using QPalma, or not.
67 """
68
def __init__(self,run_fname,data_fname,param_fname,result_fname,settings):
    """
    Load the run description, reads, sequence lookup table and learned
    parameters, and open the two result files.

    Parameters:
      run_fname    -- path of a pickled run object; it is indexed like a
                      dict and must provide 'numLengthSuppPoints',
                      'numDonSuppPoints', 'numAccSuppPoints',
                      'matchmatrixRows', 'matchmatrixCols',
                      'totalQualSuppPoints' and 'numFeatures'.
      data_fname   -- path of the remapped-reads file, parsed by
                      BracketWrapper.
      param_fname  -- path of the pickled parameter vector handed to
                      set_param_palma.
      result_fname -- prefix of the output files; '<prefix>.spliced' and
                      '<prefix>.unspliced' are created/truncated.
      settings     -- mapping providing 'allowed_fragments' and
                      'prb_offset'; it is also passed on to
                      DataAccessWrapper and LookupTable.

    The heuristic thresholds (splice_thresh, max_intron_size, ...) are set
    to fixed values here and written as '#' header lines to both result
    files.
    """

    # NOTE(security): unpickling executes arbitrary code embedded in the
    # file -- only load run/parameter files from trusted sources.
    # try/finally (rather than `with`) keeps this usable on old Python 2
    # interpreters while still closing the handle.
    run_fh = open(run_fname)
    try:
        run = cPickle.load(run_fh)
    finally:
        run_fh.close()
    self.run = run

    start = cpu()
    self.all_remapped_reads = BracketWrapper(data_fname)
    stop = cpu()

    print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

    self.chromo_range = settings['allowed_fragments']

    # Neither object is used further in this class. They are still
    # constructed because doing so may validate `settings`
    # -- TODO confirm and drop if it has no side effects.
    accessWrapper = DataAccessWrapper(settings)
    seqInfo = SeqSpliceInfo(accessWrapper,settings['allowed_fragments'])

    start = cpu()
    self.lt1 = LookupTable(settings)
    stop = cpu()

    print('prefetched sequence and splice data in %f sec' % (stop-start))

    # The result files stay open for the lifetime of the object; filter()
    # writes one line per classified read into them.
    self.result_spliced_fh = open('%s.spliced'%result_fname,'w+')
    self.result_unspliced_fh = open('%s.unspliced'%result_fname,'w+')

    # init_time covers everything from here to the end of the constructor.
    start = cpu()

    self.data_fname = data_fname

    param_fh = open(param_fname)
    try:
        self.param = cPickle.load(param_fh)
    finally:
        param_fh.close()

    # Set the parameters such as limits penalties for the Plifs
    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

    self.h = h
    self.d = d
    self.a = a
    self.mmatrix = mmatrix
    self.qualityPlifs = qualityPlifs

    self.prb_offset = settings['prb_offset']

    # parameters of the heuristics to decide whether the read is spliced
    self.splice_thresh = 0.01
    self.max_intron_size = 2000
    self.max_mismatch = 2
    self.splice_stop_thresh = 0.8
    self.spliced_bias = 0.0

    # Record the settings used as '#' header lines in both result files.
    header_spec = [
        ('%f','splice_thresh',self.splice_thresh),
        ('%d','max_intron_size',self.max_intron_size),
        ('%d','max_mismatch',self.max_mismatch),
        ('%f','splice_stop_thresh',self.splice_stop_thresh),
        ('%f','spliced_bias',self.spliced_bias)]

    param_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in header_spec]
    param_lines.append('# data: %s\n'%self.data_fname)
    param_lines.append('# param: %s\n'%param_fname)

    for p_line in param_lines:
        self.result_spliced_fh.write(p_line)
        self.result_unspliced_fh.write(p_line)

    lengthSP = run['numLengthSuppPoints']
    donSP = run['numDonSuppPoints']
    accSP = run['numAccSuppPoints']
    mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
    totalQualSP = run['totalQualSuppPoints']

    # Assemble the parameter vector by concatenating, in this order: intron
    # length penalties, donor penalties, acceptor penalties, match matrix
    # and the quality penalties (the tail of self.param).
    don_off = lengthSP
    acc_off = don_off + donSP
    mm_off = acc_off + accSP
    qual_off = mm_off + mmatrixSP

    currentPhi = zeros((run['numFeatures'],1))
    currentPhi[0:don_off] = mat(h.penalties[:]).reshape(lengthSP,1)
    currentPhi[don_off:acc_off] = mat(d.penalties[:]).reshape(donSP,1)
    currentPhi[acc_off:mm_off] = mat(a.penalties[:]).reshape(accSP,1)
    currentPhi[mm_off:qual_off] = mmatrix[:]

    totalQualityPenalties = self.param[-totalQualSP:]
    currentPhi[qual_off:] = totalQualityPenalties[:]
    self.currentPhi = currentPhi

    # we want to identify spliced reads
    # so true pos are spliced reads that are predicted "spliced"
    self.true_pos = 0

    # as false positives we count all reads that are not spliced but predicted
    # as "spliced"
    self.false_pos = 0

    self.true_neg = 0
    self.false_neg = 0

    # accumulated CPU time (seconds) of the individual processing stages
    self.get_time = 0.0
    self.calcAlignmentScoreTime = 0.0
    self.alternativeScoresTime = 0.0

    self.count_time = 0.0
    self.main_loop = 0.0
    self.splice_site_time = 0.0
    self.computeSpliceAlignWithQualityTime = 0.0
    self.computeSpliceWeightsTime = 0.0
    self.DotProdTime = 0.0
    self.array_stuff = 0.0
    stop = cpu()

    self.init_time = stop-start
193
def filter(self):
    """
    Classify every remapped read as spliced or unspliced.

    For each read on an allowed chromosome the score of the unspliced
    Vmatch alignment is compared with the best score among the alternative
    spliced alignments from calcAlternativeAlignments(). If the Vmatch
    score is strictly higher the original line is written to the
    '.unspliced' result file, otherwise to the '.spliced' one.

    Side effects: updates the true/false positive/negative counters and the
    timing attributes (get_time, count_time, main_loop), writes to and
    flushes both result files and prints progress and a summary to stdout.
    """
    # 'D'/'P' are the strand codes found in the read file.
    strand_map = {'D':'+', 'P':'-'}

    # Reads whose id exceeds this offset are counted as truly unspliced
    # when tallying true/false positives/negatives (presumably an artefact
    # of how the benchmark read ids were generated -- TODO confirm).
    unspliced_id_offset = 1000000300000

    ctr = 0
    unspliced_ctr = 0
    spliced_ctr = 0
    failed_ctr = 0

    print('Starting filtering...')
    _start = cpu()

    for location,original_line in self.all_remapped_reads:

        read_id = int(location['id'])
        chrom = location['chr']
        pos = location['pos']
        strand = strand_map[location['strand']]
        seq = location['seq'].lower()
        prb = location['prb']

        if chrom not in self.chromo_range:
            continue

        # Report progress only for reads that are actually processed;
        # `ctr` does not advance on skipped reads, so printing before the
        # check above would repeat the same number for each of them.
        if ctr % 1000 == 0:
            print(ctr)

        unb_seq = unbracket_seq(seq)

        if strand == '-':
            unb_seq = reverse_complement(unb_seq)
            seq = reverse_complement(seq)

        effective_len = len(unb_seq)

        # Genomic window covered by the unspliced alignment.
        genomicSeq_start = pos
        genomicSeq_stop = pos+effective_len

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop)
        stop = cpu()
        self.get_time += stop-start

        # A single exon spanning the whole read.
        exons = zeros((2,1))
        exons[0,0] = 0
        exons[1,0] = effective_len

        quality = [ord(ch)-self.prb_offset for ch in prb]

        currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, quality,
                                  currentAcc, currentDon)

        # Best effort: a read whose alternative alignments cannot be
        # computed is treated as having none. Only ordinary errors are
        # absorbed (KeyboardInterrupt/SystemExit propagate) and they are
        # counted so they do not go unnoticed.
        try:
            alternativeAlignmentScores = self.calcAlternativeAlignments(location)
        except Exception:
            failed_ctr += 1
            alternativeAlignmentScores = []

        if not alternativeAlignmentScores:
            # no alignment necessary
            maxAlternativeAlignmentScore = -inf
            vMatchScore = 0.0
        else:
            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
            # compute alignment for vmatch unspliced read
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

        start = cpu()

        unspliced = (read_id - unspliced_id_offset) > 0

        if maxAlternativeAlignmentScore < vMatchScore:
            # According to the learned parameters Vmatch found a good
            # alignment of the current read.
            unspliced_ctr += 1
            self.result_unspliced_fh.write(original_line+'\n')

            if unspliced:
                self.true_neg += 1
            else:
                self.false_neg += 1
        else:
            # An alternative alignment through splice sites scores at
            # least as high as the Vmatch alignment.
            spliced_ctr += 1
            self.result_spliced_fh.write(original_line+'\n')

            if unspliced:
                self.false_pos += 1
            else:
                self.true_pos += 1

        ctr += 1
        stop = cpu()
        # Accumulate like the other timers instead of keeping only the
        # last iteration.
        self.count_time += stop-start

    _stop = cpu()
    self.main_loop = _stop-_start

    # Make sure everything classified so far reaches the disk; the handles
    # stay open because they belong to the object.
    self.result_spliced_fh.flush()
    self.result_unspliced_fh.flush()

    if failed_ctr:
        print('calcAlternativeAlignments failed for %d reads' % failed_ctr)

    print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
    print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
    print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))
331
332
def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
    """
    Collect candidate splice sites inside and around the read.

    The read is taken to occupy positions
    [max_intron_size, max_intron_size+read_size) of `currentAcc`,
    `currentDon` and `DNA`. Only positions whose score is at least
    `splice_thresh` are considered.

    Returns a 4-tuple (proximal_acc, proximal_don, distal_acc, distal_don):
      proximal_acc -- at most two (idx, score) acceptors in the first half
                      of the read, ordered by ascending score.
      proximal_don -- at most two (idx, score) donors in the second half of
                      the read, ordered by ascending score.
      distal_acc   -- all (idx, score, DNA[idx+1:idx+read_size]) acceptors
                      downstream of the read that leave at least read_size
                      positions to the end, ordered by ascending idx.
      distal_don   -- all (idx, score, DNA[idx-read_size:idx]) donors
                      upstream of the read with idx > read_size, ordered by
                      descending idx (closest to the read first).
    """
    read_start = max_intron_size
    read_mid = read_start + read_size // 2
    read_end = read_start + read_size

    def by_score(candidate):
        return candidate[1]

    proximal_acc = [(idx, currentAcc[idx])
                    for idx in range(read_start, read_mid)
                    if currentAcc[idx] >= splice_thresh]
    # Stable ascending sort, then keep the two best candidates.
    proximal_acc.sort(key=by_score)
    proximal_acc = proximal_acc[-2:]

    # The upper bound keeps only sites with idx + read_size < len(currentAcc).
    distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                  for idx in range(read_end, len(currentAcc) - read_size)
                  if currentAcc[idx] >= splice_thresh]

    proximal_don = [(idx, currentDon[idx])
                    for idx in range(read_mid, read_end)
                    if currentDon[idx] >= splice_thresh]
    proximal_don.sort(key=by_score)
    proximal_don = proximal_don[-2:]

    # Walk from the read towards the start of the window, stopping at the
    # smallest idx that still satisfies idx >= 1 and idx > read_size; this
    # yields the descending-idx order directly.
    distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                  for idx in range(max_intron_size - 1, max(read_size, 0), -1)
                  if currentDon[idx] >= splice_thresh]

    return proximal_acc,proximal_don,distal_acc,distal_don
376
377
def calcAlternativeAlignments(self,location):
    """
    Given an alignment proposed by Vmatch this function calculates the
    scores of possible alternative (spliced) alignments of the same read.

    `location` is one read entry; the keys 'id', 'chr', 'pos', 'strand',
    'seq' and 'prb' are read from it.

    The genomic window fetched here extends max_intron_size positions to
    both sides of the read, so the read itself starts at offset
    max_intron_size of that window. Two families of candidates are scored:

      * an intron towards the 3' end: a donor inside the read (proximal)
        combined with an acceptor downstream of it (distal);
      * an intron towards the 5' end: an acceptor inside the read
        (proximal) combined with a donor upstream of it (distal).

    For every combination the bracket-annotated read is rebuilt against
    the DNA on the far side of the intron. Candidates with more than
    self.max_mismatch mismatches are dropped. The remaining ones get the
    alignment score of the rebuilt read plus donor, acceptor and intron
    length scores plus self.spliced_bias.

    Returns the list of candidate scores (possibly empty). Time spent is
    accumulated in self.get_time and self.alternativeScoresTime.
    """

    run = self.run

    # NOTE(review): `id` and `chr` shadow builtins; `id` is not used below.
    id = location['id']
    chr = location['chr']
    pos = location['pos']
    strand = location['strand']
    original_est = location['seq']
    quality = location['prb']
    # convert the quality characters to numbers relative to prb_offset
    quality = map(lambda x:ord(x)-self.prb_offset,quality)

    original_est = original_est.lower()
    est = unbracket_seq(original_est)
    # NOTE(review): effective_len is not used below.
    effective_len = len(est)

    # window: max_intron_size positions before the read start and after
    # the read end
    genomicSeq_start = pos - self.max_intron_size
    genomicSeq_stop = pos + self.max_intron_size + len(est)

    # 'D'/'P' are the strand codes of the read file
    strand_map = {'D':'+', 'P':'-'}
    strand = strand_map[strand]

    # NOTE(review): the read is reverse-complemented for '-' strand reads
    # but `quality` keeps its original order -- confirm this is intended.
    if strand == '-':
        est = reverse_complement(est)
        original_est = reverse_complement(original_est)

    start = cpu()
    #currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
    currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop)
    stop = cpu()
    self.get_time += stop-start
    dna = currentDNASeq

    proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

    alternativeScores = []

    #pdb.set_trace()

    # local aliases for the scoring functions/parameters
    h = self.h
    d = self.d
    a = self.a
    mmatrix = self.mmatrix
    qualityPlifs = self.qualityPlifs

    # find an intron on the 3' end
    _start = cpu()
    for (don_pos,don_score) in proximal_don:
        DonorScore = calculatePlif(d, [don_score])[0]

        for (acc_pos,acc_score,acc_dna) in distal_acc:

            # the intron length is the distance between the two sites
            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            #print 'don splice: ', (don_pos,don_score), (acc_pos,acc_score,acc_dna), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est": keep the read as is up to
            # the donor, re-annotate the rest against acc_dna
            original_est_cut=''

            est_ptr=0
            # the read starts max_intron_size positions into the window
            dna_ptr=self.max_intron_size
            ptr=0
            acc_dna_ptr=0
            num_mismatch = 0

            while ptr<len(original_est):
                #print acc_dna_ptr,len(acc_dna),acc_pos,don_pos

                if original_est[ptr]=='[':
                    # four-character annotation: '[' + DNA letter +
                    # read letter + ']'
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr < don_pos:
                        # before the donor: keep the annotation unchanged
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        # after the donor: compare with the DNA behind
                        # the acceptor instead
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+acc_dna[acc_dna_ptr]+estletter+']'
                        acc_dna_ptr+=1
                    ptr+=4
                else:
                    # plain letter: read and DNA agree in the original
                    # alignment
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr < don_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            num_mismatch += 1
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            #print '('+acc_dna[acc_dna_ptr]+estletter+')'
                        acc_dna_ptr+=1

                    ptr+=1

                # advance the DNA/read positions; '-' marks a gap on the
                # respective side
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1
            # too many mismatches: skip this acceptor candidate
            if num_mismatch>self.max_mismatch:
                continue

            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # a sufficiently strong acceptor ends the search for this donor
            if acc_score>=self.splice_stop_thresh:
                break

            #pdb.set_trace()

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    # find an intron on the 5' end
    _start = cpu()
    for (acc_pos,acc_score) in proximal_acc:

        AcceptorScore = calculatePlif(a, [acc_score])[0]

        for (don_pos,don_score,don_dna) in distal_don:

            DonorScore = calculatePlif(d, [don_score])[0]
            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

            #print 'acc splice: ', (don_pos,don_score,don_dna), (acc_pos,acc_score), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est": re-annotate the part up to
            # the acceptor against don_dna, keep the rest as is
            original_est_cut=''

            est_ptr=0
            dna_ptr=self.max_intron_size
            ptr=0
            num_mismatch = 0
            # start offset inside don_dna, derived from the acceptor's
            # offset within the read
            # NOTE(review): verify this start offset and the `>` bound
            # below (position acc_pos itself is taken from don_dna)
            don_dna_ptr=len(don_dna)-(acc_pos-self.max_intron_size)-1
            while ptr<len(original_est):

                if original_est[ptr]=='[':
                    # four-character annotation: '[' + DNA letter +
                    # read letter + ']'
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr > acc_pos:
                        # behind the acceptor: keep the annotation unchanged
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+don_dna[don_dna_ptr]+estletter+']'
                        don_dna_ptr+=1
                    ptr+=4
                else:
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr > acc_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '('+don_dna[don_dna_ptr]+estletter+')'
                        don_dna_ptr+=1

                    ptr+=1

                # advance the DNA/read positions; '-' marks a gap on the
                # respective side
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1

            # too many mismatches: skip this donor candidate
            if num_mismatch>self.max_mismatch:
                continue
            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # a sufficiently strong donor ends the search for this acceptor
            if don_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    return alternativeScores
596
597
def calcAlignmentScore(self,alignment):
    """
    Return the score of `alignment` under the current QPalma parameters.

    `alignment` is a 7-tuple (dna, exons, est, original_est, quality,
    acc_supp, don_supp). Only the bracket-annotated read (original_est) and
    its quality values enter the computation; they are passed to
    computeSpliceAlignScoreWithQuality together with the quality plifs, the
    run object and the parameter vector self.currentPhi.

    The CPU time spent is added to self.calcAlignmentScoreTime.
    """
    t_begin = cpu()
    run_settings = self.run

    # The full unpack is kept so that a tuple of the wrong length is still
    # rejected; the unused fields get throw-away names.
    _dna, _exons, _est, bracketed_read, read_quality, _acc, _don = alignment

    alignment_score = computeSpliceAlignScoreWithQuality(
        bracketed_read, read_quality, self.qualityPlifs, run_settings, self.currentPhi)

    t_end = cpu()
    self.calcAlignmentScoreTime += t_end-t_begin

    return alignment_score