+ changed training method to work with new dataset format
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31 from qpalma.SIQP_CPX import SIQPSolver
32 from qpalma.DataProc import *
33 from qpalma.computeSpliceWeights import *
34 from qpalma.set_param_palma import *
35 from qpalma.computeSpliceAlignWithQuality import *
36 from qpalma.penalty_lookup_new import *
37 from qpalma.compute_donacc import *
38 from qpalma.TrainingParam import Param
39 from qpalma.Plif import Plf
40
41 from qpalma.tools.splicesites import getDonAccScores
42 from qpalma.Configuration import *
43
44 # this two imports are needed for the load genomic resp. interval query
45 # functions
46 from Genefinding import *
47 from genome_utils import load_genomic
48
49 from Utils import calc_stat, calc_info, pprint_alignment
50
51
def unbracket_est(est):
    """
    Resolve the bracketed notation of a read and return it lower-cased.

    Ordinary characters are copied unchanged. A '[' starts a group of four
    characters; only the character two positions after the '[' is kept and
    the whole group is skipped (the closing character is not checked —
    presumably it is always ']').

    >>> unbracket_est('AC[ag]T')
    'acgt'
    """
    # Collect pieces in a list and join once instead of growing a string
    # with += (quadratic in the read length).
    pieces = []
    pos = 0
    length = len(est)
    while pos < length:
        if est[pos] == '[':
            pieces.append(est[pos + 2])
            pos += 4
        else:
            pieces.append(est[pos])
            pos += 1
    return "".join(pieces).lower()
