optimized heuristic
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 #from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Return the read sequence contained in a bracketed vmatch alignment string.

    A mismatch (or gap) is encoded as the 4-character token '[xy]' where x is
    the DNA letter and y the read letter; only y is kept. Every other
    character is copied unchanged. The result is lower-cased.

    >>> unbracket_est('AC[ag]T')
    'acgt'
    """
    # Collect the pieces in a list and join once instead of growing a string
    # with += (quadratic on interpreters without the CPython fast path).
    pieces = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            # '[xy]' -> keep the read letter y and skip the whole token
            pieces.append(est[e + 2])
            e += 4
        else:
            pieces.append(est[e])
            e += 1
    return ''.join(pieces).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the (unspliced) vmatch alignment is compared
    with the best score of alternative alignments that place an intron close
    to the 5' or the 3' end of the read. If some alternative scores at least
    as high as the vmatch alignment the read is written to the "spliced"
    result file, otherwise to the "unspliced" result file.
    """

    # Default location of the file holding the original reads.
    ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Reads with an id above this offset are counted as truly unspliced when
    # the true/false positive/negative statistics are computed.
    # NOTE(review): presumably the test data encodes the ground truth in the
    # read id -- confirm against the data set generator.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self, run_fname, data_fname, param_fname, result_spliced_fname,
                 result_unspliced_fname, original_reads_fname=None):
        """
        We need a run object holding information about the nr. of support points
        etc.

        run_fname              -- pickled run object (support-point counts,
                                  'dna_flat_files', ...)
        data_fname             -- vmatch output parsed by PipelineReadParser
        param_fname            -- pickled QPalma parameter vector
        result_spliced_fname   -- output file for reads classified as spliced
        result_unspliced_fname -- output file for reads classified as unspliced
        original_reads_fname   -- file with the original reads; defaults to
                                  ORIGINAL_READS_FNAME

        Call close() when done to close the two result files.
        """
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(run_fname, 'rb') as fh:
            run = cPickle.load(fh)
        self.run = run

        self.result_spliced_fh = open(result_spliced_fname, 'w+')
        self.result_unspliced_fh = open(result_unspliced_fname, 'w+')

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname, 'rb') as fh:
            self.param = cPickle.load(fh)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        self.read_size = 36

        # parameters of the heuristics to decide whether the read is spliced
        self.splice_thresh = 0.005
        self.max_intron_size = 2000
        self.max_mismatch = 2
        self.splice_stop_thresh = 0.99
        self.spliced_bias = 0.0

        if original_reads_fname is None:
            original_reads_fname = self.ORIGINAL_READS_FNAME

        # read id -> sequence; every non-empty line must consist of exactly five
        # whitespace-separated fields of which the first two are kept.
        self.original_reads = {}
        with open(original_reads_fname) as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    continue
                read_id, seq, _q1, _q2, _q3 = fields
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Assemble the parameter vector in the layout
        # [intron length | donor | acceptor | match matrix | quality penalties]
        don_start = lengthSP
        acc_start = don_start + donSP
        mm_start = acc_start + accSP
        qual_start = mm_start + mmatrixSP

        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:don_start] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[don_start:acc_start] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[acc_start:mm_start] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mm_start:qual_start] = mmatrix[:]
        # the quality penalties are the last totalQualSP entries of the parameters
        currentPhi[qual_start:] = self.param[-totalQualSP:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated CPU times (seconds) of the individual steps
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0

        self.init_time = cpu() - start

    def close(self):
        """Close the spliced/unspliced result files opened in __init__."""
        self.result_spliced_fh.close()
        self.result_unspliced_fh.close()

    def filter(self, max_reads=1000):
        """
        Classify the first reported location of every parsed read as spliced
        or unspliced.

        Reads on the '-' strand are skipped. For every other read the vmatch
        alignment score is compared with the best alternative spliced score;
        the original vmatch line is written to the matching result file and
        the true/false positive/negative counters are updated.

        max_reads -- stop after this many classified reads (None = no limit).
                     The default reproduces the former hard-coded limit.
        """
        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        self.read_parsing = cpu() - start

        strand_map = {'+': 'D', '-': 'P'}

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for _readId, currentReadLocations in all_remapped_reads.items():
            # stop the whole scan (not just this read) once the limit is reached
            if max_reads is not None and ctr >= max_reads:
                break

            # only the first reported location of every read is considered
            if not currentReadLocations:
                continue
            location = currentReadLocations[0]

            strand = location['strand']
            if strand == '-':
                continue

            read_id = int(location['id'])
            chrom = location['chr']
            pos = location['pos']
            seq = location['seq']
            prb = location['prb']

            original_line = '%d\t%d\t%d\t%s\t%d\t%d\t%d\t%s\t%s\t%s\t%s\n' % \
                (read_id, chrom, pos, strand_map[strand],
                 int(location['mismatches']), int(location['length']),
                 int(location['offset']), seq.upper(), prb,
                 location['cal_prb'], location['chastity'])

            unb_seq = unbracket_est(seq)
            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len - 1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
            self.get_time += cpu() - start

            # a single exon spanning the whole read
            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len

            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                      currentAcc, currentDon)
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            alternativeAlignmentScores = self.calcAlternativeAlignments(location)

            start = cpu()
            if alternativeAlignmentScores:
                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
            else:
                # found no alternatives
                maxAlternativeAlignmentScore = -inf

            unspliced = read_id > self.UNSPLICED_ID_OFFSET

            if maxAlternativeAlignmentScore < vMatchScore:
                # according to our learned parameters VMatch found a good
                # alignment of the current read
                unspliced_ctr += 1
                self.result_unspliced_fh.write(original_line)
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # an alternative alignment considering splice sites scores at
                # least as high as the VMatch alignment
                spliced_ctr += 1
                self.result_spliced_fh.write(original_line)
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            self.count_time += cpu() - start

        self.main_loop = cpu() - _start

        # make sure all results written so far reach the files
        self.result_spliced_fh.flush()
        self.result_unspliced_fh.flush()

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
        """
        Collect candidate splice sites for a read that starts at position
        max_intron_size of the genomic window DNA.

        Only sites whose score is >= splice_thresh are considered. Returns
        (proximal_acc, proximal_don, distal_acc, distal_don):

          proximal_acc -- the (up to) two best-scoring (pos, score) acceptors in
                          the first half of the read, ascending by score
          proximal_don -- the (up to) two best-scoring (pos, score) donors in the
                          second half of the read, ascending by score
          distal_acc   -- all (pos, score, dna) acceptors behind the read in
                          increasing position order; dna are the read_size-1
                          bases following pos
          distal_don   -- all (pos, score, dna) donors in front of the read in
                          decreasing position order; dna are the read_size bases
                          preceding pos
        """
        half = read_size // 2
        read_start = max_intron_size
        read_stop = max_intron_size + read_size

        def by_score(site):
            return site[1]

        proximal_acc = [(idx, currentAcc[idx])
                        for idx in range(read_start, read_start + half)
                        if currentAcc[idx] >= splice_thresh]
        proximal_acc.sort(key=by_score)
        proximal_acc = proximal_acc[-2:]

        # idx + read_size must stay inside the window
        distal_acc = [(idx, currentAcc[idx], DNA[idx + 1:idx + read_size])
                      for idx in range(read_stop, len(currentAcc) - read_size)
                      if currentAcc[idx] >= splice_thresh]

        proximal_don = [(idx, currentDon[idx])
                        for idx in range(read_start + half, read_stop)
                        if currentDon[idx] >= splice_thresh]
        proximal_don.sort(key=by_score)
        proximal_don = proximal_don[-2:]

        # idx must exceed read_size so that read_size bases precede it
        distal_don = [(idx, currentDon[idx], DNA[idx - read_size:idx])
                      for idx in range(max(1, read_size + 1), max_intron_size)
                      if currentDon[idx] >= splice_thresh]
        # closest donor (largest position) first
        distal_don.reverse()

        return proximal_acc, proximal_don, distal_acc, distal_don

    def _realignEst(self, original_est, split_pos, keep_before, alt_dna, alt_ptr):
        """
        Rebuild a bracketed vmatch alignment string around a splice site.

        In original_est a mismatch/gap is written as '[xy]' (x = DNA letter,
        y = read letter) and a match as the plain letter. While walking along
        it, dna_ptr starts at self.max_intron_size. Letters are copied
        unchanged while dna_ptr < split_pos (keep_before=True) or
        dna_ptr > split_pos (keep_before=False); all other read letters are
        compared against alt_dna starting at index alt_ptr.

        Returns (alignment_string, dna_ptr, est_ptr) with the final pointer
        positions, or None as soon as more than self.max_mismatch mismatches
        were counted (such candidates are discarded by the caller).
        """
        max_mismatch = self.max_mismatch
        pieces = []
        num_mismatch = 0
        est_ptr = 0
        dna_ptr = self.max_intron_size
        ptr = 0
        n = len(original_est)

        while ptr < n:
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                token = original_est[ptr:ptr + 4]
                token_mismatch = 1
                ptr += 4
            else:
                dnaletter = estletter = token = original_est[ptr]
                token_mismatch = 0
                ptr += 1

            if keep_before:
                keep = dna_ptr < split_pos
            else:
                keep = dna_ptr > split_pos

            if keep:
                pieces.append(token)
                num_mismatch += token_mismatch
            else:
                genomic_letter = alt_dna[alt_ptr]
                alt_ptr += 1
                if genomic_letter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('[' + genomic_letter + estletter + ']')
                    num_mismatch += 1

            # the mismatch count only grows, so we can give up early
            if num_mismatch > max_mismatch:
                return None

            # '-' marks a gap on the read resp. DNA side
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        Two kinds of alternatives are scored:
          * intron at the 3' end: the read keeps its vmatch alignment up to a
            proximal donor, the remaining bases are aligned behind a distal
            acceptor;
          * intron at the 5' end: the read keeps its vmatch alignment behind a
            proximal acceptor, the leading bases are aligned in front of a
            distal donor.
        Candidates with more than self.max_mismatch mismatches are discarded.

        Returns a (possibly empty) list of scores, each including the donor,
        acceptor and intron-length Plif scores plus self.spliced_bias.
        """
        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        original_est = location['seq']
        quality = location['prb']

        est = unbracket_est(original_est)

        genomicSeq_start = pos - self.max_intron_size
        genomicSeq_stop = pos + self.max_intron_size + len(est)

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(
            chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        self.get_time += cpu() - start

        proximal_acc, proximal_don, distal_acc, distal_don = \
            self.findHighestScoringSpliceSites(currentAcc, currentDon, dna,
                                               self.max_intron_size, len(est),
                                               self.splice_thresh)

        alternativeScores = []

        # local aliases for the loops below
        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        currentPhi = self.currentPhi
        spliced_bias = self.spliced_bias
        splice_stop_thresh = self.splice_stop_thresh

        # find an intron on the 3' end
        _start = cpu()
        for (don_pos, don_score) in proximal_don:
            DonorScore = calculatePlif(d, [don_score])[0]

            for (acc_pos, acc_score, acc_dna) in distal_acc:
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]
                AcceptorScore = calculatePlif(a, [acc_score])[0]

                # keep the vmatch alignment before the donor, realign the rest
                # against the DNA behind the acceptor
                realigned = self._realignEst(original_est, don_pos, True, acc_dna, 0)
                if realigned is None:
                    continue
                original_est_cut, dna_ptr, est_ptr = realigned

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + spliced_bias
                alternativeScores.append(score)

                # a very strong acceptor: no need to look at further ones
                if acc_score >= splice_stop_thresh:
                    break

        self.alternativeScoresTime += cpu() - _start

        # find an intron on the 5' end
        _start = cpu()
        for (acc_pos, acc_score) in proximal_acc:
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            for (don_pos, don_score, don_dna) in distal_don:
                DonorScore = calculatePlif(d, [don_score])[0]
                IntronScore = calculatePlif(h, [acc_pos - don_pos])[0]

                # keep the vmatch alignment behind the acceptor, realign the
                # leading bases against the DNA in front of the donor
                don_dna_ptr = len(don_dna) - (acc_pos - self.max_intron_size) - 1
                realigned = self._realignEst(original_est, acc_pos, False, don_dna, don_dna_ptr)
                if realigned is None:
                    continue
                original_est_cut, dna_ptr, est_ptr = realigned

                assert dna_ptr <= len(dna)
                assert est_ptr <= len(est)

                score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, currentPhi)
                score += AcceptorScore + IntronScore + DonorScore + spliced_bias
                alternativeScores.append(score)

                # a very strong donor: no need to look at further ones
                if don_score >= splice_stop_thresh:
                    break

        self.alternativeScoresTime += cpu() - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna, exons, est, original_est, quality, acc_supp,
        don_supp) and the current parameter for QPalma this function calculates
        the alignment score via computeSpliceAlignScoreWithQuality.

        Only original_est and quality are passed on; the other tuple entries
        are accepted for interface compatibility.
        """
        start = cpu()

        _dna, _exons, _est, original_est, quality, _acc_supp, _don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        self.calcAlignmentScoreTime += cpu() - start

        return score
