git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8590 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import pydb
7 import pdb
8 import os
9 import os.path
10 import math
11 import resource
12
13 from qpalma.DataProc import *
14 from qpalma.computeSpliceWeights import *
15 from qpalma.set_param_palma import *
16 from qpalma.computeSpliceAlignWithQuality import *
17 from qpalma.penalty_lookup_new import *
18 from qpalma.compute_donacc import *
19 from qpalma.TrainingParam import Param
20 from qpalma.Plif import Plf
21
22 from qpalma.tools.splicesites import getDonAccScores
23 from qpalma.Configuration import *
24
25 from compile_dataset import getSpliceScores, get_seq_and_scores
26
27 from numpy.matlib import mat,zeros,ones,inf
28 from numpy import inf,mean
29
30 from qpalma.parsers import PipelineReadParser
31
32
def unbracket_est(est):
    """
    Return the read sequence *est* with its bracketed annotations resolved,
    converted to lower case.

    Every 4-character group "[xy]" is replaced by its third character ``y``;
    all other characters are copied unchanged, e.g. "AC[AG]T" -> "acgt".
    Raises IndexError if a '[' is not followed by at least two characters.
    """
    # Collect pieces and join once instead of repeated string concatenation.
    pieces = []
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            # "[xy]": keep y and skip the whole 4-character group
            pieces.append(est[e + 2])
            e += 4
        else:
            pieces.append(est[e])
            e += 1
    return ''.join(pieces).lower()
