25e7dd67c8d710b0140e7202eda63929148fc8ce
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11
12 from qpalma.DataProc import *
13 from qpalma.computeSpliceWeights import *
14 from qpalma.set_param_palma import *
15 from qpalma.computeSpliceAlignWithQuality import *
16 from qpalma.penalty_lookup_new import *
17 from qpalma.compute_donacc import *
18 from qpalma.TrainingParam import Param
19 from qpalma.Plif import Plf
20
21 from qpalma.tools.splicesites import getDonAccScores
22 from qpalma.Configuration import *
23
24 from compile_dataset import getSpliceScores, get_seq_and_scores
25
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy import inf
28
29 from qpalma.parsers import PipelineReadParser
30
31
def unbracket_est(est):
    """
    Collapse the bracket annotations of a read string and lower-case it.

    The input is scanned left to right.  A '[' starts a four-character group
    of the form ``[xy]``; only the second character inside the brackets
    (``y``) is kept and the scan jumps past the whole group.  Every other
    character is copied unchanged.

    NOTE(review): which of the two bracketed characters is the read base and
    which the genomic base is not visible here -- confirm against the
    producer of this format.

    Args:
        est: the (possibly bracket-annotated) read sequence as a string.

    Returns:
        The un-bracketed sequence, converted to lower case.

    Raises:
        IndexError: if a '[' is followed by fewer than two characters.
    """
    # Collect the pieces in a list and join once instead of growing a string
    # with += on every iteration.
    kept = []
    pos = 0
    est_len = len(est)

    while pos < est_len:
        if est[pos] == '[':
            # '[', first char, second char, ']' -> keep the second char only
            kept.append(est[pos + 2])
            pos += 4
        else:
            kept.append(est[pos])
            pos += 1

    return "".join(kept).lower()
