+ added script which draws the coverage bar plots presented in the QPalma paper
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 #from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 #from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import *
30
31 from Lookup import *
32
33
def unbracket_est(est):
    """
    Return the read contained in an annotated alignment string, lower-cased.

    Mismatches are annotated as four character groups of the form '[XY]'.
    For such a group only the second letter ('Y') is kept; every other
    character is copied unchanged.
    """
    kept = []
    cursor = 0
    n_chars = len(est)

    while cursor < n_chars:
        symbol = est[cursor]
        if symbol == '[':
            # skip '[' and the first letter, keep the second, skip ']'
            kept.append(est[cursor + 2])
            cursor += 4
        else:
            kept.append(symbol)
            cursor += 1

    return "".join(kept).lower()
50
51
52 class PipelineHeuristic:
53 """
54 This class wraps the filter which decides whether an alignment found by
55 vmatch is spliced and should therefore be re-aligned using QPalma or not.
56 """
57
def __init__(self,run_fname,data_fname,reads_pipeline_fn,param_fname,result_spliced_fname,result_unspliced_fname):
    """
    Load everything the heuristic needs and open the two result files.

    run_fname              -- pickled run object holding information about
                              the nr. of support points etc.
    data_fname             -- vmatch output, read via parse_map_vm_heuristic
    reads_pipeline_fn      -- text file with five whitespace separated
                              columns per line; the first two are the read
                              id and the read sequence
    param_fname            -- pickled parameter vector
    result_spliced_fname   -- output file for reads predicted as spliced
    result_unspliced_fname -- output file for reads predicted as unspliced

    Both result files are truncated and receive a '#'-prefixed header that
    lists the settings of the heuristic. Their handles stay open because
    filter() writes to them.
    """

    # NOTE(review): unpickling executes arbitrary code, so run_fname and
    # param_fname must come from a trusted source.
    with open(run_fname) as run_fh:
        run = cPickle.load(run_fh)
    self.run = run

    # NOTE(review): the lookup is built from this hard-coded path while the
    # later queries pass run['dna_flat_files'] -- confirm both agree.
    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    start = cpu()
    self.all_remapped_reads = parse_map_vm_heuristic(data_fname)
    stop = cpu()

    print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

    start = cpu()
    self.lt1 = Lookup(dna_flat_files)
    stop = cpu()
    print('prefetched sequence and splice data in %f sec' % (stop-start))

    self.result_spliced_fh = open(result_spliced_fname,'w+')
    self.result_unspliced_fh = open(result_unspliced_fname,'w+')

    # everything from here on is accounted as init_time
    start = cpu()

    self.data_fname = data_fname

    with open(param_fname) as param_fh:
        self.param = cPickle.load(param_fh)

    # Set the parameters such as limits penalties for the Plifs
    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

    self.h = h
    self.d = d
    self.a = a
    self.mmatrix = mmatrix
    self.qualityPlifs = qualityPlifs

    self.read_size = 36

    # parameters of the heuristics to decide whether the read is spliced
    self.splice_thresh = 0.01
    self.max_intron_size = 2000
    self.max_mismatch = 2
    self.splice_stop_thresh = 0.8
    self.spliced_bias = 0.0

    settings = [
        ('%f','splice_thresh',self.splice_thresh),
        ('%d','max_intron_size',self.max_intron_size),
        ('%d','max_mismatch',self.max_mismatch),
        ('%f','splice_stop_thresh',self.splice_stop_thresh),
        ('%f','spliced_bias',self.spliced_bias)]

    header_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in settings]
    header_lines.append('# data: %s\n'%self.data_fname)
    header_lines.append('# param: %s\n'%param_fname)

    # both result files start with the same header
    for header_line in header_lines:
        self.result_spliced_fh.write(header_line)
        self.result_unspliced_fh.write(header_line)

    # read id -> read sequence
    self.original_reads = {}

    with open(reads_pipeline_fn) as reads_fh:
        for line in reads_fh:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            # exactly five columns are expected; only the first two are used
            read_id,seq,_q1,_q2,_q3 = line.split()
            self.original_reads[int(read_id)] = seq

    lengthSP = run['numLengthSuppPoints']
    donSP = run['numDonSuppPoints']
    accSP = run['numAccSuppPoints']
    mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
    totalQualSP = run['totalQualSuppPoints']

    # Assemble the parameter vector: intron length penalties, donor
    # penalties, acceptor penalties, match matrix and finally the quality
    # penalties (the last totalQualSP entries of the loaded parameters).
    currentPhi = zeros((run['numFeatures'],1))
    currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
    currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
    currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
    currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

    totalQualityPenalties = self.param[-totalQualSP:]
    currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]
    self.currentPhi = currentPhi

    # we want to identify spliced reads
    # so true pos are spliced reads that are predicted "spliced"
    self.true_pos = 0

    # as false positives we count all reads that are not spliced but predicted
    # as "spliced"
    self.false_pos = 0

    self.true_neg = 0
    self.false_neg = 0

    # timers (CPU seconds); the first one is the total time spent for
    # fetching sequences and scores
    self.get_time = 0.0
    self.calcAlignmentScoreTime = 0.0
    self.alternativeScoresTime = 0.0

    self.count_time = 0.0
    self.main_loop = 0.0
    self.splice_site_time = 0.0
    self.computeSpliceAlignWithQualityTime = 0.0
    self.computeSpliceWeightsTime = 0.0
    self.DotProdTime = 0.0
    self.array_stuff = 0.0
    stop = cpu()

    self.init_time = stop-start
