git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8629 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
import cPickle
import sys
import pdb
import os
import os.path
import math
import resource

# NOTE: the order of the imports below is significant. The qpalma star imports
# may export names such as `zeros` or `inf`; the numpy imports that follow them
# deliberately shadow those (e.g. `zeros` must be the numpy.matlib matrix
# version because scores are computed with matrix products).
from qpalma.DataProc import *
from qpalma.computeSpliceWeights import *
from qpalma.set_param_palma import *
from qpalma.computeSpliceAlignWithQuality import *
from qpalma.penalty_lookup_new import *
from qpalma.compute_donacc import *
from qpalma.TrainingParam import Param
from qpalma.Plif import Plf

from qpalma.tools.splicesites import getDonAccScores
from qpalma.Configuration import *

from compile_dataset import getSpliceScores, get_seq_and_scores

import numpy
from numpy.matlib import mat,zeros,ones,inf
from numpy import inf,mean

from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Strip the bracketed mismatch notation from an aligned read.

    A mismatch is written as a 4-character group "[xy]". Only the character at
    offset 2 of the group (y, the read/EST letter) is kept; the DNA letter at
    offset 1 and the brackets are dropped. All other characters are copied
    unchanged. The result is lower-cased.

    Example: "AC[GT]A" -> "acta".

    Raises IndexError if a '[' is one of the last two characters of est
    (i.e. the group is truncated).
    """
    letters = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            letters.append(est[e + 2])
            e += 4
        else:
            letters.append(est[e])
            e += 1
    # Build once with join instead of repeated string concatenation.
    return ''.join(letters).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the (unspliced) vmatch alignment is compared
    with the scores of alternative spliced alignments whose donor candidates
    lie at the start of the genomic window beginning at the vmatch position.
    If the best alternative scores at least as high as the vmatch alignment,
    the read is predicted to be spliced.
    """

    # Reads file used by __init__ when no reads_fname is given (this path used
    # to be hard-coded).
    DEFAULT_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    def __init__(self, run_fname, data_fname, param_fname, reads_fname=None):
        """
        Load the run object and the trained parameters and precompute the
        parameter vector used to score alignments.

        run_fname   -- pickled run object holding information about the nr. of
                       support points etc.
        data_fname  -- vmatch output, parsed by PipelineReadParser in filter()
        param_fname -- pickled trained QPalma parameter vector
        reads_fname -- file with five whitespace-separated fields per line
                       ("id seq q1 q2 q3"); only id and seq are kept in
                       self.original_reads. Defaults to DEFAULT_READS_FNAME.
        """
        # NOTE(review): unpickling executes code embedded in the file; only
        # load run/parameter pickles produced by our own pipeline.
        with open(run_fname, 'rb') as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname, 'rb') as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is the
        # mean of intron size
        self.intron_size = 90

        self.read_size = 36

        if reads_fname is None:
            reads_fname = self.DEFAULT_READS_FNAME

        self.original_reads = {}
        with open(reads_fname) as reads_file:
            for line in reads_file:
                fields = line.split()
                if not fields:
                    # tolerate blank (e.g. trailing) lines
                    continue
                read_id, seq, _q1, _q2, _q3 = fields
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Parameter vector layout:
        # [ intron length | donor | acceptor | match matrix | quality ]
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        # the quality penalties are the last totalQualSP entries of param
        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated CPU times (seconds) of the individual processing steps
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0

        self.init_time = cpu() - start

    def filter(self, max_reads=100):
        """
        Classify the reads in data_fname as spliced or unspliced.

        Only the first reported location of every read and only reads on the
        '+' strand are considered. Processing stops once max_reads reads have
        been classified (a read counts only if at least one alternative
        alignment could be scored); pass None to process all reads.

        The confusion counts are accumulated in self.true_pos, self.false_pos,
        self.true_neg and self.false_neg and printed at the end. A read is
        treated as truly unspliced when its id exceeds 1000000300000
        (presumably the data set encodes the ground truth in the id range).
        """
        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        self.read_parsing = cpu() - start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            if max_reads is not None and ctr >= max_reads:
                # stop iterating altogether instead of skipping every
                # remaining read
                break

            first_locations = currentReadLocations[:1]
            if not first_locations:
                continue
            location = first_locations[0]

            read_id = int(location['id'])
            chrom = location['chr']
            pos = location['pos']
            strand = location['strand']
            seq = location['seq']
            prb = location['prb']

            if strand == '-':
                continue

            unb_seq = unbracket_est(seq)
            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len - 1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
            self.get_time += cpu() - start

            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len

            # NOTE: calcAlignmentScore currently only uses the bracketed read
            # (seq) and its quality; the other members are passed along for
            # completeness.
            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                      currentAcc, currentDon)
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            alternativeAlignmentScores = self.calcAlternativeAlignments(location)

            start = cpu()
            # found no alternatives
            if not alternativeAlignmentScores:
                continue

            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

            unspliced = read_id > 1000000300000

            if maxAlternativeAlignmentScore < vMatchScore:
                # According to our learned parameters VMatch found a good
                # (unspliced) alignment of the current read.
                unspliced_ctr += 1
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # An alternative alignment considering splice sites scores at
                # least as high as the VMatch alignment.
                spliced_ctr += 1
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            # accumulate like all other timers (was overwritten per read)
            self.count_time += cpu() - start

        self.main_loop = cpu() - _start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Collect splice site candidates at the start of a genomic window.

        Despite the name no ranking takes place: every position
        0 .. read_size+1 whose score is greater than -inf is returned as an
        (index, score) pair, in increasing index order.

        Returns (don, acc) -- note the order differs from the arguments.
        """
        # positions 0 .. read_size+1 inclusive
        window = self.read_size + 2
        # enumerate already yields increasing indices, so no sorting is needed
        don = [(idx, score) for idx, score in enumerate(currentDon[:window])
               if score > -inf]
        acc = [(idx, score) for idx, score in enumerate(currentAcc[:window])
               if score > -inf]
        return don, acc

    @staticmethod
    def _removeSecondExonMismatches(original_est, last_exon1_pos, dna_len, est_len):
        """
        Return original_est with the mismatch groups "[xy]" that start after
        read position last_exon1_pos replaced by their read letter y; groups at
        or before last_exon1_pos are kept verbatim.

        A gap in the read (y == '-') advances only the DNA position, a gap in
        the DNA (x == '-') only the read position. The asserts check that
        neither position runs past dna_len / est_len.
        """
        cut = []
        est_ptr = 0
        dna_ptr = 0
        ptr = 0
        while ptr < len(original_est):
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                if est_ptr <= last_exon1_pos:
                    cut.append(original_est[ptr:ptr + 4])
                else:
                    cut.append(estletter)  # EST letter
                ptr += 4
            else:
                dnaletter = original_est[ptr]
                estletter = dnaletter
                cut.append(estletter)  # EST letter
                ptr += 1

            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

            assert dna_ptr <= dna_len
            assert est_ptr <= est_len

        return ''.join(cut)

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        For every donor candidate an intron of length self.intron_size is
        assumed and the second exon is taken verbatim from the read (only
        correct if there are no indels). Returns the list of feature-based
        scores w'phi(x,y), one per donor candidate. Debug comparisons with a
        second scoring route are printed to stdout.
        """
        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        est = unbracket_est(seq)
        original_est = seq
        quality = prb

        genomicSeq_start = pos
        genomicSeq_stop = pos + self.intron_size * 2 + self.read_size * 2

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        self.get_time += cpu() - start

        start = cpu()
        alt_don, _alt_acc = self.findHighestScoringSpliceSites(currentAcc, currentDon)
        self.splice_site_time += cpu() - start

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        phi = self.currentPhi

        IntronScore = calculatePlif(h, [self.intron_size])[0]
        dummyAcceptorScore = calculatePlif(a, [0.25])[0]
        dummyDonorScore = calculatePlif(d, [0.25])[0]
        print('%s %s %s' % (IntronScore, dummyAcceptorScore, dummyDonorScore))

        _start = cpu()
        for don_pos, don_score in alt_don:
            start = cpu()

            acc_pos = don_pos + self.intron_size

            # int dtype: the builtin replaces the removed numpy.int alias
            exons = zeros((2, 2), dtype=int)
            exons[0, 0] = 0
            exons[0, 1] = don_pos
            exons[1, 0] = acc_pos + 1
            exons[1, 1] = acc_pos + 1 + (self.read_size - don_pos)

            exon2_start = int(exons[1, 0])
            exon2_stop = int(exons[1, 1])

            # Build the candidate DNA from the ORIGINAL window every time; the
            # old code rebound `dna` here and so corrupted later candidates.
            _dna = dna[:exon2_stop]
            _dna = _dna[:exon2_start] + est[don_pos:]  # only correct if there are no indels!!!

            _currentAcc = [0.25] * len(currentAcc[:exon2_stop])
            _currentDon = currentDon[:exon2_stop]

            self.array_stuff += cpu() - start

            # Calculate the parameters of the true alignment (but with untrained d,a,h ...)
            start = cpu()
            trueSpliceAlign, trueWeightMatch, trueWeightQuality, dna_calc = \
                computeSpliceAlignWithQuality(_dna, exons, est, original_est,
                                              quality, qualityPlifs, run)
            self.computeSpliceAlignWithQualityTime += cpu() - start

            # Calculate the weights
            start = cpu()
            trueWeightDon, trueWeightAcc, trueWeightIntron = \
                computeSpliceWeights(d, a, h, trueSpliceAlign, _currentDon, _currentAcc)
            self.computeSpliceWeightsTime += cpu() - start

            start = cpu()

            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

            # Calculate w'phi(x,y) the total score of the alignment
            featureScore = (trueWeight.T * phi)[0, 0]
            alternativeScores.append(featureScore)

            # remove mismatching positions in the second exon
            original_est_cut = self._removeSecondExonMismatches(
                original_est, exons[0, 1], len(_dna), len(est))

            print('%s %s' % (original_est, original_est_cut))

            # new score computed without the feature vector, for comparison
            DonorScore = calculatePlif(d, [don_score])[0]

            score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, phi)
            score += dummyAcceptorScore + IntronScore + DonorScore

            print('diff %f,%f,%f' % (featureScore - score, featureScore, score))

            self.DotProdTime += cpu() - start

        self.alternativeScoresTime += cpu() - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna, exons, est, original_est, quality, acc_supp,
        don_supp) and the current parameter for QPalma this function
        calculates the alignment score.

        Only original_est (the read in bracketed mismatch notation) and quality
        enter the score; the other members of the tuple are currently unused.
        """
        start = cpu()

        _dna, _exons, _est, original_est, quality, _acc_supp, _don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        self.calcAlignmentScoreTime += cpu() - start

        return score
480
481
def cpu():
    """
    Return the CPU time in seconds (user + system) consumed so far by this
    process, from resource.getrusage(RUSAGE_SELF).

    Both fields are read from a single getrusage snapshot. This is consistent
    and cheaper than the two separate calls used before; cpu() is called many
    times inside the filtering loops.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
