+ added framework code for training modus
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import resource
10
11 from qpalma.computeSpliceWeights import *
12 from qpalma.set_param_palma import *
13 from qpalma.computeSpliceAlignWithQuality import *
14
15 from numpy.matlib import mat,zeros,ones,inf
16 from numpy import inf,mean
17
18 from ParaParser import ParaParser,IN_VECTOR
19
20 from qpalma.Lookup import LookupTable
21 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,reverse_complement,unbracket_seq
22
23
def cpu():
    """
    Return the CPU time consumed so far by this process, in seconds.

    The value is the sum of user and system time reported by getrusage for
    RUSAGE_SELF. Both components are read from one snapshot so that they
    describe the same instant.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
27
28
class BracketWrapper:
    """
    Container-like view on a reads file parsed with ParaParser.

    The file is parsed once in the constructor. Afterwards the entries can be
    counted with len(), fetched by index and iterated over.
    """

    # names handed to the parser, one per format specifier used in __init__
    fields = ['id', 'chr', 'pos', 'strand', 'seq', 'prb']

    def __init__(self, filename):
        """Create the parser for `fields` and let it parse `filename`."""
        column_names = self.fields
        self.parser = ParaParser("%lu%d%d%s%s%s", column_names, len(column_names), IN_VECTOR)
        self.parser.parseFile(filename)

    def __len__(self):
        """Return the size the parser reports for IN_VECTOR."""
        return self.parser.getSize(IN_VECTOR)

    def __getitem__(self, key):
        """Return whatever the parser's fetchEntry yields for `key`."""
        return self.parser.fetchEntry(key)

    def __iter__(self):
        """
        Reset the iteration cursor and return self as the iterator.

        The cursor is stored on the instance, so only one iteration over a
        wrapper can be in progress at a time.
        """
        self.counter = 0
        self.size = self.parser.getSize(IN_VECTOR)
        return self

    def next(self):
        """Return the entry under the cursor and advance; StopIteration at the end."""
        if self.counter >= self.size:
            raise StopIteration
        entry = self.parser.fetchEntry(self.counter)
        self.counter += 1
        return entry
