10b8fbebfcd57e61da88e691292526a66f0c2113
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11 import resource
12
13 from qpalma.DataProc import *
14 from qpalma.computeSpliceWeights import *
15 from qpalma.set_param_palma import *
16 from qpalma.computeSpliceAlignWithQuality import *
17 from qpalma.penalty_lookup_new import *
18 from qpalma.compute_donacc import *
19 from qpalma.TrainingParam import Param
20 from qpalma.Plif import Plf
21
22 from qpalma.tools.splicesites import getDonAccScores
23 from qpalma.Configuration import *
24
25 from compile_dataset import getSpliceScores, get_seq_and_scores
26
27 from numpy.matlib import mat,zeros,ones,inf
28 from numpy import inf,mean
29
30 from qpalma.parsers import PipelineReadParser
31
32
def unbracket_est(est):
    """
    Collapse the bracket notation of an aligned read into a plain sequence.

    Every character outside brackets is copied as is. A four character group
    starting with '[' (i.e. '[xy]') is replaced by its second inner character
    'y' -- presumably the base of the read at a mismatch position; TODO
    confirm against the vmatch output format.

    The result is returned in lower case. A bracket group that is cut off
    before its second inner character raises IndexError.
    """
    # Collect the pieces in a list and join once instead of growing a string
    bases = []
    pos = 0
    est_len = len(est)

    while pos < est_len:
        if est[pos] == '[':
            # '[xy]': keep y and skip the whole four character group
            bases.append(est[pos+2])
            pos += 4
        else:
            bases.append(est[pos])
            pos += 1

    return "".join(bases).lower()
