+ fixed some bugs in the negative strand lookup table
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 #from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 #from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import *
30
31 from Lookup import *
32
33 from qpalma.sequence_utils import *
34
35
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every vmatch location the QPalma score of the unspliced alignment is
    compared with the best score among alternative alignments that place an
    intron close to the 5' or the 3' end of the read.
    """

    # maps the strand codes of the parsed locations ('D'/'P') onto '+'/'-'
    _STRAND_MAP = {'D':'+', 'P':'-'}

    def __init__(self,run_fname,data_fname,reads_pipeline_fn,param_fname,result_spliced_fname,result_unspliced_fname):
        """
        Load everything the filter needs and open the two result files.

        Each result file starts with a header of '#'-prefixed lines listing the
        heuristic parameters, the data file and the parameter file.

        run_fname              -- pickled run object; provides the numbers of
                                  support points used to assemble currentPhi
                                  and the 'dna_flat_files' entry
        data_fname             -- vmatch mapping file, parsed with
                                  parse_map_vm_heuristic
        reads_pipeline_fn      -- text file whose non-empty lines hold five
                                  whitespace-separated fields: an integer read
                                  id, the read sequence and three more fields
        param_fname            -- pickled QPalma parameter vector
        result_spliced_fname   -- output file for reads classified as spliced
        result_unspliced_fname -- output file for reads classified as unspliced
        """
        with open(run_fname) as run_fh:
            run = cPickle.load(run_fh)
        self.run = run

        # NOTE(review): hard-coded genome location for the lookup table; the
        # per-read lookups below pass run['dna_flat_files'] instead
        dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

        start = cpu()
        self.all_remapped_reads = parse_map_vm_heuristic(data_fname)
        stop = cpu()

        print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

        start = cpu()
        self.lt1 = Lookup(dna_flat_files)
        stop = cpu()
        print('prefetched sequence and splice data in %f sec' % (stop-start))

        self.result_spliced_fh = open(result_spliced_fname,'w+')
        self.result_unspliced_fh = open(result_unspliced_fname,'w+')

        # init_time is measured from this point on
        start = cpu()

        self.data_fname = data_fname

        with open(param_fname) as param_fh:
            self.param = cPickle.load(param_fh)

        # Set the parameters such as limits and penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 36

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.01       # minimal splice site score of a candidate
        self.max_intron_size = 2000     # genomic flank fetched on each side of the read
        self.max_mismatch = 2           # candidates with more mismatches are dropped
        self.splice_stop_thresh = 0.8   # stop scanning distal sites once one scores this high
        self.spliced_bias = 0.0         # added to every alternative (spliced) score

        param_lines = [
            ('%f','splice_thresh',self.splice_thresh),
            ('%d','max_intron_size',self.max_intron_size),
            ('%d','max_mismatch',self.max_mismatch),
            ('%f','splice_stop_thresh',self.splice_stop_thresh),
            ('%f','spliced_bias',self.spliced_bias)]

        param_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in param_lines]
        param_lines.append('# data: %s\n'%self.data_fname)
        param_lines.append('# param: %s\n'%param_fname)

        # the same parameter header goes into both result files
        for p_line in param_lines:
            self.result_spliced_fh.write(p_line)
            self.result_unspliced_fh.write(p_line)

        self.original_reads = {}
        with open(reads_pipeline_fn) as reads_fh:
            for line in reads_fh:
                line = line.strip()
                if not line:
                    # skip blank lines instead of failing the 5-field unpacking
                    continue
                read_id,seq,q1,q2,q3 = line.split()
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # full weight vector: intron length, donor, acceptor, match matrix and
        # quality penalties, in this order
        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

        # the quality penalties are the last totalQualSP entries of the parameters
        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated timings (CPU seconds)
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    def filter(self):
        """
        Classify every parsed vmatch location as spliced or unspliced.

        For each location on the '+' strand the QPalma score of the unspliced
        vmatch alignment is compared with the best score returned by
        calcAlternativeAlignments. The original input line is written to the
        unspliced result file when the vmatch score is strictly higher, and to
        the spliced result file otherwise. A location without any alternative
        counts as unspliced. Locations on the '-' strand are skipped.

        Reads whose id minus 1000000300000 is positive are counted as truly
        unspliced when updating the true/false pos/neg counters.
        """
        run = self.run

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0
        failed_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for location,original_line in self.all_remapped_reads:

            if ctr % 1000 == 0:
                print(ctr)

            read_id = int(location['id'])
            chrom = location['chr']
            pos = location['pos']
            seq = location['seq'].lower()
            prb = location['prb']
            strand = self._STRAND_MAP[location['strand']]

            # reads on the '-' strand are skipped (and not counted in ctr)
            if strand == '-':
                continue

            unb_seq = unbracket_est(seq)
            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos+effective_len-1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop-start

            exons = zeros((2,1))
            exons[0,0] = 0
            exons[1,0] = effective_len

            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                      currentAcc, currentDon)

            try:
                alternativeAlignmentScores = self.calcAlternativeAlignments(location)
            except Exception:
                # best effort: a location whose alternatives cannot be computed
                # is treated as having none, but it is counted and reported
                failed_ctr += 1
                alternativeAlignmentScores = []

            if not alternativeAlignmentScores:
                # no alignment necessary
                maxAlternativeAlignmentScore = -inf
                vMatchScore = 0.0
            else:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
                # compute alignment for vmatch unspliced read
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            start = cpu()

            # ids above this offset are the reads treated as truly unspliced
            unspliced = read_id - 1000000300000 > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # according to the learned parameters vmatch found a good
                # alignment of the current read
                unspliced_ctr += 1
                self.result_unspliced_fh.write(original_line+'\n')

                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # an alternative alignment using splice sites scores at least
                # as high as the vmatch alignment
                spliced_ctr += 1
                self.result_spliced_fh.write(original_line+'\n')

                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulated over all reads, like the other timers
            self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))
        if failed_ctr:
            print('Failed alternative alignment computations: %d'%failed_ctr)

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites in and around the read.

        The read is taken to start at index max_intron_size of currentAcc,
        currentDon and DNA, as set up by calcAlternativeAlignments. Only sites
        with a score >= splice_thresh are considered.

        Returns (proximal_acc, proximal_don, distal_acc, distal_don):
          proximal_acc -- at most two (idx, score) acceptors in the first half
                          of the read, in ascending score order
          proximal_don -- at most two (idx, score) donors in the second half
                          of the read, in ascending score order
          distal_acc   -- all (idx, score, DNA[idx+1:idx+read_size]) acceptors
                          downstream of the read, in ascending position order
          distal_don   -- all (idx, score, DNA[idx-read_size:idx]) donors
                          upstream of the read, in descending position order
        """
        half = read_size // 2
        n_acc = len(currentAcc)

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(max_intron_size, max_intron_size+half)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=lambda site: site[1])
        proximal_acc = proximal_acc[-2:]

        distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                      for idx in range(max_intron_size+read_size, n_acc)
                      if currentAcc[idx] >= splice_thresh and idx+read_size < n_acc]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(max_intron_size+half, max_intron_size+read_size)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=lambda site: site[1])
        proximal_don = proximal_don[-2:]

        distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                      for idx in range(1, max_intron_size)
                      if currentDon[idx] >= splice_thresh and idx > read_size]
        # nearest donor (highest position) first
        distal_don.sort(key=lambda site: site[0], reverse=True)

        return proximal_acc,proximal_don,distal_acc,distal_don

    def _realignEst(self, original_est, boundary, keep_upstream, genomic, genomic_ptr, dna_start):
        """
        Re-annotate a bracketed read for an alternative spliced alignment.

        original_est is the lower-case read in which a mismatch or gap is
        written as the four characters '[xy]' (x: genomic letter, y: read
        letter). The walk tracks the genomic position dna_ptr, starting at
        dna_start. A letter is copied unchanged when dna_ptr lies on the kept
        side of boundary (dna_ptr < boundary if keep_upstream, otherwise
        dna_ptr > boundary). Every other letter is compared with
        genomic[genomic_ptr], where genomic_ptr advances by one per such
        letter, and is written as '[gy]' when it differs.

        Returns (est_cut, num_mismatch, dna_ptr, est_ptr). num_mismatch counts
        both the kept brackets and the new mismatches; dna_ptr and est_ptr are
        the final positions. IndexError propagates if genomic is too short.
        """
        est_cut = []
        num_mismatch = 0
        dna_ptr = dna_start
        est_ptr = 0
        ptr = 0

        while ptr < len(original_est):
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                token = original_est[ptr:ptr+4]
                bracketed = True
                ptr += 4
            else:
                dnaletter = estletter = original_est[ptr]
                token = estletter
                bracketed = False
                ptr += 1

            if keep_upstream:
                keep = dna_ptr < boundary
            else:
                keep = dna_ptr > boundary

            if keep:
                est_cut.append(token)
                if bracketed:
                    num_mismatch += 1
            else:
                genomic_letter = genomic[genomic_ptr]
                genomic_ptr += 1
                if genomic_letter == estletter:
                    est_cut.append(estletter)
                else:
                    est_cut.append('['+genomic_letter+estletter+']')
                    num_mismatch += 1

            # a gap in the read consumes only genomic sequence, a gap in the
            # genome only read sequence
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(est_cut), num_mismatch, dna_ptr, est_ptr

    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by vmatch, score alternative spliced
        alignments of the same read.

        The genomic region from pos - max_intron_size to
        pos + max_intron_size + len(read) is fetched with its acceptor and
        donor scores. Two kinds of alternatives are scored:

          * 3' intron: a donor in the second half of the read is combined
            with an acceptor downstream of the read; the read letters from the
            donor on are re-aligned against the DNA following the acceptor.
          * 5' intron: an acceptor in the first half of the read is combined
            with a donor upstream of the read; the read letters up to the
            acceptor are re-aligned against the DNA preceding the donor.

        Candidates with more than max_mismatch mismatches are discarded. A
        score is the quality-aware alignment score of the re-annotated read
        plus the donor, acceptor and intron-length Plif scores and
        spliced_bias. For each proximal site, scanning of the distal sites
        stops after the first kept candidate whose distal site score reaches
        splice_stop_thresh.

        Returns the (possibly empty) list of scores.
        """
        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = self._STRAND_MAP[location['strand']]
        original_est = location['seq'].lower()
        quality = location['prb']

        est = unbracket_est(original_est)
        max_intron_size = self.max_intron_size

        genomicSeq_start = pos - max_intron_size
        genomicSeq_stop = pos + max_intron_size + len(est)

        start = cpu()
        dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start

        proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, max_intron_size, len(est), self.splice_thresh)

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        currentPhi = self.currentPhi

        alternativeScores = []

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos,don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos,acc_score,acc_dna) in distal_acc:
                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # letters before the donor stay, the rest follow acc_dna
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._realignEst(original_est, don_pos, True, acc_dna, 0, max_intron_size)

                if num_mismatch > self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if acc_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos,acc_score) in proximal_acc:
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos,don_score,don_dna) in distal_don:
                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

                # letters after the acceptor stay, the leading ones follow the
                # tail of don_dna
                don_dna_start = len(don_dna)-(acc_pos-max_intron_size)-1
                original_est_cut, num_mismatch, dna_ptr, est_ptr = \
                    self._realignEst(original_est, acc_pos, False, don_dna, don_dna_start, max_intron_size)

                if num_mismatch > self.max_mismatch:
                    continue

                assert(dna_ptr<=len(dna))
                assert(est_ptr<=len(est))

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias
                alternativeScores.append(score)

                if don_score >= self.splice_stop_thresh:
                    break

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores

    def calcAlignmentScore(self,alignment):
        """
        Score an unspliced alignment with the current QPalma parameters.

        alignment is the tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp); only original_est and quality are passed to
        computeSpliceAlignScoreWithQuality together with the quality Plifs
        and currentPhi. The elapsed CPU time is added to
        calcAlignmentScoreTime.
        """
        start = cpu()
        run = self.run

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop-start

        return score
573
574
def cpu():
    """
    Return the CPU time (user + system) consumed so far by this process, in
    seconds.

    Both components are read from a single getrusage snapshot so that they
    refer to the same instant.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
578
579
if __name__ == '__main__':
    # usage: data reads param run spliced.results unspliced.results
    if len(sys.argv) != 7:
        print('Usage: ./%s data reads param run spliced.results unspliced.results' % (sys.argv[0]))
        sys.exit(1)

    data_fname = sys.argv[1]
    reads_pipeline_fn = sys.argv[2]
    param_fname = sys.argv[3]
    run_fname = sys.argv[4]

    result_spliced_fname = sys.argv[5]
    result_unspliced_fname = sys.argv[6]

    ph1 = PipelineHeuristic(run_fname,data_fname,reads_pipeline_fn,param_fname,result_spliced_fname,result_unspliced_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()
    print('total time elapsed: %f' % (stop-start))

    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f'% ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f'% ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f'% ph1.DotProdTime)
    print('time spend forarray_stuff time: %f'% ph1.array_stuff)

    # make sure all classified reads reach the disk
    ph1.result_spliced_fh.close()
    ph1.result_unspliced_fh.close()