178
def filter(self):
    """
    Classify every parsed vmatch hit as spliced or unspliced.

    Hits on the '-' strand are skipped and written to neither result file.
    For every other hit the scores of the alternative (spliced) alignments
    are computed. If there is no alternative, or the best alternative
    scores lower than the unspliced vmatch alignment, the original line
    goes to the unspliced result file; otherwise it goes to the spliced
    result file.

    The confusion counters (true_pos, false_pos, true_neg, false_neg) and
    the timers on self are updated and a summary is printed at the end.
    """
    run = self.run
    strand_map = {'D':'+', 'P':'-'}

    ctr = 0
    unspliced_ctr = 0
    spliced_ctr = 0

    print('Starting filtering...')
    _start = cpu()

    for location,original_line in self.all_remapped_reads:

        read_id = int(location['id'])
        chrom = location['chr']
        pos = location['pos']
        strand = strand_map[location['strand']]
        seq = location['seq'].lower()
        prb = location['prb']

        if strand == '-':
            continue

        # progress output; placed after the strand check so that a run of
        # skipped reads does not print the same counter value repeatedly
        if ctr % 1000 == 0:
            print(ctr)

        unb_seq = unbracket_est(seq)
        effective_len = len(unb_seq)

        genomicSeq_start = pos
        genomicSeq_stop = pos+effective_len-1

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start

        # the unspliced alignment consists of one exon spanning the read
        exons = zeros((2,1))
        exons[0,0] = 0
        exons[1,0] = effective_len

        currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
        currentAcc, currentDon

        # Best effort: a hit for which the alternatives cannot be computed
        # is treated like a hit without alternatives. Only Exception is
        # caught so that Ctrl-C and sys.exit() still get through.
        try:
            alternativeAlignmentScores = self.calcAlternativeAlignments(location)
        except Exception:
            alternativeAlignmentScores = []

        if not alternativeAlignmentScores:
            # no alignment necessary
            maxAlternativeAlignmentScore = -inf
            vMatchScore = 0.0
        else:
            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
            # compute alignment for vmatch unspliced read
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

        start = cpu()

        # NOTE(review): presumably read ids above this offset mark reads
        # that are known to be unspliced -- confirm against the data set.
        unspliced = read_id > 1000000300000

        # Seems that according to our learned parameters VMatch found a good
        # alignment of the current read
        if maxAlternativeAlignmentScore < vMatchScore:
            unspliced_ctr += 1

            self.result_unspliced_fh.write(original_line+'\n')

            if unspliced:
                self.true_neg += 1
            else:
                self.false_neg += 1

        # We found an alternative alignment considering splice sites that scores
        # at least as high as the VMatch alignment
        else:
            spliced_ctr += 1

            self.result_spliced_fh.write(original_line+'\n')

            if unspliced:
                self.false_pos += 1
            else:
                self.true_pos += 1

        ctr += 1
        stop = cpu()
        # accumulate like all other timers instead of keeping only the
        # time of the last iteration
        self.count_time += stop-start

    _stop = cpu()
    self.main_loop = _stop-_start

    # the handles stay open, so at least push the results to disk
    self.result_spliced_fh.flush()
    self.result_unspliced_fh.flush()

    print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
    print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
    print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))
