+ refactored code further
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import subprocess
20 import scipy.io
21 import pdb
22 import re
23 import os.path
24 #import pydb
25
26 from compile_dataset import getSpliceScores
27
28 import numpy
29 from numpy.matlib import mat,zeros,ones,inf
30 from numpy.linalg import norm
31
32 import QPalmaDP
33 import qpalma
34 from qpalma.SIQP_CPX import SIQPSolver
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.tools.splicesites import getDonAccScores
45 from qpalma.Configuration import *
46
47 # these two imports are needed for the load genomic and interval query
48 # functions
49 from Genefinding import *
50 from genome_utils import load_genomic
51
52 from Utils import calc_stat, calc_info, pprint_alignment
53
class QPalma:
    """
    A training and prediction driver for the QPalma project.

    An instance is configured by a ``run`` dictionary.  ``train`` runs the
    iterative training loop and pickles the parameter vector it learns;
    ``evaluate`` / ``predict`` load such a vector and pickle the alignments
    computed with it.
    """

    def __init__(self,run):
        """
        Store the run settings and derive the quality score flag from them.

        run -- dictionary of settings; run['mode'] has to be either
               'normal' or 'using_quality_scores'.
        """
        self.ARGS = Param()
        self.run = run
        self.use_quality_scores = self._uses_quality_scores()


    def _uses_quality_scores(self):
        """
        Translate run['mode'] into the boolean use_quality_scores flag.

        Raises an AssertionError for an unknown mode.  The error is raised
        explicitly (instead of via ``assert False``) so that the check is
        still performed when Python runs with -O.
        """
        mode = self.run['mode']

        if mode == 'normal':
            return False

        if mode == 'using_quality_scores':
            return True

        raise AssertionError('unknown run mode: %r' % (mode,))


    def _dump_pickle(self,obj,filename):
        """
        Pickle obj into filename and close the file again, even on errors.
        """
        fh = open(filename,'w+')
        try:
            cPickle.dump(obj,fh)
        finally:
            fh.close()


    def _load_pickle(self,filename):
        """
        Unpickle and return the object stored in filename.

        NOTE: unpickling can execute arbitrary code, so only files from a
        trusted source must be passed in here.
        """
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()


    def _zero_finite_scores(self,scores):
        """
        Set every entry of scores that is not -inf to 0.0 (in place).

        Used when run['enable_splice_signals'] is switched off: entries that
        are -inf stay untouched, all other scores are flattened to 0.0.
        """
        for idx,elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0


    def _select_quality(self,est,example_quality):
        """
        Return the quality values to be used for one read.

        The stored values are only used in 'using_quality_scores' mode with
        run['enable_quality_scores'] switched on; in every other case a
        constant quality of 40 is used for each position of est.
        """
        run = self.run

        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return example_quality

        return [40]*len(est)


    def _strip_est_brackets(self,est):
        """
        Collapse every four character group starting with '[' in est to
        the single character at offset two of that group; all other
        characters are copied unchanged.
        """
        # presumably the groups look like "[xy]" -- TODO confirm the format
        new_est = ''
        e = 0
        while e < len(est):
            if est[e] == '[':
                new_est += est[e+2]
                e += 4
            else:
                new_est += est[e]
                e += 1

        return new_est


    def plog(self,string):
        """
        Write string to the current log file and flush it immediately.
        """
        self.logfh.write(string)
        self.logfh.flush()


    def do_alignment(self,dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment.

        current_num_path -- number of alignments (paths) that are requested
        prediction_mode  -- if True the dna/est alignment arrays are fetched
                            as well, otherwise None is returned for them

        Returns the tuple (newSpliceAlign, newEstAlign, newWeightMatch,
        newDPScores, newQualityPlifsFeatures, dna_array, est_array).  The
        first five entries are column matrices holding the values of all
        paths one after another.
        """
        run = self.run

        # These two settings were referenced as bare names that are not
        # defined in this method; they are part of the run settings.
        remove_duplicate_scores = run['remove_duplicate_scores']
        print_matrix = run['print_matrix']

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows']*run['matchmatrixCols']
        qual_len = run['totalQualSuppPoints']*current_num_path

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
        currentAlignment.myalign(current_num_path, dna, dna_len,
            est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,
            acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
            print_matrix)

        # buffers that are filled by getAlignmentResults below
        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)

        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*qual_len)

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
            c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))

            currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)

            dna_array = [0.0]*result_len
            est_array = [0.0]*result_len

            for r_idx in range(result_len):
                dna_array[r_idx] = c_dna_array[r_idx]
                est_array[r_idx] = c_est_array[r_idx]

        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
            c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = zeros((current_num_path*dna_len,1))
        newEstAlign = zeros((est_len*current_num_path,1))
        newWeightMatch = zeros((current_num_path*mm_len,1))
        newDPScores = zeros((current_num_path,1))
        newQualityPlifsFeatures = zeros((qual_len,1))

        # copy the results element by element into numpy matrices
        for i in range(dna_len*current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len*current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len*current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        # the quality features stay zero if no quality scores are used
        if self.use_quality_scores:
            for i in range(qual_len):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array


    def train(self):
        """
        Set up the working directory of this run and execute the training.

        A new directory run['experiment_path']/run['name'] is created (it
        must not exist yet) and becomes the current working directory.  The
        run settings, the log file and all parameter pickles are written
        there.  The log file is closed even if the training fails.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'],run['name'])

        assert not os.path.exists(full_working_path)
        os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self._dump_pickle(run,'run_object.pickle')

        self.logfh = open('_qpalma_train.log','w+')
        try:
            self._run_training()
        finally:
            self.logfh.close()


    def _run_training(self):
        """
        The main training loop; expects self.logfh to be an open log file.

        For every example the true alignment and the num_path best decoded
        alignments are computed.  The first decoded alignment that differs
        from the true one yields a constraint for the solver.  After a
        solver call the parameter vector is updated and pickled.
        """
        # math.fabs is used below but no import of math is visible at the
        # top of this file, so it is imported here.
        import math

        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n"%str(run))

        # Load the whole dataset
        data_filename = run['dataset_filename']
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
            UpCut, StartPos, AlternativeSequences =\
            paths_load_data(data_filename,'training',None,self.ARGS)

        self.use_quality_scores = self._uses_quality_scores()

        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors

        calc_info(self.Acceptors,self.Donors,self.Exons,self.Qualities)

        # restrict the dataset to the training interval
        beg = run['training_begin']
        end = run['training_end']

        Sequences = Sequences[beg:end]
        Exons = Exons[beg:end]
        Ests = Ests[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]
        Acceptors = Acceptors[beg:end]
        Donors = Donors[beg:end]

        # number of training instances
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
            and len(Qualities) == N and len(Acceptors) == N\
            and len(Donors) == N, 'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize parameter vector
        param = Conf.fixedParam[:run['numFeatures']]

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # offsets of the feature groups inside the parameter vector
        donEnd = lengthSP+donSP
        accEnd = donEnd+accSP
        mmEnd = accEnd+mmatrixSP

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        # Initialize solver
        self.plog('Initializing problem...\n')
        solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'],1))
        totalQualityPenalties = zeros((totalQualSP,1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        # the main training loop
        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx in range(numExamples):
                if (exampleIdx%100) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = "".join(Ests[exampleIdx]).lower().replace('-','')
                original_est = "".join(OriginalEsts[exampleIdx]).lower()

                dna_len = len(dna)

                # a failing check drops into the debugger
                assert len(est) == run['read_size'], pdb.set_trace()

                quality = self._select_quality(est,Qualities[exampleIdx])

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                if not run['enable_splice_signals']:
                    self._zero_finite_scores(don_supp)
                    self._zero_finite_scores(acc_supp)

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    # the fourth return value is not needed here
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, unused_dna_calc =\
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                        quality, qualityPlifs,run)
                else:
                    # NOTE(review): this call passes fewer arguments and
                    # unpacks fewer results than the one above -- confirm
                    # which function is meant for the 'normal' mode.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[lengthSP:donEnd] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[donEnd:accEnd] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[accEnd:mmEnd] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[mmEnd:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((run['numFeatures'],num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
                    newQualityPlifsFeatures, unneeded1, unneeded2 =\
                    self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
                newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(num_path[exampleIdx],totalQualSP)

                # Calculate weights of the respective alignments. Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(num_path[exampleIdx]+1)
                true_map[0] = 1

                # The current parameters as one column vector; they do not
                # change while the paths of one example are processed.
                hpen = mat(h.penalties).reshape(len(h.penalties),1)
                dpen = mat(d.penalties).reshape(len(d.penalties),1)
                apen = mat(a.penalties).reshape(len(a.penalties),1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(num_path[exampleIdx]):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,
                        h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,
                        acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr,:].T

                    # store the weights of this path in its column of allWeights
                    allWeights[:,pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures[:]])

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check whether scalar product + loss equals viterbi score
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"%AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"%(newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                # the true alignment scores better than every decoded one
                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" %exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" %exampleIdx)

                # If every decoded alignment equals the true one, all n-best
                # paths are too close to each other and we have to extend
                # the n-best search to a (n+1)-best
                if len([elem for elem in true_map if elem == 1]) == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights

                    solver.addConstraint(differenceVector, exampleIdx)
                    const_added_ctr += 1
                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue,w,self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'% numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)

                    print('sum of slacks is %f'% sum_xis)
                    self.plog('sum of slacks is %f\n'% sum_xis)

                    # update the parameter vector in place
                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param,'param_%d.pickle'%param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] =\
                        set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" %suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param,'param_%d.pickle'%param_idx)


    ###############################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def evaluate(self,param_filename):
        """
        Predict with the parameters stored in param_filename.

        The dataset is loaded and a prediction is made for the examples
        [0,run['prediction_begin']) (flagged 'TRAIN') and for the examples
        [run['prediction_begin'],run['prediction_end']) (flagged 'TEST').
        Progress is logged to '_qpalma_predict.log', which is closed even
        if a prediction fails.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        data_filename = run['dataset_filename']
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
            UpCut, StartPos, AlternativeSequences =\
            paths_load_data(data_filename,'training',None,self.ARGS)

        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors
        self.UpCut = UpCut
        self.StartPos = StartPos

        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log','w+')
        try:
            # predict on training set
            self.plog('##### Prediction on the training set #####\n')
            self.predict(param_filename,0,beg,'TRAIN')

            # predict on test set
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename,beg,end,'TEST')

            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()


    def predict(self,param_filename,beg,end,set_flag):
        """
        Predict alignments for the examples [beg,end) of the loaded dataset.

        For each example one prediction is made on the dna fragment of the
        ground truth (label True) and one for every entry of its
        alternative sequences.  All predictions are pickled into the file
        '<run name>_allPredictions_<set_flag>'.

        param_filename -- pickle file containing the parameter vector
        set_flag       -- suffix of the result file, e.g. 'TRAIN' or 'TEST'
        """
        run = self.run

        self.use_quality_scores = self._uses_quality_scores()

        Sequences = self.Sequences[beg:end]
        Exons = self.Exons[beg:end]
        Ests = self.Ests[beg:end]
        Qualities = self.Qualities[beg:end]
        Acceptors = self.Acceptors[beg:end]
        Donors = self.Donors[beg:end]
        UpCut = self.UpCut[beg:end]
        StartPos = self.StartPos[beg:end]

        AlternativeSequences = self.AlternativeSequences[beg:end]

        # number of instances to predict on
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
            and len(Qualities) == N and len(Acceptors) == N\
            and len(Donors) == N, 'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        param = self._load_pickle(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] =\
            set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        #############################################################################################
        # Prediction
        #############################################################################################
        self.plog('Starting prediction...\n')

        # where we store the predictions
        allPredictions = []

        # beginning of the prediction loop
        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n'%exampleIdx)

            dna = Sequences[exampleIdx]
            est = self._strip_est_brackets(Ests[exampleIdx]).lower()

            exons = Exons[exampleIdx]
            current_start_pos = StartPos[exampleIdx]
            currentAlternatives = AlternativeSequences[exampleIdx]

            quality = self._select_quality(est,Qualities[exampleIdx])

            don_supp = Donors[exampleIdx]
            acc_supp = Acceptors[exampleIdx]

            if not run['enable_splice_signals']:
                self._zero_finite_scores(don_supp)
                self._zero_finite_scores(acc_supp)

            current_example_predictions = []

            # first make a prediction on the dna fragment which comes from the ground truth
            current_prediction = self.calc_alignment(dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs)
            current_prediction['exampleIdx'] = exampleIdx
            current_prediction['start_pos'] = current_start_pos
            current_prediction['label'] = True

            current_example_predictions.append(current_prediction)

            # then make predictions for all dna fragments that were occurring in
            # the vmatch results
            for alternative_alignment in currentAlternatives:
                chrom_nr, strand, genomicSeq_start, genomicSeq_stop, currentLabel = alternative_alignment
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chrom_nr,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._zero_finite_scores(currentDon)
                    self._zero_finite_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,
                    quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['start_pos'] = current_start_pos
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        # end of the prediction loop we save all predictions in a pickle file and exit
        self._dump_pickle(allPredictions,'%s_allPredictions_%s'%(run['name'],set_flag))
        print('Prediction completed')


    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Given two sequences and the parameters we calculate one alignment.

        Returns a dictionary with the keys 'predExons' (exon boundaries of
        the predicted alignment), 'trueExons' (the exons argument), 'dna',
        'est' and 'DPScores'.
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        # myalign wants the acceptor site on the g of the ag
        acceptor = acceptor[1:]
        acceptor.append(-inf)

        dna = str(dna)
        est = str(est)
        dna_len = len(dna)

        ps = h.convert2SWIG()

        # only the splice alignment and the score of the single best path
        # are needed for the prediction
        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array =\
            self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,1,True)

        newSpliceAlign = newSpliceAlign.reshape(1,dna_len)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons':newExons, 'trueExons':exons,\
            'dna':dna, 'est':est, 'DPScores':newDPScores}

        return current_prediction


    def calculatePredictedExons(self,SpliceAlign):
        """
        Return the exon boundaries encoded in a splice alignment.

        SpliceAlign is a matrix whose first row holds one label per dna
        position; positions labelled 0 belong to an exon.  The result is a
        flat list of positions: the index of the first position of each run
        of zeros followed by the index of the first position after it.
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # the sentinel closes an exon that reaches the end of the sequence
        labels.append(-1)

        newExons = []
        previous = -1
        for pos,elem in enumerate(labels):
            # a boundary is every change between "is 0" and "is not 0"
            if (previous == 0) != (elem == 0):
                newExons.append(pos)

            previous = elem

        return newExons
