+ changed output format of alignment file script to be compatible with est2gff
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from qpalma.sequence_utils import *
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.Configuration import *
45
46 # these two imports are needed for the genome loading and interval query
47 # functions
48 from Genefinding import *
49 from genome_utils import load_genomic
50 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
51
class SpliceSiteException(Exception):
    """
    Raised by getData when a two-exon training example does not have a
    canonical donor ('gt' or 'gc') or acceptor ('ag') dinucleotide.

    Derives from Exception so that it can be raised and caught on every
    Python version (plain classes cannot be raised in Python 3).
    """
54
55
def getData(training_set, exampleKey, run, dna_flat_files=None):
    """
    Fetch and preprocess one training example.

    training_set[exampleKey] is unpacked into
    (seqInfo, exons, original_est, qualities), where seqInfo is
    (id, chromosome, strand, up_cut, down_cut). The genomic sequence between
    up_cut and down_cut and its acceptor/donor scores are loaded via
    get_seq_and_scores.

    :param training_set: mapping from example key to a training example
    :param exampleKey: key of the example to fetch
    :param run: run settings; only run['read_size'] is read here
    :param dna_flat_files: directory of the genome flat files; defaults to
        the A. thaliana location that was previously hard-coded
    :returns: (dna, est, acc_supp, don_supp, exons, original_est,
        currentQualities), where est is the lower-cased, unbracketed,
        gap-free read and exons are shifted relative to the fetched sequence
    :raises SpliceSiteException: if a two-exon example has a donor other
        than 'gt'/'gc' or an acceptor other than 'ag'
    """
    if dna_flat_files is None:
        dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    currentSeqInfo, currentExons, original_est, currentQualities = training_set[exampleKey]
    seq_id, chromosome, strand, up_cut, down_cut = currentSeqInfo

    # Normalize the read: join, lower-case, resolve brackets, drop gaps.
    est = unbracket_est("".join(original_est).lower())
    est = est.replace('-', '')

    assert len(est) == run['read_size'], \
        'example %s: read length %d differs from read_size %s' % (exampleKey, len(est), run['read_size'])

    original_est = "".join(original_est).lower()

    dna, acc_supp, don_supp = get_seq_and_scores(chromosome, strand, up_cut, down_cut, dna_flat_files)

    # The acceptor score is located at the g of each 'ag' dinucleotide.
    ag_tuple_pos = [p for p, e in enumerate(dna) if p > 1 and dna[p-1] == 'a' and e == 'g']
    assert ag_tuple_pos == [p for p, e in enumerate(acc_supp) if e != -inf and p > 1], \
        'example %s: acceptor scores do not match the ag positions' % (exampleKey,)

    # The donor score is located at the g of each 'gt'/'gc' dinucleotide.
    gt_tuple_pos = [p for p, e in enumerate(dna)
                    if 0 < p < len(dna)-1 and e == 'g' and dna[p+1] in ('t', 'c')]
    assert gt_tuple_pos == [p for p, e in enumerate(don_supp) if e != -inf and p > 0], \
        'example %s: donor scores do not match the gt/gc positions' % (exampleKey,)

    # Shift exon coordinates so they are relative to the fetched sequence and
    # move both exon starts back by one (presumably to get 0-based slice
    # starts; the code below slices dna with them). Requires at least two rows.
    exons = currentExons - (up_cut-1)
    exons[0, 0] -= 1
    exons[1, 0] -= 1

    if exons.shape == (2, 2):
        fetched_dna_subseq = dna[exons[0, 0]:exons[0, 1]] + dna[exons[1, 0]:exons[1, 1]]

        donor_elem = dna[exons[0, 1]:exons[0, 1]+2]
        acceptor_elem = dna[exons[1, 0]-2:exons[1, 0]]

        if donor_elem not in ('gt', 'gc'):
            print('invalid donor in example %d' % exampleKey)
            raise SpliceSiteException

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d' % exampleKey)
            raise SpliceSiteException

        # Only checkable here: the exonic subsequence exists for two exons.
        assert len(fetched_dna_subseq) == len(est), \
            'example %s: exonic sequence length %d differs from read length %d' \
            % (exampleKey, len(fetched_dna_subseq), len(est))

    return dna, est, acc_supp, don_supp, exons, original_est, currentQualities
