+ saving alignment information now in an array container (saves a factor of 8 in memory)
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from qpalma.sequence_utils import *
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 #from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34 import array
35
36 from qpalma.DataProc import *
37 from qpalma.computeSpliceWeights import *
38 from qpalma.set_param_palma import *
39 from qpalma.computeSpliceAlignWithQuality import *
40 from qpalma.penalty_lookup_new import *
41 from qpalma.compute_donacc import *
42 from qpalma.TrainingParam import Param
43 from qpalma.Plif import Plf
44
45 from qpalma.Configuration import *
46
47 # these two imports are needed for the genomic-sequence loading and
48 # interval query functions, respectively
49 from Genefinding import *
50 from genome_utils import load_genomic
51 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
52
class SpliceSiteException(Exception):
    """
    Raised by getData() when a two-exon training example does not have a
    canonical donor ('gt'/'gc') or acceptor ('ag') dinucleotide.

    Derives from Exception so it is a valid exception class on every Python
    version; existing ``except SpliceSiteException`` clauses keep working.
    """
    pass
55
56
def getData(training_set, exampleKey, run, dna_flat_files=None):
    """
    Load one training example together with its genomic sequence and scores.

    training_set[exampleKey] is unpacked as
    (seqInfo, exons, original_est, qualities) with
    seqInfo = (id, chromosome, strand, up_cut, down_cut).

    :param training_set: mapping from example key to the tuple above.
    :param exampleKey: key of the example; also used in error messages.
    :param run: run configuration; run['read_size'] must equal the read length.
    :param dna_flat_files: directory handed to get_seq_and_scores. Defaults
        to the A. thaliana genome directory that used to be hard-coded here.
    :returns: (dna, est, acc_supp, don_supp, exons, original_est, qualities)
        where est is the lower-cased read with brackets resolved and gaps
        removed, and exons are shifted relative to the fetched dna.
    :raises SpliceSiteException: for two-exon examples whose donor is not
        'gt'/'gc' or whose acceptor is not 'ag'.

    Note: failing sanity checks drop into the debugger, because
    pdb.set_trace() is used as the assertion message.
    """
    currentSeqInfo, currentExons, original_est, currentQualities = training_set[exampleKey]
    # the sequence id is not needed here
    _seq_id, chrom, strand, up_cut, down_cut = currentSeqInfo

    if dna_flat_files is None:
        dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    # read as used for alignment: lower case, brackets resolved, gaps removed
    est = unbracket_est("".join(original_est).lower()).replace('-', '')
    assert len(est) == run['read_size'], pdb.set_trace()

    original_est = "".join(original_est).lower()

    dna, acc_supp, don_supp = get_seq_and_scores(chrom, strand, up_cut, down_cut, dna_flat_files)

    # the acceptor score is located at the g of every 'ag' (positions > 1)
    ag_tuple_pos = [p for p in range(2, len(dna)) if dna[p - 1] == 'a' and dna[p] == 'g']
    assert ag_tuple_pos == [p for p, e in enumerate(acc_supp) if e != -inf and p > 1], pdb.set_trace()

    # the donor score is located at the g of every 'gt'/'gc' (not at either end)
    gt_tuple_pos = [p for p in range(1, len(dna) - 1) if dna[p] == 'g' and dna[p + 1] in ('t', 'c')]
    assert gt_tuple_pos == [p for p, e in enumerate(don_supp) if e != -inf and p > 0], pdb.set_trace()

    # Shift exon coordinates relative to the fetched dna; the start column
    # gets an extra -1 (presumably converting 1-based starts -- verify).
    exons = currentExons - (up_cut - 1)
    exons[0, 0] -= 1
    exons[1, 0] -= 1

    if exons.shape == (2, 2):
        fetched_dna_subseq = dna[exons[0, 0]:exons[0, 1]] + dna[exons[1, 0]:exons[1, 1]]

        donor_elem = dna[exons[0, 1]:exons[0, 1] + 2]
        acceptor_elem = dna[exons[1, 0] - 2:exons[1, 0]]

        if donor_elem not in ('gt', 'gc'):
            print('invalid donor in example %d' % exampleKey)
            raise SpliceSiteException

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d' % exampleKey)
            raise SpliceSiteException

        # fetched_dna_subseq only exists for two-exon examples, so the
        # length check belongs inside this branch
        assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

    return dna, est, acc_supp, don_supp, exons, original_est, currentQualities