852
853
def get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,dna_flat_files):
    """
    Fetch the genomic sequence of an interval together with its splice
    site scores.

    The interval is described by a chromosome number, a strand and its
    start and stop positions.  Returned is the tuple
    (sequence, acceptor scores, donor scores).
    """
    chromosome_name = 'chr%d' % chr

    # load_genomic is called with one_based=False and the start position
    # shifted down by one
    raw_seq = load_genomic(chromosome_name,strand,genomicSeq_start-1,genomicSeq_stop,dna_flat_files,one_based=False)
    lowered_seq = raw_seq.lower()

    # check the obtained dna sequence
    assert lowered_seq != '', 'load_genomic returned empty sequence!'

    # every character that is not one of a,c,g,t becomes an 'n'
    currentDNASeq = re.sub('[^acgt]','n',lowered_seq)

    # the scores are requested for the interval widened by 100 positions
    # on both sides
    currentAcc, currentDon = getSpliceScores(chr,strand,genomicSeq_start-100,genomicSeq_stop+100,currentDNASeq,genomicSeq_start)

    return currentDNASeq, currentAcc, currentDon
880
881
882 ###########################
883 # A simple command line
884 # interface
885 ###########################
886
if __name__ == '__main__':
    # Usage:
    #   qpalma_main.py train   <run_object_file>
    #   qpalma_main.py predict <run_object_file> <param_file>
    #
    # The command line is checked before any argument is used, so that a
    # missing argument leads to the usage message instead of an IndexError.
    # NOTE: the run object is unpickled, so only trusted files must be used.
    cli_args = sys.argv[1:]

    mode = None
    if cli_args:
        mode = cli_args[0]

    train_requested = (len(cli_args) == 2 and mode == 'train')
    predict_requested = (len(cli_args) == 3 and mode == 'predict')

    if train_requested or predict_requested:
        run_obj_fh = open(cli_args[1])
        try:
            run_obj = cPickle.load(run_obj_fh)
        finally:
            run_obj_fh.close()

        # not called "qpalma" in order not to shadow the imported package
        qpalma_runner = QPalma(run_obj)

        if train_requested:
            qpalma_runner.train()
        else:
            param_filename = cli_args[2]
            assert os.path.exists(param_filename)
            qpalma_runner.evaluate(param_filename)
    else:
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')