+ take into account multiple hits per unique read
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores, get_seq_and_scores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.Configuration import *
45
46 # this two imports are needed for the load genomic resp. interval query
47 # functions
48 from Genefinding import *
49 from genome_utils import load_genomic
50 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
51
class SpliceSiteException(Exception):
    """
    Raised by getData when the annotated exon boundaries of an example are
    not flanked by the expected splice site dinucleotides.
    """
    pass
54
55
def unbracket_est(est):
    """
    Return the read sequence with every bracketed group collapsed.

    A group occupies four positions starting at '[' ; only the character at
    offset 2 of the group is kept. All other characters are copied as they
    are. The result is lower-cased.

    An input ending in a truncated group raises IndexError.
    """
    kept = []
    pos = 0
    total = len(est)

    while pos < total:
        if est[pos] == '[':
            # keep the second character inside the brackets, skip the group
            kept.append(est[pos+2])
            pos += 4
        else:
            kept.append(est[pos])
            pos += 1

    # join once at the end instead of growing a string character by character
    return "".join(kept).lower()
72
73
def getData(training_set,exampleKey,run):
    """
    Fetch one training example and the genomic data belonging to it.

    training_set[exampleKey] is unpacked into (sequence info, exons, read,
    qualities); the sequence info is unpacked into five fields of which the
    chromosome, strand and the two cut positions are passed on to
    get_seq_and_scores.

    Returns the tuple
        (dna, est, acc_supp, don_supp, exons, original_est, currentQualities)
    where est is the lower-cased read with bracket groups collapsed and '-'
    removed, and original_est is the lower-cased read as stored.

    Raises SpliceSiteException if, for a (2,2) exon matrix, the donor is not
    'gt'/'gc' or the acceptor is not 'ag'.
    """
    seq_info,annotated_exons,raw_est,currentQualities = training_set[exampleKey]
    read_id,chromosome,strand,up_cut,down_cut = seq_info

    # the stored read, joined and lower-cased
    original_est = "".join(raw_est).lower()

    # the read as used for the alignment: no bracket groups, no gap symbols
    est = unbracket_est(original_est).replace('-','')

    # on failure the message expression is evaluated, i.e. the debugger opens
    assert len(est) == run['read_size'], pdb.set_trace()

    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    dna, acc_supp, don_supp = get_seq_and_scores(chromosome,strand,up_cut,down_cut,dna_flat_files)

    # splice score is located at g of ag
    scored_acc_pos = [p for p,e in enumerate(acc_supp) if e != -inf and p > 1]
    ag_tuple_pos = [p for p in range(2,len(dna)) if dna[p-1]=='a' and dna[p]=='g']
    assert ag_tuple_pos == scored_acc_pos, pdb.set_trace()

    # donor score is located at the g of gt / gc
    scored_don_pos = [p for p,e in enumerate(don_supp) if e != -inf and p > 0]
    gt_tuple_pos = [p for p in range(1,len(dna)-1) if dna[p]=='g' and dna[p+1] in ('t','c')]
    assert gt_tuple_pos == scored_don_pos, pdb.set_trace()

    # shift the exon coordinates so that they are relative to the fetched dna
    exons = annotated_exons - (up_cut-1)
    exons[0,0] -= 1
    exons[1,0] -= 1

    if exons.shape == (2,2):
        first_exon = dna[exons[0,0]:exons[0,1]]
        second_exon = dna[exons[1,0]:exons[1,1]]
        fetched_dna_subseq = first_exon + second_exon

        donor_elem = dna[exons[0,1]:exons[0,1]+2]
        acceptor_elem = dna[exons[1,0]-2:exons[1,0]]

        if donor_elem != 'gt' and donor_elem != 'gc':
            print('invalid donor in example %d'% exampleKey)
            raise SpliceSiteException

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d'% exampleKey)
            raise SpliceSiteException

        assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

    return dna,est,acc_supp,don_supp,exons,original_est,currentQualities
