+ added script which draws the coverage bar plots presented in the QPalma paper
[qpalma.git] / scripts / compile_dataset.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import random
import os
import re
import pdb
import cPickle

import numpy
from numpy.matlib import mat,zeros,ones,inf

import qpalma
import qpalma.tools

from qpalma.parsers import *

from Genefinding import *

from genome_utils import load_genomic

import qpalma.Configuration as Conf

class DatasetGenerator:
    """
    Compiles training and testing datasets from a file of filtered reads
    and two Vmatch map files (first and second run).
    """

    def __init__(self,filtered_reads,map_file,map_2nd_file):
        assert os.path.exists(filtered_reads), 'Error: Cannot find reads file'
        self.filtered_reads = filtered_reads

        assert os.path.exists(map_file), 'Error: Cannot find map file'
        assert os.path.exists(map_2nd_file), 'Error: Cannot find map_2nd file'
        self.map_file = map_file
        self.map_2nd_file = map_2nd_file

        # both sets are dictionaries mapping read id -> example tuple
        self.training_set = {}
        self.testing_set = {}

        print 'parsing filtered reads..'
        self.all_filtered_reads = parse_filtered_reads(self.filtered_reads)
        print 'found %d filtered reads' % len(self.all_filtered_reads)

    def saveAs(self,dataset_file):
        assert not os.path.exists(dataset_file), 'The data_file already exists!'

        all_keys = self.training_set.keys()
        random.shuffle(all_keys)
        training_keys = all_keys[0:10000]

        # saving new datasets
        #cPickle.dump(self.training_set,open('%s.train.pickle'%dataset_file,'w+'),protocol=2)
        #cPickle.dump(training_keys,open('%s.train_keys.pickle'%dataset_file,'w+'),protocol=2)
        cPickle.dump(self.testing_set,open('%s.test.pickle'%dataset_file,'w+'),protocol=2)

        prediction_keys = [0]*len(self.testing_set.keys())
        for pos,key in enumerate(self.testing_set.keys()):
            prediction_keys[pos] = key

        cPickle.dump(prediction_keys,open('%s.test_keys.pickle'%dataset_file,'w+'),protocol=2)

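    # Usage sketch for loading the dumps written above, assuming saveAs()
    # was called with the hypothetical prefix 'mydata':
    #
    #   testing_set     = cPickle.load(open('mydata.test.pickle'))
    #   prediction_keys = cPickle.load(open('mydata.test_keys.pickle'))
    #   for key in prediction_keys:
    #       currentSeqInfo,original_est,currentQualities = testing_set[key]
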
    def compile_training_set(self):
        # this stores the new dataset
        dataset = {}

        # Iterate over all remapped reads in order to generate a training /
        # prediction example for each read
        instance_counter = 1
        skipped_ctr = 0

        for id,filteredRead in self.all_filtered_reads.items():
            if instance_counter % 1001 == 0:
                print 'processed %d examples' % instance_counter

            # the training set consists only of spliced reads
            if not id < 1000000300000:
                continue

            if filteredRead['strand'] != '+':
                skipped_ctr += 1
                continue

            if not filteredRead['chr'] in range(1,6):
                skipped_ctr += 1
                continue

            # we cut out the genomic region for training
            CUT_OFFSET = random.randint(Conf.extension[0],Conf.extension[1])
            genomicSeq_start = filteredRead['p_start'] - CUT_OFFSET
            genomicSeq_stop = filteredRead['p_stop'] + CUT_OFFSET

            # information on the full read that we need for training
            chromo = filteredRead['chr']
            strand = filteredRead['strand']
            original_est = filteredRead['seq']
            quality = filteredRead['prb']
            cal_prb = filteredRead['cal_prb']
            chastity = filteredRead['chastity']

            currentExons = zeros((2,2),dtype=numpy.int)
            currentExons[0,0] = filteredRead['p_start']
            currentExons[0,1] = filteredRead['exon_stop']
            currentExons[1,0] = filteredRead['exon_start']
            currentExons[1,1] = filteredRead['p_stop']

            # add instance to set
            currentSeqInfo = (id,chromo,strand,genomicSeq_start,genomicSeq_stop)
            currentQualities = (quality,cal_prb,chastity)

            dataset[id] = (currentSeqInfo,currentExons,original_est,currentQualities)

            instance_counter += 1

        print 'Full dataset has size %d' % len(dataset)
        print 'Skipped %d reads' % skipped_ctr

        self.training_set = dataset

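    # Sketch of how one training example stored above can be unpacked,
    # where some_id stands for any key of the dictionary:
    #
    #   currentSeqInfo,currentExons,original_est,currentQualities = self.training_set[some_id]
    #   id,chromo,strand,genomicSeq_start,genomicSeq_stop = currentSeqInfo
    #   quality,cal_prb,chastity = currentQualities
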
    def parse_map_file(self,dataset,map_file):
        strand_map = ['-','+']
        instance_counter = 1

        for line in open(map_file):
            if instance_counter % 1001 == 0:
                print 'processed %d examples' % instance_counter

            line = line.strip()
            slist = line.split()
            id = int(slist[0])
            chromo = int(slist[1])
            pos = int(slist[2])
            strand = slist[3]
            strand = strand_map[strand == 'D']

            genomicSeq_start = pos - 1500
            genomicSeq_stop = pos + 1500

            # fetch missing information from the original reads
            filteredRead = self.all_filtered_reads[id]
            original_est = filteredRead['seq']
            original_est = original_est.lower()

            prb = filteredRead['prb']
            cal_prb = filteredRead['cal_prb']
            chastity = filteredRead['chastity']

            # add instance to set
            currentSeqInfo = (id,chromo,strand,genomicSeq_start,genomicSeq_stop)
            currentQualities = (prb,cal_prb,chastity)

            dataset[id] = (currentSeqInfo,original_est,currentQualities)

            instance_counter += 1

        return dataset

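    # The map files are whitespace-separated; only the first four fields are
    # used above: read id, chromosome, position and strand, where 'D' denotes
    # the forward and 'P' the reverse strand. A hypothetical example line:
    #
    #   1000000300001 3 17462 D ...
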
    def compile_testing_set(self):

        dataset = {}

        # usually we have two files to parse:
        # the map file from the second run and a subset of the map file
        # from the first run
        dataset = self.parse_map_file(dataset,self.map_file)
        dataset = self.parse_map_file(dataset,self.map_2nd_file)

        # store the full set
        self.testing_set = dataset

def compile_dataset_direct(filtered_reads,dataset_file):

    strand_map = ['-','+']

    SeqInfo = []
    OriginalEsts = []
    Qualities = []

    instance_counter = 0

    for line in open(filtered_reads):
        line = line.strip()
        slist = line.split()
        id = int(slist[0])

        if not id < 1000000300000:
            continue

        if instance_counter % 1000 == 0:
            print 'processed %d examples' % instance_counter

        chr = int(slist[1])
        strand = slist[2]
        strand = strand_map[strand == 'D']

        genomicSeq_start = int(slist[10]) - 1000
        genomicSeq_stop = int(slist[13]) + 1000

        original_est = slist[3]
        original_est = original_est.lower()
        #print original_est

        prb = [ord(elem)-50 for elem in slist[6]]
        cal_prb = [ord(elem)-64 for elem in slist[7]]
        chastity = [ord(elem)+10 for elem in slist[8]]

        #pdb.set_trace()

        # add instance to set
        SeqInfo.append((id,chr,strand,genomicSeq_start,genomicSeq_stop))
        OriginalEsts.append(original_est)
        Qualities.append((prb,cal_prb,chastity))

        instance_counter += 1

    dataset = [SeqInfo,OriginalEsts,Qualities]

    # saving the new dataset
    cPickle.dump(dataset,open(dataset_file,'w+'),protocol=2)

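# Usage sketch for the direct variant above, with hypothetical file names:
#
#   compile_dataset_direct('filtered_reads.txt','dataset_direct.pickle')
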
def getSpliceScores(chr,strand,intervalBegin,intervalEnd):
    """
    Use interval_query to fetch the predicted splice-site scores trained on
    the TAIR sequence and annotation.
    """

    size = intervalEnd-intervalBegin
    assert size > 1, 'Error (getSpliceScores): interval size is less than 2!'

    # fetch acceptor scores
    sscore_filename = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc/contig_%d%s'
    acc = doIntervalQuery(chr,strand,intervalBegin,intervalEnd,sscore_filename)

    # fetch donor scores
    sscore_filename = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don/contig_%d%s'
    don = doIntervalQuery(chr,strand,intervalBegin,intervalEnd,sscore_filename)

    return acc, don

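# Usage sketch: fetch the scores for a 2kb window on the forward strand of
# chromosome 1 (the interval bounds are made up for illustration):
#
#   acc,don = getSpliceScores(1,'+',10000,12000)
#
# acc[i] and don[i] hold the score at position intervalBegin+i; positions
# without a predicted acceptor/donor site carry a score of -inf.
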
def process_map(currentRRead,fRead):
    """
    For each match found by Vmatch we compute the genomic fragment on which
    we would like to perform an alignment during prediction.
    """

    fragment_size = 3000

    rId = currentRRead['id']
    pos = currentRRead['pos']
    chr = currentRRead['chr']
    strand = currentRRead['strand']

    length = currentRRead['length']
    offset = currentRRead['offset']

    CUT_OFFSET = random.randint(Conf.extension[0],Conf.extension[1])

    # vmatch found the correct position
    if fRead['chr'] == chr and fRead['strand'] == strand and fRead['p_start']-36 <= pos <= fRead['p_stop']+36:
        genomicSeq_start = fRead['p_start'] - CUT_OFFSET
        genomicSeq_stop = fRead['p_stop'] + CUT_OFFSET
    else:
        genomicSeq_start = pos - (fragment_size/2)
        genomicSeq_stop = pos + (fragment_size/2)

    return (rId,chr,strand,genomicSeq_start,genomicSeq_stop)

def get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,dna_flat_files):
    """
    This function expects an interval, chromosome and strand information and
    returns the genomic sequence of this interval together with the
    associated splice-site scores.
    """

    assert genomicSeq_start < genomicSeq_stop

    chrom = 'chr%d' % chr
    genomicSeq = load_genomic(chrom,strand,genomicSeq_start-1,genomicSeq_stop,dna_flat_files,one_based=False)
    genomicSeq = genomicSeq.lower()

    # check the obtained dna sequence
    assert genomicSeq != '', 'load_genomic returned an empty sequence!'

    # all entries other than a,c,g,t are set to n
    no_base = re.compile('[^acgt]')
    genomicSeq = no_base.sub('n',genomicSeq)

    intervalBegin = genomicSeq_start-100
    intervalEnd = genomicSeq_stop+100
    seq_pos_offset = genomicSeq_start

    currentAcc, currentDon = getSpliceScores(chr,strand,intervalBegin,intervalEnd)

    # trim the flanking 100nt margins and align the score arrays with the
    # sequence coordinates
    currentAcc = currentAcc[100:-98]
    currentAcc = currentAcc[1:]
    currentDon = currentDon[100:-100]

    length = len(genomicSeq)
    currentAcc = currentAcc[:length]

    # pad the donor scores with -inf up to the sequence length
    currentDon = currentDon+[-inf]*(length-len(currentDon))

    ag_tuple_pos = [p for p,e in enumerate(genomicSeq) if p>1 and genomicSeq[p-1]=='a' and genomicSeq[p]=='g']
    gt_tuple_pos = [p for p,e in enumerate(genomicSeq) if p>0 and p<len(genomicSeq)-1 and e=='g' and (genomicSeq[p+1]=='t' or genomicSeq[p+1]=='c')]

    # sanity check: scores other than -inf must sit exactly on ag (acceptor)
    # and gt/gc (donor) dinucleotides; drop into the debugger on mismatch
    assert ag_tuple_pos == [p for p,e in enumerate(currentAcc) if e != -inf and p > 1], pdb.set_trace()
    assert gt_tuple_pos == [p for p,e in enumerate(currentDon) if e != -inf and p > 0], pdb.set_trace()
    assert len(currentAcc) == len(currentDon)

    return genomicSeq, currentAcc, currentDon

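# Usage sketch, reusing the Arabidopsis flat-file directory referenced in
# get_seq() below:
#
#   flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
#   seq,acc,don = get_seq_and_scores(1,'+',10000,12000,flat_files)
#
# seq, acc and don share coordinates: acc[p] / don[p] score position p of seq.
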
def reverse_complement(seq):
    # base complement lookup table
    complement = {'a':'t','c':'g','g':'c','t':'a'}

    new_seq = [complement[elem] for elem in seq]
    new_seq.reverse()
    new_seq = "".join(new_seq)

    return new_seq

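# Example: reverse_complement('gattaca') returns 'tgtaatc'.
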
def get_seq(begin,end,exon_end):
    """
    Cut out the genomic sequence for the given interval on the forward
    strand of chromosome 1, extending it by two bases at the exon boundary:
    downstream if exon_end is true, upstream otherwise.
    """

    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    if exon_end:
        gene_start = begin
        gene_stop = end+2
    else:
        gene_start = begin-2
        gene_stop = end

    chrom = 'chr%d' % 1
    strand = '+'

    genomicSeq = load_genomic(chrom,strand,gene_start,gene_stop,dna_flat_files,one_based=False)
    genomicSeq = genomicSeq.lower()

    return genomicSeq

def parseLine(line):
    """
    We assume that a line has the following entries:

    read_nr,chr,strand,seq,splitpos,read_size,prb,cal_prb,chastity,gene_id,p_start,exon_stop,exon_start,p_stop,true_cut

    """
    #try:
    id,chr,strand,seq,splitpos,read_size,prb,cal_prb,chastity,gene_id,p_start,exon_stop,exon_start,p_stop,true_cut = line.split()
    #except:
    #   id,chr,strand,seq,splitpos,read_size,prb,cal_prb,chastity,gene_id,p_start,exon_stop,exon_start,p_stop = line.split()
    #   true_cut = -1

    splitpos = int(splitpos)
    read_size = int(read_size)

    seq = seq.lower()

    assert strand in ['D','P']

    if strand == 'D':
        strand = '+'

    if strand == 'P':
        strand = '-'

    chr = int(chr)

    prb = [ord(elem)-50 for elem in prb]
    cal_prb = [ord(elem)-64 for elem in cal_prb]
    chastity = [ord(elem)+10 for elem in chastity]

    p_start = int(p_start)
    exon_stop = int(exon_stop)
    exon_start = int(exon_start)
    p_stop = int(p_stop)
    true_cut = int(true_cut)

    line_d = {'id':id, 'chr':chr, 'strand':strand, 'seq':seq, 'splitpos':splitpos,
              'read_size':read_size, 'prb':prb, 'cal_prb':cal_prb, 'chastity':chastity,
              'gene_id':gene_id, 'p_start':p_start, 'exon_stop':exon_stop,
              'exon_start':exon_start, 'p_stop':p_stop, 'true_cut':true_cut}

    return line_d

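# Worked example of the quality decoding above: prb stores each value as
# chr(value+50), so the character 'h' (ord 104) decodes to 104-50 = 54;
# cal_prb uses an offset of 64, while chastity values are decoded as
# ord(c)+10. A typical call:
#
#   line_d = parseLine(line)
#   line_d['prb']   # list of integer quality values, one per base
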
if __name__ == '__main__':
    # minimal command-line driver for the DatasetGenerator class above
    if len(sys.argv) != 5:
        print 'usage: %s <filtered_reads> <map_file> <map_2nd_file> <dataset_file>' % sys.argv[0]
        sys.exit(1)

    filtered_reads,map_file,map_2nd_file,dataset_file = sys.argv[1:]

    generator = DatasetGenerator(filtered_reads,map_file,map_2nd_file)
    generator.compile_training_set()
    generator.compile_testing_set()
    generator.saveAs(dataset_file)