107
108
109
class QPalma:
    """
    This class wraps the training and prediction functions for
    the alignment.

    Both train() and predict() change the process working directory to
    run['alignment_dir']/run['name'] and write their log and pickle files
    there.
    """

    def __init__(self):
        # provides train_with_intronlengthinformation for set_param_palma
        self.ARGS = Param()

    def plog(self, string):
        """Write *string* to the currently open log file and flush it."""
        self.logfh.write(string)
        self.logfh.flush()

    ###########################################################################
    # Helpers shared by training and prediction
    ###########################################################################

    def _enter_working_dir(self):
        """Create (if needed) and chdir into run['alignment_dir']/run['name']; return that path."""
        full_working_path = os.path.join(self.run['alignment_dir'], self.run['name'])

        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)
        return full_working_path

    def _set_quality_mode(self):
        """Set self.use_quality_scores from run['mode']; unknown modes fail the assertion."""
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            assert False, 'unknown mode: %r' % (mode,)

    def _select_quality(self, est, currentQualities):
        """
        Return the per-base qualities used to align *est*.

        The real qualities (currentQualities[0]) are only used in
        'using_quality_scores' mode with run['enable_quality_scores'] set;
        otherwise every base gets the constant quality 40.
        """
        run = self.run
        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return currentQualities[0]
        return [40] * len(est)

    @staticmethod
    def _flatten_splice_scores(scores):
        """Set every entry of *scores* that is not -inf to 0.0, in place."""
        for idx, elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0

    @staticmethod
    def _c_array_to_column(c_array, length):
        """Copy the first *length* entries of an indexable (C) array into a column matrix."""
        column = zeros((length, 1))
        for idx in range(length):
            column[idx] = c_array[idx]
        return column

    @staticmethod
    def _dump_pickle(obj, filename):
        """Pickle *obj* into *filename* and make sure the file gets closed."""
        fh = open(filename, 'w+')
        try:
            cPickle.dump(obj, fh)
        finally:
            fh.close()

    ###########################################################################
    # Dynamic programming interface
    ###########################################################################

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps,
                     qualityPlifs, current_num_path, prediction_mode):
        """
        Call the QPalma C module, which computes the current_num_path best
        alignments of *est* against *dna* by dynamic programming.

        remove_duplicate_scores and print_matrix are taken from self.run.

        Returns (newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,
        newQualityPlifsFeatures, dna_array, est_array). The first five are
        column matrices holding the results of all paths concatenated
        (lengths dna_len, est_len, match-matrix size, 1 and
        totalQualSuppPoints per path). The quality features stay zero unless
        quality scores are used. dna_array/est_array are lists copied from
        the C module in prediction_mode and None otherwise.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # previously referenced as undefined names in this scope
        remove_duplicate_scores = run['remove_duplicate_scores']
        print_matrix = run['print_matrix']

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'], run['numQualSuppPoints'],
                                              self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(
            [elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 donor, d_len, acceptor, a_len, c_qualityPlifs,
                                 remove_duplicate_scores, print_matrix)

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList(
            [.0] * (totalQualSP * current_num_path))

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[idx] for idx in range(result_len)]
            est_array = [c_est_array[idx] for idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = self._c_array_to_column(c_SpliceAlign, dna_len * current_num_path)
        newEstAlign = self._c_array_to_column(c_EstAlign, est_len * current_num_path)
        newWeightMatch = self._c_array_to_column(c_WeightMatch, mm_len * current_num_path)
        newDPScores = self._c_array_to_column(c_DPScores, current_num_path)

        if self.use_quality_scores:
            newQualityPlifsFeatures = self._c_array_to_column(
                c_qualityPlifsFeatures, totalQualSP * current_num_path)
        else:
            newQualityPlifsFeatures = zeros((totalQualSP * current_num_path, 1))

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array

    ###########################################################################
    # Training
    ###########################################################################

    def train(self, run, training_set):
        """
        Train the alignment parameters on *training_set*.

        For every example the true alignment is compared with the
        num_path best decoded alignments; the first decoded alignment that
        differs from the true one yields a constraint for the QP solver,
        which is solved every numConstPerRound examples. Intermediate and
        final parameter vectors are pickled as param_<n>.pickle in the
        working directory.

        NOTE(review): the QP solver construction is disabled in this
        revision (its imports and creation are commented out), so 'solver'
        is an undefined name and training stops with a NameError as soon as
        the first constraint is added. Re-enable SIQPSolver before training.
        """
        import math

        self.run = run
        self._enter_working_dir()

        self.logfh = open('_qpalma_train.log', 'w+')
        self._dump_pickle(run, 'run_obj.pickle')

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        self._set_quality_mode()

        numExamples = len(training_set)
        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize parameter vector randomly
        param = numpy.matlib.rand(run['numFeatures'], 1)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # Layout of the feature vector:
        # [intron length | donor | acceptor | match matrix | quality plifs]
        acc_start = lengthSP + donSP
        mm_start = acc_start + accSP
        qual_start = mm_start + mmatrixSP

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Initializing problem...\n')

        # number of alignments computed per example (best path, second-best path, ...)
        num_path = [anzpath] * numExamples

        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        # the main training loop
        while iteration_nr != iteration_steps:

            for exampleIdx, example_key in enumerate(training_set.keys()):
                print('Current example %d' % example_key)
                try:
                    dna, est, acc_supp, don_supp, exons, original_est, currentQualities = \
                        getData(training_set, example_key, run)
                except SpliceSiteException:
                    continue

                dna_len = len(dna)
                quality = self._select_quality(est, currentQualities)

                if not run['enable_splice_signals']:
                    self._flatten_splice_scores(don_supp)
                    self._flatten_splice_scores(acc_supp)

                # Compute the features of the true alignment (with the current d, a, h ...)
                if run['mode'] == 'using_quality_scores':
                    # the fourth result (dna_calc) is not needed here
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                      quality, qualityPlifs, run)
                else:
                    # NOTE(review): only two arguments are passed here, unlike the
                    # quality branch above -- confirm this call form is supported.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:acc_start] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[acc_start:mm_start] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[mm_start:qual_start] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qual_start:] = totalQualityPenalties[:]

                # w'phi(x,y): the total score of the true alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # num_path[exampleIdx] may only grow after the paths below are scored
                nPaths = num_path[exampleIdx]

                # column 0 holds the true alignment's features, columns 1..nPaths
                # those of the decoded alignments
                allWeights = zeros((run['numFeatures'], nPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (nPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # compute_donacc returns two lists of doubles
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                ps = h.convert2SWIG()

                newSpliceAlign, _newEstAlign, newWeightMatch, newDPScores, \
                    newQualityPlifsFeatures, _, _ = \
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                                      qualityPlifs, nPaths, False)

                newSpliceAlign = newSpliceAlign.reshape(nPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(nPaths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(nPaths, totalQualSP)

                # The n-best alignments are computed without hamming loss, so some
                # of them may coincide with the true alignment; those must not
                # become constraints. true_map[k] == 1 marks entries considered
                # equivalent to the true alignment (entry 0 is the true one).
                true_map = [0] * (nPaths + 1)
                true_map[0] = 1

                # the parameter vector is fixed while this example's paths are scored
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(nPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    # store the weights in the remaining columns of the matrix
                    allWeights[:, pathNr + 1] = numpy.vstack([weightIntron, weightDon, weightAcc,
                                                              newWeightMatch[pathNr, :].T,
                                                              decodedQualityFeatures[:]])

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # Check whether scalar product + loss equals the Viterbi score
                    if not math.fabs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"
                              % (newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # if the pathNr-best alignment is very close to the true one consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # If every decoded path is equivalent to the true alignment no
                # constraint can be built; search one more path next time.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] += 1

                # if there is at least one useful false alignment add the
                # corresponding constraint (true vs. first false) to the problem
                if 0 in true_map:
                    firstFalseWeights = allWeights[:, true_map.index(0)]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)
                    const_added_ctr += 1

                # call the solver every numConstPerRound-th example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                        # leaves only the loop over the examples
                        if self.noImprovementCtr == numExamples + 1:
                            break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)

                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            # end of one iteration through all examples
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                break

            iteration_nr += 1

        # end of optimization
        print('Training completed')

        self._dump_pickle(param, 'param_%d.pickle' % param_idx)
        self.logfh.close()

    ###########################################################################
    # Prediction
    ###########################################################################

    def predict(self, run, dataset_fn, prediction_keys, param, set_name):
        """
        Align every example listed in *prediction_keys* using *param*.

        :param run: run configuration (stored as self.run).
        :param dataset_fn: pickle file holding a mapping from example id to a
            list of (seqInfo, original_est, qualities) tuples with
            seqInfo = (id, chromosome, strand, start, stop). A relative path
            is resolved before the working directory changes.
        :param prediction_keys: ids of the examples to predict.
        :param param: parameter vector handed to set_param_palma.
        :param set_name: names the log file and the output pickle
            '<set_name>.predictions.pickle' in the working directory.

        Examples on chromosomes outside 1-5 are skipped; examples whose
        genomic sequence cannot be fetched are counted, logged and skipped.
        """
        self.run = run

        # resolve before the working directory changes
        dataset_fn = os.path.abspath(dataset_fn)

        full_working_path = self._enter_working_dir()
        print('full_working_path is %s' % full_working_path)

        self.logfh = open('_qpalma_predict_%s.log' % set_name, 'w+')

        self._set_quality_mode()

        # number of prediction instances
        self.plog('Number of prediction examples: %d\n' % len(prediction_keys))

        # load dataset and fetch instances that shall be predicted
        dataset_fh = open(dataset_fn)
        try:
            dataset = cPickle.load(dataset_fh)
        finally:
            dataset_fh.close()

        prediction_set = dict((key, dataset[key]) for key in prediction_keys)

        # we do not need the full dataset anymore
        del dataset

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Starting prediction...\n')

        problem_ctr = 0
        allPredictions = []

        for example_key in prediction_set.keys():
            print('Current example %d' % example_key)

            for example in prediction_set[example_key]:
                currentSeqInfo, original_est, currentQualities = example
                read_id, chrom, strand, genomicSeq_start, genomicSeq_stop = currentSeqInfo

                assert read_id == example_key

                # only chromosomes 1 to 5 are handled
                if chrom not in range(1, 6):
                    continue

                self.plog('Loading example id: %d...\n' % int(read_id))

                est = unbracket_est(original_est)
                quality = self._select_quality(est, currentQualities)

                try:
                    currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                        chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
                except Exception:
                    # best effort: skip this example, but record why
                    problem_ctr += 1
                    self.plog('Could not fetch genomic sequence for example %s: %s\n'
                              % (read_id, sys.exc_info()[1]))
                    continue

                if not run['enable_splice_signals']:
                    self._flatten_splice_scores(currentDon)
                    self._flatten_splice_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, quality,
                                                         currentDon, currentAcc,
                                                         d, a, h, mmatrix, qualityPlifs)

                current_prediction['id'] = read_id
                current_prediction['start_pos'] = genomicSeq_start
                current_prediction['chr'] = chrom
                current_prediction['strand'] = strand

                allPredictions.append(current_prediction)

        # end of the prediction loop: save all predictions in a pickle file
        self._dump_pickle(allPredictions, '%s.predictions.pickle' % set_name)
        print('Prediction completed')
        self.plog('Prediction completed\n')
        mes = 'Problem ctr %d' % problem_ctr
        print(mes)
        self.plog(mes + '\n')
        self.logfh.close()

    def calc_alignment(self, dna, est, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Compute the best alignment of *est* against *dna* with the given parameters.

        Gaps ('-') in est are removed (and logged) first; the resulting read
        must then have length run['read_size'].

        Returns a dict with the keys 'predExons', 'dna', 'est', 'DPScores',
        'alignment', 'spliceAlign', 'estAlign', 'dna_array' and 'est_array'.
        """
        run = self.run
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)

        if '-' in est:
            self.plog('found gap\n')
            est = est.replace('-', '')
            # same read-length setting that getData() checks
            assert len(est) == run['read_size']

        dna_len = len(dna)

        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array = \
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, 1, True)

        newSpliceAlign = newSpliceAlign.reshape(1, dna_len)

        # one unsigned byte per position; int() makes the float->int
        # conversion explicit instead of relying on array's implicit coercion
        _newSpliceAlign = array.array('B', [int(x) for x in newSpliceAlign.flatten().tolist()[0]])
        _newEstAlign = array.array('B', [int(x) for x in newEstAlign.flatten().tolist()[0]])

        #(qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)
        alignment = get_alignment(_newSpliceAlign, _newEstAlign, dna_array, est_array)

        # NOTE(review): do_alignment fills these lists from arrays created with
        # createIntArrayFromList; typecode 'c' (Python 2 only) accepts single
        # characters only -- confirm the entries really are characters.
        dna_array = array.array('c', dna_array)
        est_array = array.array('c', est_array)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons': newExons, 'dna': dna, 'est': est, 'DPScores': newDPScores,
                              'alignment': alignment,
                              'spliceAlign': _newSpliceAlign, 'estAlign': _newEstAlign,
                              'dna_array': dna_array, 'est_array': est_array}

        return current_prediction

    def calculatePredictedExons(self, SpliceAlign):
        """
        Return exon boundaries from a 1 x N splice-alignment matrix.

        Positions labelled 0 are treated as exonic. The result is a flat list
        [start1, end1, start2, end2, ...] where each start is the first 0 of a
        run of zeros and each end is the position just after that run.
        """
        labels = SpliceAlign.flatten().tolist()[0]

        newExons = []
        previous = -1
        # the -1 sentinel closes an exon that reaches the last position
        for pos, label in enumerate(labels + [-1]):
            if previous != 0 and label == 0:    # start of exon
                newExons.append(pos)
            elif previous == 0 and label != 0:  # end of exon
                newExons.append(pos)
            previous = label

        return newExons