48
49
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    For every candidate location the score of the unspliced vmatch alignment
    is compared against the scores of alternative two-exon alignments built
    from the donor/acceptor candidates found downstream of the location.
    The outcome is tallied in the attributes true_pos, false_pos, true_neg
    and false_neg.
    """

    # Default file holding the original reads (one read per line).
    DEFAULT_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    # Reads whose numeric id is strictly greater than this value are treated
    # as ground-truth "unspliced" when the predictions are evaluated.
    # NOTE(review): the origin of this constant is not visible in this file.
    UNSPLICED_ID_OFFSET = 1000000300000

    def __init__(self, run_fname, data_fname, param_fname, reads_fname=None):
        """
        We need a run object holding information about the nr. of support
        points etc.

        Args:
            run_fname: path of the pickled run object (indexed like a dict).
            data_fname: path of the vmatch result file handed to
                PipelineReadParser in filter().
            param_fname: path of the pickled parameter vector.
            reads_fname: optional path of the file with the original reads;
                defaults to DEFAULT_READS_FNAME.

        NOTE: both pickle files are unpickled as-is.  Unpickling can execute
        arbitrary code, so only pass files from a trusted source.
        """

        run = self._loadPickle(run_fname)
        self.run = run

        self.data_fname = data_fname

        self.param = self._loadPickle(param_fname)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the mean of intron size
        self.intron_size = 90

        self.read_size = 36

        if reads_fname is None:
            reads_fname = self.DEFAULT_READS_FNAME

        # maps the integer read id to the original read sequence
        self.original_reads = self._loadOriginalReads(reads_fname)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Start offsets of the consecutive segments of the parameter vector:
        # [ length | donor | acceptor | match matrix | quality ]
        don_start = lengthSP
        acc_start = don_start + donSP
        mm_start = acc_start + accSP
        qual_start = mm_start + mmatrixSP

        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:don_start] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[don_start:acc_start] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[acc_start:mm_start] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mm_start:qual_start] = mmatrix[:]

        # the quality penalties are the last totalQualSP entries of param
        totalQualityPenalties = self.param[-totalQualSP:]
        currentPhi[qual_start:] = totalQualityPenalties[:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

    @staticmethod
    def _loadPickle(fname):
        """Unpickle and return the object stored in fname, closing the file."""
        handle = open(fname)
        try:
            return cPickle.load(handle)
        finally:
            handle.close()

    @staticmethod
    def _loadOriginalReads(reads_fname):
        """
        Read the original reads file into a dict {int(id): sequence}.

        Every non-empty line must consist of exactly five whitespace
        separated fields; the first is the read id, the second the sequence.
        Empty lines are skipped.
        """
        reads = {}
        handle = open(reads_fname)
        try:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                # the three trailing fields are not needed here
                read_id, seq, _q1, _q2, _q3 = line.split()
                reads[int(read_id)] = seq
        finally:
            handle.close()
        return reads

    def filter(self, max_locations=100):
        """
        Classify the remapped read locations as spliced or unspliced.

        Locations on the '-' strand are ignored.  For every other location
        the vmatch alignment score is compared with the best alternative
        (spliced) alignment score.  If the vmatch score is strictly higher
        the location is predicted "unspliced", otherwise "spliced".  The
        prediction is compared with the label derived from the read id (see
        UNSPLICED_ID_OFFSET) and the true/false positive/negative counters
        of this object are incremented accordingly.  A summary is printed at
        the end.

        Locations for which no alternative alignment exists are skipped;
        they are neither classified nor counted towards max_locations.

        Args:
            max_locations: stop after this many classified locations
                (default 100, the previously hard-coded limit).  Pass None
                to process everything.
        """
        run = self.run

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0
        limit_reached = False

        print('Starting filtering...')

        for readId, currentReadLocations in all_remapped_reads.items():
            for location in currentReadLocations:
                strand = location['strand']

                if strand == '-':
                    continue

                if max_locations is not None and ctr >= max_locations:
                    limit_reached = True
                    break

                read_id = int(location['id'])
                chrom = location['chr']
                pos = location['pos']
                seq = location['seq']
                prb = location['prb']

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos + effective_len - 1

                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chrom, strand, genomicSeq_start, genomicSeq_stop,
                    run['dna_flat_files'])

                # the vmatch alignment is one single exon covering the read
                exons = zeros((2, 1))
                exons[0, 0] = 0
                exons[1, 0] = effective_len

                currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq,
                                          prb, currentAcc, currentDon)
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # ground truth label derived from the read id
                unspliced = read_id > self.UNSPLICED_ID_OFFSET

                # Seems that according to our learned parameters VMatch found
                # a good alignment of the current read
                if maxAlternativeAlignmentScore < vMatchScore:
                    unspliced_ctr += 1

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1

                # We found an alternative alignment considering splice sites
                # that scores higher than the VMatch alignment
                else:
                    spliced_ctr += 1

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1

            # leave the outer loop as well instead of walking over all the
            # remaining reads without doing any work
            if limit_reached:
                break

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Collect the candidate donor and acceptor positions.

        Despite its name this method does not rank the sites; it returns
        every position whose score is greater than -inf and which lies in
        the allowed index range.

        Args:
            currentAcc: sequence of acceptor scores, indexed by position.
            currentDon: sequence of donor scores, indexed by position.

        Returns:
            A tuple (don_pos, acc_pos) of two lists of indices:
            donor indices idx with 1 < idx < self.read_size and acceptor
            indices idx with idx > self.intron_size.
        """
        don_pos = [idx for idx, score in enumerate(currentDon)
                   if score > -inf and 1 < idx < self.read_size]

        acc_pos = [idx for idx, score in enumerate(currentAcc)
                   if score > -inf and idx > self.intron_size]

        return don_pos, acc_pos

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by Vmatch this function calculates
        possible alternative alignments taking into account for example
        matched donor/acceptor positions.

        Args:
            location: dict describing one vmatch hit; the keys 'id', 'chr',
                'pos', 'strand', 'seq' and 'prb' are read.

        Returns:
            A list with the alignment score of every donor/acceptor
            combination (empty if there is no candidate pair).
        """

        run = self.run

        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[int(location['id'])]

        unb_seq = unbracket_est(seq)

        # fetch a window large enough for two reads and two mean introns
        genomicSeq_start = pos
        genomicSeq_stop = pos + self.intron_size * 2 + self.read_size * 2

        dna, currentAcc, currentDon = get_seq_and_scores(
            chrom, strand, genomicSeq_start, genomicSeq_stop,
            run['dna_flat_files'])

        alt_don_pos, alt_acc_pos = self.findHighestScoringSpliceSites(currentAcc, currentDon)

        alternativeScores = []

        for don_pos in alt_don_pos:
            for acc_pos in alt_acc_pos:
                # The first exon ends at the donor; the second exon starts
                # right after the acceptor and takes the rest of the read.
                exon2_start = acc_pos + 1
                exon2_stop = exon2_start + (self.read_size - don_pos)

                exons = zeros((2, 2), dtype=int)
                exons[0, 0] = 0
                exons[0, 1] = don_pos
                exons[1, 0] = exon2_start
                exons[1, 1] = exon2_stop

                # genomic sequence up to the second exon followed by the
                # tail of the original read in place of the second exon
                _dna = dna[:exon2_start] + orig_seq[don_pos:]

                _currentAcc = currentAcc[:exon2_stop]
                _currentDon = currentDon[:exon2_stop]

                alternativeAlignment = (_dna, exons, unb_seq, seq, prb,
                                        _currentAcc, _currentDon)

                alternativeScores.append(self.calcAlignmentScore(alternativeAlignment))

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for
        QPalma this function calculates the dot product of the feature
        representation of the alignment and the parameter vector i.e the
        alignment score.

        Args:
            alignment: tuple (dna, exons, est, original_est, quality,
                acc_supp, don_supp).

        Returns:
            The scalar alignment score.
        """
        # The module only imports selected names from numpy; import the
        # package here so that this method does not depend on one of the
        # star imports leaking the name `numpy`.
        import numpy

        run = self.run

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the parameters of the true alignment (but with untrained
        # d,a,h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality, dna_calc = \
            computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                          quality, qualityPlifs, run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron = \
            computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)

        # same segment order as the parameter vector built in __init__
        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                   trueWeightAcc, trueWeightMatch,
                                   trueWeightQuality])

        # Calculate w'phi(x,y) the total score of the alignment
        return (trueWeight.T * self.currentPhi)[0, 0]
344
345
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    #
    # Without the three arguments the hard-coded test setup below is used.
    # Other data sets that were used with this script:
    #   /fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_flag
    #   /fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k

    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    else:
        base_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        jp = os.path.join

        run_fname = jp(base_dir, 'run_object.pickle')
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
        param_fname = jp(base_dir, 'param_500.pickle')

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    # To profile the filter, wrap this call with cProfile or hotshot.
    ph1.filter()