first version of faster heuristic
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11 import resource
12
13 from qpalma.DataProc import *
14 from qpalma.computeSpliceWeights import *
15 from qpalma.set_param_palma import *
16 from qpalma.computeSpliceAlignWithQuality import *
17 from qpalma.penalty_lookup_new import *
18 from qpalma.compute_donacc import *
19 from qpalma.TrainingParam import Param
20 from qpalma.Plif import Plf
21
22 from qpalma.tools.splicesites import getDonAccScores
23 from qpalma.Configuration import *
24
25 from compile_dataset import getSpliceScores, get_seq_and_scores
26
27 from numpy.matlib import mat,zeros,ones,inf
28 from numpy import inf,mean
29
30 from qpalma.parsers import PipelineReadParser
31
32
def unbracket_est(est):
    """
    Collapse the bracket annotations of a read into a plain sequence.

    Every group that starts with '[' is taken to span four characters
    (e.g. '[ab]'); it is replaced by the character at offset 2 of the group
    (the second character inside the brackets).  All other characters are
    copied as they are.  The result is returned in lower case.

    An IndexError is raised if a '[' occurs less than three characters before
    the end of the string.
    """
    pieces = []
    cursor = 0
    total = len(est)

    while cursor < total:
        if est[cursor] == '[':
            # keep the second symbol of the '[xy]' group and skip the group
            pieces.append(est[cursor + 2])
            cursor += 4
        else:
            pieces.append(est[cursor])
            cursor += 1

    return "".join(pieces).lower()
