git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8651 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pdb
7 import os
8 import os.path
9 import math
10 import resource
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf,mean
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Return the read sequence contained in a bracketed alignment string.

    Plain characters are copied unchanged.  A four character mismatch token
    '[xy]' is replaced by its third character y, i.e. the EST/read letter of
    the bracket notation (x being the DNA letter).  Gap characters ('-') are
    copied like any other letter.  The result is lower-cased.

    >>> unbracket_est('AC[ag]T')
    'acgt'
    """
    # Collect letters in a list and join once instead of repeated string
    # concatenation (which is quadratic in the worst case).
    letters = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            letters.append(est[e + 2])
            e += 4  # skip '[', DNA letter, EST letter and ']'
        else:
            letters.append(est[e])
            e += 1
    return ''.join(letters).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the (unspliced) vmatch alignment under the
    learned QPalma parameters is compared with the best score of a few
    alternative alignments that assume a donor or acceptor splice site near
    the read.  If an alternative scores at least as high as the vmatch
    alignment, the read is predicted to be spliced.
    """

    def __init__(self, run_fname, data_fname, param_fname,
                 reads_fname='/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'):
        """
        Load the training run and parameter vector and precompute the scoring
        state used by filter().

        run_fname   -- pickled run object, indexed like a dict (e.g.
                       run['numFeatures'], run['dna_flat_files'])
        data_fname  -- vmatch result file, parsed by PipelineReadParser
        param_fname -- pickled QPalma parameter vector
        reads_fname -- whitespace separated reads file with five columns
                       (id, sequence and three further columns); defaults to
                       the location that used to be hard-coded here
        """
        # NOTE(review): unpickling can execute arbitrary code -- only point
        # these paths at trusted, locally generated files.
        with open(run_fname) as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname) as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits and penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # used as the mean intron size
        self.intron_size = 250

        self.read_size = 36

        # read id -> read sequence (not read anywhere else in this class)
        self.original_reads = {}
        with open(reads_fname) as reads_file:
            for line in reads_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                read_id, seq, _, _, _ = fields
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Assemble the full parameter vector in the layout
        # intron length | donor | acceptor | match matrix | quality penalties
        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated CPU times (seconds) of the individual processing steps
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0

        self.init_time = cpu() - start

    def filter(self, max_reads=1000):
        """
        Classify the vmatch alignments of self.data_fname as spliced or not.

        Only the first location of each read is considered and reads on the
        '-' strand are skipped.  Processing stops once max_reads reads have
        been classified (pass None to process all reads).  The confusion
        counts are accumulated in self.true_pos / false_pos / true_neg /
        false_neg and printed at the end.
        """
        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        self.read_parsing = cpu() - start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for readId, currentReadLocations in all_remapped_reads.items():
            # stop the whole scan (not just the inner loop) once enough
            # reads have been classified
            if max_reads is not None and ctr >= max_reads:
                break

            for location in currentReadLocations[:1]:
                read_id = int(location['id'])
                chrom = location['chr']
                pos = location['pos']
                strand = location['strand']
                seq = location['seq']
                prb = location['prb']

                if strand == '-':
                    continue

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos + effective_len - 1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
                self.get_time += cpu() - start

                exons = zeros((2, 1))
                exons[0, 0] = 0
                exons[1, 0] = effective_len

                currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq, prb,
                                          currentAcc, currentDon)
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                start = cpu()
                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # NOTE(review): the ground truth is encoded in the read id --
                # ids above 1000000300000 are counted as unspliced reads.
                unspliced = read_id - 1000000300000 > 0

                if maxAlternativeAlignmentScore < vMatchScore:
                    # According to our learned parameters vmatch found a good
                    # alignment of the current read.
                    unspliced_ctr += 1
                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1
                else:
                    # An alternative alignment considering splice sites scores
                    # at least as high as the vmatch alignment.
                    spliced_ctr += 1
                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                # accumulate (the original overwrote this every iteration)
                self.count_time += cpu() - start

        self.main_loop = cpu() - _start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Return the (at most) two highest scoring donor and acceptor sites.

        Only positions 0 .. self.read_size + 1 of the score sequences are
        examined; positions whose score is -inf are ignored.  Each result is
        a list of (position, score) tuples ordered by increasing score, so
        the best site comes last.

        Returns (don, acc) -- note the order differs from the arguments.
        """
        limit = self.read_size + 1

        def best_two(scores):
            candidates = []
            for idx, score in enumerate(scores):
                if idx > limit:
                    break
                if score > -inf:
                    candidates.append((idx, score))
            # sort by score (the original sorted by position, which made the
            # selection return the last two positions instead of the best)
            candidates.sort(key=lambda c: c[1])
            return candidates[-2:]

        acc = best_two(currentAcc)
        don = best_two(currentDon)
        return don, acc

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, score alternative alignments
        that assume a donor or acceptor splice site near the read.

        For each of the best donor sites, mismatches from the donor position
        on (the presumed second exon) are removed before scoring; for each of
        the best acceptor sites, mismatches before the acceptor position are
        removed.  Each score is the resulting alignment score plus the
        intron-length score, the splice site score and a dummy score for the
        splice site that was not searched for.

        Returns a list of scores (empty if no splice site candidate exists).
        """
        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        original_est = location['seq']
        quality = location['prb']

        est = unbracket_est(original_est)

        # a fixed window of 1000 bases starting at the read position is fetched
        genomicSeq_start = pos
        genomicSeq_stop = pos + 1000

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(
            chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
        self.get_time += cpu() - start

        alt_don, alt_acc = self.findHighestScoringSpliceSites(currentAcc, currentDon)

        # inlined attribute lookups
        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        currentPhi = self.currentPhi

        # dummy scores for the parts of the spliced alignment that are not
        # evaluated explicitly
        IntronScore = calculatePlif(h, [self.intron_size])[0] - 0.5
        dummyAcceptorScore = calculatePlif(a, [0.25])[0]
        dummyDonorScore = calculatePlif(d, [0.25])[0]

        alternativeScores = []

        _start = cpu()
        for don_pos, don_score in alt_don:
            # keep mismatches before the donor, replace later ones by the
            # DNA letter
            original_est_cut, dna_ptr, est_ptr = self._maskMismatches(
                original_est, lambda dna_i, est_i: dna_i < don_pos, True)
            assert dna_ptr <= len(dna)
            assert est_ptr <= len(est)

            DonorScore = calculatePlif(d, [don_score])[0]

            score = computeSpliceAlignScoreWithQuality(
                original_est_cut, quality, qualityPlifs, run, currentPhi)
            score += dummyAcceptorScore + IntronScore + DonorScore
            alternativeScores.append(score)
        self.alternativeScoresTime += cpu() - _start

        _start = cpu()
        for acc_pos, acc_score in alt_acc:
            # keep mismatches from the acceptor on, replace earlier ones by
            # the EST letter
            original_est_cut, dna_ptr, est_ptr = self._maskMismatches(
                original_est, lambda dna_i, est_i: est_i >= acc_pos, False)
            assert dna_ptr <= len(dna)
            assert est_ptr <= len(est)

            # acceptor scores must go through the acceptor plif `a`
            # (the original mistakenly used the donor plif `d`)
            AcceptorScore = calculatePlif(a, [acc_score])[0]

            score = computeSpliceAlignScoreWithQuality(
                original_est_cut, quality, qualityPlifs, run, currentPhi)
            score += AcceptorScore + IntronScore + dummyDonorScore
            alternativeScores.append(score)
        self.alternativeScoresTime += cpu() - _start

        return alternativeScores

    @staticmethod
    def _maskMismatches(original_est, keep_mismatch, use_dna_letter):
        """
        Rewrite a bracketed alignment string, dropping selected mismatches.

        A mismatch is written as '[xy]' with x the DNA letter and y the EST
        letter.  keep_mismatch(dna_ptr, est_ptr) is called with the number of
        DNA / EST positions consumed before the token; if it returns True the
        token is copied verbatim, otherwise it is replaced by the DNA letter
        (use_dna_letter=True) or the EST letter (use_dna_letter=False).

        Returns (new_string, dna_ptr, est_ptr) where the pointers are the
        total numbers of DNA and EST positions consumed.
        """
        pieces = []
        est_ptr = 0
        dna_ptr = 0
        ptr = 0
        n = len(original_est)
        while ptr < n:
            if original_est[ptr] == '[':
                dnaletter = original_est[ptr + 1]
                estletter = original_est[ptr + 2]
                if keep_mismatch(dna_ptr, est_ptr):
                    pieces.append(original_est[ptr:ptr + 4])
                elif use_dna_letter:
                    pieces.append(dnaletter)
                else:
                    pieces.append(estletter)
                ptr += 4
            else:
                dnaletter = estletter = original_est[ptr]
                pieces.append(estletter)
                ptr += 1

            # a gap in the EST consumes only DNA, a gap in the DNA only EST
            if estletter == '-':
                dna_ptr += 1
            elif dnaletter == '-':
                est_ptr += 1
            else:
                dna_ptr += 1
                est_ptr += 1

        return ''.join(pieces), dna_ptr, est_ptr

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment tuple (dna, exons, est, original_est, quality,
        acc_supp, don_supp) and the current QPalma parameters, return the
        alignment score computed by computeSpliceAlignScoreWithQuality from
        the bracketed alignment string and the read quality.
        """
        start = cpu()

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        score = computeSpliceAlignScoreWithQuality(
            original_est, quality, self.qualityPlifs, self.run, self.currentPhi)

        self.calcAlignmentScoreTime += cpu() - start
        return score
477
478
def cpu():
    """
    Return the CPU time (user + system, in seconds) used so far by this
    process.

    A single getrusage() snapshot is taken so both components refer to the
    same instant (the original queried the kernel twice).
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
482
483
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    # Without arguments the previously hard-coded locations are used.
    jp = os.path.join

    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    elif len(sys.argv) == 1:
        base_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        run_fname = jp(base_dir, 'run_object.pickle')
        # other data sets used before:
        #   .../solexa/current_data/map.vm_unspliced_1k
        #   .../solexa/pipeline_data/map.vm_100
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'
        param_fname = jp(base_dir, 'param_500.pickle')
    else:
        sys.exit('usage: %s [run_fname data_fname param_fname]' % sys.argv[0])

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop - start))
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
    print('time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime)
    print('time spend for count time: %f' % ph1.count_time)
    print('time spend for init time: %f' % ph1.init_time)
    print('time spend for read_parsing time: %f' % ph1.read_parsing)
    print('time spend for main_loop time: %f' % ph1.main_loop)
    print('time spend for splice_site_time time: %f' % ph1.splice_site_time)

    print('time spend for computeSpliceAlignWithQualityTime time: %f' % ph1.computeSpliceAlignWithQualityTime)
    print('time spend for computeSpliceWeightsTime time: %f' % ph1.computeSpliceWeightsTime)
    print('time spend for DotProdTime time: %f' % ph1.DotProdTime)
    print('time spend for array_stuff time: %f' % ph1.array_stuff)

    # For profiling, e.g.:
    #   import cProfile; cProfile.run('ph1.filter()')