+ using new parsers now
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 #from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 #from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import *
30
31 from Lookup import *
32
33
def unbracket_est(est):
    """
    Strip the mismatch annotations from a read and return it lower-cased.

    A mismatch is encoded as the four-character group '[' + dna + est + ']';
    each such group is replaced by its third character (the est letter), all
    other characters are copied unchanged, e.g. 'AC[GT]A' -> 'acta'.
    """
    letters = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            # '[', dna letter, est letter, ']' -> keep only the est letter
            letters.append(est[e + 2])
            e += 4
        else:
            letters.append(est[e])
            e += 1
    # build once instead of repeated string concatenation
    return ''.join(letters).lower()
50
51
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read, the score of the unspliced vmatch alignment is compared
    with the best score of candidate spliced alignments built from nearby
    donor/acceptor predictions; the read's original input line is then written
    to the "spliced" or the "unspliced" result file.
    """

    # Default genome directory handed to Lookup.
    DEFAULT_DNA_FLAT_FILES = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    # Default file of whitespace-separated records: read id, sequence and
    # three further columns that are not used here.
    DEFAULT_ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Strand codes found in the parsed read locations -> genomic strand.
    STRAND_MAP = {'D': '+', 'P': '-'}

    # Read ids greater than this value are counted as truly unspliced in the
    # TP/FP/TN/FN statistics of filter().
    # NOTE(review): presumably how the labelled data set was numbered -- confirm.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self, run_fname, data_fname, param_fname, result_spliced_fname,
                 result_unspliced_fname, dna_flat_files=DEFAULT_DNA_FLAT_FILES,
                 original_reads_fname=DEFAULT_ORIGINAL_READS_FNAME):
        """
        Load the run settings, the reads to filter, the genome lookup and the
        learned QPalma parameters, and write the filter settings as '#'
        header lines to both result files.

        run_fname              -- pickled run object (nr. of support points etc.)
        data_fname             -- read alignments, parsed with parse()
        param_fname            -- pickled QPalma parameter vector
        result_spliced_fname   -- output file for reads classified as spliced
        result_unspliced_fname -- output file for reads classified as unspliced
        dna_flat_files         -- genome directory handed to Lookup
        original_reads_fname   -- file mapping read ids to their sequence
        """
        run = self._loadPickle(run_fname)
        self.run = run

        start = cpu()
        self.all_remapped_reads = parse(data_fname)
        stop = cpu()
        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads), stop - start))

        start = cpu()
        self.lt1 = Lookup(dna_flat_files)
        stop = cpu()
        print('prefetched sequence and splice data in %f sec' % (stop - start))

        # Kept open for the lifetime of the object; filter() appends to them.
        self.result_spliced_fh = open(result_spliced_fname, 'w+')
        self.result_unspliced_fh = open(result_unspliced_fname, 'w+')

        start = cpu()

        self.data_fname = data_fname
        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits/penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 36

        # Parameters of the heuristic deciding whether a read is spliced:
        # minimum score for a candidate splice site
        self.splice_thresh = 0.01
        # flank (in bases) searched on each side of the read for splice sites
        self.max_intron_size = 2000
        # candidates with more mismatches than this are discarded
        self.max_mismatch = 2
        # stop scanning distal sites once one at least this strong was scored
        self.splice_stop_thresh = 0.8
        # added to the score of every spliced candidate
        self.spliced_bias = 0.0

        settings = [
            ('%f', 'splice_thresh', self.splice_thresh),
            ('%d', 'max_intron_size', self.max_intron_size),
            ('%d', 'max_mismatch', self.max_mismatch),
            ('%f', 'splice_stop_thresh', self.splice_stop_thresh),
            ('%f', 'spliced_bias', self.spliced_bias),
        ]
        param_lines = [('# %s: ' + form + '\n') % (name, val) for form, name, val in settings]
        param_lines.append('# data: %s\n' % self.data_fname)
        param_lines.append('# param: %s\n' % param_fname)

        for p_line in param_lines:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        self.original_reads = self._loadOriginalReads(original_reads_fname)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Assemble the full parameter (weight) vector in the order
        # length | donor | acceptor | match matrix | quality penalties.
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # We want to identify spliced reads, so true positives are spliced
        # reads predicted "spliced" and false positives are unspliced reads
        # predicted "spliced".
        self.true_pos = 0
        self.false_pos = 0
        self.true_neg = 0
        self.false_neg = 0

        # accumulated timings (seconds of CPU time)
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop - start

    @staticmethod
    def _loadPickle(fname):
        """Unpickle and return the object stored in fname, closing the file afterwards."""
        fh = open(fname, 'rb')
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    @staticmethod
    def _loadOriginalReads(fname):
        """
        Return a dict mapping int read id -> sequence (second column).

        Every non-blank line must have exactly five whitespace-separated
        fields; other lines raise ValueError.
        """
        reads = {}
        fh = open(fname)
        try:
            for line in fh:
                fields = line.split()
                if not fields:
                    # tolerate blank lines, e.g. a trailing empty line
                    continue
                read_id, seq, _q1, _q2, _q3 = fields
                reads[int(read_id)] = seq
        finally:
            fh.close()
        return reads

    def filter(self):
        """
        Classify every parsed read as spliced or unspliced.

        Reads whose strand maps to '-' are skipped and written to neither
        result file. For the remaining reads, the vmatch alignment score is
        compared with the best alternative spliced alignment found by
        calcAlternativeAlignments(). If there is no alternative, or the vmatch
        alignment scores strictly higher, the read's original input line goes
        to the unspliced result file, otherwise to the spliced one. The
        TP/FP/TN/FN counters are updated from the read id (see
        UNSPLICED_ID_OFFSET), and a summary is printed at the end.
        """
        run = self.run

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for location, original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chromosome = location['chr']
            pos = location['pos']
            strand = self.STRAND_MAP[location['strand']]
            seq = location['seq'].lower()
            prb = location['prb']

            if strand == '-':
                # only the forward strand is handled by this heuristic
                continue

            unb_seq = unbracket_est(seq)
            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len - 1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(
                chromosome, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop - start

            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len

            # (dna, exons, est, original_est, quality, acc, don) as unpacked
            # by calcAlignmentScore
            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                      currentAcc, currentDon)

            alternativeAlignmentScores = self.calcAlternativeAlignments(location)

            if not alternativeAlignmentScores:
                # no spliced candidate: no need to score the vmatch alignment
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment score for the vmatch unspliced read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            unspliced = read_id > self.UNSPLICED_ID_OFFSET

            if maxAlternativeAlignmentScore < vMatchScore:
                # According to the learned parameters, vmatch found the better
                # (unspliced) alignment for this read.
                unspliced_ctr += 1
                self.result_unspliced_fh.write(original_line + '\n')
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # An alternative alignment using splice sites scores at least
                # as high as the vmatch alignment.
                spliced_ctr += 1
                self.result_spliced_fh.write(original_line + '\n')
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulate over all reads, like the other timers
            self.count_time += stop - start

        _stop = cpu()
        self.main_loop = _stop - _start

        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites around a read that starts at window
        offset max_intron_size (DNA, currentAcc and currentDon are indexed by
        window position). Only sites with score >= splice_thresh are kept.

        Returns (proximal_acc, proximal_don, distal_acc, distal_don):
          proximal_acc -- up to two (pos, score) acceptors in the first half of
                          the read, ascending by score (best last)
          proximal_don -- up to two (pos, score) donors in the second half of
                          the read, ascending by score (best last)
          distal_acc   -- (pos, score, DNA[pos+1:pos+read_size]) acceptors after
                          the read with a full slice available, ascending by pos
          distal_don   -- (pos, score, DNA[pos-read_size:pos]) donors before the
                          read with pos > read_size, descending by pos
        """
        half = read_size // 2
        read_start = max_intron_size
        read_stop = max_intron_size + read_size

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(read_start, read_start + half)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=lambda site: site[1])
        proximal_acc = proximal_acc[-2:]

        # upper bound keeps idx + read_size < len(currentAcc)
        distal_acc = [(idx, currentAcc[idx], DNA[idx + 1:idx + read_size])
                      for idx in range(read_stop, len(currentAcc) - read_size)
                      if currentAcc[idx] >= splice_thresh]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_start + half, read_stop)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=lambda site: site[1])
        proximal_don = proximal_don[-2:]

        # lower bound keeps idx > read_size so the DNA slice is complete
        distal_don = [(idx, currentDon[idx], DNA[idx - read_size:idx])
                      for idx in range(max(1, read_size + 1), max_intron_size)
                      if currentDon[idx] >= splice_thresh]
        # nearest donor first
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc, proximal_don, distal_acc, distal_don

    @staticmethod
    def _spliceEst(original_est, dna_ptr, other_dna, other_ptr, replace_from=None, replace_to=None):
        """
        Rewrite a bracket-annotated read so part of it aligns to other_dna.

        Walks original_est while tracking the genomic position dna_ptr and the
        read position est_ptr. Letters whose dna_ptr lies in
        [replace_from, replace_to] (None = unbounded) are re-aligned against
        other_dna[other_ptr], other_ptr, ...: a match yields the plain letter,
        otherwise a '[' + dna + est + ']' mismatch group. All other letters keep
        their original annotation.

        Returns (new_est, num_mismatch, final_dna_ptr, final_est_ptr), where
        num_mismatch counts every bracket group in new_est.
        """
        pieces = []
        num_mismatch = 0
        est_ptr = 0
        ptr = 0
        n = len(original_est)

        while ptr < n:
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                token = original_est[ptr:ptr + 4]
                step = 4
            else:
                dnaletter = estletter = token = original_est[ptr]
                step = 1

            replace = ((replace_from is None or dna_ptr >= replace_from) and
                       (replace_to is None or dna_ptr <= replace_to))

            if not replace:
                pieces.append(token)
                if step == 4:
                    num_mismatch += 1
            else:
                other = other_dna[other_ptr]
                if other == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('[' + other + estletter + ']')
                    num_mismatch += 1
                other_ptr += 1

            ptr += step

            # a gap in the read advances only the genome and vice versa
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, calculate the scores of
        alternative spliced alignments that use matched donor/acceptor
        positions near the read.

        Two families are tried: an intron at the 3' end (donor inside the
        read, acceptor downstream) and an intron at the 5' end (acceptor inside
        the read, donor upstream). Candidates with more than max_mismatch
        mismatches are discarded. Returns the list of candidate scores, which
        may be empty.
        """
        run = self.run

        chromosome = location['chr']
        pos = location['pos']
        original_est = location['seq'].lower()
        quality = location['prb']

        est = unbracket_est(original_est)
        strand = self.STRAND_MAP[location['strand']]

        # window of max_intron_size bases on either side of the read
        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        start = cpu()
        dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(
            chromosome, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop - start

        proximal_acc, proximal_don, distal_acc, distal_don = self.findHighestScoringSpliceSites(
            currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        # window position of the first read base
        read_start = self.max_intron_size

        # find an intron on the 3' end
        _start = cpu()
        for don_pos, don_score in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for acc_pos, acc_score, acc_dna in distal_acc:
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # bases from the donor on are re-aligned after the acceptor
                original_est_cut, num_mismatch, dna_ptr, est_ptr = self._spliceEst(
                    original_est, read_start, acc_dna, 0, replace_from=don_pos)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                if acc_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        # find an intron on the 5' end
        _start = cpu()
        for acc_pos, acc_score in proximal_acc:
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for don_pos, don_score, don_dna in distal_don:
                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]

                # bases up to the acceptor are re-aligned before the donor
                original_est_cut, num_mismatch, dna_ptr, est_ptr = self._spliceEst(
                    original_est, read_start, don_dna,
                    len(don_dna) - (acc_pos - read_start) - 1, replace_to=acc_pos)

                if num_mismatch > self.max_mismatch:
                    continue

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

                alternativeScores.append(score)

                if don_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Score an existing alignment with the current QPalma parameters.

        alignment is the 7-tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp) built in filter(). Only original_est (the
        bracket-annotated read) and quality enter the score, which is computed
        by computeSpliceAlignScoreWithQuality with self.currentPhi.
        """
        start = cpu()
        run = self.run

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop - start

        return score
584 return score
585
586
def cpu():
    """
    Return the user + system CPU time (in seconds) consumed so far by this
    process, as reported by resource.getrusage(RUSAGE_SELF).
    """
    # take a single snapshot so both fields come from the same measurement
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
590
591
if __name__ == '__main__':
    if len(sys.argv) != 6:
        print('Usage: ./%s data param run spliced.results unspliced.results' % (sys.argv[0]))
        # without all five arguments the indexing below would fail with IndexError
        sys.exit(1)

    data_fname = sys.argv[1]
    param_fname = sys.argv[2]
    run_fname = sys.argv[3]

    result_spliced_fname = sys.argv[4]
    result_unspliced_fname = sys.argv[5]

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname, result_spliced_fname, result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    # the result files are opened by the PipelineHeuristic constructor and left open
    ph1.result_spliced_fh.close()
    ph1.result_unspliced_fh.close()

    print('total time elapsed: %f' % (stop - start))

    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f' % ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f' % ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f' % ph1.DotProdTime)
    print('time spend forarray_stuff time: %f' % ph1.array_stuff)