49
50
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every read the score of the unspliced vmatch alignment is compared
    with the scores of a few artificially spliced alternative alignments. If
    an alternative scores at least as high, the read is counted as "spliced".

    NOTE: numpy is expected to be in scope at module level (presumably via
    one of the star imports of the module) -- TODO confirm.
    """

    # Default file listing the original reads, one per line with five
    # whitespace separated columns (id, sequence and three further columns).
    ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Reads whose id is larger than this offset are counted as unspliced when
    # the predictions are evaluated.
    UNSPLICED_ID_OFFSET = 1000000300000

    # Default number of evaluated reads after which filter() stops.
    DEFAULT_MAX_READS = 100

    def __init__(self,run_fname,data_fname,param_fname,original_reads_fname=None):
        """
        We need a run object holding information about the nr. of support points
        etc.

        run_fname            -- pickle file of the run object (dictionary)
        data_fname           -- file with the vmatch alignments to be filtered
        param_fname          -- pickle file of the learned parameter vector
        original_reads_fname -- file with the original reads; defaults to
                                ORIGINAL_READS_FNAME
        """
        if original_reads_fname is None:
            original_reads_fname = self.ORIGINAL_READS_FNAME

        # NOTE: pickle files must come from a trusted source, unpickling can
        # execute arbitrary code.
        with open(run_fname) as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname) as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is the
        # mean of intron size
        self.intron_size = 90

        self.read_size = 36

        # maps the integer read id to the original read sequence
        self.original_reads = {}

        with open(original_reads_fname) as reads_file:
            for line in reads_file:
                line = line.strip()
                if not line:
                    continue
                read_id,seq,_q1,_q2,_q3 = line.split()
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Offsets of the individual parts inside the parameter vector; the
        # order is intron length, donor, acceptor, match matrix, qualities.
        don_begin = lengthSP
        acc_begin = don_begin+donSP
        mmatrix_begin = acc_begin+accSP
        qual_begin = mmatrix_begin+mmatrixSP

        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:don_begin] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[acc_begin:mmatrix_begin] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[mmatrix_begin:qual_begin] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qual_begin:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but predicted
        # as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated cpu time of the individual processing stages
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0

        self.init_time = cpu()-start

    def filter(self,max_reads=None):
        """
        Classify the reads of the data file as spliced or unspliced.

        Only the first location of every read is considered and reads on the
        '-' strand are skipped. Reads without any alternative alignment are
        skipped as well and do not count towards max_reads. The method
        updates the counters true_pos, false_pos, true_neg and false_neg as
        well as the timers of this object and prints a short summary.

        max_reads -- number of evaluated reads after which the method stops;
                     defaults to DEFAULT_MAX_READS
        """
        if max_reads is None:
            max_reads = self.DEFAULT_MAX_READS

        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        self.read_parsing = cpu()-start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        loop_start = cpu()

        for _read_key,currentReadLocations in all_remapped_reads.items():
            # at most one read is evaluated per iteration, so checking the
            # limit here stops the whole scan as soon as it is reached
            if ctr >= max_reads:
                break

            for location in currentReadLocations[:1]:
                read_id = int(location['id'])
                chromo = location['chr']
                pos = location['pos']
                strand = location['strand']
                seq = location['seq']
                prb = location['prb']

                if strand == '-':
                    continue

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos+effective_len-1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
                self.get_time += cpu()-start

                # the vmatch alignment consists of one single exon
                exons = zeros((2,1))
                exons[0,0] = 0
                exons[1,0] = effective_len

                currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
                currentAcc, currentDon

                start = cpu()
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)
                self.calcAlignmentScoreTime += cpu()-start

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                start = cpu()
                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                unspliced = (read_id - self.UNSPLICED_ID_OFFSET) > 0

                # Seems that according to our learned parameters VMatch found a good
                # alignment of the current read
                if maxAlternativeAlignmentScore < vMatchScore:
                    unspliced_ctr += 1

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1

                # We found an alternative alignment considering splice sites that scores
                # higher than the VMatch alignment
                else:
                    spliced_ctr += 1

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                self.count_time += cpu()-start

        self.main_loop = cpu()-loop_start

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))

    def findHighestScoringSpliceSites(self,currentAcc,currentDon):
        """
        Pick candidate splice sites from the given score sequences.

        Despite its name the method does not rank the scores, it takes the
        first positions with a score larger than -inf:

        donor    -- up to two positions idx with 1 < idx < read_size
        acceptor -- the first position idx with idx >= intron_size

        Returns the tuple (don_pos,acc_pos) where don_pos is a list of
        positions and acc_pos is a single position or [] if there is no
        acceptor candidate.
        """
        don_pos = []
        for idx,score in enumerate(currentDon):
            # positions are increasing, nothing beyond the read qualifies
            if idx >= self.read_size:
                break

            if idx > 1 and score > -inf:
                don_pos.append(idx)

                if len(don_pos) == 2:
                    break

        acc_pos = []
        for idx,score in enumerate(currentAcc):
            if idx >= self.intron_size and score > -inf:
                acc_pos = idx
                break

        return don_pos,acc_pos

    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates possible
        alternative alignments taking into account for example matched
        donor/acceptor positions.

        For every donor candidate a two exon alignment is constructed: the
        first exon ends at the donor, the second exon starts right after the
        acceptor candidate and the genomic sequence behind the acceptor is
        replaced by the remainder of the original read.

        Returns the list of scores of these alignments. The list is empty if
        no donor or no acceptor candidate was found.
        """
        run = self.run

        read_id = location['id']
        chromo = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[int(read_id)]

        est = unbracket_est(seq)

        genomicSeq_start = pos
        genomicSeq_stop = pos+self.intron_size*2+self.read_size*2

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
        self.get_time += cpu()-start

        start = cpu()
        alt_don_pos,acc_pos = self.findHighestScoringSpliceSites(currentAcc,currentDon)
        self.splice_site_time += cpu()-start

        # acc_pos is [] if there is no acceptor candidate; without donor and
        # acceptor no spliced alternative can be constructed
        if not alt_don_pos or isinstance(acc_pos,list):
            return []

        alternativeScores = []

        exons = zeros((2,2),dtype=int)
        acc_begin = acc_pos+1

        loop_start = cpu()
        for don_pos in alt_don_pos:
            start = cpu()

            second_exon_stop = acc_begin+(self.read_size-don_pos)

            exons[0,0] = 0
            exons[0,1] = don_pos
            exons[1,0] = acc_begin
            exons[1,1] = second_exon_stop

            # genomic sequence up to the second exon followed by the rest of
            # the original read
            spliced_dna = dna[:acc_begin] + orig_seq[don_pos:]

            # all splice site scores are set to the constant 0.25
            acc_supp = [0.25]*len(currentAcc[:second_exon_stop])
            don_supp = [0.25]*len(currentDon[:second_exon_stop])

            self.array_stuff += cpu()-start
            start = cpu()

            # Compute the parameters of the true alignment (but with untrained d,a,h ...)
            trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc =\
            computeSpliceAlignWithQuality(spliced_dna, exons, est, seq,\
            prb, self.qualityPlifs, run)

            self.computeSpliceAlignWithQualityTime += cpu()-start
            start = cpu()

            # Calculate the weights
            trueWeightDon, trueWeightAcc, trueWeightIntron =\
            computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign, don_supp, acc_supp)

            self.computeSpliceWeightsTime += cpu()-start
            start = cpu()

            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

            # Calculate w'phi(x,y) the total score of the alignment
            alternativeScores.append((trueWeight.T * self.currentPhi)[0,0])

            self.DotProdTime += cpu()-start

        self.alternativeScoresTime += cpu()-loop_start

        return alternativeScores

    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for QPalma
        this function calculates the dot product of the feature representation of
        the alignment and the parameter vector i.e the alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality, acc_supp,
                     don_supp)
        """
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the parameters of the true alignment (but with untrained d,a,h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc =\
        computeSpliceAlignWithQuality(dna, exons, est, original_est,\
        quality, self.qualityPlifs, self.run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron =\
        computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign, don_supp, acc_supp)

        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

        # Calculate w'phi(x,y) the total score of the alignment
        return (trueWeight.T * self.currentPhi)[0,0]
432
433
def cpu():
    """
    Return the cpu time (user plus system time) in seconds that was consumed
    by this process so far.
    """
    # take one single snapshot so that both values belong to the same instant
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime+usage.ru_stime
437
438
def main(argv=None):
    """
    Run the filter once and print the time spent in the individual stages.

    argv -- command line arguments without the program name; defaults to
            sys.argv[1:]. If exactly three arguments are given they are used
            as run_fname, data_fname and param_fname, otherwise the
            hard-coded default files are used.
    """
    if argv is None:
        argv = sys.argv[1:]

    if len(argv) == 3:
        run_fname, data_fname, param_fname = argv
    else:
        run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'

        run_fname = os.path.join(run_dir,'run_object.pickle')

        # alternative inputs:
        # /fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_flag
        # /fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'

        param_fname = os.path.join(run_dir,'param_500.pickle')

    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop-start))

    # (label, value) pairs in the order in which they are reported
    timings = [
        ('get seq', ph1.get_time),
        ('calcAlignmentScoreTime', ph1.calcAlignmentScoreTime),
        ('alternativeScoresTime', ph1.alternativeScoresTime),
        ('count time', ph1.count_time),
        ('init time', ph1.init_time),
        ('read_parsing time', ph1.read_parsing),
        ('main_loop time', ph1.main_loop),
        ('splice_site_time time', ph1.splice_site_time),
        ('computeSpliceAlignWithQualityTime time', ph1.computeSpliceAlignWithQualityTime),
        ('computeSpliceWeightsTime time', ph1.computeSpliceWeightsTime),
        ('DotProdTime time', ph1.DotProdTime),
        ('array_stuff time', ph1.array_stuff),
    ]

    for label,seconds in timings:
        print('time spend for %s: %f' % (label,seconds))


if __name__ == '__main__':
    main()