554
555
def cpu():
    """
    Return the CPU time (user + system, in seconds) used so far by this process.

    A single getrusage() call is used so that both components come from the
    same snapshot (and the syscall is issued only once per measurement).
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
559
560
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    # Without arguments the defaults of the original experiment are used.
    jp = os.path.join

    run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'

    run_fname = jp(run_dir, 'run_object.pickle')
    data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'
    param_fname = jp(run_dir, 'param_500.pickle')

    if len(sys.argv) > 3:
        run_fname, data_fname, param_fname = sys.argv[1:4]

    result_spliced_fname = 'splicedReads.heuristic'
    result_unspliced_fname = 'unsplicedReads.heuristic'

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname, result_spliced_fname, result_unspliced_fname)

    try:
        start = cpu()
        ph1.filter()
        stop = cpu()
    finally:
        # always release the result files, even if filtering fails
        ph1.result_spliced_fh.close()
        ph1.result_unspliced_fh.close()

    print('total time elapsed: %f' % (stop - start))
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for read_parsing time: %f' % ph1.read_parsing)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f' % ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f' % ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f' % ph1.DotProdTime)
    print('time spend for array_stuff time: %f' % ph1.array_stuff)

    # To profile the filter step:
    #   import cProfile; cProfile.run('ph1.filter()')