49
50
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be newly aligned using QPalma or not.

    The decision is made by scoring the unspliced vmatch alignment and a small
    number of candidate spliced alignments with the trained parameter vector
    and comparing the scores (see filter()).

    NOTE(review): `numpy`, `computeSpliceAlignScoreWithQuality` and the other
    qpalma helpers are not imported by name in this file; presumably they are
    provided by the qpalma star imports - confirm.
    """

    # filter() stops after this many reads have been classified.
    MAX_READS = 100

    def __init__(self,run_fname,data_fname,param_fname):
        """
        We need a run object holding information about the nr. of support
        points etc.

        run_fname   -- path of the pickled run object (a mapping of settings)
        data_fname  -- path of the vmatch map file handed to
                       PipelineReadParser in filter()
        param_fname -- path of the pickled parameter vector

        Besides the two pickles, a read file with five whitespace separated
        columns per line (id, sequence and three further fields) is loaded
        from a fixed path; only id and sequence are kept.
        """

        # NOTE(review): unpickling can execute arbitrary code; only load run
        # and parameter files from trusted locations.
        with open(run_fname,'rb') as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname,'rb') as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits penalties for the Plifs
        h,d,a,mmatrix,qualityPlifs = set_param_palma(self.param,True,run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the mean of intron size
        self.intron_size = 90

        self.read_size = 36

        # maps the integer read id to the read sequence of the read file
        self.original_reads = {}

        with open('/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline') as reads_file:
            for line in reads_file:
                # exactly five columns are expected; anything else raises
                # a ValueError
                read_id,seq,_q1,_q2,_q3 = line.split()
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # offsets of the consecutive sections of the parameter vector
        don_begin = lengthSP
        acc_begin = don_begin+donSP
        mm_begin = acc_begin+accSP
        qual_begin = mm_begin+mmatrixSP

        currentPhi = zeros((run['numFeatures'],1))
        currentPhi[0:don_begin] = mat(h.penalties[:]).reshape(lengthSP,1)
        currentPhi[don_begin:acc_begin] = mat(d.penalties[:]).reshape(donSP,1)
        currentPhi[acc_begin:mm_begin] = mat(a.penalties[:]).reshape(accSP,1)
        currentPhi[mm_begin:qual_begin] = mmatrix[:]
        # the quality penalties are the last totalQualSP entries of the
        # parameter vector
        currentPhi[qual_begin:] = self.param[-totalQualSP:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated CPU time (seconds) of the individual processing steps
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0
        stop = cpu()

        self.init_time = stop-start

    def filter(self):
        """
        Classify the reads of the map file as "spliced" or "unspliced".

        For the first location of every read on the '+' strand the score of
        the vmatch alignment is compared with the best score of the
        alternative (spliced) alignments.  Reads without any alternative
        alignment are skipped.  The method updates the true/false
        positive/negative counters and the timers of the instance, prints a
        summary and returns nothing.  It stops after MAX_READS classified
        reads.
        """
        run = self.run

        start = cpu()

        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()

        stop = cpu()

        self.read_parsing = stop-start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            # leave the outer loop as well once enough reads were classified
            if ctr == self.MAX_READS:
                break

            # only the first location of each read is considered
            for location in currentReadLocations[:1]:

                read_id = int(location['id'])
                chrom = location['chr']
                pos = location['pos']
                strand = location['strand']
                seq = location['seq']
                prb = location['prb']

                if strand == '-':
                    continue

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos+effective_len-1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chrom,strand,genomicSeq_start,genomicSeq_stop,
                    run['dna_flat_files'])
                stop = cpu()
                self.get_time += stop-start

                # the vmatch alignment is a single exon covering the read
                exons = zeros((2,1))
                exons[0,0] = 0
                exons[1,0] = effective_len

                currentVMatchAlignment = currentDNASeq, exons, unb_seq, seq,\
                    prb, currentAcc, currentDon

                start = cpu()
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)
                stop = cpu()
                self.calcAlignmentScoreTime += stop-start

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                start = cpu()
                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # read ids above this offset are treated as unspliced reads
                unspliced = (read_id - 1000000300000) > 0

                if maxAlternativeAlignmentScore < vMatchScore:
                    # Seems that according to our learned parameters VMatch
                    # found a good alignment of the current read
                    unspliced_ctr += 1

                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1
                else:
                    # We found an alternative alignment considering splice
                    # sites that scores higher than the VMatch alignment
                    spliced_ctr += 1

                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                stop = cpu()
                # accumulate like the other timers instead of keeping only
                # the last measurement
                self.count_time += stop-start

        _stop = cpu()
        self.main_loop = _stop-_start

        print('Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr))
        print('True pos / false pos : %d %d'%(self.true_pos,self.false_pos))
        print('True neg / false neg : %d %d'%(self.true_neg,self.false_neg))


    def findHighestScoringSpliceSites(self,currentAcc,currentDon):
        """
        Pick candidate donor and acceptor positions from the score vectors.

        Despite its name the method does not rank the scores: it returns the
        first (at most two) donor indices idx with 1 < idx < read_size whose
        score is greater than -inf, and the first acceptor index
        >= intron_size whose score is greater than -inf.

        Returns the tuple (don_pos, acc_pos) where don_pos is a list of
        indices and acc_pos is a single index, or an empty list if no
        acceptor was found.
        """
        don_pos = []
        for idx,score in enumerate(currentDon):
            if score > -inf and 1 < idx < self.read_size:
                don_pos.append(idx)

                if len(don_pos) == 2:
                    break

        # the empty list is the "nothing found" marker callers have to check
        acc_pos = []
        for idx,score in enumerate(currentAcc):
            if score > -inf and idx >= self.intron_size:
                acc_pos = idx
                break

        return don_pos,acc_pos


    def calcAlternativeAlignments(self,location):
        """
        Given an alignment proposed by Vmatch this function calculates
        possible alternative alignments taking into account for example
        matched donor/acceptor positions.

        location -- mapping describing one vmatch hit; the keys 'id', 'chr',
                    'pos', 'strand', 'seq' and 'prb' are read.

        Returns the list of scores of the alternative alignments, one per
        candidate donor position.  The list is empty if no donor or no
        acceptor candidate was found.
        """

        run = self.run

        read_id = location['id']
        chrom = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[int(read_id)]

        unb_seq = unbracket_est(seq)

        # fetch a window large enough for two reads and two mean introns
        genomicSeq_start = pos
        genomicSeq_stop = pos+self.intron_size*2+self.read_size*2

        start = cpu()
        genomic_dna, currentAcc, currentDon = get_seq_and_scores(
            chrom,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop-start

        start = cpu()
        alt_don_pos,acc_pos = self.findHighestScoringSpliceSites(currentAcc,currentDon)
        stop = cpu()
        self.splice_site_time += stop-start

        alternativeScores = []

        # without an acceptor no spliced alternative can be constructed
        if acc_pos == []:
            return alternativeScores

        # builtin int is what the numpy.int alias stood for
        exons = zeros((2,2),dtype=int)
        est = unb_seq
        original_est = seq
        quality = prb

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs

        _start = cpu()
        for don_pos in alt_don_pos:
            start = cpu()

            # first exon: read start up to the donor; second exon: the rest
            # of the read placed directly behind the acceptor
            exons[0,0] = 0
            exons[0,1] = don_pos
            exons[1,0] = acc_pos+1
            exons[1,1] = acc_pos+1+(self.read_size-don_pos)

            exon2_start = int(exons[1,0])
            exon2_stop = int(exons[1,1])

            # genomic sequence up to the second exon followed by the read
            # suffix; don_pos < read_size, hence exon2_start <= exon2_stop
            # and slicing to exon2_stop first would not change the result
            spliced_dna = genomic_dna[:exon2_start] + orig_seq[don_pos:]

            # constant splice site scores over the considered window
            acc_supp = [0.25]*len(currentAcc[:exon2_stop])
            don_supp = [0.25]*len(currentDon[:exon2_stop])

            stop = cpu()
            self.array_stuff += stop - start
            start = cpu()

            # Compute the parameters of the actual alignment
            # (but with untrained d,a,h ...)
            trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
                computeSpliceAlignWithQuality(spliced_dna, exons, est,\
                original_est, quality, qualityPlifs, run)

            score = computeSpliceAlignScoreWithQuality(spliced_dna, exons,\
                est, original_est, quality, qualityPlifs, run, self.currentPhi)

            stop = cpu()
            self.computeSpliceAlignWithQualityTime += stop-start
            start = cpu()

            # Calculate the weights
            trueWeightDon, trueWeightAcc, trueWeightIntron =\
                computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)

            stop = cpu()
            self.computeSpliceWeightsTime += stop-start

            start = cpu()

            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

            # Calculate w'phi(x,y) the total score of the alignment
            alternativeScore = (trueWeight.T * self.currentPhi)[0,0]
            alternativeScores.append(alternativeScore)

            if score!=0.0:
                # the subtraction has to happen before the formatting
                print('diff %f' % (alternativeScore - score))

            stop = cpu()
            self.DotProdTime += stop-start

        _stop = cpu()
        self.alternativeScoresTime += _stop-_start

        return alternativeScores


    def calcAlignmentScore(self,alignment):
        """
        Given an alignment (dna,exons,est) and the current parameter for
        QPalma this function calculates the dot product of the feature
        representation of the alignment and the parameter vector i.e the
        alignment score.

        alignment -- tuple (dna, exons, est, original_est, quality,
                     acc_supp, don_supp)

        Returns the score as a scalar.
        """

        run = self.run

        # Lets start calculation
        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the parameters of the actual alignment
        # (but with untrained d,a,h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
            computeSpliceAlignWithQuality(dna, exons, est, original_est,\
            quality, self.qualityPlifs, run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron =\
            computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign, don_supp, acc_supp)

        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

        # Calculate w'phi(x,y) the total score of the alignment
        return (trueWeight.T * self.currentPhi)[0,0]
438
439
def cpu():
    """
    Return the CPU time (user plus system time) used by this process so far,
    in seconds.

    Both values are taken from a single getrusage() snapshot so that they
    refer to the same point in time.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
