+ added new format for QPalma alignment output
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11 import resource
12
13 from qpalma.DataProc import *
14 from qpalma.computeSpliceWeights import *
15 from qpalma.set_param_palma import *
16 from qpalma.computeSpliceAlignWithQuality import *
17 from qpalma.penalty_lookup_new import *
18 from qpalma.compute_donacc import *
19 from qpalma.TrainingParam import Param
20 from qpalma.Plif import Plf
21
22 from qpalma.tools.splicesites import getDonAccScores
23 from qpalma.Configuration import *
24
25 from compile_dataset import getSpliceScores, get_seq_and_scores
26
27 from numpy.matlib import mat,zeros,ones,inf
28 from numpy import inf,mean
29
30 from qpalma.parsers import PipelineReadParser
31
32
def unbracket_est(est):
    """
    Return a lower-cased copy of `est` with every bracketed group collapsed
    to a single character.

    A group starts at '[' and occupies four characters; only the character
    two positions after the '[' is kept (so '[xy]' becomes 'y'). All other
    characters are copied unchanged.
    (presumably the bracket holds a genome/read mismatch pair -- confirm
    against the code that writes this format)
    """
    kept = []
    cursor = 0
    total = len(est)

    while cursor < total:
        if est[cursor] == '[':
            # keep the second character inside the bracket, skip the group
            kept.append(est[cursor + 2])
            cursor += 4
        else:
            kept.append(est[cursor])
            cursor += 1

    return "".join(kept).lower()
