git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8619 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Strip the bracketed mismatch annotations from an annotated read string.

    A mismatch is written as a four-character group ``[XY]``. Every such group
    is replaced by its third character ``Y`` (presumably the read letter, with
    ``X`` the genomic letter). All other characters are copied through
    unchanged. The result is lower-cased.

    :param est: annotated read string, e.g. ``'AC[TG]T'``
    :returns: the plain read sequence in lower case, e.g. ``'acgt'``
    """
    letters = []
    pos = 0
    while pos < len(est):
        if est[pos] == '[':
            # '[' X Y ']' -> keep only Y
            letters.append(est[pos + 2])
            pos += 4
        else:
            letters.append(est[pos])
            pos += 1
    return ''.join(letters).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read, the score of the (unspliced) vmatch alignment under the
    learned QPalma parameters is compared with the best score of a few
    heuristically constructed spliced alternatives. If an alternative scores
    at least as high, the read is predicted "spliced".
    """

    # Default file holding the original reads. Each line must consist of five
    # whitespace-separated fields; the first two are the numeric read id and
    # the read sequence.
    DEFAULT_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Reads whose id exceeds this offset are counted as truly unspliced when
    # evaluating the filter. NOTE(review): this encodes an id-numbering
    # convention of the test data set; confirm before reusing on other data.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self, run_fname, data_fname, param_fname, reads_fname=None):
        """
        Load the run object, the learned parameters and the original reads,
        and assemble the parameter vector ``self.currentPhi``.

        :param run_fname: pickle file holding the run object (nr. of support
            points etc.)
        :param data_fname: vmatch output, parsed later by ``PipelineReadParser``
        :param param_fname: pickle file holding the learned parameter vector
        :param reads_fname: file with the original reads; defaults to
            ``DEFAULT_READS_FNAME``
        """
        # NOTE: unpickling can execute arbitrary code; only load run and
        # parameter files from trusted sources.
        with open(run_fname, 'rb') as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname, 'rb') as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits and penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the mean intron size
        self.intron_size = 90

        self.read_size = 36

        if reads_fname is None:
            reads_fname = self.DEFAULT_READS_FNAME

        # read id -> original (unannotated) read sequence
        self.original_reads = {}
        with open(reads_fname) as reads_file:
            for line in reads_file:
                line = line.strip()
                if not line:
                    continue  # tolerate blank (e.g. trailing) lines
                read_id, read_seq, _q1, _q2, _q3 = line.split()
                self.original_reads[int(read_id)] = read_seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Parameter vector laid out as
        # [intron length | donor | acceptor | match matrix | quality] penalties
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = \
            mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads,
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated CPU times (seconds) of the individual stages
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop - start

    def filter(self, max_reads=100):
        """
        Run the heuristic on the reads in ``self.data_fname`` and print the
        resulting confusion counts.

        Only the first location reported for each read is used, and reads on
        the '-' strand are skipped. A read is predicted "spliced" when the best
        alternative (spliced) alignment scores at least as high as the vmatch
        alignment. Reads for which no alternative could be built are skipped.
        ``true_pos``/``false_pos``/``true_neg``/``false_neg`` and the timers
        are updated in place.

        :param max_reads: stop after this many reads have been classified;
            ``None`` means no limit.
        """
        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        stop = cpu()
        self.read_parsing = stop - start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            # stop the whole scan once enough reads have been classified
            if max_reads is not None and ctr >= max_reads:
                break
            if not currentReadLocations:
                continue

            # only the first reported location of every read is considered
            location = currentReadLocations[0]

            read_id = int(location['id'])
            chromo = location['chr']
            pos = location['pos']
            strand = location['strand']
            seq = location['seq']
            prb = location['prb']

            if strand == '-':
                continue

            unb_seq = unbracket_est(seq)
            effective_len = len(unb_seq)

            genomicSeq_start = pos
            genomicSeq_stop = pos + effective_len - 1

            start = cpu()
            currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                chromo, strand, genomicSeq_start, genomicSeq_stop,
                run['dna_flat_files'])
            stop = cpu()
            self.get_time += stop - start

            # the unspliced vmatch alignment: one exon covering the whole read
            exons = zeros((2, 1))
            exons[0, 0] = 0
            exons[1, 0] = effective_len
            currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                      currentAcc, currentDon)
            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

            alternativeAlignmentScores = self.calcAlternativeAlignments(location)

            # found no alternatives
            if not alternativeAlignmentScores:
                continue

            start = cpu()
            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

            # ground truth for the evaluation, see UNSPLICED_ID_OFFSET
            unspliced = read_id - self.UNSPLICED_ID_OFFSET > 0

            if maxAlternativeAlignmentScore < vMatchScore:
                # According to the learned parameters, vmatch found a good
                # alignment of the current read.
                unspliced_ctr += 1
                if unspliced:
                    self.true_neg += 1
                else:
                    self.false_neg += 1
            else:
                # An alternative alignment using splice sites scores at least
                # as high as the vmatch alignment.
                spliced_ctr += 1
                if unspliced:
                    self.false_pos += 1
                else:
                    self.true_pos += 1

            ctr += 1
            stop = cpu()
            # accumulate over all reads like the other timers
            self.count_time += stop - start

        _stop = cpu()
        self.main_loop = _stop - _start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Pick candidate splice sites for the alternative alignments.

        No scores are compared. The donor candidates are the first two
        positions ``idx`` with ``1 < idx < self.read_size`` whose donor score
        is not ``-inf``. The acceptor is the first position
        ``idx >= self.intron_size`` whose acceptor score is not ``-inf``.

        :returns: ``(don_pos, acc_pos)``. ``don_pos`` is a list of at most two
            indices. ``acc_pos`` is an int index, or ``[]`` when no acceptor
            was found.
        """
        don_pos = []
        for idx, score in enumerate(currentDon):
            if idx >= self.read_size:
                break  # later positions can never qualify
            if idx > 1 and score > -inf:
                don_pos.append(idx)
                if len(don_pos) == 2:
                    break

        # [] signals "no acceptor found"
        acc_pos = []
        for idx, score in enumerate(currentAcc):
            if score > -inf and idx >= self.intron_size:
                acc_pos = idx
                break

        return don_pos, acc_pos

    @staticmethod
    def _cutSecondExonMismatches(original_est, first_exon_end):
        """
        Drop the mismatch annotations that fall into the second exon.

        In ``original_est`` a mismatch is written as ``[XY]``, where X is the
        DNA letter and Y the read letter ('-' marks a gap). Groups reached
        while the read pointer is at most ``first_exon_end`` are kept
        verbatim; later groups are replaced by Y.

        :returns: ``(cut_string, dna_ptr, est_ptr)``, where the pointers count
            the DNA and read letters consumed.
        """
        cut = []
        est_ptr = 0
        dna_ptr = 0
        ptr = 0
        while ptr < len(original_est):
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                if est_ptr <= first_exon_end:
                    cut.append(original_est[ptr:ptr + 4])
                else:
                    cut.append(estletter)
                ptr += 4
            else:
                dnaletter = estletter = original_est[ptr]
                cut.append(estletter)
                ptr += 1

            if estletter == '-':
                dna_ptr += 1  # gap in the read: only the DNA advances
            elif dnaletter == '-':
                est_ptr += 1  # gap in the DNA: only the read advances
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(cut), dna_ptr, est_ptr

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, score alternative spliced
        alignments built from candidate donor/acceptor positions.

        For each donor candidate ``don_pos``, a two-exon alignment is built in
        the genomic window starting at ``location['pos']``:
        ``exons[0] = (0, don_pos)`` and
        ``exons[1] = (acc_pos + 1, acc_pos + 1 + read_size - don_pos)``.

        :param location: one read location as produced by
            ``PipelineReadParser`` (keys 'id', 'chr', 'pos', 'strand', 'seq'
            and 'prb' are used)
        :returns: list with one score per donor candidate; empty when no donor
            or no acceptor candidate was found.
        """
        import numpy  # explicit; previously relied on a star import

        run = self.run

        chromo = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        quality = location['prb']

        orig_seq = self.original_reads[int(location['id'])]
        est = unbracket_est(seq)
        original_est = seq

        genomicSeq_start = pos
        genomicSeq_stop = pos + self.intron_size * 2 + self.read_size * 2

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(
            chromo, strand, genomicSeq_start, genomicSeq_stop,
            run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop - start

        start = cpu()
        alt_don_pos, acc_pos = self.findHighestScoringSpliceSites(currentAcc, currentDon)
        stop = cpu()
        self.splice_site_time += stop - start

        # Without an acceptor no spliced alternative can be built; computing
        # acc_pos + 1 below would raise a TypeError.
        if acc_pos == []:
            alt_don_pos = []

        alternativeScores = []

        exons = zeros((2, 2), dtype=int)

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # Plif values for an intron of length 90 and splice-site signals of
        # 0.8; used by the cross-check score below.
        IntronScore = calculatePlif(h, [90])[0]
        dummyAcceptorScore = calculatePlif(a, [0.8])[0]
        dummyDonorScore = calculatePlif(d, [0.8])[0]
        print('%s %s %s' % (IntronScore, dummyAcceptorScore, dummyDonorScore))

        _start = cpu()
        for don_pos in alt_don_pos:
            start = cpu()

            exons[0, 0] = 0
            exons[0, 1] = don_pos
            exons[1, 0] = acc_pos + 1
            exons[1, 1] = acc_pos + 1 + (self.read_size - don_pos)
            second_exon_start = int(exons[1, 0])
            second_exon_stop = int(exons[1, 1])

            # genomic window up to and including the acceptor position,
            # followed by the read letters after the donor position
            _dna = dna[:second_exon_start] + orig_seq[don_pos:]

            # constant dummy splice site scores
            _currentAcc = [0.25] * len(currentAcc[:second_exon_stop])
            _currentDon = [0.25] * len(currentDon[:second_exon_stop])

            stop = cpu()
            self.array_stuff += stop - start
            start = cpu()

            # Compute the parameters of the actual alignment
            # (but with untrained d, a, h ...)
            trueSpliceAlign, trueWeightMatch, trueWeightQuality, dna_calc = \
                computeSpliceAlignWithQuality(_dna, exons, est, original_est,
                                              quality, qualityPlifs, run)

            stop = cpu()
            self.computeSpliceAlignWithQualityTime += stop - start
            start = cpu()

            # Calculate the weights
            trueWeightDon, trueWeightAcc, trueWeightIntron = \
                computeSpliceWeights(d, a, h, trueSpliceAlign, _currentDon, _currentAcc)

            stop = cpu()
            self.computeSpliceWeightsTime += stop - start
            start = cpu()

            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                       trueWeightAcc, trueWeightMatch,
                                       trueWeightQuality])

            # w'phi(x,y): the total score of the alignment
            alternativeScore = (trueWeight.T * self.currentPhi)[0, 0]
            alternativeScores.append(alternativeScore)

            # Cross-check: score the annotated string with the second-exon
            # mismatches removed, plus the dummy acceptor and intron scores,
            # and print the difference.
            original_est_cut, dna_ptr, est_ptr = \
                self._cutSecondExonMismatches(original_est, int(exons[0, 1]))
            assert dna_ptr <= len(_dna)
            assert est_ptr <= len(est)

            score = computeSpliceAlignScoreWithQuality(
                original_est_cut, quality, qualityPlifs, run, self.currentPhi)
            score += dummyAcceptorScore + IntronScore

            print('diff %f,%f,%f' % (alternativeScore - score, alternativeScore, score))

            stop = cpu()
            self.DotProdTime += stop - start

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Score an alignment under the current QPalma parameters.

        :param alignment: tuple ``(dna, exons, est, original_est, quality,
            acc_supp, don_supp)``. Only ``original_est`` and ``quality`` are
            used; the remaining entries are accepted for interface
            compatibility.
        :returns: the value of ``computeSpliceAlignScoreWithQuality`` for the
            annotated read string (presumably w'phi of the alignment).
        """
        start = cpu()

        _dna, _exons, _est, original_est, quality, _acc_supp, _don_supp = alignment
        score = computeSpliceAlignScoreWithQuality(
            original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        stop = cpu()
        self.calcAlignmentScoreTime += stop - start

        return score
464
465
def cpu():
    """
    Return the CPU time (user + system, in seconds) used so far by this
    process.

    Only differences between two calls are meaningful; the timers in this
    script accumulate such differences.
    """
    # a single getrusage() call yields one consistent snapshot of both fields
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
469
470
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run.pickle data_file param.pickle]
    # Without arguments the development paths below are used.
    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    elif len(sys.argv) == 1:
        run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        jp = os.path.join

        run_fname = jp(run_dir, 'run_object.pickle')
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
        param_fname = jp(run_dir, 'param_500.pickle')
    else:
        sys.stderr.write('usage: %s [run.pickle data_file param.pickle]\n' % sys.argv[0])
        sys.exit(1)

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop - start))
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for read_parsing time: %f' % ph1.read_parsing)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f' % ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f' % ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f' % ph1.DotProdTime)
    print('time spend for array_stuff time: %f' % ph1.array_stuff)
    # For profiling, run e.g.: python -m cProfile PipelineHeuristic.py
511 #p = hotshot.Profile('profile.log')
512 #p.runcall(ph1.filter)