+ restructured test cases
[qpalma.git] / qpalma / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 2 of the License, or
7 # (at your option) any later version.
8 #
# Written (W) 2008 Gunnar Rätsch, Fabio De Bona
10 # Copyright (C) 2008 Max-Planck-Society
11
12 import cPickle
13 import sys
14 import pdb
15 import os
16 import os.path
17 import resource
18
19 from qpalma.computeSpliceWeights import *
20 from qpalma.set_param_palma import *
21 from qpalma.computeSpliceAlignWithQuality import *
22
23 from numpy.matlib import mat,zeros,ones,inf
24 from numpy import inf,mean
25
26 from ParaParser import ParaParser,IN_VECTOR
27
28 from qpalma.Lookup import LookupTable
29 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,reverse_complement,unbracket_seq
30
31
def cpu():
    """
    Return the CPU time (user + system, in seconds) consumed so far by this
    process.

    A single getrusage() snapshot is used so that the user and system parts
    are taken at the same instant.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
35
36
class BracketWrapper:
    """
    Indexable and iterable view of a read file parsed by ParaParser.

    Records have the fields listed in `fields` and are parsed with the format
    string "%lu%d%d%s%s%s"; entries are returned exactly as produced by
    ParaParser.fetchEntry.
    """

    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self, filename):
        self.parser = ParaParser("%lu%d%d%s%s%s", self.fields, len(self.fields), IN_VECTOR)
        self.parser.parseFile(filename)
        # Iteration state, (re)set by __iter__. Initialised here so that a
        # next() call before iter() stops cleanly instead of raising
        # AttributeError.
        self.counter = 0
        self.size = 0

    def __len__(self):
        """Return the number of parsed entries."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self, key):
        """Return entry number `key` as returned by ParaParser.fetchEntry."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """
        Restart iteration at the first entry and return self as the iterator.

        The iteration state lives on the wrapper itself, so two simultaneous
        loops over the same wrapper interfere with each other.
        """
        self.counter = 0
        self.size = len(self)
        return self

    def next(self):
        """Return the next entry, or raise StopIteration after the last one."""
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry

    # Python 3 spelling of the iterator protocol.
    __next__ = next
61
62
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read, the unspliced vmatch alignment is scored with the learned
    QPalma parameters and compared with the best spliced alternative (see
    calcAlternativeAlignments). Reads whose best alternative scores at least
    as high as the vmatch alignment are written to '<result_fname>.spliced',
    all others to '<result_fname>.unspliced'.
    """

    # Strand letters used in the read file -> genomic strand.
    STRAND_MAP = {'D': '+', 'P': '-'}

    def __init__(self, data_fname, param_fname, result_fname, settings):
        """
        Load reads, the sequence/splice-score lookup table and the QPalma
        parameters, and open the result files.

        data_fname   -- read file, parsed with BracketWrapper
        param_fname  -- pickled QPalma parameter vector
        result_fname -- prefix of the output files '<prefix>.spliced' and
                        '<prefix>.unspliced' (opened with mode 'w+')
        settings     -- settings mapping; besides the keys read here it is
                        passed on to LookupTable, set_param_palma and
                        computeSpliceAlignScoreWithQuality
        """
        start = cpu()
        self.all_remapped_reads = BracketWrapper(data_fname)
        stop = cpu()
        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads), stop - start))

        # Kept for the scoring methods, which pass it on to
        # computeSpliceAlignScoreWithQuality.
        self.settings = settings
        self.chromo_range = settings['allowed_fragments']

        start = cpu()
        self.lt1 = LookupTable(settings)
        stop = cpu()
        print('prefetched sequence and splice data in %f sec' % (stop - start))

        self.result_spliced_fh = open('%s.spliced' % result_fname, 'w+')
        self.result_unspliced_fh = open('%s.unspliced' % result_fname, 'w+')

        start = cpu()

        self.data_fname = data_fname

        # NOTE: unpickling can execute arbitrary code -- param_fname must be trusted.
        with open(param_fname, 'rb') as param_fh:
            self.param = cPickle.load(param_fh)

        # Set the parameters such as limits and penalties for the Plifs.
        h, d, a, mmatrix, qualityPlifs = set_param_palma(self.param, True, settings)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.prb_offset = settings['prb_offset']

        # Parameters of the heuristic that decides whether a read is spliced.
        self.splice_thresh = 0.01       # minimal splice-site score of a candidate site
        self.max_intron_size = 2000     # genomic flank searched on each side of the read
        self.max_mismatch = 2           # candidates with more mismatches are dropped
        self.splice_stop_thresh = 0.8   # stop scanning after scoring a site this strong
        self.spliced_bias = 0.0         # constant added to every spliced score

        header_fields = [
            ('%f', 'splice_thresh', self.splice_thresh),
            ('%d', 'max_intron_size', self.max_intron_size),
            ('%d', 'max_mismatch', self.max_mismatch),
            ('%f', 'splice_stop_thresh', self.splice_stop_thresh),
            ('%f', 'spliced_bias', self.spliced_bias),
        ]
        header = [('# %s: ' + form + '\n') % (name, val) for form, name, val in header_fields]
        header.append('# data: %s\n' % self.data_fname)
        header.append('# param: %s\n' % param_fname)

        # Both result files start with the same parameter header.
        for line in header:
            self.result_spliced_fh.write(line)
            self.result_unspliced_fh.write(line)

        lengthSP = settings['numLengthSuppPoints']
        donSP = settings['numDonSuppPoints']
        accSP = settings['numAccSuppPoints']
        mmatrixSP = settings['matchmatrixRows'] * settings['matchmatrixCols']
        totalQualSP = settings['totalQualSuppPoints']

        # Full parameter vector in feature order:
        # intron length | donor | acceptor | match matrix | quality plifs.
        don_start = lengthSP
        acc_start = don_start + donSP
        mm_start = acc_start + accSP
        qual_start = mm_start + mmatrixSP

        currentPhi = zeros((settings['numFeatures'], 1))
        currentPhi[0:don_start] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[don_start:acc_start] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[acc_start:mm_start] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mm_start:qual_start] = mmatrix[:]
        currentPhi[qual_start:] = self.param[-totalQualSP:]
        self.currentPhi = currentPhi

        # Confusion counts: "positive" means predicted spliced.
        self.true_pos = 0
        self.false_pos = 0
        self.true_neg = 0
        self.false_neg = 0
        # Reads whose alternative alignments raised an exception.
        self.failed_alternatives = 0

        # CPU-time accounting (seconds).
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop - start

    def close(self):
        """Close both result files."""
        self.result_spliced_fh.close()
        self.result_unspliced_fh.close()

    def filter(self):
        """
        Classify all reads and write them to the spliced/unspliced result files.

        Reads on chromosomes outside self.chromo_range are skipped. For every
        other read, the best score among its spliced alternatives
        (calcAlternativeAlignments) is compared with the score of its vmatch
        alignment (calcAlignmentScore). The original input line is appended to
        the spliced file if an alternative is at least as good, otherwise to
        the unspliced file. Confusion counts and timings are accumulated on
        self and a summary is printed at the end.
        """
        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for location, original_line in self.all_remapped_reads:
            chromo = location['chr']
            if chromo not in self.chromo_range:
                continue

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            pos = location['pos']
            strand = self.STRAND_MAP[location['strand']]
            seq = location['seq'].lower()
            prb = location['prb']

            unb_seq = unbracket_seq(seq)
            if strand == '-':
                unb_seq = reverse_complement(unb_seq)
                seq = reverse_complement(seq)
            effective_len = len(unb_seq)

            # Genomic region covered by the unspliced vmatch alignment.
            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(
                chromo, strand, pos, pos + effective_len)
            self.get_time += cpu() - start

            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len
            quality = [ord(letter) - self.prb_offset for letter in prb]
            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, quality,
                                      currentAcc, currentDon)

            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                # Best effort: a read whose alternatives cannot be computed is
                # treated as unspliced, but the failure is counted and
                # reported instead of disappearing silently.
                self.failed_alternatives += 1
                alternativeAlignmentScores = []

            if alternativeAlignmentScores:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # Only score the vmatch alignment when there is something to
                # compare it with.
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)
            else:
                # No candidate: the read is kept as unspliced.
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0

            start = cpu()

            # NOTE(review): reads with ids above this offset appear to be the
            # known unspliced reads (ground truth for the counts) -- confirm.
            unspliced = read_id - 1000000300000 > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # According to the learned parameters vmatch found a good
                # (unspliced) alignment of the current read.
                unspliced_ctr += 1
                self.result_unspliced_fh.write(original_line + '\n')
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # An alternative alignment using splice sites scores at least
                # as high as the vmatch alignment.
                spliced_ctr += 1
                self.result_spliced_fh.write(original_line + '\n')
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            self.count_time += cpu() - start

        self.main_loop = cpu() - _start

        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))
        if self.failed_alternatives:
            print('Reads whose alternative alignments failed: %d' % self.failed_alternatives)

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites in and around a read.

        All arguments refer to one genomic window in which the read starts at
        index max_intron_size and covers read_size positions. Only sites with
        score >= splice_thresh are reported.

        Returns (proximal_acc, proximal_don, distal_acc, distal_don):
          proximal_acc -- at most two best acceptors (idx, score) in the first
                          half of the read, in ascending score order
          proximal_don -- at most two best donors (idx, score) in the second
                          half of the read, in ascending score order
          distal_acc   -- all acceptors (idx, score, DNA[idx+1:idx+read_size])
                          downstream of the read, by increasing position
          distal_don   -- all donors (idx, score, DNA[idx-read_size:idx])
                          upstream of the read, by decreasing position
                          (closest to the read first)
        """
        read_start = max_intron_size
        read_mid = max_intron_size + read_size // 2
        read_end = max_intron_size + read_size

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(read_start, read_mid)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=lambda site: site[1])
        proximal_acc = proximal_acc[-2:]

        # The upper bound keeps idx + read_size inside the window.
        distal_acc = [(idx, currentAcc[idx], DNA[idx + 1:idx + read_size])
                      for idx in range(read_end, len(currentAcc) - read_size)
                      if currentAcc[idx] >= splice_thresh]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_mid, read_end)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=lambda site: site[1])
        proximal_don = proximal_don[-2:]

        # Donors need read_size bases of upstream exon in front of them.
        distal_don = [(idx, currentDon[idx], DNA[idx - read_size:idx])
                      for idx in range(max(1, read_size + 1), max_intron_size)
                      if currentDon[idx] >= splice_thresh]
        distal_don.reverse()

        return proximal_acc, proximal_don, distal_acc, distal_don

    @staticmethod
    def _rebuild_alignment(original_est, keep_original, exon_dna, exon_ptr, dna_ptr):
        """
        Rewrite a bracketed read alignment against an alternative exon.

        original_est encodes the alignment column by column: a match is a
        single letter, any other column is the 4-character group
        '[<dna letter><read letter>]' where '-' marks a gap. Columns whose
        genomic position (starting at dna_ptr) satisfies keep_original(pos)
        are copied unchanged; every other column is re-aligned, in order,
        against exon_dna[exon_ptr], exon_dna[exon_ptr + 1], ...

        Returns (alignment, num_mismatch, dna_ptr, est_ptr): the rewritten
        alignment string, the number of bracketed columns in it, and the
        genomic and read positions reached after the last column (the read
        position starts at 0).
        """
        pieces = []
        num_mismatch = 0
        est_ptr = 0
        ptr = 0

        while ptr < len(original_est):
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                column = original_est[ptr:ptr + 4]
                bracketed = True
                ptr += 4
            else:
                dnaletter = estletter = column = original_est[ptr]
                bracketed = False
                ptr += 1

            if keep_original(dna_ptr):
                pieces.append(column)
                if bracketed:
                    num_mismatch += 1
            else:
                genomic_letter = exon_dna[exon_ptr]
                exon_ptr += 1
                if genomic_letter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('[' + genomic_letter + estletter + ']')
                    num_mismatch += 1

            # Advance the genomic / read coordinates past this column.
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, score possible spliced
        alternatives that use nearby donor/acceptor sites.

        Two families of candidates are tried:
          * intron at the 3' end: the read is kept up to one of the best
            donors in its second half and the rest is re-aligned to the
            sequence following an acceptor downstream of the read;
          * intron at the 5' end: the read is kept from one of the best
            acceptors in its first half on and the front is re-aligned to the
            sequence preceding a donor upstream of the read.
        Candidates with more than self.max_mismatch mismatches are dropped.

        Returns the list of candidate scores (possibly empty). Each score is
        the QPalma alignment score of the rebuilt alignment plus the donor,
        acceptor and intron-length plif scores and self.spliced_bias.
        """
        chromo = location['chr']
        pos = location['pos']
        strand = self.STRAND_MAP[location['strand']]
        quality = [ord(letter) - self.prb_offset for letter in location['prb']]

        original_est = location['seq'].lower()
        est = unbracket_seq(original_est)

        # Window: max_intron_size bases on either side of the read.
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        if strand == '-':
            est = reverse_complement(est)
            original_est = reverse_complement(original_est)

        start = cpu()
        dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(
            chromo, strand, genomicSeq_start, genomicSeq_stop)
        self.get_time += cpu() - start

        proximal_acc, proximal_don, distal_acc, distal_don = self.findHighestScoringSpliceSites(
            currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        settings = self.settings
        read_start = self.max_intron_size  # window index of the read's first base

        alternativeScores = []

        # Intron at the 3' end.
        _start = cpu()
        for don_pos, don_score in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for acc_pos, acc_score, acc_dna in distal_acc:
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # Keep columns before the donor; re-align the rest to acc_dna.
                original_est_cut, num_mismatch, dna_ptr, est_ptr = self._rebuild_alignment(
                    original_est, lambda p, don_pos=don_pos: p < don_pos, acc_dna, 0, read_start)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(
                    original_est_cut, quality, qualityPlifs, settings, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if acc_score >= self.splice_stop_thresh:
                    break

        self.alternativeScoresTime += cpu() - _start

        # Intron at the 5' end.
        _start = cpu()
        for acc_pos, acc_score in proximal_acc:
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for don_pos, don_score, don_dna in distal_don:
                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]

                # Columns up to acc_pos are re-aligned so that the last of
                # them meets the final letter of don_dna.
                first_exon_idx = len(don_dna) - (acc_pos - read_start) - 1
                original_est_cut, num_mismatch, dna_ptr, est_ptr = self._rebuild_alignment(
                    original_est, lambda p, acc_pos=acc_pos: p > acc_pos,
                    don_dna, first_exon_idx, read_start)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(
                    original_est_cut, quality, qualityPlifs, settings, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if don_score >= self.splice_stop_thresh:
                    break

        self.alternativeScoresTime += cpu() - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Score one alignment with the current QPalma parameters.

        alignment is the tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp) built in filter(); only original_est (the
        bracketed alignment string) and quality are used. Returns the value
        of computeSpliceAlignScoreWithQuality for self.currentPhi -- per the
        original author, the dot product of the alignment's feature
        representation with the parameter vector.
        """
        start = cpu()

        _dna, _exons, _est, original_est, quality, _acc_supp, _don_supp = alignment
        score = computeSpliceAlignScoreWithQuality(
            original_est, quality, self.qualityPlifs, self.settings, self.currentPhi)

        self.calcAlignmentScoreTime += cpu() - start

        return score
607
608 return score