311
312
def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
    """
    Collect candidate splice sites inside and around the read.

    currentAcc, currentDon -- acceptor / donor scores, indexable by position
    DNA                    -- sequence the scores refer to
    max_intron_size        -- index at which the read window starts
    read_size              -- length of the read window
    splice_thresh          -- minimal score of a candidate site

    The read window is [max_intron_size, max_intron_size+read_size).

    Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

    proximal_acc -- (index, score) of the at most two best scoring
                    acceptors in the first half of the window, ascending
                    by score
    proximal_don -- the same for donors in the second half of the window
    distal_acc   -- (index, score, DNA[index+1:index+read_size]) for every
                    acceptor behind the window that leaves more than
                    read_size positions up to the end of currentAcc, in
                    ascending index order
    distal_don   -- (index, score, DNA[index-read_size:index]) for every
                    donor in front of the window with index > read_size, in
                    descending index order
    """
    window_start = max_intron_size
    window_mid = max_intron_size + read_size // 2
    window_end = max_intron_size + read_size
    n_acc = len(currentAcc)

    def by_score(site):
        return site[1]

    proximal_acc = [(idx, currentAcc[idx])
                    for idx in range(window_start, window_mid)
                    if currentAcc[idx] >= splice_thresh]
    # stable ascending sort, so the two best candidates end up last
    proximal_acc.sort(key=by_score)
    proximal_acc = proximal_acc[-2:]

    distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                  for idx in range(window_end, n_acc)
                  if currentAcc[idx] >= splice_thresh and idx+read_size < n_acc]

    proximal_don = [(idx, currentDon[idx])
                    for idx in range(window_mid, window_end)
                    if currentDon[idx] >= splice_thresh]
    proximal_don.sort(key=by_score)
    proximal_don = proximal_don[-2:]

    distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                  for idx in range(1, max_intron_size)
                  if currentDon[idx] >= splice_thresh and idx > read_size]
    # highest index first
    distal_don.sort(key=lambda site: -site[0])

    return proximal_acc,proximal_don,distal_acc,distal_don