53
54
55 class PipelineHeuristic:
56 """
57 This class wraps the filter which decides whether an alignment found by
58 vmatch is spliced and should then be newly aligned using QPalma or not.
59 """
60
def __init__(self,run_fname,data_fname,param_fname,result_fname,settings):
    """
    Set up everything the filter needs.

    run_fname    -- pickle file holding the run object (information about
                    the nr. of support points etc.)
    data_fname   -- file with the remapped reads, parsed via BracketWrapper
    param_fname  -- pickle file holding the parameter vector
    result_fname -- prefix of the two result files; '<prefix>.spliced' and
                    '<prefix>.unspliced' are created (or truncated)
    settings     -- mapping; 'allowed_fragments' is read from it and the
                    whole object is handed to DataAccessWrapper

    Both result files get a '#'-prefixed header listing the heuristic
    parameters, the data file and the parameter file.
    """

    def load_pickle(path):
        # NOTE(review): unpickling can execute arbitrary code; only point
        # this at files from a trusted source.
        handle = open(path)
        try:
            return cPickle.load(handle)
        finally:
            # the handle used to be left to the garbage collector
            handle.close()

    run = load_pickle(run_fname)
    self.run = run

    start = cpu()
    self.all_remapped_reads = BracketWrapper(data_fname)
    stop = cpu()

    print('parsed %d reads in %f sec' % (len(self.all_remapped_reads),stop-start))

    self.chromo_range = settings['allowed_fragments']

    # NOTE(review): seqInfo is not used after construction in this method;
    # the calls are kept in case the constructors have side effects.
    accessWrapper = DataAccessWrapper(settings)
    seqInfo = SeqSpliceInfo(accessWrapper,settings['allowed_fragments'])

    start = cpu()
    # NOTE(review): g_dir, acc_dir, don_dir, g_fmt and s_fmt are not defined
    # in this method; unless one of the star imports provides them this line
    # raises NameError. Left unchanged because the intended values cannot be
    # derived from this file.
    self.lt1 = LookupTable(g_dir,acc_dir,don_dir,g_fmt,s_fmt,self.chromo_range)
    stop = cpu()

    print('prefetched sequence and splice data in %f sec' % (stop-start))

    self.result_spliced_fh = open('%s.spliced'%result_fname,'w+')
    self.result_unspliced_fh = open('%s.unspliced'%result_fname,'w+')

    start = cpu()

    self.data_fname = data_fname

    self.param = load_pickle(param_fname)

    # Set the parameters such as limits penalties for the Plifs
    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

    self.h = h
    self.d = d
    self.a = a
    self.mmatrix = mmatrix
    self.qualityPlifs = qualityPlifs

    self.read_size = 38
    # subtracted from ord() of every quality character
    self.prb_offset = 64

    # parameters of the heuristics to decide whether the read is spliced
    self.splice_thresh = 0.01
    self.max_intron_size = 2000
    self.max_mismatch = 2
    self.splice_stop_thresh = 0.8
    self.spliced_bias = 0.0

    header_entries = [
        ('%f','splice_thresh',self.splice_thresh),
        ('%d','max_intron_size',self.max_intron_size),
        ('%d','max_mismatch',self.max_mismatch),
        ('%f','splice_stop_thresh',self.splice_stop_thresh),
        ('%f','spliced_bias',self.spliced_bias)]

    param_lines = [('# %s: '+form+'\n')%(name,val) for form,name,val in header_entries]
    param_lines.append('# data: %s\n'%self.data_fname)
    param_lines.append('# param: %s\n'%param_fname)

    for p_line in param_lines:
        self.result_spliced_fh.write(p_line)
        self.result_unspliced_fh.write(p_line)

    # Assemble the parameter vector: length, donor and acceptor penalties,
    # then the match matrix, then the quality penalties.
    lengthSP = run['numLengthSuppPoints']
    donSP = run['numDonSuppPoints']
    accSP = run['numAccSuppPoints']
    mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
    totalQualSP = run['totalQualSuppPoints']

    don_begin = lengthSP
    acc_begin = don_begin+donSP
    mmatrix_begin = acc_begin+accSP
    qual_begin = mmatrix_begin+mmatrixSP

    currentPhi = zeros((run['numFeatures'],1))
    currentPhi[0:don_begin] = mat(h.penalties[:]).reshape(lengthSP,1)
    currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP,1)
    currentPhi[acc_begin:mmatrix_begin] = mat(a.penalties[:]).reshape(accSP,1)
    currentPhi[mmatrix_begin:qual_begin] = mmatrix[:]

    # the quality penalties are the tail of the raw parameter vector
    totalQualityPenalties = self.param[-totalQualSP:]
    currentPhi[qual_begin:] = totalQualityPenalties[:]
    self.currentPhi = currentPhi

    # we want to identify spliced reads
    # so true pos are spliced reads that are predicted "spliced"
    self.true_pos = 0

    # as false positives we count all reads that are not spliced but predicted
    # as "spliced"
    self.false_pos = 0

    self.true_neg = 0
    self.false_neg = 0

    # total time spent for get seq and scores
    self.get_time = 0.0
    self.calcAlignmentScoreTime = 0.0
    self.alternativeScoresTime = 0.0

    self.count_time = 0.0
    self.main_loop = 0.0
    self.splice_site_time = 0.0
    self.computeSpliceAlignWithQualityTime = 0.0
    self.computeSpliceWeightsTime = 0.0
    self.DotProdTime = 0.0
    self.array_stuff = 0.0
    stop = cpu()

    self.init_time = stop-start
187
def filter(self):
    """
    Classify every remapped read as spliced or unspliced.

    For each read on an allowed chromosome the score of the alignment
    proposed by vmatch is compared with the best score among the alternative
    (spliced) alignments from calcAlternativeAlignments. The original input
    line is written to the '.unspliced' result file if the vmatch score is
    strictly higher, otherwise to the '.spliced' file. If no alternative
    could be computed (none found, or the computation raised) the read
    counts as unspliced.

    Updates the true/false positive/negative counters and the timers on
    self, and prints progress and a final summary. Returns nothing.
    """
    strand_map = {'D':'+', 'P':'-'}

    # NOTE(review): reads whose id exceeds this value are counted as
    # unspliced for the true/false statistics below; presumably an id
    # convention of the data set used for evaluation - confirm.
    unspliced_id_offset = 1000000300000

    ctr = 0
    unspliced_ctr = 0
    spliced_ctr = 0

    print('Starting filtering...')
    _start = cpu()

    for location,original_line in self.all_remapped_reads:

        if ctr % 1000 == 0:
            print(ctr)

        read_id = int(location['id'])
        chromo = location['chr']
        pos = location['pos']
        seq = location['seq'].lower()
        prb = location['prb']
        strand = strand_map[location['strand']]

        if chromo not in self.chromo_range:
            continue

        unb_seq = unbracket_seq(seq)

        if strand == '-':
            unb_seq = reverse_complement(unb_seq)
            seq = reverse_complement(seq)

        effective_len = len(unb_seq)

        genomicSeq_start = pos
        genomicSeq_stop = pos+effective_len

        start = cpu()
        currentDNASeq, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
        stop = cpu()
        self.get_time += stop-start

        # the vmatch alignment is a single exon covering the whole read
        exons = zeros((2,1))
        exons[0,0] = 0
        exons[1,0] = effective_len
        quality = [ord(x)-self.prb_offset for x in prb]

        currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, quality,\
        currentAcc, currentDon

        # Best effort: a read whose alternatives cannot be computed is
        # treated like a read without alternatives. Only Exception is
        # caught so that Ctrl-C and SystemExit still stop the run.
        try:
            alternativeAlignmentScores = self.calcAlternativeAlignments(location)
        except Exception:
            alternativeAlignmentScores = []

        if not alternativeAlignmentScores:
            # no alignment necessary
            maxAlternativeAlignmentScore = -inf
            vMatchScore = 0.0
        else:
            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
            # compute alignment for vmatch unspliced read
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

        start = cpu()

        unspliced = (read_id - unspliced_id_offset) > 0

        # Seems that according to our learned parameters VMatch found a good
        # alignment of the current read
        if maxAlternativeAlignmentScore < vMatchScore:
            unspliced_ctr += 1

            self.result_unspliced_fh.write(original_line+'\n')

            if unspliced:
                self.true_neg += 1
            else:
                self.false_neg += 1

        # We found an alternative alignment considering splice sites that scores
        # higher than the VMatch alignment
        else:
            spliced_ctr += 1

            self.result_spliced_fh.write(original_line+'\n')

            if unspliced:
                self.false_pos += 1
            else:
                self.true_pos += 1

        ctr += 1
        stop = cpu()
        # accumulate like the other timers instead of keeping only the
        # last iteration
        self.count_time += stop-start

    _stop = cpu()
    self.main_loop = _stop-_start

    print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
    print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
    print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))