68
69
class QPalma:
    """
    A training method for the QPalma project.

    The instance is configured by the ``run`` mapping. ``train`` fits the
    parameter vector by adding one constraint per training example to an
    SIQP solver and re-solving periodically; ``evaluate``/``predict`` align
    reads with a stored parameter vector.
    """

    def __init__(self, run):
        self.ARGS = Param()
        self.run = run
        self._set_quality_mode()

    def _set_quality_mode(self):
        """Set self.use_quality_scores from run['mode'].

        Only the modes 'normal' and 'using_quality_scores' are supported;
        any other mode trips an assertion.
        """
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            assert False, 'unknown run mode: %r' % (mode,)

    @staticmethod
    def _zero_splice_scores(scores):
        """Set every entry of *scores* that is not -inf to 0.0, in place.

        Used when run['enable_splice_signals'] is off: positions without a
        splice site keep -inf, all others lose their signal value.
        """
        for idx, elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0

    def plog(self, string):
        """Write *string* to the current log file and flush it immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps,
                     qualityPlifs, current_num_path, prediction_mode):
        """
        Align est against dna with the QPalma C module (dynamic programming).

        dna, est         -- sequences; their len() is passed to the C code
        quality          -- per-position quality values of est
        mmatrix          -- match matrix (numpy matrix, flattened row-wise)
        donor, acceptor  -- per-position splice-site score lists
        ps               -- h.convert2SWIG() as built by the callers
        qualityPlifs     -- quality Plifs (each converted via convert2SWIG())
        current_num_path -- number of best alignments to compute
        prediction_mode  -- if true, also fetch the alignment arrays

        Returns (SpliceAlign, EstAlign, WeightMatch, DPScores,
        QualityPlifsFeatures, dna_array, est_array). The first five are
        column matrices holding the results of all current_num_path paths
        one after another; QualityPlifsFeatures stays all zeros unless
        quality scores are used. dna_array/est_array are lists in
        prediction mode and None otherwise.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        # chastity values are always passed as zeros here
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],
                                              run['numQualSuppPoints'],
                                              self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(
            [elem.convert2SWIG() for elem in qualityPlifs])

        # Calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest.
        # The two trailing flags are read from the run configuration (the
        # same keys train() and predict() read); they are not bound as bare
        # names in this scope.
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 donor, d_len, acceptor, a_len, c_qualityPlifs,
                                 run['remove_duplicate_scores'], run['print_matrix'])

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList(
            [.0] * (totalQualSP * current_num_path))

        if prediction_mode:
            # the alignment arrays are only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[r_idx] for r_idx in range(result_len)]
            est_array = [c_est_array[r_idx] for r_idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores,
                                             c_qualityPlifsFeatures)

        # copy the C arrays into numpy column matrices
        newSpliceAlign = zeros((current_num_path * dna_len, 1))
        newEstAlign = zeros((est_len * current_num_path, 1))
        newWeightMatch = zeros((current_num_path * mm_len, 1))
        newDPScores = zeros((current_num_path, 1))
        newQualityPlifsFeatures = zeros((totalQualSP * current_num_path, 1))

        for i in range(dna_len * current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len * current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len * current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        if self.use_quality_scores:
            for i in range(totalQualSP * current_num_path):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array

    def train(self):
        """
        Train the QPalma parameter vector.

        Creates the directory <run['experiment_path']>/<run['name']> (it must
        not exist yet) and changes the working directory into it. Writes the
        pickled run object, the log '_qpalma_train.log', one
        'param_<k>.pickle' after every solver call and a final one at the end.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'], run['name'])

        assert not os.path.exists(full_working_path)
        os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        with open('run_object.pickle', 'w+') as run_fh:
            cPickle.dump(run, run_fh)

        self.logfh = open('_qpalma_train.log', 'w+')
        try:
            param, param_idx = self._run_training()

            print('Training completed')

            with open('param_%d.pickle' % param_idx, 'w+') as param_fh:
                cPickle.dump(param, param_fh)
        finally:
            self.logfh.close()

    def _get_training_example(self, exampleIdx, SeqInfo, OriginalEsts, Qualities,
                              Exons, dna_flat_files):
        """Load read, genomic sequence, splice scores, qualities and exons of one example.

        Returns (dna, est, original_est, quality, exons, don_supp, acc_supp),
        where exons are shifted to be relative to the start (up_cut) of the
        loaded genomic fragment.
        """
        run = self.run

        chrom, strand, up_cut, down_cut = SeqInfo[exampleIdx]

        # the read as stored in the dataset (possibly with bracket notation)
        original_est = "".join(OriginalEsts[exampleIdx]).lower()
        # the plain read: bracket groups resolved and '-' removed
        est = unbracket_est(original_est).replace('-', '')

        dna, acc_supp, don_supp = get_seq_and_scores(chrom, strand, up_cut, down_cut,
                                                     dna_flat_files)

        # drop the first score and pad the end with -inf
        don_supp = don_supp[1:] + [-inf]
        acc_supp = acc_supp[1:] + [-inf]

        assert len(est) == run['read_size'], pdb.set_trace()

        # NOTE(review): predict() uses Qualities[idx] without the [0] —
        # confirm which form matches the dataset format.
        if run['mode'] == 'using_quality_scores':
            quality = Qualities[exampleIdx][0]
        else:
            quality = [40] * len(est)

        if not run['enable_quality_scores']:
            quality = [40] * len(est)

        # Subtract into a new array, not in place: Exons[exampleIdx] is
        # shared with self.Exons and is revisited in every training
        # iteration, so an in-place shift would accumulate.
        exons = Exons[exampleIdx] - up_cut

        if not run['enable_splice_signals']:
            self._zero_splice_scores(don_supp)
            self._zero_splice_scores(acc_supp)

        return dna, est, original_est, quality, exons, don_supp, acc_supp

    def _run_training(self):
        """Run the training loop; return (final param, index for the final param file)."""
        import copy

        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        # Load the whole dataset
        SeqInfo, Exons, OriginalEsts, Qualities,\
            AlternativeSequences = paths_load_data(run['dataset_filename'], 'training',
                                                   None, self.ARGS)

        self._set_quality_mode()

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities

        beg = run['training_begin']
        end = run['training_end']

        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        # number of training instances
        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N and len(Qualities) == N,\
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize the parameter vector from the fixed defaults. Copy it so
        # that training does not modify the shared Conf.fixedParam in place.
        param = copy.copy(Conf.fixedParam[:run['numFeatures']])

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
            param, self.ARGS.train_with_intronlengthinformation, run)

        # Initialize solver
        self.plog('Initializing problem...\n')
        solver = SIQPSolver(run['numFeatures'], numExamples, run['C'], self.logfh, run)

        # number of alignments computed per example (best, second-best, ...)
        num_path = [anzpath] * numExamples

        # Prefer the genome location from the run configuration (as predict()
        # does); fall back to the previously hard-coded location.
        try:
            dna_flat_files = run['dna_flat_files']
        except KeyError:
            dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        # the main training loop
        while iteration_nr != iteration_steps:
            for exampleIdx in range(numExamples):
                if (exampleIdx % 100) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna, est, original_est, quality, exons, don_supp, acc_supp = \
                    self._get_training_example(exampleIdx, SeqInfo, OriginalEsts,
                                               Qualities, Exons, dna_flat_files)
                dna_len = len(dna)

                # Features of the true alignment (under the current d, a, h, ...)
                if run['mode'] == 'using_quality_scores':
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc =\
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                      quality, qualityPlifs, run)
                else:
                    # NOTE(review): this passes two arguments and unpacks
                    # three results, unlike the branch above (seven arguments,
                    # four results) — confirm computeSpliceAlignWithQuality
                    # supports this form.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(
                    d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                numPaths = num_path[exampleIdx]

                # column 0: weights of the true alignment; columns 1..numPaths:
                # weights of the decoded alignments
                allWeights = zeros((run['numFeatures'], numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                ps = h.convert2SWIG()

                newSpliceAlign, _, newWeightMatch, newDPScores, newQualityPlifsFeatures, _, _ =\
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                                      qualityPlifs, numPaths, False)

                newSpliceAlign = newSpliceAlign.reshape(numPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(numPaths, totalQualSP)

                # The penalty vector does not change while the decoded paths
                # are scored, so it is built once per example.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                # The n-best alignments are computed without hamming loss, so
                # some of them may equal the true alignment. true_map marks
                # those (1) so that no true alignment ends up in a constraint.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    # store the weights in the remaining columns of the matrix
                    allWeights[:, pathNr+1] = numpy.vstack(
                        [weightIntron, weightDon, weightAcc,
                         newWeightMatch[pathNr, :].T, decodedQualityFeatures[:]])

                    AlignmentScores[pathNr+1] = (allWeights[:, pathNr+1].T * features)[0, 0]

                    # Check whether scalar product + loss equals viterbi score
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n" % (newDPScores[pathNr, 0],
                                                                  AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true
                    # alignment consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr+1]) < 1e-5:
                        true_map[pathNr+1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # All n-best paths are (nearly) identical to the true one, so
                # no constraint can be formed: extend to an (n+1)-best search.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] = num_path[exampleIdx] + 1

                # index of the first decoded alignment that is not true (-1: none)
                firstFalseIdx = next((map_idx for map_idx, elem in enumerate(true_map)
                                      if elem == 0), -1)

                # if there is at least one useful false alignment add the
                # corresponding constraint to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:, firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)
                    const_added_ctr += 1

                # call the solver every numConstPerRound-th example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                        if self.noImprovementCtr == numExamples + 1:
                            break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)

                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    with open('param_%d.pickle' % param_idx, 'w+') as param_fh:
                        cPickle.dump(param, param_fh)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] =\
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            # end of one iteration through all examples
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                break

            iteration_nr += 1

        return param, param_idx

    ###############################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def evaluate(self, param_filename):
        """
        Predict with the parameters pickled in param_filename.

        Runs predict() on examples [0, run['prediction_begin']) (tagged
        'TRAIN') and on [run['prediction_begin'], run['prediction_end'])
        (tagged 'TEST'); logs to '_qpalma_predict.log' in the current
        working directory.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        # NOTE(review): train() unpacks five values from paths_load_data (new
        # dataset format) while this unpacks ten — confirm this still matches
        # the dataset format.
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
            UpCut, StartPos, AlternativeSequences =\
            paths_load_data(run['dataset_filename'], 'training', None, self.ARGS)

        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors
        self.UpCut = UpCut
        self.StartPos = StartPos

        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log', 'w+')
        try:
            # predict on training set
            self.plog('##### Prediction on the training set #####\n')
            self.predict(param_filename, 0, beg, 'TRAIN')

            # predict on test set
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename, beg, end, 'TEST')

            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()

    def predict(self, param_filename, beg, end, set_flag):
        """
        Align every example in [beg, end) with the parameters in param_filename.

        For each example the read is aligned first against its ground-truth
        DNA fragment (label True) and then against every alternative fragment
        listed for it. All predictions are pickled to
        '<run['name']>_allPredictions_<set_flag>' as a list (one entry per
        example) of lists of the dicts returned by calc_alignment().
        """
        run = self.run

        self._set_quality_mode()

        Sequences = self.Sequences[beg:end]
        Exons = self.Exons[beg:end]
        Ests = self.Ests[beg:end]
        Qualities = self.Qualities[beg:end]
        Acceptors = self.Acceptors[beg:end]
        Donors = self.Donors[beg:end]
        StartPos = self.StartPos[beg:end]

        AlternativeSequences = self.AlternativeSequences[beg:end]

        # number of examples in this range
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
            and len(Qualities) == N and len(Acceptors) == N\
            and len(Donors) == N, 'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of examples: %d\n' % numExamples)

        with open(param_filename) as param_fh:
            param = cPickle.load(param_fh)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] =\
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        #############################################################################################
        # Prediction
        #############################################################################################
        self.plog('Starting prediction...\n')

        # where we store the predictions
        allPredictions = []

        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n' % exampleIdx)

            dna = Sequences[exampleIdx]
            est = Ests[exampleIdx]
            exons = Exons[exampleIdx]
            current_start_pos = StartPos[exampleIdx]
            currentAlternatives = AlternativeSequences[exampleIdx]

            if run['mode'] == 'using_quality_scores':
                quality = Qualities[exampleIdx]
            else:
                quality = [40] * len(est)

            if not run['enable_quality_scores']:
                quality = [40] * len(est)

            don_supp = Donors[exampleIdx]
            acc_supp = Acceptors[exampleIdx]

            if not run['enable_splice_signals']:
                self._zero_splice_scores(don_supp)
                self._zero_splice_scores(acc_supp)

            current_example_predictions = []

            # first make a prediction on the dna fragment which comes from the ground truth
            current_prediction = self.calc_alignment(dna, est, exons, quality, don_supp,
                                                     acc_supp, d, a, h, mmatrix, qualityPlifs)
            current_prediction['exampleIdx'] = exampleIdx
            current_prediction['start_pos'] = current_start_pos
            current_prediction['label'] = True

            current_example_predictions.append(current_prediction)

            # then make predictions for all dna fragments that occurred in
            # the vmatch results
            for alternative_alignment in currentAlternatives:
                chrom, strand, genomicSeq_start, genomicSeq_stop, currentLabel = alternative_alignment
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._zero_splice_scores(currentDon)
                    self._zero_splice_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,
                                                         quality, currentDon, currentAcc,
                                                         d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['start_pos'] = current_start_pos
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        # end of the prediction loop: save all predictions in a pickle file
        with open('%s_allPredictions_%s' % (run['name'], set_flag), 'w+') as pred_fh:
            cPickle.dump(allPredictions, pred_fh)
        print('Prediction completed')

    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h,
                       mmatrix, qualityPlifs):
        """
        Compute the best alignment of est against dna under the given parameters.

        Returns a dict with keys 'predExons' (from calculatePredictedExons),
        'trueExons' (the exons argument), 'dna', 'est' (as str) and
        'DPScores' (the dynamic-programming score matrix from do_alignment).
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        # myalign wants the acceptor site on the g of the ag
        acceptor = acceptor[1:]
        acceptor.append(-inf)

        dna = str(dna)
        est = str(est)

        ps = h.convert2SWIG()

        newSpliceAlign, _, _, newDPScores, _, _, _ =\
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                              qualityPlifs, 1, True)

        newSpliceAlign = newSpliceAlign.reshape(1, len(dna))
        newExons = self.calculatePredictedExons(newSpliceAlign)

        return {'predExons': newExons, 'trueExons': exons,
                'dna': dna, 'est': est, 'DPScores': newDPScores}

    def calculatePredictedExons(self, SpliceAlign):
        """
        Convert a SpliceAlign row into a flat list of exon boundaries.

        SpliceAlign must be a numpy matrix (``flatten().tolist()[0]`` has to
        yield the list of labels). Positions labelled 0 are treated as
        exonic. Returns [start_1, end_1, start_2, end_2, ...] where each start
        is the first position of a run of zeros and each end the position
        just after it.
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # sentinel so that an exon reaching the last position is closed
        labels.append(-1)

        newExons = []
        previous = -1
        for pos, elem in enumerate(labels):
            if previous != 0 and elem == 0:    # start of exon
                newExons.append(pos)
            elif previous == 0 and elem != 0:  # end of exon
                newExons.append(pos)
            previous = elem

        return newExons
