ca04bd3bfa2fd9b4ea0506ecfa5f05c60c82f2bd
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11 import resource
12
13 from qpalma.DataProc import *
14 from qpalma.computeSpliceWeights import *
15 from qpalma.set_param_palma import *
16 from qpalma.computeSpliceAlignWithQuality import *
17 from qpalma.penalty_lookup_new import *
18 from qpalma.compute_donacc import *
19 from qpalma.TrainingParam import Param
20 from qpalma.Plif import Plf
21
22 from qpalma.tools.splicesites import getDonAccScores
23 from qpalma.Configuration import *
24
25 from compile_dataset import getSpliceScores, get_seq_and_scores
26
27 from numpy.matlib import mat,zeros,ones,inf
28 from numpy import inf,mean
29
30 from qpalma.parsers import PipelineReadParser
31
32
def unbracket_est(est):
    """
    Remove bracketed mismatch annotations from a read string and lowercase it.

    A '[' starts a four-character group such as '[ag]'. Of that group only the
    character at offset 2 from the bracket is kept; every other character of
    est is copied unchanged.
    NOTE(review): presumably a group is '[<genome base><read base>]' so that
    the read base survives -- confirm against the VMatch output format.

    Raises IndexError when a '[' is one of the last two characters of est
    (truncated group).
    """
    # Collect pieces in a list and join once instead of quadratic string +=.
    parts = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            parts.append(est[e + 2])
            e += 4
        else:
            parts.append(est[e])
            e += 1
    return ''.join(parts).lower()