329
330
def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
    """
    Collect candidate splice sites inside and around the read.

    The read is taken to occupy the index range
    [max_intron_size, max_intron_size+read_size) of currentAcc, currentDon
    and DNA. Only positions whose score is >= splice_thresh are considered.

    Returns the tuple (proximal_acc, proximal_don, distal_acc, distal_don):

    proximal_acc -- (idx, score) of acceptors in the first half of the read;
                    at most the two best, sorted by ascending score
    proximal_don -- (idx, score) of donors in the second half of the read;
                    at most the two best, sorted by ascending score
    distal_acc   -- (idx, score, DNA[idx+1:idx+read_size]) of acceptors
                    behind the read that leave room for read_size further
                    positions, in ascending index order
    distal_don   -- (idx, score, DNA[idx-read_size:idx]) of donors in front
                    of the read with idx > read_size, in descending index
                    order
    """

    def by_score(site):
        return site[1]

    def by_index(site):
        return site[0]

    read_begin = max_intron_size
    read_middle = read_begin + read_size//2
    read_end = read_begin + read_size
    num_scores = len(currentAcc)

    proximal_acc = [(idx, currentAcc[idx])
                    for idx in range(read_begin, read_middle)
                    if currentAcc[idx] >= splice_thresh]
    # key-based sort is stable, so ties keep their index order
    proximal_acc.sort(key=by_score)
    proximal_acc = proximal_acc[-2:]

    distal_acc = [(idx, currentAcc[idx], DNA[idx+1:idx+read_size])
                  for idx in range(read_end, num_scores)
                  if currentAcc[idx] >= splice_thresh and idx+read_size < num_scores]

    proximal_don = [(idx, currentDon[idx])
                    for idx in range(read_middle, read_end)
                    if currentDon[idx] >= splice_thresh]
    proximal_don.sort(key=by_score)
    proximal_don = proximal_don[-2:]

    distal_don = [(idx, currentDon[idx], DNA[idx-read_size:idx])
                  for idx in range(1, max_intron_size)
                  if currentDon[idx] >= splice_thresh and idx > read_size]
    # closest donor (largest index) first
    distal_don.sort(key=by_index, reverse=True)

    return proximal_acc,proximal_don,distal_acc,distal_don