856
857
def get_seq_and_scores(chr, strand, genomicSeq_start, genomicSeq_stop, dna_flat_files):
    """
    Fetch the genomic sequence of an interval together with its splice scores.

    The sequence is loaded with load_genomic from dna_flat_files for
    chromosome 'chr<chr>' on the given strand (start is passed as
    genomicSeq_start - 1 with one_based=False — presumably 1-based input
    coordinates; confirm). It is lower-cased, and every character other
    than a, c, g, t is replaced by 'n'. Splice scores are requested for the
    interval widened by 100 positions on each side, with genomicSeq_start as
    the sequence position offset.

    Returns (sequence, acceptor scores, donor scores).
    """
    sequence = load_genomic('chr%d' % chr, strand, genomicSeq_start - 1, genomicSeq_stop,
                            dna_flat_files, one_based=False).lower()

    # check the obtained dna sequence
    assert sequence != '', 'load_genomic returned empty sequence!'

    # mask everything that is not a regular base
    sequence = re.sub('[^acgt]', 'n', sequence)

    acc_scores, don_scores = getSpliceScores(chr, strand,
                                             genomicSeq_start - 100,
                                             genomicSeq_stop + 100,
                                             sequence,
                                             genomicSeq_start)

    return sequence, acc_scores, don_scores
884
885
886 ###########################
887 # A simple command line
888 # interface
889 ###########################
890
if __name__ == '__main__':
    # Expected invocations:
    #   qpalma_main.py train   <run_object_file>
    #   qpalma_main.py predict <run_object_file> <param_file>
    argc = len(sys.argv)
    mode = sys.argv[1] if argc > 1 else None

    # Validate the command line before touching any file, so that a missing
    # argument prints the usage instead of raising IndexError.
    if not ((argc == 3 and mode == 'train') or (argc == 4 and mode == 'predict')):
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')
        sys.exit(1)

    # NOTE: the run object is unpickled; only pass trusted files.
    with open(sys.argv[2]) as run_obj_fh:
        run_obj = cPickle.load(run_obj_fh)

    # named qpalma_obj so the imported `qpalma` package is not shadowed
    qpalma_obj = QPalma(run_obj)

    if mode == 'train':
        qpalma_obj.train()
    else:
        param_filename = sys.argv[3]
        assert os.path.exists(param_filename)
        qpalma_obj.evaluate(param_filename)