781
###########################
# A simple command line
# interface
#
# usage: qpalma_main.py <run.pickle> <dataset.pickle> (train | <param.pickle>)
###########################

def _load_pickle(filename):
    """Unpickle and return the object stored in *filename*, closing the file afterwards."""
    fh = open(filename)
    try:
        return cPickle.load(fh)
    finally:
        fh.close()


if __name__ == '__main__':
    if len(sys.argv) != 4:
        sys.stderr.write('usage: %s run_file dataset_file (train|param_file)\n' % sys.argv[0])
        sys.exit(2)

    run_fn, dataset_fn, param_fn = sys.argv[1:4]

    run_obj = _load_pickle(run_fn)
    dataset_obj = _load_pickle(dataset_fn)

    # named 'qp' so the imported 'qpalma' package is not shadowed
    qp = QPalma()

    if param_fn == 'train':
        qp.train(run_obj, dataset_obj)
    else:
        param_obj = _load_pickle(param_fn)

        # predict() reloads the dataset itself after changing into the run's
        # working directory, so it needs the file name (made absolute) and
        # the keys to predict; the set name is derived from the file name.
        prediction_keys = list(dataset_obj.keys())
        del dataset_obj
        set_name = os.path.splitext(os.path.basename(dataset_fn))[0]

        qp.predict(run_obj, os.path.abspath(dataset_fn), prediction_keys, param_obj, set_name)