49
50
class PipelineHeuristic:
    """
    This class wraps the filter which decides whether an alignment found by
    vmatch is spliced and should then be realigned using QPalma or not.

    For each read the score of the (unspliced) vmatch alignment is compared
    with the scores of a few candidate spliced alignments built from nearby
    donor/acceptor positions. The read is predicted "spliced" when some
    candidate scores at least as high as the vmatch alignment.
    """

    # filter() counts a read as truly unspliced when its id exceeds this value.
    # NOTE(review): derived only from the comparison in filter(); confirm
    # against how the dataset assigns read ids.
    UNSPLICED_ID_OFFSET = 1000000300000

    # Default whitespace separated file "id seq q1 q2 q3" with the original reads.
    DEFAULT_READS_FNAME = '/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'

    def __init__(self, run_fname, data_fname, param_fname, reads_fname=None):
        """
        Load the run object, the trained parameters and the original reads,
        and precompute the parameter vector used to score alignments.

        run_fname   -- pickled run object (nr. of support points etc.)
        data_fname  -- vmatch output, parsed by PipelineReadParser in filter()
        param_fname -- pickled parameter vector
        reads_fname -- file with the original reads; defaults to
                       DEFAULT_READS_FNAME
        """
        with open(run_fname, 'rb') as run_file:
            run = cPickle.load(run_file)
        self.run = run

        start = cpu()

        self.data_fname = data_fname

        with open(param_fname, 'rb') as param_file:
            self.param = cPickle.load(param_file)

        # Set the parameters such as limits and penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = set_param_palma(self.param, True, run)

        self.h = h
        self.d = d
        self.a = a
        self.mmatrix = mmatrix
        self.qualityPlifs = qualityPlifs

        # when we look for alternative alignments with introns this value is
        # the assumed mean intron size
        self.intron_size = 90

        self.read_size = 36

        if reads_fname is None:
            reads_fname = self.DEFAULT_READS_FNAME

        # read id -> original read sequence
        self.original_reads = {}
        with open(reads_fname) as reads_file:
            for line in reads_file:
                read_id, seq, _q1, _q2, _q3 = line.split()
                self.original_reads[int(read_id)] = seq

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Parameter vector layout (same order as the numpy.vstack of weights
        # in the scoring methods): [h | d | a | match matrix | quality]
        donStart = lengthSP
        accStart = donStart + donSP
        mmatrixStart = accStart + accSP
        qualStart = mmatrixStart + mmatrixSP

        currentPhi = zeros((run['numFeatures'], 1))
        currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP, 1)
        currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP, 1)
        currentPhi[accStart:mmatrixStart] = mat(a.penalties[:]).reshape(accSP, 1)
        currentPhi[mmatrixStart:qualStart] = mmatrix[:]
        currentPhi[qualStart:] = self.param[-totalQualSP:]
        self.currentPhi = currentPhi

        # we want to identify spliced reads
        # so true pos are spliced reads that are predicted "spliced"
        self.true_pos = 0

        # as false positives we count all reads that are not spliced but
        # predicted as "spliced"
        self.false_pos = 0

        self.true_neg = 0
        self.false_neg = 0

        # accumulated timings (CPU seconds, see cpu())
        self.get_time = 0.0
        self.calcAlignmentScoreTime = 0.0
        self.alternativeScoresTime = 0.0

        self.count_time = 0.0
        self.read_parsing = 0.0
        self.main_loop = 0.0
        self.splice_site_time = 0.0
        self.computeSpliceAlignWithQualityTime = 0.0
        self.computeSpliceWeightsTime = 0.0
        self.DotProdTime = 0.0
        self.array_stuff = 0.0

        stop = cpu()
        self.init_time = stop - start

    def filter(self, max_reads=100):
        """
        Classify the reads of self.data_fname as spliced / unspliced.

        Only the first location of every read is used and reads on the '-'
        strand are skipped. The scan stops after max_reads reads have been
        scored (None means no limit). Updates the confusion counts
        (true_pos, false_pos, true_neg, false_neg) and timers, then prints a
        summary.
        """
        run = self.run

        start = cpu()
        rrp = PipelineReadParser(self.data_fname)
        all_remapped_reads = rrp.parse()
        stop = cpu()
        self.read_parsing = stop - start

        ctr = 0
        unspliced_ctr = 0
        spliced_ctr = 0

        print('Starting filtering...')
        _start = cpu()

        for currentReadLocations in all_remapped_reads.values():
            # stop the whole scan, not only the inner loop
            if max_reads is not None and ctr >= max_reads:
                break

            for location in currentReadLocations[:1]:
                strand = location['strand']
                if strand == '-':
                    continue

                read_id = int(location['id'])
                chromo = location['chr']
                pos = location['pos']
                seq = location['seq']
                prb = location['prb']

                unb_seq = unbracket_est(seq)
                effective_len = len(unb_seq)

                genomicSeq_start = pos
                genomicSeq_stop = pos + effective_len - 1

                start = cpu()
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chromo, strand, genomicSeq_start, genomicSeq_stop,
                    run['dna_flat_files'])
                stop = cpu()
                self.get_time += stop - start

                # exon description of the unspliced vmatch alignment
                exons = zeros((2, 1))
                exons[0, 0] = 0
                exons[1, 0] = effective_len

                currentVMatchAlignment = (currentDNASeq, exons, unb_seq, seq,
                                          prb, currentAcc, currentDon)
                vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)

                alternativeAlignmentScores = self.calcAlternativeAlignments(location)

                start = cpu()
                # found no alternatives
                if not alternativeAlignmentScores:
                    continue

                maxAlternativeAlignmentScore = max(alternativeAlignmentScores)

                # ground truth is encoded in the read id
                unspliced = read_id > self.UNSPLICED_ID_OFFSET

                if maxAlternativeAlignmentScore < vMatchScore:
                    # According to the learned parameters vmatch found a good
                    # alignment of the current read
                    unspliced_ctr += 1
                    if unspliced:
                        self.true_neg += 1
                    else:
                        self.false_neg += 1
                else:
                    # An alternative alignment considering splice sites scores
                    # at least as high as the vmatch alignment
                    spliced_ctr += 1
                    if unspliced:
                        self.false_pos += 1
                    else:
                        self.true_pos += 1

                ctr += 1
                stop = cpu()
                self.count_time += stop - start

        _stop = cpu()
        self.main_loop = _stop - _start

        print('Unspliced/Splice: %d %d' % (unspliced_ctr, spliced_ctr))
        print('True pos / false pos : %d %d' % (self.true_pos, self.false_pos))
        print('True neg / false neg : %d %d' % (self.true_neg, self.false_neg))

    def findHighestScoringSpliceSites(self, currentAcc, currentDon):
        """
        Pick candidate splice sites for an alternative (spliced) alignment.

        Despite the name no scores are compared: any position whose score is
        > -inf qualifies.

        Returns (don_pos, acc_pos):
          don_pos -- list of at most two donor positions idx with
                     1 < idx < self.read_size, in increasing order
          acc_pos -- first acceptor position idx >= self.intron_size, or []
                     if there is none (sentinel kept for compatibility)
        """
        don_pos = []
        for idx, score in enumerate(currentDon):
            if idx >= self.read_size:
                # later positions cannot satisfy idx < read_size
                break
            if idx > 1 and score > -inf:
                don_pos.append(idx)
                if len(don_pos) == 2:
                    break

        acc_pos = []
        for idx, score in enumerate(currentAcc):
            if idx >= self.intron_size and score > -inf:
                acc_pos = idx
                break

        return don_pos, acc_pos

    def calcAlternativeAlignments(self, location):
        """
        Given an alignment proposed by vmatch, score alternative spliced
        alignments built from candidate donor/acceptor positions (see
        findHighestScoringSpliceSites).

        Returns a list with one score per candidate donor position. The list
        is empty when no donor or no acceptor candidate was found.
        """
        run = self.run

        read_id = int(location['id'])
        chromo = location['chr']
        pos = location['pos']
        strand = location['strand']
        seq = location['seq']
        prb = location['prb']

        orig_seq = self.original_reads[read_id]

        est = unbracket_est(seq)
        original_est = seq
        quality = prb

        # genomic window covering two reads and two introns
        genomicSeq_start = pos
        genomicSeq_stop = pos + self.intron_size * 2 + self.read_size * 2

        start = cpu()
        dna, currentAcc, currentDon = get_seq_and_scores(
            chromo, strand, genomicSeq_start, genomicSeq_stop,
            run['dna_flat_files'])
        stop = cpu()
        self.get_time += stop - start

        start = cpu()
        alt_don_pos, acc_pos = self.findHighestScoringSpliceSites(currentAcc, currentDon)
        stop = cpu()
        self.splice_site_time += stop - start

        if acc_pos == []:
            # no acceptor found: no spliced alternative can be built
            alt_don_pos = []

        alternativeScores = []

        h = self.h
        d = self.d
        a = self.a
        qualityPlifs = self.qualityPlifs
        currentPhi = self.currentPhi

        _start = cpu()
        for don_pos in alt_don_pos:
            start = cpu()

            # exons = [[0, don_pos], [acc_pos+1, acc_pos+1+read_size-don_pos]]
            acc_start = acc_pos + 1
            exon_end = acc_start + (self.read_size - don_pos)
            exons = zeros((2, 2), dtype=int)
            exons[0, 0] = 0
            exons[0, 1] = don_pos
            exons[1, 0] = acc_start
            exons[1, 1] = exon_end

            # Candidate DNA: genomic prefix up to the acceptor followed by the
            # rest of the original read; splice-site scores are set to a flat
            # 0.25. Kept in separate variables so `dna` stays the genomic
            # sequence for every candidate.
            _dna = dna[:acc_start] + orig_seq[don_pos:]
            _currentAcc = [0.25] * len(currentAcc[:exon_end])
            _currentDon = [0.25] * len(currentDon[:exon_end])

            stop = cpu()
            self.array_stuff += stop - start
            start = cpu()

            # Compute the features of this alignment (with untrained d,a,h ...)
            trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
                computeSpliceAlignWithQuality(_dna, exons, est, original_est,
                                              quality, qualityPlifs, run)

            score = computeSpliceAlignScoreWithQuality(_dna, exons, est,
                                                       original_est, quality,
                                                       qualityPlifs, run,
                                                       currentPhi)

            stop = cpu()
            self.computeSpliceAlignWithQualityTime += stop - start
            start = cpu()

            # Calculate the weights
            trueWeightDon, trueWeightAcc, trueWeightIntron = \
                computeSpliceWeights(d, a, h, trueSpliceAlign,
                                     _currentDon, _currentAcc)

            stop = cpu()
            self.computeSpliceWeightsTime += stop - start
            start = cpu()

            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                       trueWeightAcc, trueWeightMatch,
                                       trueWeightQuality])

            # w'phi(x,y): the total score of the alignment
            alternativeScore = (trueWeight.T * currentPhi)[0, 0]
            alternativeScores.append(alternativeScore)

            if score != 0.0:
                # difference between the dot-product score and the score
                # returned by computeSpliceAlignScoreWithQuality
                print('diff %f' % (alternativeScore - score))

            stop = cpu()
            self.DotProdTime += stop - start

        _stop = cpu()
        self.alternativeScoresTime += _stop - _start

        return alternativeScores

    def calcAlignmentScore(self, alignment):
        """
        Given an alignment (dna, exons, est, original_est, quality, acc_supp,
        don_supp) and the current QPalma parameters, return the dot product of
        the alignment's feature representation with the parameter vector,
        i.e. the alignment score. Time spent is added to
        self.calcAlignmentScoreTime.
        """
        start = cpu()

        dna, exons, est, original_est, quality, acc_supp, don_supp = alignment

        # Compute the features of the alignment (with untrained d,a,h ...)
        trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
            computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                          quality, self.qualityPlifs, self.run)

        # Calculate the weights
        trueWeightDon, trueWeightAcc, trueWeightIntron = \
            computeSpliceWeights(self.d, self.a, self.h, trueSpliceAlign,
                                 don_supp, acc_supp)

        trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                   trueWeightAcc, trueWeightMatch,
                                   trueWeightQuality])

        # w'phi(x,y): the total score of the alignment
        score = (trueWeight.T * self.currentPhi)[0, 0]

        stop = cpu()
        self.calcAlignmentScoreTime += stop - start
        return score