374
375
def calcAlternativeAlignments(self,location):
    """
    Given an alignment proposed by Vmatch this function calculates possible
    alternative alignments taking into account for example matched
    donor/acceptor positions.

    location -- mapping with the keys 'chr', 'pos', 'strand', 'seq' and
                'prb' describing one remapped read

    Returns a list with one score per alternative alignment that stays
    within self.max_mismatch mismatches; the list may be empty.

    Two families of alternatives are tried: an intron towards the 3' end
    (a donor inside the read paired with an acceptor behind it) and an
    intron towards the 5' end (an acceptor inside the read paired with a
    donor in front of it). The time spent is added to self.get_time and
    self.alternativeScoresTime.
    """

    run = self.run

    chromo = location['chr']
    pos = location['pos']
    original_est = location['seq'].lower()
    quality = [ord(x)-self.prb_offset for x in location['prb']]

    est = unbracket_seq(original_est)

    # fetch max_intron_size extra positions on both sides of the read
    genomicSeq_start = pos - self.max_intron_size
    genomicSeq_stop = pos + self.max_intron_size + len(est)

    strand_map = {'D':'+', 'P':'-'}
    strand = strand_map[location['strand']]

    if strand == '-':
        est = reverse_complement(est)
        original_est = reverse_complement(original_est)

    start = cpu()
    dna, currentAcc, currentDon = self.lt1.get_seq_and_scores(chromo, strand, genomicSeq_start, genomicSeq_stop)
    stop = cpu()
    self.get_time += stop-start

    proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc, currentDon, dna, self.max_intron_size, len(est), self.splice_thresh)

    alternativeScores = []

    h = self.h
    d = self.d
    a = self.a
    qualityPlifs = self.qualityPlifs

    read_begin = self.max_intron_size
    bracketed_len = len(original_est)

    def respliceRead(boundary, keep_before, distal_dna, distal_ptr):
        """
        Construct a new "original_est" for one donor/acceptor pair.

        Read positions on the unchanged side of `boundary` (dna position
        < boundary if keep_before, > boundary otherwise) are copied as they
        are. All other positions are compared letter by letter against
        distal_dna, starting at distal_ptr, and written in the bracket
        notation '[<dna letter><read letter>]' where they differ.

        Returns (new bracketed read, nr. of mismatches, final dna pointer,
        final read pointer).
        """
        pieces = []
        num_mismatch = 0
        dna_ptr = read_begin
        est_ptr = 0
        ptr = 0

        while ptr < bracketed_len:
            bracketed = original_est[ptr] == '['
            if bracketed:
                # '[' dna letter, read letter ']'
                dnaletter = original_est[ptr+1]
                estletter = original_est[ptr+2]
                width = 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter
                width = 1

            if keep_before:
                unchanged = dna_ptr < boundary
            else:
                unchanged = dna_ptr > boundary

            if unchanged:
                pieces.append(original_est[ptr:ptr+width])
                if bracketed:
                    num_mismatch += 1
            else:
                distal_letter = distal_dna[distal_ptr]
                if distal_letter == estletter:
                    pieces.append(estletter)
                else:
                    pieces.append('['+distal_letter+estletter+']')
                    num_mismatch += 1
                distal_ptr += 1

            ptr += width

            # '-' marks a gap: only the side that has a letter advances
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), num_mismatch, dna_ptr, est_ptr

    # find an intron on the 3' end
    _start = cpu()
    for (don_pos,don_score) in proximal_don:
        DonorScore = calculatePlif(d, [don_score])[0]

        for (acc_pos,acc_score,acc_dna) in distal_acc:

            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            original_est_cut, num_mismatch, dna_ptr, est_ptr = respliceRead(don_pos, True, acc_dna, 0)

            if num_mismatch>self.max_mismatch:
                continue

            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # a sufficiently strong acceptor ends the search for this donor
            if acc_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    # find an intron on the 5' end
    _start = cpu()
    for (acc_pos,acc_score) in proximal_acc:

        AcceptorScore = calculatePlif(a, [acc_score])[0]

        for (don_pos,don_score,don_dna) in distal_don:

            DonorScore = calculatePlif(d, [don_score])[0]
            IntronScore = calculatePlif(h, [acc_pos-don_pos])[0]

            # start far enough from the end of don_dna that the compared
            # stretch ends at its last letter when the acceptor is reached
            don_dna_ptr = len(don_dna)-(acc_pos-self.max_intron_size)-1

            original_est_cut, num_mismatch, dna_ptr, est_ptr = respliceRead(acc_pos, False, don_dna, don_dna_ptr)

            if num_mismatch>self.max_mismatch:
                continue

            assert(dna_ptr<=len(dna))
            assert(est_ptr<=len(est))

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += AcceptorScore + IntronScore + DonorScore + self.spliced_bias

            alternativeScores.append(score)

            # a sufficiently strong donor ends the search for this acceptor
            if don_score>=self.splice_stop_thresh:
                break

    _stop = cpu()
    self.alternativeScoresTime += _stop-_start

    return alternativeScores
594
595
def calcAlignmentScore(self,alignment):
    """
    Score an alignment under the current QPalma parameters.

    alignment -- 7-tuple (dna, exons, est, original_est, quality, acc_supp,
                 don_supp); only original_est and quality enter the score

    Returns the value of computeSpliceAlignScoreWithQuality for the
    bracketed read, its quality values, the quality Plifs, the run object
    and the parameter vector. The time spent is added to
    self.calcAlignmentScoreTime.
    """

    t_begin = cpu()
    run = self.run

    # unpacking all seven entries keeps the check on the tuple layout
    _dna, _exons, _est, bracketed_read, quality_values, _acc_supp, _don_supp = alignment

    result = computeSpliceAlignScoreWithQuality(bracketed_read, quality_values, self.qualityPlifs, run, self.currentPhi)

    t_end = cpu()
    self.calcAlignmentScoreTime += t_end-t_begin

    return result