124
125
126
class QPalma:
    """
    This class wraps the training and prediction functions for
    the alignment.

    train() and predict() store the run settings in self.run, open a log
    file in self.logfh and set self.use_quality_scores from run['mode'];
    the remaining methods rely on these attributes being set.
    """

    def __init__(self):
        self.ARGS = Param()


    def plog(self,string):
        """
        Write string to the log file handle self.logfh and flush it.
        """
        self.logfh.write(string)
        self.logfh.flush()


    def _dump_pickle(self,obj,filename):
        """
        Pickle obj into filename (opened with mode 'w+') and close the file.
        """
        fh = open(filename,'w+')
        try:
            cPickle.dump(obj,fh)
        finally:
            fh.close()


    def do_alignment(self,dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment

        dna, est            -- the two sequences
        quality             -- list converted into a double array for the read
        mmatrix             -- match matrix, flattened before it is handed over
        donor, acceptor     -- lists of splice site scores
        ps                  -- passed through to myalign unchanged
        qualityPlifs        -- objects offering convert2SWIG()
        current_num_path    -- number of paths for which results are fetched
        prediction_mode     -- if true, the alignment arrays are fetched too

        Returns (newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,
        newQualityPlifsFeatures, dna_array, est_array). The first five are
        column matrices, dna_array and est_array are lists in prediction mode
        and None otherwise.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
        mm_len = run['matchmatrixRows']*run['matchmatrixCols']

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
        #
        # NOTE(review): remove_duplicate_scores and print_matrix are not
        # defined in this method; they have to be module-level names
        # (presumably provided by one of the star imports). train() reads
        # values of the same name from run -- confirm which one is intended.
        currentAlignment.myalign( current_num_path, dna, dna_len,\
            est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
            acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
            print_matrix)

        # buffers the C module fills with the results
        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)

        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(run['totalQualSuppPoints']*current_num_path))

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
            c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))

            currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)

            dna_array = [0.0]*result_len
            est_array = [0.0]*result_len

            for r_idx in range(result_len):
                dna_array[r_idx] = c_dna_array[r_idx]
                est_array[r_idx] = c_est_array[r_idx]

        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
            c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = zeros((current_num_path*dna_len,1))
        newEstAlign = zeros((est_len*current_num_path,1))
        newWeightMatch = zeros((current_num_path*mm_len,1))
        newDPScores = zeros((current_num_path,1))
        newQualityPlifsFeatures = zeros((run['totalQualSuppPoints']*current_num_path,1))

        # copy the results element by element out of the C arrays
        for i in range(dna_len*current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len*current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len*current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        # without quality scores the feature vector stays all zero
        if self.use_quality_scores:
            for i in range(run['totalQualSuppPoints']*current_num_path):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array


    def train(self,run,training_set):
        """
        Run the training loop.

        run          -- settings, accessed via run[key]
        training_set -- mapping from example key to the data getData expects

        Side effects: creates (if needed) and changes into the directory
        run['alignment_dir']/run['name'], writes '_qpalma_train.log',
        'run_obj.pickle' and one 'param_<n>.pickle' per solver call plus a
        final one. Exits the process with status 99 if the solver cannot be
        created.
        """
        # math is not imported at the top of this file; import it here so the
        # math.fabs calls below do not depend on another module leaking it.
        import math

        self.run = run

        full_working_path = os.path.join(run['alignment_dir'],run['name'])

        #assert not os.path.exists(full_working_path)
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_train.log','w+')
        self._dump_pickle(run,'run_obj.pickle')

        self.plog("Settings are:\n")
        self.plog("%s\n"%str(run))

        if self.run['mode'] == 'normal':
            self.use_quality_scores = False

        elif self.run['mode'] == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            assert(False)

        numExamples = len(training_set)
        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        remove_duplicate_scores = run['remove_duplicate_scores']
        print_matrix = run['print_matrix']
        anzpath = run['anzpath']

        # Initialize parameter vector /
        #param = Conf.fixedParam[:run['numFeatures']]
        param = numpy.matlib.rand(run['numFeatures'],1)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        numq = run['numQualSuppPoints']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        # Initialize solver
        self.plog('Initializing problem...\n')

        # NOTE(review): the bare except also catches KeyboardInterrupt and
        # programming errors and reports all of them as a missing license.
        try:
            solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)
        except:
            self.plog('Got no license. Telling queue to reschedule job...\n')
            sys.exit(99)

        #solver.enforceMonotonicity(lengthSP,lengthSP+donSP)
        #solver.enforceMonotonicity(lengthSP+donSP,lengthSP+donSP+accSP)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples
        # stores the gap for each example
        gap = [0.0]*numExamples
        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'],1))
        totalQualityPenalties = zeros((totalQualSP,1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        featureVectors = zeros((run['numFeatures'],numExamples))

        # the main training loop
        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx,example_key in enumerate(training_set.keys()):
                print('Current example %d' % example_key)
                try:
                    dna,est,acc_supp,don_supp,exons,original_est,currentQualities =\
                        getData(training_set,example_key,run)
                except SpliceSiteException:
                    # examples with invalid splice sites are skipped
                    continue

                dna_len = len(dna)

                if run['mode'] == 'normal':
                    quality = [40]*len(est)

                if run['mode'] == 'using_quality_scores':
                    quality = currentQualities[0]

                if not run['enable_quality_scores']:
                    quality = [40]*len(est)

                # with splice signals disabled every finite score becomes 0.0
                if not run['enable_splice_signals']:
                    for idx,elem in enumerate(don_supp):
                        if elem != -inf:
                            don_supp[idx] = 0.0

                    for idx,elem in enumerate(acc_supp):
                        if elem != -inf:
                            acc_supp[idx] = 0.0

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,\
                        quality, qualityPlifs,run)

                    # dna_calc only exists in this branch; stripping the gaps
                    # outside of it failed in 'normal' mode.
                    dna_calc = dna_calc.replace('-','')
                else:
                    # NOTE(review): this call passes two arguments and expects
                    # three results, unlike the call above -- confirm that the
                    # intended function is called here.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron =\
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((run['numFeatures'],num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                #myalign wants the acceptor site on the g of the ag
                #acceptor = acceptor[1:]
                #acceptor.append(-inf)

                #donor = [-inf] + donor[:-1]

                ps = h.convert2SWIG()

                _newSpliceAlign, _newEstAlign, _newWeightMatch, _newDPScores,\
                    _newQualityPlifsFeatures, unneeded1, unneeded2 =\
                    self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)
                mm_len = run['matchmatrixRows']*run['matchmatrixCols']

                newSpliceAlign = _newSpliceAlign
                newEstAlign = _newEstAlign
                newWeightMatch = _newWeightMatch
                newDPScores = _newDPScores
                newQualityPlifsFeatures = _newQualityPlifsFeatures

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
                newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mm_len)

                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(num_path[exampleIdx],run['totalQualSuppPoints'])
                # Calculate weights of the respective alignments. Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(num_path[exampleIdx]+1)
                true_map[0] = 1

                for pathNr in range(num_path[exampleIdx]):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,\
                        h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,\
                        acc_supp)

                    decodedQualityFeatures = zeros((run['totalQualSuppPoints'],1))
                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr,:].T
                    # Store the weights in the remaining columns of the matrix
                    allWeights[:,pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures[:]])

                    hpen = mat(h.penalties).reshape(len(h.penalties),1)
                    dpen = mat(d.penalties).reshape(len(d.penalties),1)
                    apen = mat(a.penalties).reshape(len(a.penalties),1)
                    features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                    featureVectors[:,exampleIdx] = allWeights[:,pathNr+1]

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    distinct_scores = False
                    if math.fabs(AlignmentScores[pathNr] - AlignmentScores[pathNr+1]) > 1e-5:
                        distinct_scores = True

                    # Check whether scalar product + loss equals viterbi score
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"%AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"%(newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" %exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" %exampleIdx)

                # the true label sequence should not have a larger score than the maximal one WHYYYYY?
                # this means that all n-best paths are to close to each other
                # we have to extend the n-best search to a (n+1)-best
                if len([elem for elem in true_map if elem == 1]) == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights

                    #print 'NOT ADDING ANY CONSTRAINTS'
                    const_added = solver.addConstraint(differenceVector, exampleIdx)

                    const_added_ctr += 1
                #
                # end of one example processing
                #

                # call solver every nth example //added constraint
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue,w,self.slacks = solver.solve()
                    solver_call_ctr += 1

                    # after the fifth solver call the solver is called less often
                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'% numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = 0
                    for elem in self.slacks:
                        sum_xis += elem

                    print('sum of slacks is %f'% sum_xis)
                    self.plog('sum of slacks is %f\n'% sum_xis)

                    # take over the solution and derive the new Plifs from it
                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param,'param_%d.pickle'%param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] =\
                        set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" %suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param,'param_%d.pickle'%param_idx)
        self.logfh.close()


    ###############################################################################
    #
    # End of the code needed for training
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def predict(self,run,dataset_fn,prediction_keys,param,set_name):
        """
        Predict alignments for the selected entries of a pickled dataset.

        run             -- settings, accessed via run[key]
        dataset_fn      -- file name of the pickled dataset
        prediction_keys -- keys of the dataset entries to predict; each entry
                           is iterated, i.e. may hold several hits per read
        param           -- parameter vector handed to set_param_palma
        set_name        -- used in the names of the log and the result file

        Side effects: creates (if needed) and changes into the directory
        run['alignment_dir']/run['name'], writes
        '_qpalma_predict_<set_name>.log' and '<set_name>.predictions.pickle'.
        """
        self.run = run

        full_working_path = os.path.join(run['alignment_dir'],run['name'])

        print('full_working_path is %s' % full_working_path)

        #assert not os.path.exists(full_working_path)
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_predict_%s.log'%set_name,'w+')

        if self.run['mode'] == 'normal':
            self.use_quality_scores = False

        elif self.run['mode'] == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            assert(False)

        # number of prediction instances
        self.plog('Number of prediction examples: %d\n'% len(prediction_keys))

        # load dataset and fetch instances that shall be predicted
        # NOTE: unpickling executes code from the file; only use trusted datasets.
        dataset_fh = open(dataset_fn)
        try:
            dataset = cPickle.load(dataset_fh)
        finally:
            dataset_fh.close()

        prediction_set = {}
        for key in prediction_keys:
            prediction_set[key] = dataset[key]

        # we do not need the full dataset anymore
        del dataset

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] =\
            set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        #############################################################################################
        # Prediction
        #############################################################################################
        self.plog('Starting prediction...\n')

        # counts the hits for which sequence and scores could not be fetched
        problem_ctr = 0

        # where we store the predictions
        allPredictions = []

        # beginning of the prediction loop
        for example_key in prediction_set.keys():
            print('Current example %d' % example_key)

            for example in prediction_set[example_key]:

                currentSeqInfo,original_est,currentQualities = example

                read_id,chromosome,strand,genomicSeq_start,genomicSeq_stop =\
                    currentSeqInfo

                assert read_id == example_key

                # only chromosomes 1..5 are processed
                if chromosome not in range(1,6):
                    continue

                self.plog('Loading example id: %d...\n'% int(read_id))

                est = original_est
                est = unbracket_est(est)

                if run['mode'] == 'normal':
                    quality = [40]*len(est)

                if run['mode'] == 'using_quality_scores':
                    quality = currentQualities[0]

                if not run['enable_quality_scores']:
                    quality = [40]*len(est)

                # NOTE(review): any error raised here is only counted, the hit
                # is then skipped without a log entry.
                try:
                    currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chromosome,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
                except:
                    problem_ctr += 1
                    continue

                # with splice signals disabled every finite score becomes 0.0
                if not run['enable_splice_signals']:
                    for idx,elem in enumerate(currentDon):
                        if elem != -inf:
                            currentDon[idx] = 0.0

                    for idx,elem in enumerate(currentAcc):
                        if elem != -inf:
                            currentAcc[idx] = 0.0

                current_prediction = self.calc_alignment(currentDNASeq, est,\
                    quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)

                current_prediction['id'] = read_id
                #current_prediction['start_pos'] = up_cut
                current_prediction['start_pos'] = genomicSeq_start
                current_prediction['chr'] = chromosome
                current_prediction['strand'] = strand

                allPredictions.append(current_prediction)

        # end of the prediction loop we save all predictions in a pickle file and exit
        self._dump_pickle(allPredictions,'%s.predictions.pickle'%(set_name))
        print('Prediction completed')
        print('Problem ctr %d' % problem_ctr)
        self.logfh.close()


    def calc_alignment(self, dna, est, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Given two sequences and the parameters we calculate one alignment

        Returns a dictionary with the keys 'predExons', 'dna', 'est',
        'DPScores' and 'alignment'.
        """

        run = self.run
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)

        if '-' in est:
            self.plog('found gap\n')
            est = est.replace('-','')
            # NOTE(review): hard-coded read length; getData checks against
            # run['read_size'] instead -- confirm whether this should, too.
            assert len(est) == 36

        dna_len = len(dna)

        ps = h.convert2SWIG()

        # a single path in prediction mode
        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array =\
            self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,1,True)

        mm_len = run['matchmatrixRows']*run['matchmatrixCols']

        newSpliceAlign = newSpliceAlign.reshape(1,dna_len)
        newWeightMatch = newWeightMatch.reshape(1,mm_len)

        _newSpliceAlign = newSpliceAlign.flatten().tolist()[0]
        _newEstAlign = newEstAlign.flatten().tolist()[0]

        alignment = get_alignment(_newSpliceAlign,_newEstAlign, dna_array, est_array) #(qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons':newExons, 'dna':dna, 'est':est, 'DPScores':newDPScores,\
            'alignment':alignment}

        return current_prediction


    def calculatePredictedExons(self,SpliceAlign):
        """
        Return the boundaries of all runs of zeros in SpliceAlign.

        SpliceAlign is flattened to a list. For every run of zeros the
        position of its first element and the position following its last
        element are appended to the result, i.e. the result is a flat list
        [start_1, end_1, start_2, end_2, ...].
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # sentinel: closes a run of zeros that reaches the end of the list
        labels.append(-1)

        newExons = []
        previous = -1
        for pos,elem in enumerate(labels):
            if previous != 0 and elem == 0: # start of exon
                newExons.append(pos)

            if previous == 0 and elem != 0: # end of exon
                newExons.append(pos)

            previous = elem

        return newExons
789
790 ###########################
791 # A simple command line
792 # interface
793 ###########################
794
if __name__ == '__main__':
    # usage: qpalma_main.py <run pickle> <dataset pickle> <param pickle | train>
    if len(sys.argv) != 4:
        sys.stderr.write('usage: %s run_fn dataset_fn (param_fn|train)\n' % sys.argv[0])
        sys.exit(1)

    run_fn = sys.argv[1]
    dataset_fn = sys.argv[2]
    param_fn = sys.argv[3]

    # NOTE: unpickling executes code from the file; only use trusted input.
    run_obj = cPickle.load(open(run_fn))
    dataset_obj = cPickle.load(open(dataset_fn))

    # do not reuse the name "qpalma": it is the imported package
    aligner = QPalma()

    if param_fn == 'train':
        aligner.train(run_obj,dataset_obj)
    else:
        param_obj = cPickle.load(open(param_fn))

        # predict() takes the dataset file name, the keys to predict, the
        # parameters and a set name. Predict every entry of the dataset and
        # name the result files after the dataset file.
        # NOTE(review): predict() changes the working directory before it
        # opens dataset_fn, hence the absolute path.
        set_name = os.path.splitext(os.path.basename(dataset_fn))[0]
        aligner.predict(run_obj,os.path.abspath(dataset_fn),dataset_obj.keys(),param_obj,set_name)