485
486
def main(argv=None):
    """
    Run the pipeline heuristic on one data set and print timing statistics.

    argv -- command-line arguments without the program name, defaulting to
            sys.argv[1:]. Either empty (use the hard-coded test run below) or
            exactly: run_fname data_fname param_fname.

    Returns the process exit status (0 on success, 2 on a usage error).
    """
    if argv is None:
        argv = sys.argv[1:]

    if len(argv) == 3:
        run_fname, data_fname, param_fname = argv
    elif not argv:
        run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        run_fname = os.path.join(run_dir, 'run_object.pickle')
        # Other data sets used during development:
        #   /fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_flag
        #   /fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
        param_fname = os.path.join(run_dir, 'param_500.pickle')
    else:
        sys.stderr.write('usage: %s [run_fname data_fname param_fname]\n' % sys.argv[0])
        return 2

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop - start))

    timings = [
        ('get seq', ph1.get_time),
        ('calcAlignmentScoreTime', ph1.calcAlignmentScoreTime),
        ('alternativeScoresTime', ph1.alternativeScoresTime),
        ('count time', ph1.count_time),
        ('init time', ph1.init_time),
        ('read_parsing time', ph1.read_parsing),
        ('main_loop time', ph1.main_loop),
        ('splice_site_time time', ph1.splice_site_time),
        ('computeSpliceAlignWithQualityTime time', ph1.computeSpliceAlignWithQualityTime),
        ('computeSpliceWeightsTime time', ph1.computeSpliceWeightsTime),
        ('DotProdTime time', ph1.DotProdTime),
        ('array_stuff time', ph1.array_stuff),
    ]
    for label, seconds in timings:
        print('time spend for %s: %f' % (label, seconds))

    # To profile instead of timing by hand:
    #   import cProfile; cProfile.run('ph1.filter()')
    return 0


if __name__ == '__main__':
    sys.exit(main())