+ removed duplicated code
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores, get_seq_and_scores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.tools.splicesites import getDonAccScores
45 from qpalma.Configuration import *
46
47 # these two imports are needed for the load_genomic and interval query
48 # functions
49 from Genefinding import *
50 from genome_utils import load_genomic
51 from Utils import calc_stat, calc_info, pprint_alignment
52
53
def unbracket_est(est):
    """
    Collapse bracketed mismatch groups of an annotated read.

    Every character is copied unchanged, except that a four character group
    starting with '[' contributes only its third character (the one at
    offset 2 behind the bracket); the remaining three are skipped.  The
    result is returned in lower case.

    A bracket group that is cut off at the end of the string raises an
    IndexError.
    """
    pieces = []
    pos = 0
    total = len(est)

    while pos < total:
        if est[pos] == '[':
            # '[' + two symbols + closing bracket: keep the second symbol
            pieces.append(est[pos + 2])
            pos += 4
        else:
            pieces.append(est[pos])
            pos += 1

    return "".join(pieces).lower()
70
71
def getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run):
    """
    Collect everything needed to align example number exampleIdx.

    Returns the tuple (dna, est, acc_supp, don_supp, exons, original_est):
    dna/acc_supp/don_supp come from get_seq_and_scores for the region stored
    in SeqInfo[exampleIdx], est is the lower-cased read with bracket groups
    and '-' removed, exons are the exon boundaries shifted by the upstream
    cut, and original_est is the lower-cased, still annotated read.

    Failed consistency checks drop into the debugger, because the assert
    messages are pdb.set_trace() calls.
    """
    chromosome, strand, up_cut, down_cut = SeqInfo[exampleIdx]

    # the annotated read as stored, and its plain version without
    # mismatch brackets and gap symbols
    original_est = "".join(OriginalEsts[exampleIdx]).lower()
    est = unbracket_est(original_est).replace('-','')

    assert len(est) == run['read_size'], pdb.set_trace()

    # NOTE(review): this path is hard-coded here while predict() reads
    # run['dna_flat_files'] -- confirm whether both are meant to agree.
    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    dna, acc_supp, don_supp = get_seq_and_scores(chromosome,strand,up_cut,down_cut,dna_flat_files)

    # splice score is located at g of ag
    ag_tuple_pos = [p for p in range(2, len(dna)) if dna[p-1:p+1] == 'ag']
    assert ag_tuple_pos == [p for p,score in enumerate(acc_supp) if p > 1 and score != -inf], pdb.set_trace()

    # donor candidates: a 'g' followed by 't' or 'c'
    gt_tuple_pos = [p for p in range(1, len(dna)-1) if dna[p] == 'g' and dna[p+1] in ('t','c')]
    assert gt_tuple_pos == [p for p,score in enumerate(don_supp) if p > 0 and score != -inf], pdb.set_trace()

    # move the exon boundaries into the coordinates of the fetched dna
    exons = Exons[exampleIdx] - (up_cut-1)
    exons[0,0] -= 1
    exons[1,0] -= 1

    first_exon = dna[exons[0,0]:exons[0,1]]
    second_exon = dna[exons[1,0]:exons[1,1]]
    fetched_dna_subseq = first_exon + second_exon

    donor_elem = dna[exons[0,1]:exons[0,1]+2]
    acceptor_elem = dna[exons[1,0]-2:exons[1,0]]

    # single-argument print(...) behaves like the Python 2 print statement
    if donor_elem not in ('gt','gc'):
        print('invalid donor in example %d'% exampleIdx)

    if acceptor_elem != 'ag':
        print('invalid acceptor in example %d'% exampleIdx)

    assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

    return dna,est,acc_supp,don_supp,exons,original_est