439
440
def cpu():
    """
    Return the CPU time (user + system, in seconds) used so far by this
    process, as reported by resource.getrusage(RUSAGE_SELF).

    Both fields come from a single getrusage() snapshot so they are
    consistent with each other.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_utime + usage.ru_stime
444
445
if __name__ == '__main__':
    # Usage: PipelineHeuristic.py [run_fname data_fname param_fname]
    # Without arguments the hard-coded test setup below is used.
    if len(sys.argv) == 4:
        run_fname, data_fname, param_fname = sys.argv[1:4]
    else:
        run_dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
        run_fname = os.path.join(run_dir, 'run_object.pickle')
        data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
        param_fname = os.path.join(run_dir, 'param_500.pickle')

    ph1 = PipelineHeuristic(run_fname, data_fname, param_fname)

    start = cpu()
    ph1.filter()
    stop = cpu()

    print('total time elapsed: %f' % (stop - start))

    # (label, accumulated CPU seconds) for the timing report
    timings = [
        ('get seq', ph1.get_time),
        ('calcAlignmentScoreTime', ph1.calcAlignmentScoreTime),
        ('alternativeScoresTime', ph1.alternativeScoresTime),
        ('count time', ph1.count_time),
        ('init time', ph1.init_time),
        ('read_parsing time', ph1.read_parsing),
        ('main_loop time', ph1.main_loop),
        ('splice_site_time time', ph1.splice_site_time),
        ('computeSpliceAlignWithQualityTime time', ph1.computeSpliceAlignWithQualityTime),
        ('computeSpliceWeightsTime time', ph1.computeSpliceWeightsTime),
        ('DotProdTime time', ph1.DotProdTime),
        ('array_stuff time', ph1.array_stuff),
    ]
    for label, seconds in timings:
        print('time spend for %s: %f' % (label, seconds))

    # To profile instead: import cProfile; cProfile.run('ph1.filter()')