+ performed some further refactoring
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import subprocess
20 import scipy.io
21 import pdb
22 import re
23 import os.path
24 import pydb
25
26 from compile_dataset import getSpliceScores
27
28 import numpy
29 from numpy.matlib import mat,zeros,ones,inf
30 from numpy.linalg import norm
31
32 import QPalmaDP
33 import qpalma
34 from qpalma.SIQP_CPX import SIQPSolver
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.tools.splicesites import getDonAccScores
45 from qpalma.Configuration import *
46
47 # these two imports are needed for the load genomic resp. interval query
48 # functions
49 from Genefinding import *
50 from genome_utils import load_genomic
51
52 from Utils import calc_stat, calc_info, pprint_alignment
53
class QPalma:
    """
    A training and prediction method for the QPalma project.

    An instance is configured by a ``run`` dictionary.  Constructing an
    instance changes the current working directory to
    ``run['experiment_path']/run['name']`` and writes ``run_object.pickle``
    there; all log and result files are subsequently written relative to
    that directory.
    """

    def __init__(self, run):
        """
        Store the run settings, switch into the experiment directory and
        save a pickled copy of ``run``.

        run -- dictionary of settings; this method reads the keys 'mode',
               'experiment_path' and 'name'.

        Raises AssertionError if the mode is unknown or the experiment
        directory does not exist.
        """
        self.ARGS = Param()
        self.run = run
        self._configure_mode()

        full_working_path = os.path.join(run['experiment_path'], run['name'])
        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self._dump_pickle(run, 'run_object.pickle')

    def _configure_mode(self):
        """Set self.use_quality_scores according to self.run['mode']."""
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            # Raised explicitly (instead of ``assert False``) so that the
            # check is not stripped when Python runs with -O.
            raise AssertionError("unknown run mode: %r" % (mode,))

    def _dump_pickle(self, obj, filename):
        """Pickle obj into filename and make sure the file gets closed."""
        fh = open(filename, 'w+')
        try:
            cPickle.dump(obj, fh)
        finally:
            fh.close()

    def _load_pickle(self, filename):
        """Unpickle and return the content of filename, closing the file."""
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    def plog(self, string):
        """Write string to the currently open log file and flush it."""
        self.logfh.write(string)
        self.logfh.flush()

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps,
                     qualityPlifs, current_num_path, prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment.

        current_num_path -- number of alignments (paths) requested.
        prediction_mode  -- if true, the explicit dna/est alignment arrays
                            are fetched as well; otherwise they are None.

        Returns the tuple (newSpliceAlign, newEstAlign, newWeightMatch,
        newDPScores, newQualityPlifsFeatures, dna_array, est_array).  The
        first five entries are column matrices holding the results of all
        paths one after another.  newQualityPlifsFeatures stays all-zero
        unless quality scores are in use.
        """
        run = self.run

        # These two settings used to be locals of the calling methods; they
        # are read from the run settings here so that the call below does
        # not reference undefined names.
        remove_duplicate_scores = run['remove_duplicate_scores']
        print_matrix = run['print_matrix']

        dna_len = len(dna)
        est_len = len(est)

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(
            mmatrix.flatten().tolist()[0])
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the
        # C/C++ code.
        currentAlignment = QPalmaDP.Alignment(
            run['numQualPlifs'], run['numQualSuppPoints'],
            self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(
            [elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
        currentAlignment.myalign(
            current_num_path, dna, dna_len,
            est, est_len, prb, chastity, ps, matchmatrix, mm_len,
            donor, d_len, acceptor, a_len, c_qualityPlifs,
            remove_duplicate_scores, print_matrix)

        qual_len = run['totalQualSuppPoints'] * current_num_path

        # Buffers that the C side fills with the results.
        c_SpliceAlign = QPalmaDP.createIntArrayFromList(
            [0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList(
            [0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList(
            [0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList(
            [.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList(
            [.0] * qual_len)

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[r_idx] for r_idx in range(result_len)]
            est_array = [c_est_array[r_idx] for r_idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(
            c_SpliceAlign, c_EstAlign, c_WeightMatch, c_DPScores,
            c_qualityPlifsFeatures)

        # Copy the results into numpy column matrices.
        newSpliceAlign = zeros((current_num_path * dna_len, 1))
        newEstAlign = zeros((est_len * current_num_path, 1))
        newWeightMatch = zeros((current_num_path * mm_len, 1))
        newDPScores = zeros((current_num_path, 1))
        newQualityPlifsFeatures = zeros((qual_len, 1))

        for i in range(dna_len * current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len * current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len * current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        if self.use_quality_scores:
            for i in range(qual_len):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array

    def train(self):
        """
        Run the training loop.

        Opens the log file '_qpalma_train.log' in the current working
        directory, trains on the examples selected by
        run['training_begin'] / run['training_end'] and writes the
        parameter vector to 'param_<n>.pickle' after every solver call and
        once more at the end.  The log file is closed even if training
        fails.
        """
        self.logfh = open('_qpalma_train.log', 'w+')
        try:
            self._train()
        finally:
            self.logfh.close()

    def _train(self):
        """Body of train(); expects self.logfh to be open."""
        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        data_filename = run['dataset_filename']
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
            AlternativeSequences =\
            paths_load_data(data_filename, 'training', None, self.ARGS)

        self._configure_mode()

        # Keep the whole dataset on the instance
        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors

        calc_info(self.Acceptors, self.Donors, self.Exons, self.Qualities)

        beg = run['training_begin']
        end = run['training_end']

        Sequences = Sequences[beg:end]
        Exons = Exons[beg:end]
        Ests = Ests[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]
        Acceptors = Acceptors[beg:end]
        Donors = Donors[beg:end]

        # number of training instances
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
            and len(Qualities) == N and len(Acceptors) == N\
            and len(Donors) == N, \
            'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize parameter vector
        param = Conf.fixedParam[:run['numFeatures']]

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
            param, self.ARGS.train_with_intronlengthinformation, run)

        # Initialize solver
        self.plog('Initializing problem...\n')
        solver = SIQPSolver(run['numFeatures'], numExamples, run['C'],
                            self.logfh, run)

        # stores the number of alignments done for each example (best
        # path, second-best path etc.)
        num_path = [anzpath] * numExamples

        #####################################################################
        # Training
        #####################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        # the main training loop
        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx in range(numExamples):
                if (exampleIdx % 100) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]
                est = "".join(est)
                est = est.lower()
                est = est.replace('-', '')
                original_est = OriginalEsts[exampleIdx]
                original_est = "".join(original_est)
                original_est = original_est.lower()

                dna_len = len(dna)

                # NOTE(review): on failure this drops into the debugger
                # before the AssertionError is raised.
                assert len(est) == run['read_size'], pdb.set_trace()

                if run['mode'] == 'normal':
                    quality = [40] * len(est)

                if run['mode'] == 'using_quality_scores':
                    quality = Qualities[exampleIdx]

                if not run['enable_quality_scores']:
                    quality = [40] * len(est)

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                if not run['enable_splice_signals']:
                    # overwrite every finite splice score with 0.0 (the
                    # lists of the loaded dataset are modified in place)
                    for idx, elem in enumerate(don_supp):
                        if elem != -inf:
                            don_supp[idx] = 0.0

                    for idx, elem in enumerate(acc_supp):
                        if elem != -inf:
                            acc_supp[idx] = 0.0

                # Compute the parameters of the true alignment (but with
                # untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    # dna_calc is returned but not needed here
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality,\
                        dna_calc =\
                        computeSpliceAlignWithQuality(
                            dna, exons, est, original_est,
                            quality, qualityPlifs, run)
                else:
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality =\
                        computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron =\
                    computeSpliceWeights(d, a, h, trueSpliceAlign,
                                         don_supp, acc_supp)
                trueWeight = numpy.vstack(
                    [trueWeightIntron, trueWeightDon, trueWeightAcc,
                     trueWeightMatch, trueWeightQuality])

                offset = lengthSP + donSP
                currentPhi[0:lengthSP] =\
                    mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:offset] =\
                    mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[offset:offset + accSP] =\
                    mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[offset + accSP:offset + accSP + mmatrixSP] =\
                    mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[offset + accSP + mmatrixSP:] =\
                        totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # The allWeights vector is supposed to store the weight
                # parameter of the true alignment as well as the weight
                # parameters of the num_path[exampleIdx] other alignments
                current_num_path = num_path[exampleIdx]
                allWeights = zeros((run['numFeatures'], current_num_path + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (current_num_path + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################ Calculate wrong alignment(s) ################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
                    newQualityPlifsFeatures, unneeded1, unneeded2 =\
                    self.do_alignment(dna, est, quality, mmatrix, donor,
                                      acceptor, ps, qualityPlifs,
                                      current_num_path, False)

                mm_len = run['matchmatrixRows'] * run['matchmatrixCols']

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(
                    current_num_path, dna_len)
                newWeightMatch = newWeightMatch.reshape(
                    current_num_path, mm_len)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(
                    current_num_path, run['totalQualSuppPoints'])

                # Calculate weights of the respective alignments. Note that
                # we are calculating n-best alignments without hamming loss,
                # so we have to keep track which of the n-best alignments
                # correspond to the true one in order not to incorporate a
                # true alignment in the constraints. To keep track of the
                # true and false alignments we define an array true_map with
                # a boolean indicating the equivalence to the true alignment
                # for each decoded alignment.
                true_map = [0] * (current_num_path + 1)
                true_map[0] = 1

                # The current parameter values stacked in feature order;
                # they do not change while the paths of one example are
                # scored.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack(
                    [hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(current_num_path):
                    weightDon, weightAcc, weightIntron =\
                        computeSpliceWeights(
                            d, a, h,
                            newSpliceAlign[pathNr, :].flatten().tolist()[0],
                            don_supp, acc_supp)

                    decodedQualityFeatures =\
                        newQualityPlifsFeatures[pathNr, :].T

                    # store the weights in the remaining columns
                    allWeights[:, pathNr + 1] = numpy.vstack(
                        [weightIntron, weightDon, weightAcc,
                         newWeightMatch[pathNr, :].T,
                         decodedQualityFeatures[:]])

                    AlignmentScores[pathNr + 1] =\
                        (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # Check wether scalar product + loss equals viterbi score
                    if not abs(newDPScores[pathNr, 0]
                               - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog(
                            "Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"
                              % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"
                              % (newDPScores[pathNr, 0],
                                 AlignmentScores[pathNr + 1]))

                    # if the pathNr-best alignment is very close to the true
                    # alignment consider it as true
                    if norm(allWeights[:, 0]
                            - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # All decoded paths were considered equal to the true one:
                # this means that all n-best paths are too close to each
                # other, so we have to extend the n-best search to a
                # (n+1)-best search for this example.
                if len([elem for elem in true_map if elem == 1])\
                        == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx] + 1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx, elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:, firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)

                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'
                                  % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    if self.noImprovementCtr == numExamples + 1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = 0
                    for elem in self.slacks:
                        sum_xis += elem

                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
                        param, self.ARGS.train_with_intronlengthinformation,
                        run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param, 'param_%d.pickle' % param_idx)

    ###########################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###########################################################################

    def evaluate(self, param_filename):
        """
        Load the dataset and run predict() twice: on the examples
        [0, run['prediction_begin']) flagged 'TRAIN' and on the examples
        [run['prediction_begin'], run['prediction_end']) flagged 'TEST'.

        param_filename -- pickle file holding the parameter vector.

        Log output goes to '_qpalma_predict.log', which is closed even if a
        prediction fails.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        data_filename = run['dataset_filename']
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
            AlternativeSequences =\
            paths_load_data(data_filename, 'training', None, self.ARGS)

        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors

        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log', 'w+')
        try:
            # predict on training set
            print('##### Prediction on the training set #####')
            self.predict(param_filename, 0, beg, 'TRAIN')

            # predict on test set
            print('##### Prediction on the test set #####')
            self.predict(param_filename, beg, end, 'TEST')
        finally:
            self.logfh.close()

    def predict(self, param_filename, beg, end, set_flag):
        """
        Compute one alignment per alternative genomic position of every
        example in [beg, end) and pickle the list of predictions into
        '<run name>_allPredictions_<set_flag>'.

        Expects the dataset attributes (self.Sequences, ...) and self.logfh
        to be set, as done by evaluate().

        param_filename -- pickle file holding the parameter vector.
        beg, end       -- slice bounds into the loaded dataset.
        set_flag       -- label that becomes part of the result file name.
        """
        run = self.run

        self._configure_mode()

        Sequences = self.Sequences[beg:end]
        Exons = self.Exons[beg:end]
        Ests = self.Ests[beg:end]
        Qualities = self.Qualities[beg:end]
        Acceptors = self.Acceptors[beg:end]
        Donors = self.Donors[beg:end]

        AlternativeSequences = self.AlternativeSequences[beg:end]

        print('STARTING PREDICTION')

        # number of instances
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
            and len(Qualities) == N and len(Acceptors) == N\
            and len(Donors) == N, \
            'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        param = self._load_pickle(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
            param, self.ARGS.train_with_intronlengthinformation, run)

        #####################################################################
        # Prediction
        #####################################################################
        self.plog('Starting prediction...\n')

        allPredictions = []

        # beginning of the prediction loop
        for exampleIdx in range(numExamples):
            dna = Sequences[exampleIdx]
            est = Ests[exampleIdx]

            # Collapse every four-character bracketed group to its third
            # character; everything else is copied unchanged.
            est_chars = []
            e = 0
            while e < len(est):
                if est[e] == '[':
                    est_chars.append(est[e + 2])
                    e += 4
                else:
                    est_chars.append(est[e])
                    e += 1

            est = "".join(est_chars).lower()

            exons = Exons[exampleIdx]

            currentAlternatives = AlternativeSequences[exampleIdx]

            if run['mode'] == 'normal':
                quality = [40] * len(est)

            if run['mode'] == 'using_quality_scores':
                quality = Qualities[exampleIdx]

            if not run['enable_quality_scores']:
                quality = [40] * len(est)

            don_supp = Donors[exampleIdx]
            acc_supp = Acceptors[exampleIdx]

            if not run['enable_splice_signals']:
                for idx, elem in enumerate(don_supp):
                    if elem != -inf:
                        don_supp[idx] = 0.0

                for idx, elem in enumerate(acc_supp):
                    if elem != -inf:
                        acc_supp[idx] = 0.0

            # the code up to here is equal for all alternative positions
            for alternative_alignment in currentAlternatives:
                chromosome, strand, genomicSeq_start, genomicSeq_stop =\
                    alternative_alignment

                # NOTE(review): the sequence and scores loaded for the
                # alternative position are not passed on; the alignment
                # below still uses dna/don_supp/acc_supp of the example
                # itself. Presumably the alternative data is meant to be
                # used here - confirm before changing.
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chromosome, strand, genomicSeq_start, genomicSeq_stop,
                    run['dna_flat_files'])

                current_prediction = self.calc_alignment(
                    dna, est, exons, quality, don_supp, acc_supp,
                    d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx

                allPredictions.append(current_prediction)

        self._dump_pickle(
            allPredictions,
            '%s_allPredictions_%s' % (run['name'], set_flag))

        print('Prediction completed')

    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp,
                       d, a, h, mmatrix, qualityPlifs):
        """
        Compute the single best alignment of est against dna, log a
        printout of it and compare the predicted exons with ``exons``.

        Returns a dictionary with the keys 'e1_b_off', 'e1_e_off',
        'e2_b_off', 'e2_e_off', 'predExons', 'trueExons', 'dna', 'est' and
        'DPScores'.
        """
        ################## Calculate alignment(s) ######################
        # Compute donor, acceptor with penalty_lookup_new
        # returns two double lists
        run = self.run

        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        # myalign wants the acceptor site on the g of the ag
        acceptor = acceptor[1:]
        acceptor.append(-inf)

        dna = str(dna)
        est = str(est)
        dna_len = len(dna)

        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array =\
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor,
                              ps, qualityPlifs, 1, True)

        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']

        newSpliceAlign = newSpliceAlign.reshape(1, dna_len)
        newWeightMatch = newWeightMatch.reshape(1, mm_len)

        flatSpliceAlign = newSpliceAlign.flatten().tolist()[0]
        flatEstAlign = newEstAlign.flatten().tolist()[0]

        line1, line2, line3 = pprint_alignment(
            flatSpliceAlign, flatEstAlign, dna_array, est_array)
        self.plog(line1 + '\n')
        self.plog(line2 + '\n')
        self.plog(line3 + '\n')

        e1_b_off, e1_e_off, e2_b_off, e2_e_off, newExons =\
            self.evaluateExample(dna, est, exons, newSpliceAlign, newEstAlign)

        current_prediction = {
            'e1_b_off': e1_b_off, 'e1_e_off': e1_e_off,
            'e2_b_off': e2_b_off, 'e2_e_off': e2_e_off,
            'predExons': newExons, 'trueExons': exons,
            'dna': dna, 'est': est,
            'DPScores': newDPScores}

        return current_prediction

    def evaluateExample(self, dna, est, exons, SpliceAlign, newEstAlign):
        """
        Derive the exon boundaries from SpliceAlign and compare them with
        the true ones.

        SpliceAlign -- numpy matrix; positions holding 0 are treated as
                       exon positions.
        exons       -- indexed as exons[i,j]; rows 0 and 1 hold begin and
                       end of the first resp. second true exon.

        Returns (e1_b_off, e1_e_off, e2_b_off, e2_e_off, newExons) where
        newExons is the list of predicted boundaries and the offsets are
        the absolute differences to the true boundaries.  If the number of
        predicted boundaries is not four all offsets are None.
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # sentinel that closes an exon reaching up to the last position
        labels.append(-1)

        newExons = []
        oldElem = -1
        for pos, elem in enumerate(labels):
            if oldElem != 0 and elem == 0:   # start of exon
                newExons.append(pos)

            if oldElem == 0 and elem != 0:   # end of exon
                newExons.append(pos)

            oldElem = elem

        self.plog(("read: " + str(est) + "\n"))

        if len(newExons) != 4:
            return None, None, None, None, newExons

        e1_begin, e1_end, e2_begin, e2_end = newExons

        e1_b_off = int(abs(e1_begin - exons[0, 0]))
        e1_e_off = int(abs(e1_end - exons[0, 1]))

        e2_b_off = int(abs(e2_begin - exons[1, 0]))
        e2_e_off = int(abs(e2_end - exons[1, 1]))

        return e1_b_off, e1_e_off, e2_b_off, e2_e_off, newExons
861
862
def get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,dna_flat_files):
    """
    Fetch the genomic sequence of an interval together with its splice
    site scores.

    The interval is described by chromosome number, strand and start/stop
    position.  The sequence is lower-cased and every character other than
    a, c, g or t is replaced by 'n'.

    Returns the tuple (sequence, acceptor scores, donor scores).
    """
    chromosome_name = 'chr%d' % chr

    raw_sequence = load_genomic(chromosome_name, strand,
                                genomicSeq_start - 1, genomicSeq_stop,
                                dna_flat_files, one_based=False)
    currentDNASeq = raw_sequence.lower()

    # check the obtained dna sequence
    assert currentDNASeq != '', 'load_genomic returned empty sequence!'

    # mask everything that is not a plain nucleotide
    currentDNASeq = re.sub('[^acgt]', 'n', currentDNASeq)

    # the scores are requested for the interval widened by 100 positions
    # on either side, with the interval start as sequence position offset
    currentAcc, currentDon = getSpliceScores(
        chr, strand,
        genomicSeq_start - 100, genomicSeq_stop + 100,
        currentDNASeq, genomicSeq_start)

    return currentDNASeq, currentAcc, currentDon
889
890
891 ###########################
892 # A simple command line
893 # interface
894 ###########################
895
if __name__ == '__main__':
    # Usage:
    #   train   : <script> train   <run_object_file>
    #   predict : <script> predict <run_object_file> <param_file>
    cli_args = sys.argv[1:]

    train_requested = len(cli_args) == 2 and cli_args[0] == 'train'
    predict_requested = len(cli_args) == 3 and cli_args[0] == 'predict'

    if not (train_requested or predict_requested):
        # Validate the command line before anything with side effects
        # happens (constructing QPalma changes the working directory and
        # writes a pickle file).
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')
    else:
        # NOTE: unpickling executes arbitrary code; the run object file
        # must come from a trusted source.
        run_obj_fh = open(cli_args[1])
        try:
            run_obj = cPickle.load(run_obj_fh)
        finally:
            run_obj_fh.close()

        # not named 'qpalma' in order not to shadow the imported package
        qpalma_instance = QPalma(run_obj)

        if train_requested:
            qpalma_instance.train()
        else:
            # The check runs after QPalma() has changed the working
            # directory, so a relative path refers to the experiment
            # directory.
            param_filename = cli_args[2]
            assert os.path.exists(param_filename)
            qpalma_instance.evaluate(param_filename)