49
50
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every VMatch hit the score of the unspliced VMatch alignment is
    compared with the scores of a few spliced alternative alignments built
    around nearby donor/acceptor candidates. All scores are w' * phi(x, y)
    computed with the trained QPalma parameter vector.
    """

    # Read ids greater than this value are counted as unspliced reads when
    # filter() evaluates its predictions.
    # NOTE(review): presumably the evaluation reads were numbered so that the
    # two classes can be separated this way -- confirm with the data set.
    UNSPLICED_ID_OFFSET = 1000000300000

    # File with the original reads (id, sequence, three more columns per line).
    DEFAULT_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    def __init__(self, run_fname, data_fname, param_fname,
                 reads_fname=DEFAULT_READS_FNAME):
        """
        Load the run object (nr. of support points etc.), the trained
        parameter vector and the original reads, and build the parameter
        vector currentPhi used by calcAlignmentScore().

        run_fname   -- pickled run object
        data_fname  -- VMatch output parsed by PipelineReadParser in filter()
        param_fname -- pickled QPalma parameter vector
        reads_fname -- file with the original reads (defaults to the
                       previously hard-coded location)
        """
        run = self._loadPickle(run_fname)
        self.run = run

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the mean of intron size
        self.intron_size = 90

        self.read_size = 36

        self.original_reads = self._loadOriginalReads(reads_fname)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Layout of the parameter vector:
        #   [length (h) | donor (d) | acceptor (a) | match matrix | quality]
        # i.e. the same order in which calcAlignmentScore stacks the weights.
        don_start = lengthSP
        acc_start = don_start + donSP
        mm_start = acc_start + accSP
        qual_start = mm_start + mmatrixSP

        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:don_start] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[don_start:acc_start] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[acc_start:mm_start] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mm_start:qual_start] = mmatrix[:]
        # the quality penalties are the last totalQualSP entries of param
        currentPhi[qual_start:] = self.param[-totalQualSP:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total time spent in get_seq_and_scores / calcAlignmentScore
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0

    @staticmethod
    def _loadPickle(fname):
        """Unpickle and return the object stored in fname, closing the file."""
        # NOTE(review): unpickling can execute arbitrary code -- only use this
        # with trusted run/parameter files.
        f = open(fname, 'rb')
        try:
            return cPickle.load(f)
        finally:
            f.close()

    @staticmethod
    def _loadOriginalReads(fname):
        """
        Return a dict mapping int read id -> read sequence.

        Each non-blank line must have exactly five whitespace separated
        columns: id, sequence and three further columns that are ignored.
        """
        reads = {}
        f = open(fname)
        try:
            for line in f:
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                read_id, seq, q1, q2, q3 = fields
                reads[int(read_id)] = seq
        finally:
            f.close()
        return reads

    def filter(self, max_reads=100):
        """
        Classify the VMatch hits of self.data_fname as spliced or unspliced.

        For every '+'-strand location the unspliced VMatch alignment is scored
        and compared with the best alternative spliced alignment from
        calcAlternativeAlignments(). If the VMatch alignment scores strictly
        higher the hit counts as unspliced, otherwise as spliced. The outcome
        is checked against the id-based labels (see UNSPLICED_ID_OFFSET) and
        accumulated in self.true_pos / false_pos / true_neg / false_neg.
        '-'-strand locations and locations without any alternative alignment
        are skipped and not counted.

        max_reads -- stop after this many locations have been classified
                     (None = no limit). The default reproduces the limit
                     that used to be hard-coded.
        """
        run = self.run

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')

        for currentReadLocations in all_remapped_reads.values():
            # Leave the outer loop too once the limit is reached; previously
            # only the inner loop was left and all remaining reads were still
            # iterated for nothing.
            if max_reads is not None and ctr >= max_reads:
                break

            for location in currentReadLocations:
                strand = location['strand']
                if strand == '-':
                    continue

                if max_reads is not None and ctr >= max_reads:
                    break

                location_id = int(location['id'])
                chromosome = location['chr']
                pos = location['pos']
                seq = location['seq']
                prb = location['prb']

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos + effective_len - 1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chromosome, strand, genomicSeq_start, genomicSeq_stop,
                    run['dna_flat_files'])
                self.get_time += cpu() - start

                # The VMatch hit as one block covering the whole read.
                # NOTE(review): this is a 2x1 matrix while
                # calcAlternativeAlignments uses one row per exon (2x2) --
                # confirm that computeSpliceAlignWithQuality expects this.
                exons = zeros((2, 1))
                exons[0, 0] = 0
                exons[1, 0] = effective_len

                currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq,
                                          prb, currentAcc, currentDon)
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                unspliced = location_id > self.UNSPLICED_ID_OFFSET

                if maxAlternativeAlignmentScore < vMatchScore:
                    # According to the learned parameters VMatch found a good
                    # (unspliced) alignment of the current read.
                    unspliced_ctr += 1
                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1
                else:
                    # An alternative alignment using splice sites scores at
                    # least as high as the VMatch alignment.
                    spliced_ctr += 1
                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Pick candidate donor/acceptor positions for alternative alignments.

        Despite the name no scores are compared: a position qualifies as soon
        as its score is greater than -inf.

        Returns (don_pos, acc_pos):
          don_pos -- list of the first (at most two) donor positions idx with
                     1 < idx < self.read_size
          acc_pos -- the first acceptor position idx >= self.intron_size, or
                     [] if there is none (callers must check for this)
        """
        don_pos = []
        for idx, score in enumerate(currentDon):
            if idx >= self.read_size:
                # no later position can qualify
                break
            if idx > 1 and score > -inf:
                don_pos.append(idx)
                if len(don_pos) == 2:
                    break

        acc_pos = []
        for idx, score in enumerate(currentAcc):
            if idx >= self.intron_size and score > -inf:
                acc_pos = idx
                break

        return don_pos, acc_pos

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates the
        scores of alternative spliced alignments built from the candidate
        donor/acceptor positions of findHighestScoringSpliceSites().

        Returns a (possibly empty) list of alignment scores, one per donor
        candidate. It is empty when no donor or no acceptor candidate exists.
        """
        run = self.run

        chromosome = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[int(location['id'])]
        unb_seq = unbracket_est(seq)

        genomicSeq_start = pos
        genomicSeq_stop = pos + self.intron_size * 2 + self.read_size * 2

        dna, currentAcc, currentDon = get_seq_and_scores(
            chromosome, strand, genomicSeq_start, genomicSeq_stop,
            run['dna_flat_files'])

        alt_don_pos, acc_pos = self.findHighestScoringSpliceSites(currentAcc, currentDon)

        alternativeScores = []

        if acc_pos == []:
            # No acceptor candidate: previously this crashed with a TypeError
            # ([] + 1) as soon as a donor candidate existed.
            return alternativeScores

        for don_pos in alt_don_pos:
            # second exon starts right after the acceptor and holds the rest
            # of the read (don_pos < read_size, so acc_stop > acc_start)
            acc_start = acc_pos + 1
            acc_stop = acc_start + (self.read_size - don_pos)

            exons = zeros((2, 2), dtype=int)
            exons[0, 0] = 0
            exons[0, 1] = don_pos
            exons[1, 0] = acc_start
            exons[1, 1] = acc_stop

            # genomic sequence up to the acceptor followed by the read suffix
            # starting at the donor position
            _dna = dna[:acc_start] + orig_seq[don_pos:]

            # Splice-site scores are replaced by the constant 0.25 for every
            # position of the (truncated) window.
            _currentAcc = [0.25] * min(len(currentAcc), acc_stop)
            _currentDon = [0.25] * min(len(currentDon), acc_stop)

            alternativeAlignment = (_dna, exons, unb_seq, seq, prb,
                                    _currentAcc, _currentDon)

            start = cpu()
            alternativeScores.append(self.calcAlignmentScore(alternativeAlignment))
            self.calcAlignmentScoreTime += cpu() - start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment and the current parameter for QPalma this function
        calculates the dot product of the feature representation of the
        alignment and the parameter vector, i.e. the alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality,
                     acc_supp, don_supp)
        """
        run = self.run

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the features of the true alignment (but with untrained
        # d, a, h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality, dna_calc = \
            computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                          quality, self.qualityPlifs, run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron = \
            computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign,
                                 don_supp, acc_supp)

        # same feature order as self.currentPhi
        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                   trueWeightAcc, trueWeightMatch,
                                   trueWeightQuality])

        # Calculate w'phi(x,y) the total score of the alignment
        return (trueWeight.T * self.currentPhi)[0, 0]
366
367
def cpu():
    """
    Return the CPU time (user + system, in seconds) used so far by this
    process.

    Both values are taken from a single getrusage() snapshot so that they
    belong to the same instant.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
371
372
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    # Without arguments the hard-coded test setup below is used.
    # For profiling, wrap ph1.filter() with cProfile.run(...).
    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    else:
        base_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        run_fname = os.path.join(base_dir, 'run_object.pickle')
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
        param_fname = os.path.join(base_dir, 'param_500.pickle')

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop - start))
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScoreTime: %f' % ph1.calcAlignmentScoreTime)
403 #p = hotshot.Profile('profile.log')
404 #p.runcall(ph1.filter)