106
107
108
class QPalma(object):
    """
    This class wraps the training and prediction functions for
    the alignment.
    """

    def __init__(self):
        self.ARGS = Param()


    def plog(self, string):
        """Write string to the current log file and flush it immediately."""
        self.logfh.write(string)
        self.logfh.flush()


    # ------------------------------------------------------------------
    # small private helpers shared by train() and predict()
    # ------------------------------------------------------------------

    def _set_quality_mode(self):
        """
        Set self.use_quality_scores from self.run['mode'].

        :raises ValueError: for a mode other than 'normal' or
            'using_quality_scores'
        """
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            raise ValueError('unknown run mode: %s' % (mode,))


    def _select_quality(self, est, currentQualities):
        """
        Return the quality vector used for aligning est.

        The first entry of currentQualities is used only in
        'using_quality_scores' mode with run['enable_quality_scores'] set;
        otherwise every position gets the constant quality 40.
        """
        run = self.run
        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return currentQualities[0]
        return [40]*len(est)


    @staticmethod
    def _neutralize_splice_scores(scores):
        """Set every finite splice score to 0.0 in place; -inf entries are kept."""
        for idx, value in enumerate(scores):
            if value != -inf:
                scores[idx] = 0.0


    @staticmethod
    def _swig_to_column(c_array, length):
        """Copy the first `length` entries of a SWIG array into a (length x 1) matrix."""
        column = zeros((length, 1))
        for i in range(length):
            column[i] = c_array[i]
        return column


    @staticmethod
    def _dump_pickle(obj, filename):
        """Pickle obj into filename, always closing the file."""
        fh = open(filename, 'w+')
        try:
            cPickle.dump(obj, fh)
        finally:
            fh.close()


    @staticmethod
    def _load_pickle(filename):
        """Unpickle and return the object stored in filename, always closing the file."""
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()


    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, current_num_path, prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment.

        :param current_num_path: number of best paths to decode
        :param prediction_mode: if true, the alignment arrays of the decoded
            alignment are fetched as well (dna_array, est_array); otherwise
            both are None
        :returns: (newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,
            newQualityPlifsFeatures, dna_array, est_array); the first five are
            column matrices holding current_num_path blocks each. The quality
            features stay zero unless quality scores are used.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
        mm_len = run['matchmatrixRows']*run['matchmatrixCols']

        d_len = len(donor)
        c_donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        c_acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'], run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
        # (both flags come from the run settings; they used to be referenced
        # as undefined names here)
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 c_donor, d_len, c_acceptor, a_len, c_qualityPlifs,
                                 run['remove_duplicate_scores'], run['print_matrix'])

        total_qual_len = run['totalQualSuppPoints']*current_num_path

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*total_qual_len)

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0]*result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0]*result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[r_idx] for r_idx in range(result_len)]
            est_array = [c_est_array[r_idx] for r_idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = self._swig_to_column(c_SpliceAlign, dna_len*current_num_path)
        newEstAlign = self._swig_to_column(c_EstAlign, est_len*current_num_path)
        newWeightMatch = self._swig_to_column(c_WeightMatch, mm_len*current_num_path)
        newDPScores = self._swig_to_column(c_DPScores, current_num_path)

        if self.use_quality_scores:
            newQualityPlifsFeatures = self._swig_to_column(c_qualityPlifsFeatures, total_qual_len)
        else:
            newQualityPlifsFeatures = zeros((total_qual_len, 1))

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array


    def train(self, run, training_set):
        """
        Train the alignment parameters on training_set.

        Side effects: creates (if needed) and changes into the directory
        run['alignment_dir']/run['name'], writes the log _qpalma_train.log,
        the run settings run_obj.pickle, one param_<k>.pickle per solver call
        and a final param_<k>.pickle. The process exits with status 99 if the
        SIQP solver cannot be created.

        For every example the true alignment and the num_path best decoded
        alignments are scored. The difference between the weight vector of the
        true alignment and that of the first decoded alignment differing from
        it is added as a solver constraint. If all decoded alignments coincide
        with the true one, num_path for that example is increased. The solver
        runs every numConstPerRound examples (200 after the fifth call), and
        the parameters are updated from its solution.

        :param run: run settings (indexable by key)
        :param training_set: mapping example key -> example in the format
            expected by getData
        """
        self.run = run

        full_working_path = os.path.join(run['alignment_dir'], run['name'])
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)
        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_train.log', 'w+')
        self._dump_pickle(run, 'run_obj.pickle')

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        self._set_quality_mode()

        numExamples = len(training_set)
        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize the parameter vector randomly
        param = numpy.matlib.rand(run['numFeatures'], 1)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Offsets of the parameter groups inside the parameter/feature vector:
        # [intron length | donor | acceptor | match matrix | quality]
        donStart = lengthSP
        accStart = donStart + donSP
        matchStart = accStart + accSP
        qualStart = matchStart + mmatrixSP

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Initializing problem...\n')
        try:
            solver = SIQPSolver(run['numFeatures'], numExamples, run['C'], self.logfh, run)
        except Exception:
            # Any failure here is treated as a missing solver license; exit
            # status 99 tells the queue to reschedule the job.
            self.plog('Solver initialization failed: %s\n' % (sys.exc_info()[1],))
            self.plog('Got no license. Telling queue to reschedule job...\n')
            sys.exit(99)

        # number of alignments decoded for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples

        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0
        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        # the main training loop
        while iteration_nr != iteration_steps:

            for exampleIdx, example_key in enumerate(training_set.keys()):
                print('Current example %d' % example_key)
                try:
                    dna, est, acc_supp, don_supp, exons, original_est, currentQualities = \
                        getData(training_set, example_key, run)
                except SpliceSiteException:
                    continue

                dna_len = len(dna)
                quality = self._select_quality(est, currentQualities)

                if not run['enable_splice_signals']:
                    self._neutralize_splice_scores(don_supp)
                    self._neutralize_splice_scores(acc_supp)

                # Compute the parameters of the true alignment (but with the
                # not yet trained d, a, h ...)
                if run['mode'] == 'using_quality_scores':
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, dna_calc = \
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                      quality, qualityPlifs, run)
                else:
                    # NOTE(review): called with two arguments only, unlike the
                    # branch above; confirm this signature is supported.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[accStart:matchStart] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[matchStart:qualStart] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qualStart:] = totalQualityPenalties[:]

                # w'phi(x,y): the total score of the true alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # num_path[exampleIdx] is only changed after its last use below
                n_paths = num_path[exampleIdx]

                # Column 0 holds the weights of the true alignment, columns
                # 1..n_paths those of the decoded alignments.
                allWeights = zeros((run['numFeatures'], n_paths+1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0]*(n_paths+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
                    newQualityPlifsFeatures, unneeded1, unneeded2 = \
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                                      qualityPlifs, n_paths, False)

                newSpliceAlign = newSpliceAlign.reshape(n_paths, dna_len)
                newWeightMatch = newWeightMatch.reshape(n_paths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(n_paths, totalQualSP)

                # The n-best alignments are decoded without Hamming loss, so
                # some of them may equal the true alignment. true_map flags
                # those (index 0 is the true alignment itself) so that they
                # are never used as constraints.
                true_map = [0]*(n_paths+1)
                true_map[0] = 1

                for pathNr in range(n_paths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    # Store the weights in the remaining columns of the matrix
                    allWeights[:, pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc,
                                                            newWeightMatch[pathNr, :].T,
                                                            decodedQualityFeatures[:]])

                    hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                    dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                    apen = mat(a.penalties).reshape(len(a.penalties), 1)
                    features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                    AlignmentScores[pathNr+1] = (allWeights[:, pathNr+1].T * features)[0, 0]

                    # Check whether scalar product + loss equals the Viterbi score
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n" % (newDPScores[pathNr, 0], AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr+1]) < 1e-5:
                        true_map[pathNr+1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # If every decoded alignment equals the true one, the n-best
                # paths are too close to each other: extend the search to
                # (n+1)-best for this example next time.
                if all(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # If there is at least one useful false alignment, add the
                # constraint built from the true and first false alignment.
                false_indices = [map_idx for map_idx, flag in enumerate(true_map) if flag == 0]
                if false_indices:
                    firstFalseWeights = allWeights[:, false_indices[0]]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)

                #
                # end of one example processing
                #

                # call the solver every numConstPerRound-th example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)
                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            #
            # end of one iteration through all examples
            #
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param, 'param_%d.pickle' % param_idx)
        self.logfh.close()


    ###############################################################################
    #
    # End of the code needed for training
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def predict(self, run, dataset_fn, prediction_keys, param, set_name):
        """
        Align the examples of a dataset with the given parameter vector.

        Side effects: creates (if needed) and changes into the directory
        run['alignment_dir']/run['name'], logs to
        _qpalma_predict_<set_name>.log and pickles the list of prediction
        dicts to <set_name>.predictions.pickle.

        Each dataset[key] is iterated; every element is unpacked as
        (seqInfo, est, qualities) with seqInfo =
        (id, chromosome, strand, start, stop). Examples on chromosomes outside
        1-5 are skipped. Examples whose genomic sequence cannot be loaded are
        counted in the problem counter and skipped.

        :param dataset_fn: path of the pickled dataset; relative paths are
            resolved against the working directory at call time
        :param prediction_keys: keys of the dataset entries to predict
        :param param: parameter vector passed to set_param_palma
        :param set_name: name used for the log and output file names
        """
        self.run = run

        # Resolve now: the chdir below would otherwise break relative paths.
        dataset_fn = os.path.abspath(dataset_fn)

        full_working_path = os.path.join(run['alignment_dir'], run['name'])
        print('full_working_path is %s' % full_working_path)

        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)
        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_predict_%s.log' % set_name, 'w+')
        self._set_quality_mode()

        # number of prediction instances
        self.plog('Number of prediction examples: %d\n' % len(prediction_keys))

        # load dataset and fetch instances that shall be predicted
        dataset = self._load_pickle(dataset_fn)
        prediction_set = dict([(key, dataset[key]) for key in prediction_keys])
        # we do not need the full dataset anymore
        del dataset

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Starting prediction...\n')

        problem_ctr = 0
        # where we store the predictions
        allPredictions = []

        for example_key in prediction_set.keys():
            print('Current example %d' % example_key)

            for example in prediction_set[example_key]:
                currentSeqInfo, original_est, currentQualities = example
                seq_id, chromosome, strand, genomicSeq_start, genomicSeq_stop = currentSeqInfo

                assert seq_id == example_key

                if chromosome not in range(1, 6):
                    continue

                self.plog('Loading example id: %d...\n' % int(seq_id))

                est = unbracket_est(original_est)
                quality = self._select_quality(est, currentQualities)

                try:
                    currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                        chromosome, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
                except Exception:
                    problem_ctr += 1
                    self.plog('Could not load sequence for example %d: %s\n'
                              % (int(seq_id), sys.exc_info()[1]))
                    continue

                if not run['enable_splice_signals']:
                    self._neutralize_splice_scores(currentDon)
                    self._neutralize_splice_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est,
                    quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)

                current_prediction['id'] = seq_id
                current_prediction['start_pos'] = genomicSeq_start
                current_prediction['chr'] = chromosome
                current_prediction['strand'] = strand

                allPredictions.append(current_prediction)

        # end of the prediction loop: save all predictions in a pickle file
        self._dump_pickle(allPredictions, '%s.predictions.pickle' % (set_name))
        print('Prediction completed')
        self.plog('Prediction completed\n')
        mes = 'Problem ctr %d' % problem_ctr
        print(mes)
        self.plog(mes+'\n')
        self.logfh.close()


    def calc_alignment(self, dna, est, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Align est against dna with the given parameters (best path only).

        :returns: dict with keys 'predExons' (exon boundary positions from
            calculatePredictedExons), 'dna', 'est' (gap-free), 'DPScores' and
            'alignment' (the result of get_alignment)
        """
        run = self.run
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)

        if '-' in est:
            self.plog('found gap\n')
            est = est.replace('-', '')
            # consistent with getData: reads must have the configured length
            assert len(est) == run['read_size']

        dna_len = len(dna)
        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array = \
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, 1, True)

        newSpliceAlign = newSpliceAlign.reshape(1, dna_len)

        _newSpliceAlign = newSpliceAlign.flatten().tolist()[0]
        _newEstAlign = newEstAlign.flatten().tolist()[0]

        # presumably (qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts,
        # qEnds, tExonSizes, tStarts, tEnds) -- see Utils.get_alignment
        alignment = get_alignment(_newSpliceAlign, _newEstAlign, dna_array, est_array)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons': newExons, 'dna': dna, 'est': est,
                              'DPScores': newDPScores, 'alignment': alignment}

        return current_prediction


    def calculatePredictedExons(self, SpliceAlign):
        """
        Return the exon boundaries encoded in a splice alignment.

        Label 0 marks exonic positions. The result lists, in order, each
        position where a run of 0 labels starts and the position right after
        it ends (a -1 sentinel closes a run that reaches the end).
        """
        labels = SpliceAlign.flatten().tolist()[0] + [-1]
        boundaries = []
        previous = -1
        for pos, label in enumerate(labels):
            if previous != 0 and label == 0:    # start of exon
                boundaries.append(pos)
            elif previous == 0 and label != 0:  # end of exon
                boundaries.append(pos)
            previous = label
        return boundaries
774
775 ###########################
776 # A simple command line
777 # interface
778 ###########################
779
if __name__ == '__main__':
    # usage: qpalma_main.py <run.pickle> <dataset.pickle> <param.pickle|train>
    if len(sys.argv) != 4:
        sys.stderr.write('usage: %s <run.pickle> <dataset.pickle> <param.pickle|train>\n' % sys.argv[0])
        sys.exit(1)

    run_fn, dataset_fn, param_fn = sys.argv[1:4]

    def _load_pickle(filename):
        """Unpickle and return the object stored in filename, always closing the file."""
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    run_obj = _load_pickle(run_fn)
    dataset_obj = _load_pickle(dataset_fn)

    # named qpalma_obj so the imported `qpalma` package is not shadowed
    qpalma_obj = QPalma()

    if param_fn == 'train':
        qpalma_obj.train(run_obj, dataset_obj)
    else:
        param_obj = _load_pickle(param_fn)
        # predict() loads the dataset itself from its file name; all keys are
        # predicted and the output files are named after the dataset file.
        set_name = os.path.splitext(os.path.basename(dataset_fn))[0]
        qpalma_obj.predict(run_obj, dataset_fn, dataset_obj.keys(), param_obj, set_name)