49
50
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    The decision is made by scoring the unspliced vmatch alignment with the
    trained QPalma parameters and comparing that score against the scores
    of artificially constructed spliced alternatives.
    """

    # Default location of the file holding the original (unmodified) reads.
    ORIGINAL_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Read ids above this value are counted as "unspliced" when the
    # true/false positive/negative statistics are collected in filter().
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self,run_fname,data_fname,param_fname,original_reads_fname=None):
        """
        Load the run object, the trained parameters and the original reads
        and assemble the parameter vector used for scoring.

        run_fname            -- pickle file with the run object (a mapping
                                holding the nr. of support points etc.)
        data_fname           -- file with the vmatch hits, parsed in filter()
        param_fname          -- pickle file with the trained parameter vector
        original_reads_fname -- file with one read per line (five whitespace
                                separated fields, id first, sequence second);
                                defaults to ORIGINAL_READS_FNAME
        """

        if original_reads_fname is None:
            original_reads_fname = self.ORIGINAL_READS_FNAME

        # NOTE: both files are unpickled -- only load pickles from a trusted
        # source, unpickling can execute arbitrary code.
        run = self._load_pickle(run_fname)
        self.run = run

        self.data_fname = data_fname

        self.param = self._load_pickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the mean of intron size
        self.intron_size = 90

        self.read_size = 36

        self.original_reads = self._load_original_reads(original_reads_fname)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Offsets of the consecutive segments of the parameter vector:
        # intron length | donor | acceptor | match matrix | quality
        don_begin = lengthSP
        acc_begin = don_begin+donSP
        mm_begin = acc_begin+accSP
        qual_begin = mm_begin+mmatrixSP

        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:don_begin] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[acc_begin:mm_begin] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[mm_begin:qual_begin] = mmatrix[:]

        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qual_begin:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # total cpu time spent in get_seq_and_scores (as called from filter)
        # and in calcAlignmentScore (as called for alternative alignments)
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0


    def _load_pickle(self,fname):
        """Unpickle and return the object stored in `fname`, closing the file."""
        handle = open(fname)
        try:
            return cPickle.load(handle)
        finally:
            handle.close()


    def _load_original_reads(self,fname):
        """Return a dict mapping the integer read id to the read sequence."""
        reads = {}
        handle = open(fname)
        try:
            for line in handle:
                fields = line.split()
                if not fields:
                    # tolerate blank lines
                    continue
                read_id,seq,_q1,_q2,_q3 = fields
                reads[int(read_id)] = seq
        finally:
            handle.close()

        return reads


    def filter(self,max_alignments=100):
        """
        Classify the vmatch hits from `data_fname` as spliced or unspliced.

        For every hit on the '+' strand the vmatch alignment is scored and
        compared with the best scoring alternative alignment. A hit is
        predicted "unspliced" if the vmatch score is strictly higher,
        otherwise "spliced". The prediction is compared with the label
        derived from the read id and the true_pos / false_pos / true_neg /
        false_neg counters are updated. A summary is printed at the end.

        max_alignments -- stop after this many hits have been classified
                          (None means no limit)
        """
        run = self.run

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0
        limit_reached = False

        print('Starting filtering...')

        for currentReadLocations in all_remapped_reads.values():
            for location in currentReadLocations:

                strand = location['strand']

                # only hits on the forward strand are considered
                if strand == '-':
                    continue

                if max_alignments is not None and ctr >= max_alignments:
                    limit_reached = True
                    break

                read_id = int(location['id'])
                chrom = location['chr']
                pos = location['pos']
                seq = location['seq']
                prb = location['prb']

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos+effective_len-1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
                stop = cpu()
                self.get_time += stop-start

                # the vmatch alignment is a single exon covering the read
                exons = zeros((2,1))
                exons[0,0] = 0
                exons[1,0] = effective_len

                currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq, prb,\
                currentAcc, currentDon
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # label used for the statistics, encoded in the read id
                unspliced = read_id > self.UNSPLICED_ID_OFFSET

                # Seems that according to our learned parameters VMatch found
                # a good alignment of the current read
                if maxAlternativeAlignmentScore < vMatchScore:
                    unspliced_ctr += 1

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1

                # We found an alternative alignment considering splice sites
                # that scores higher than the VMatch alignment
                else:
                    spliced_ctr += 1

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1

            # leave the outer loop as well, nothing is left to do
            if limit_reached:
                break

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))


    def findHighestScoringSpliceSites(self,currentAcc,currentDon):
        """
        Pick candidate splice site positions from the given score tracks.

        Despite the name no maximum is searched: the FIRST (at most two)
        donor positions with a score > -inf and 1 < position < read_size
        and the FIRST acceptor position with a score > -inf at or behind
        intron_size are selected.

        Returns (don_pos, acc_pos) where don_pos is a list of 0-2 positions
        and acc_pos is an integer position, or an empty list if no acceptor
        was found.
        """
        don_pos = []
        for idx,score in enumerate(currentDon):
            # donors have to lie inside the read
            if idx >= self.read_size:
                break

            if idx > 1 and score > -inf:
                don_pos.append(idx)

                if len(don_pos) == 2:
                    break

        acc_pos = []
        for idx,score in enumerate(currentAcc):
            if idx >= self.intron_size and score > -inf:
                acc_pos = idx
                break

        return don_pos,acc_pos


    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates
        possible alternative alignments taking into account for example
        matched donor/acceptor positions.

        For every candidate donor position a two exon alignment is built:
        the first exon ends at the donor, the second one starts right behind
        the candidate acceptor and is filled with the remainder of the
        original read.

        Returns the list of scores of these alignments. The list is empty if
        no donor or no acceptor candidate was found.
        """

        run = self.run

        read_id = location['id']
        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[int(read_id)]

        unb_seq = unbracket_est(seq)

        # fetch a window large enough to contain an intron behind the read
        genomicSeq_start = pos
        genomicSeq_stop = pos+self.intron_size*2+self.read_size*2

        currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])

        alt_don_pos,acc_pos = self.findHighestScoringSpliceSites(currentAcc,currentDon)

        # An empty list signals that no acceptor was found; without an
        # acceptor no spliced alternative can be constructed.
        if isinstance(acc_pos,list):
            return []

        alternativeScores = []

        for don_pos in alt_don_pos:
            second_exon_start = acc_pos+1
            second_exon_stop = second_exon_start+(self.read_size-don_pos)

            exons = zeros((2,2),dtype=int)
            exons[0,0] = 0
            exons[0,1] = don_pos
            exons[1,0] = second_exon_start
            exons[1,1] = second_exon_stop

            # genomic sequence up to the second exon followed by the part of
            # the original read lying behind the donor
            _dna = currentDNASeq[:second_exon_start] + orig_seq[don_pos:]

            # Both score tracks are replaced by a constant: the mean of the
            # finite scores of the window times an empirical factor.
            acc_window = currentAcc[:second_exon_stop]
            acc_mean = mean([e for e in acc_window if e != -inf])
            _currentAcc = [acc_mean*8.5]*len(acc_window)

            # Average the donor scores here (not the already flattened
            # acceptor track).
            # NOTE(review): the scaling factors may need re-tuning -- confirm
            don_window = currentDon[:second_exon_stop]
            don_mean = mean([e for e in don_window if e != -inf])
            _currentDon = [don_mean*2.5]*len(don_window)

            alternativeAlignment = _dna, exons, unb_seq, seq, prb,\
            _currentAcc, _currentDon

            start = cpu()
            alternativeScores.append(self.calcAlignmentScore(alternativeAlignment))
            stop = cpu()
            self.calcAlignmentScoreTime += stop-start

        return alternativeScores


    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for
        QPalma this function calculates the dot product of the feature
        representation of the alignment and the parameter vector i.e the
        alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality,
                     acc_supp, don_supp)
        """

        run = self.run

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the parameters of the true alignment
        # (but with untrained d,a,h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
        computeSpliceAlignWithQuality(dna, exons, est, original_est,\
        quality, self.qualityPlifs, run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron =\
        computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign, don_supp, acc_supp)

        # same segment order as used for currentPhi in __init__
        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

        # Calculate w'phi(x,y) the total score of the alignment
        return (trueWeight.T * self.currentPhi)[0,0]
367
368
def cpu():
    """
    Return the cpu time (user + system) consumed so far by this process,
    in seconds.
    """
    # take a single snapshot so that both values refer to the same instant
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
372
373
if __name__ == '__main__':
    # Development driver: runs the filter on a fixed data set and reports
    # the cpu time spent.
    # TODO: take run_fname, data_fname and param_fname from sys.argv[1:4]
    # instead of the hard-coded paths below.
    result_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'

    run_fname = os.path.join(result_dir,'run_object.pickle')
    param_fname = os.path.join(result_dir,'param_500.pickle')

    # Other data sets that were used with this driver:
    #   /fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_flag
    #   /fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100
    data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'

    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname)

    # For profiling replace the timed call below by e.g.
    #   import cProfile; cProfile.run('ph1.filter()')
    # or
    #   import hotshot; hotshot.Profile('profile.log').runcall(ph1.filter)
    begin = cpu()
    ph1.filter()
    elapsed = cpu()-begin

    print('total time elapsed: %f' % elapsed)
    print('time spend for get seq: %f' % ph1.get_time)
    print('time spend for calcAlignmentScore: %f' % ph1.calcAlignmentScoreTime)