133
134
class QPalma:
    """
    A training method for the QPalma project.

    train() runs the constraint generation loop and writes its parameter
    vectors into a fresh working directory; evaluate()/predict() align the
    reads of the dataset with a stored parameter vector.  All settings are
    taken from the ``run`` mapping handed to the constructor.

    print(...) is always called with exactly one argument, which behaves
    exactly like the Python 2 print statement.
    """

    def __init__(self,run):
        """Store the run settings and derive the quality score mode."""
        self.ARGS = Param()
        self.run = run
        self._set_quality_mode()


    def _set_quality_mode(self):
        """
        Set self.use_quality_scores from run['mode'].

        Raises an AssertionError for any mode other than 'normal' and
        'using_quality_scores' (also when asserts are stripped with -O).
        """
        mode = self.run['mode']

        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            raise AssertionError('unknown run mode: %r' % (mode,))


    def _example_quality(self,Qualities,exampleIdx,est):
        """
        Return the quality values used for one example.

        The stored values Qualities[exampleIdx][0] are used only in mode
        'using_quality_scores' with run['enable_quality_scores'] set; in
        every other case a constant quality of 40 per read position is used.
        """
        run = self.run

        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return Qualities[exampleIdx][0]

        return [40]*len(est)


    def _neutralize_splice_signals(self,don_supp,acc_supp):
        """Overwrite (in place) every score that is not -inf with 0.0."""
        for scores in (don_supp,acc_supp):
            for idx,elem in enumerate(scores):
                if elem != -inf:
                    scores[idx] = 0.0


    def _dump_pickle(self,obj,filename):
        """Pickle obj into filename; the file handle is always closed."""
        out = open(filename,'w+')
        try:
            cPickle.dump(obj,out)
        finally:
            out.close()


    def _load_pickle(self,filename):
        """
        Unpickle and return the content of filename.

        NOTE: unpickling executes whatever the file describes -- only use
        this on files from a trusted source.
        """
        infile = open(filename)
        try:
            return cPickle.load(infile)
        finally:
            infile.close()


    def plog(self,string):
        """Write string to the current log file and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()


    def do_alignment(self,dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment.

        current_num_path is the number of alignments (best, second best, ...)
        that are requested.  If prediction_mode is true the aligned dna/est
        arrays are fetched as well, otherwise they are returned as None.

        Returns the tuple (newSpliceAlign, newEstAlign, newWeightMatch,
        newDPScores, newQualityPlifsFeatures, dna_array, est_array); the
        first five entries are column matrices holding the values of all
        requested paths one after another.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows']*run['matchmatrixCols']
        total_qual_len = run['totalQualSuppPoints']*current_num_path

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest.
        # The two flags are read from the run settings: they used to be
        # referenced as locals of train()/predict() and were undefined here.
        currentAlignment.myalign( current_num_path, dna, dna_len,
            est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,
            acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],
            run['print_matrix'])

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)

        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*total_qual_len)

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
            c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))

            currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)

            dna_array = [0.0]*result_len
            est_array = [0.0]*result_len

            for r_idx in range(result_len):
                dna_array[r_idx] = c_dna_array[r_idx]
                est_array[r_idx] = c_est_array[r_idx]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
            c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = zeros((current_num_path*dna_len,1))
        newEstAlign = zeros((est_len*current_num_path,1))
        newWeightMatch = zeros((current_num_path*mm_len,1))
        newDPScores = zeros((current_num_path,1))
        newQualityPlifsFeatures = zeros((total_qual_len,1))

        # copy the C arrays element by element into the matrices
        for i in range(dna_len*current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len*current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len*current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        # the quality features are only copied when quality scores are in use
        if self.use_quality_scores:
            for i in range(total_qual_len):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        # drop the references to the C side objects explicitly
        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array


    def train(self):
        """
        Run the training.

        Creates the directory <experiment_path>/<name> (it must not exist
        yet), changes the working directory into it, stores the run object
        and then runs the training loop.  The log file _qpalma_train.log is
        closed even if the training aborts with an exception.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'],run['name'])

        assert not os.path.exists(full_working_path)
        os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self._dump_pickle(run,'run_object.pickle')

        self.logfh = open('_qpalma_train.log','w+')
        try:
            self._run_training()
        finally:
            self.logfh.close()


    def _run_training(self):
        """
        The training loop proper; expects self.logfh to be open.

        For every example the true alignment is compared with the n-best
        alignments decoded under the current parameters; the first decoded
        alignment that differs from the true one yields a constraint for
        the solver.  After every solver call the parameter vector is stored
        as param_<nr>.pickle.
        """
        # math is not among the visible module level imports of this file
        import math

        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n"%str(run))

        data_filename = self.run['dataset_filename']

        SeqInfo, Exons, OriginalEsts, Qualities,\
            AlternativeSequences = paths_load_data(data_filename,'training',None,self.ARGS)

        self._set_quality_mode()

        # Keep the whole dataset on the instance
        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts= OriginalEsts
        self.Qualities = Qualities

        beg = run['training_begin']
        end = run['training_end']

        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts= OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        # number of training instances
        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N and len(Qualities) == N,\
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize parameter vector
        param = Conf.fixedParam[:run['numFeatures']]

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # offsets of the feature groups inside the parameter vector
        don_off = lengthSP
        acc_off = don_off+donSP
        mm_off = acc_off+accSP
        qual_off = mm_off+mmatrixSP

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        # Initialize solver
        self.plog('Initializing problem...\n')

        solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'],1))
        totalQualityPenalties = zeros((totalQualSP,1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        # written for every decoded path, never read in this method
        featureVectors = zeros((run['numFeatures'],numExamples))

        # the main training loop
        while iteration_nr != iteration_steps:

            for exampleIdx in range(numExamples):
                if (exampleIdx%100) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna,est,acc_supp,don_supp,exons,original_est =\
                    getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run)

                dna_len = len(dna)
                current_num_path = num_path[exampleIdx]

                quality = self._example_quality(Qualities,exampleIdx,est)

                if not run['enable_splice_signals']:
                    self._neutralize_splice_signals(don_supp,acc_supp)

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    # dna_calc is returned but not needed here; it used to be
                    # post-processed for both modes, which raised a NameError
                    # in mode 'normal' where it is never bound.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,\
                        quality, qualityPlifs,run)
                else:
                    # NOTE(review): two argument call of the quality aware
                    # routine, kept as found -- confirm it is the intended one.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron =\
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                currentPhi[0:don_off] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[don_off:acc_off] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[acc_off:mm_off] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[mm_off:qual_off] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qual_off:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((run['numFeatures'],current_num_path+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(current_num_path+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
                    newQualityPlifsFeatures, unneeded1, unneeded2 =\
                    self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,False)

                newSpliceAlign = newSpliceAlign.reshape(current_num_path,dna_len)
                newWeightMatch = newWeightMatch.reshape(current_num_path,mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(current_num_path,run['totalQualSuppPoints'])

                # Calculate weights of the respective alignments. Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(current_num_path+1)
                true_map[0] = 1

                # the current parameters as one column; identical for all paths
                hpen = mat(h.penalties).reshape(len(h.penalties),1)
                dpen = mat(d.penalties).reshape(len(d.penalties),1)
                apen = mat(a.penalties).reshape(len(a.penalties),1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(current_num_path):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,\
                        h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,\
                        acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr,:].T

                    # store the weights in the remaining columns of the matrix
                    allWeights[:,pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures[:]])

                    featureVectors[:,exampleIdx] = allWeights[:,pathNr+1]

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check whether scalar product + loss equals viterbi score
                    # ("not ... <=" also triggers on NaN)
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"%AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"%(newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                # the true alignment scores higher than every decoded one
                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" %exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" %exampleIdx)

                # all n-best paths are too close to the true alignment:
                # we have to extend the n-best search to a (n+1)-best
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)

                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue,w,self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'% numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    # NOTE(review): this only leaves the loop over the
                    # examples; the outer loop tests for numExamples*2.
                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = 0
                    for elem in self.slacks:
                        sum_xis += elem

                    print('sum of slacks is %f'% sum_xis)
                    self.plog('sum of slacks is %f\n'% sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param,'param_%d.pickle'%param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] =\
                        set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" %suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param,'param_%d.pickle'%param_idx)


    ###############################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###############################################################################


    def evaluate(self,param_filename):
        """
        Predict with the parameters stored in param_filename.

        The examples [0, prediction_begin) are predicted as 'TRAIN' set and
        [prediction_begin, prediction_end) as 'TEST' set.  The log file
        _qpalma_predict.log is closed even if a prediction fails.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        data_filename = self.run['dataset_filename']

        SeqInfo, Exons, OriginalEsts, Qualities,\
            AlternativeSequences = paths_load_data(data_filename,'training',None,self.ARGS)

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts= OriginalEsts
        self.Qualities = Qualities
        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log','w+')
        try:
            # predict on training set
            self.plog('##### Prediction on the training set #####\n')
            self.predict(param_filename,0,beg,'TRAIN')

            # predict on test set
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename,beg,end,'TEST')

            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()


    def predict(self,param_filename,beg,end,set_flag):
        """
        Align the examples beg..end and store the predictions.

        For every example one prediction is made on the dna fragment of the
        ground truth (label True) and one for each entry of its alternative
        sequences.  The list with the predictions of all examples is pickled
        to '<run name>_allPredictions_<set_flag>'.
        """
        run = self.run

        self._set_quality_mode()

        SeqInfo = self.SeqInfo[beg:end]
        Exons = self.Exons[beg:end]
        OriginalEsts= self.OriginalEsts[beg:end]
        Qualities = self.Qualities[beg:end]
        AlternativeSequences = self.AlternativeSequences[beg:end]

        # number of instances to predict on
        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N\
            and len(Qualities) == N and len(AlternativeSequences) == N,\
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        param = self._load_pickle(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] =\
            set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        #############################################################################################
        # Prediction
        #############################################################################################
        self.plog('Starting prediction...\n')

        # where we store the predictions
        allPredictions = []

        # beginning of the prediction loop
        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n'%exampleIdx)

            chromo,strand,up_cut,down_cut = SeqInfo[exampleIdx]

            dna,est,acc_supp,don_supp,exons,original_est =\
                getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run)

            quality = self._example_quality(Qualities,exampleIdx,est)

            if not run['enable_splice_signals']:
                self._neutralize_splice_signals(don_supp,acc_supp)

            current_example_predictions = []

            # first make a prediction on the dna fragment which comes from the ground truth
            current_prediction = self.calc_alignment(dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs)
            current_prediction['exampleIdx'] = exampleIdx
            current_prediction['start_pos'] = up_cut
            current_prediction['label'] = True

            current_example_predictions.append(current_prediction)

            # then make predictions for all dna fragments that were occurring in
            # the vmatch results
            for alternative_alignment in AlternativeSequences[exampleIdx]:
                chromo, strand, genomicSeq_start, genomicSeq_stop, currentLabel = alternative_alignment

                # only the chromosome numbers 1..5 are handled
                if chromo not in range(1,6):
                    continue

                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._neutralize_splice_signals(currentDon,currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,\
                    quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['start_pos'] = up_cut
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        # end of the prediction loop we save all predictions in a pickle file and exit
        self._dump_pickle(allPredictions,'%s_allPredictions_%s'%(run['name'],set_flag))
        print('Prediction completed')


    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Given two sequences and the parameters we calculate one alignment.

        Returns a dictionary with the keys 'predExons' (exon boundaries of
        the best alignment), 'trueExons' (the exons argument), 'dna', 'est'
        and 'DPScores'.
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)

        ps = h.convert2SWIG()

        # only the best path is requested
        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array =\
            self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,1,True)

        newSpliceAlign = newSpliceAlign.reshape(1,len(dna))
        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons':newExons, 'trueExons':exons,\
            'dna':dna, 'est':est, 'DPScores':newDPScores}

        return current_prediction


    def calculatePredictedExons(self,SpliceAlign):
        """
        Return the exon boundaries encoded in a splice alignment row.

        SpliceAlign is a matrix whose first row holds one label per dna
        position; the label 0 marks exon positions.  The result lists, in
        ascending order, every position where an exon starts and every
        position right behind the end of an exon.  A closing -1 is appended
        to the labels so that an exon reaching the last position is closed.
        """
        labels = SpliceAlign.flatten().tolist()[0]
        labels.append(-1)

        newExons = []
        previous = -1

        for pos,elem in enumerate(labels):
            # a boundary is every change between exon (0) and non-exon
            if (previous == 0) != (elem == 0):
                newExons.append(pos)
            previous = elem

        return newExons
882
883
884 ###########################
885 # A simple command line
886 # interface
887 ###########################
888
if __name__ == '__main__':
    # Validate the command line before touching any file: the arguments
    # used to be indexed first, so a missing argument ended in an IndexError
    # instead of the usage message.
    num_args = len(sys.argv)

    mode = None
    if num_args > 1:
        mode = sys.argv[1]

    if (num_args == 3 and mode == 'train') or (num_args == 4 and mode == 'predict'):
        # NOTE: unpickling executes whatever the file describes -- only use
        # run objects from a trusted source.
        run_obj_file = open(sys.argv[2])
        try:
            run_obj = cPickle.load(run_obj_file)
        finally:
            run_obj_file.close()

        # not called "qpalma": that name is the imported package
        qpalma_runner = QPalma(run_obj)

        if mode == 'train':
            qpalma_runner.train()
        else:
            param_filename = sys.argv[3]
            assert os.path.exists(param_filename)
            qpalma_runner.evaluate(param_filename)
    else:
        # single-argument print(...) behaves like the Python 2 print statement
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')