357
def calcAlternativeAlignments(self,location):
    """
    Given an alignment proposed by Vmatch this function calculates possible
    alternative alignments taking into account for example matched
    donor/acceptor positions.

    location -- dict describing one vmatch hit; the keys 'id', 'chr', 'pos',
                'strand', 'seq', 'prb' and 'cal_prb' are read

    Two kinds of alternatives are tried: an intron towards the 3' end (a
    donor inside the read combined with an acceptor behind it) and an intron
    towards the 5' end (an acceptor inside the read combined with a donor in
    front of it). For each combination the annotated read is rebuilt against
    the sequence on the far side of the intron; combinations with more than
    self.max_mismatch mismatches are dropped.

    Returns the list of scores of the remaining alternatives (may be empty).
    The time spent is added to self.get_time and self.alternativeScoresTime.

    NOTE(review): the block nesting of the two loops below was reconstructed;
    please confirm it, in particular the position of the mismatch checks and
    of the early 'break' statements.
    """

    run = self.run

    id = location['id']
    chr = location['chr']
    pos = location['pos']
    strand = location['strand']
    original_est = location['seq']
    quality = location['prb']
    cal_prb = location['cal_prb']

    original_est = original_est.lower()
    est = unbracket_est(original_est)
    effective_len = len(est)

    # fetch self.max_intron_size additional positions on both sides of the hit;
    # presumably the read then starts at index self.max_intron_size of the
    # returned sequence (dna_ptr below is initialised accordingly)
    genomicSeq_start = pos - self.max_intron_size
    genomicSeq_stop = pos + self.max_intron_size + len(est)

    strand_map = {'D':'+', 'P':'-'}
    strand = strand_map[strand]

    start = cpu()
    #currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
    currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
    stop = cpu()
    self.get_time += stop-start
    dna = currentDNASeq

    proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

    alternativeScores = []

    # inlined
    h = self.h
    d = self.d
    a = self.a
    mmatrix = self.mmatrix
    qualityPlifs = self.qualityPlifs
    # inlined

    # find an intron on the 3' end
    _start = cpu()
    for (don_pos,don_score) in proximal_don:
        DonorScore = calculatePlif(d, [don_score])[0]

        for (acc_pos,acc_score,acc_dna) in distal_acc:

            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            #print 'don splice: ', (don_pos,don_score), (acc_pos,acc_score,acc_dna), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est": positions in front of the donor
            # are copied, the remaining read letters are compared with acc_dna
            original_est_cut=''

            est_ptr=0                      # position in the read
            dna_ptr=self.max_intron_size   # position in the fetched sequence
            ptr=0                          # position in the annotated read
            acc_dna_ptr=0                  # position in acc_dna
            num_mismatch = 0

            while ptr<len(original_est):
                #print acc_dna_ptr,len(acc_dna),acc_pos,don_pos

                if original_est[ptr]=='[':
                    # annotated mismatch '[' + dna letter + est letter + ']'
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr < don_pos:
                        # in front of the donor: keep the annotation as it is
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+acc_dna[acc_dna_ptr]+estletter+']'
                        acc_dna_ptr+=1
                    ptr+=4
                else:
                    # plain letter: read and sequence agreed in the vmatch hit
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr < don_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if acc_dna[acc_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            num_mismatch += 1
                            original_est_cut += '['+acc_dna[acc_dna_ptr]+estletter+']' # EST letter
                            #print '('+acc_dna[acc_dna_ptr]+estletter+')'
                        acc_dna_ptr+=1

                    ptr+=1

                # '-' marks a gap: only the side without the gap advances
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1
            # too many mismatches: try the next acceptor; note that this also
            # skips the splice_stop_thresh check at the end of the loop body
            if num_mismatch>self.max_mismatch:
                continue

            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # stop looking at further acceptors for this donor once a
            # sufficiently high scoring acceptor has been used
            if acc_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    # find an intron on the 5' end
    _start = cpu()
    for (acc_pos,acc_score) in proximal_acc:

        AcceptorScore = calculatePlif(a, [acc_score])[0]

        for (don_pos,don_score,don_dna) in distal_don:

            DonorScore = calculatePlif(d, [don_score])[0]
            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

            #print 'acc splice: ', (don_pos,don_score,don_dna), (acc_pos,acc_score), (DonorScore,IntronScore,AcceptorScore)

            # construct a new "original_est": read letters up to and including
            # acc_pos are compared with don_dna, positions behind are copied
            original_est_cut=''

            est_ptr=0
            dna_ptr=self.max_intron_size
            ptr=0
            num_mismatch = 0
            # start in don_dna such that the position acc_pos is compared with
            # the last letter of don_dna
            don_dna_ptr=len(don_dna)-(acc_pos-self.max_intron_size)-1
            while ptr<len(original_est):

                if original_est[ptr]=='[':
                    # annotated mismatch '[' + dna letter + est letter + ']'
                    dnaletter=original_est[ptr+1]
                    estletter=original_est[ptr+2]
                    if dna_ptr > acc_pos:
                        # behind the acceptor: keep the annotation as it is
                        original_est_cut+=original_est[ptr:ptr+4]
                        num_mismatch += 1
                    else:
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '['+don_dna[don_dna_ptr]+estletter+']'
                        don_dna_ptr+=1
                    ptr+=4
                else:
                    # plain letter: read and sequence agreed in the vmatch hit
                    dnaletter=original_est[ptr]
                    estletter=dnaletter

                    if dna_ptr > acc_pos:
                        original_est_cut+=estletter # EST letter
                    else:
                        if don_dna[don_dna_ptr]==estletter:
                            original_est_cut += estletter # EST letter
                        else:
                            original_est_cut += '['+don_dna[don_dna_ptr]+estletter+']' # EST letter
                            num_mismatch += 1
                            #print '('+don_dna[don_dna_ptr]+estletter+')'
                        don_dna_ptr+=1

                    ptr+=1

                # '-' marks a gap: only the side without the gap advances
                if estletter=='-':
                    dna_ptr+=1
                elif dnaletter=='-':
                    est_ptr+=1
                else:
                    dna_ptr+=1
                    est_ptr+=1

            # too many mismatches: try the next donor; note that this also
            # skips the splice_stop_thresh check at the end of the loop body
            if num_mismatch>self.max_mismatch:
                continue
            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            #print original_est, original_est_cut

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # stop looking at further donors for this acceptor once a
            # sufficiently high scoring donor has been used
            if don_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    return alternativeScores
568
569
def calcAlignmentScore(self,alignment):
    """
    Score an alignment with the current QPalma parameters.

    alignment -- 7-tuple (dna, exons, est, original_est, quality, acc_supp,
                 don_supp); only original_est and quality are passed on

    Returns the value of computeSpliceAlignScoreWithQuality for the annotated
    read, its quality values, the quality Plifs, the run object and the
    parameter vector self.currentPhi. The CPU time spent is added to
    self.calcAlignmentScoreTime.
    """
    t_begin = cpu()
    run_cfg = self.run

    # unpacking all seven fields keeps the implicit check of the tuple size
    _dna, _exons, _est, annotated_read, read_quality, _acc_supp, _don_supp = alignment

    result = computeSpliceAlignScoreWithQuality(
        annotated_read, read_quality, self.qualityPlifs, run_cfg, self.currentPhi)

    t_end = cpu()
    self.calcAlignmentScoreTime += t_end-t_begin

    return result
589
590
def cpu():
    """
    Return the CPU time (user plus system) used by this process in seconds.

    Both values are taken from a single getrusage() call so that they
    belong to the same point in time.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
594
595
def main(argv):
    """
    Run the heuristic from the command line and print the timing summary.

    argv -- argument list in the layout of sys.argv:
            [prog, data, reads, param, run, spliced.results, unspliced.results]

    Prints the usage line and exits with status 1 if the number of
    arguments is wrong.
    """
    if len(argv) != 7:
        print('Usage: ./%s data reads param run spliced.results unspliced.results' % (argv[0]))
        # without this exit the argument accesses below fail with an IndexError
        sys.exit(1)

    data_fname = argv[1]
    reads_pipeline_fn = argv[2]
    param_fname = argv[3]
    run_fname = argv[4]

    result_spliced_fname = argv[5]
    result_unspliced_fname = argv[6]

    ph1 = PipelineHeuristic(run_fname,data_fname,reads_pipeline_fn,param_fname,result_spliced_fname,result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop-start))

    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f'% ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f'% ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f'% ph1.DotProdTime)
    print('time spend forarray_stuff time: %f'% ph1.array_stuff)


if __name__ == '__main__':
    main(sys.argv)