443
444
if __name__ == '__main__':
    # The command line interface is currently disabled; the script runs on
    # fixed paths instead:
    #run_fname = sys.argv[1]
    #data_fname = sys.argv[2]
    #param_filename = sys.argv[3]

    base_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'

    run_fname = os.path.join(base_dir,'run_object.pickle')
    param_fname = os.path.join(base_dir,'param_500.pickle')

    # other map files that were used with this script:
    #data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_flag'
    #data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'
    data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'

    ph1 = PipelineHeuristic(run_fname,data_fname,param_fname)

    # measure the CPU time of the complete filter run
    start = cpu()
    ph1.filter()
    stop = cpu()

    # one line per timer, printed in this order
    timing_report = [
        ('total time elapsed: %f', stop-start),
        ('time spend for get seq: %f', ph1.get_time),
        ('time spend for calcAlignmentScoreTime: %f', ph1.calcAlignmentScoreTime),
        ('time spend for alternativeScoresTime: %f', ph1.alternativeScoresTime),
        ('time spend for count time: %f', ph1.count_time),
        ('time spend for init time: %f', ph1.init_time),
        ('time spend for read_parsing time: %f', ph1.read_parsing),
        ('time spend for main_loop time: %f', ph1.main_loop),
        ('time spend for splice_site_time time: %f', ph1.splice_site_time),
        ('time spend for computeSpliceAlignWithQualityTime time: %f', ph1.computeSpliceAlignWithQualityTime),
        ('time spend for computeSpliceWeightsTime time: %f', ph1.computeSpliceWeightsTime),
        ('time spend for DotProdTime time: %f', ph1.DotProdTime),
        ('time spend forarray_stuff time: %f', ph1.array_stuff),
    ]
    for template,seconds in timing_report:
        print(template % seconds)

    # profiling alternatives:
    #import cProfile
    #cProfile.run('ph1.filter()')

    #import hotshot
    #p = hotshot.Profile('profile.log')
    #p.runcall(ph1.filter)