+ fixed nasty index bugs when creating the label of the ground truth
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores, get_seq_and_scores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 #from qpalma.SIQP_CPX import SIQPSolver
33 from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.tools.splicesites import getDonAccScores
45 from qpalma.Configuration import *
46
47 # these two imports are needed for the load genomic resp. interval query
48 # functions
49 from Genefinding import *
50 from genome_utils import load_genomic
51
52 from Utils import calc_stat, calc_info, pprint_alignment
53
54
def unbracket_est(est):
    """
    Strip the bracket notation from a read and return it in lower case.

    A four character group starting with '[' (written elsewhere in this file
    as '[%s%s]' % (dna_char, est_char)) is replaced by its third character,
    i.e. the read-side base. Every other character is copied unchanged.

    Raises IndexError if a '[' is followed by fewer than two characters.
    """
    # collect the pieces in a list and join once instead of growing a string
    pieces = []
    pos = 0
    est_len = len(est)

    while pos < est_len:
        if est[pos] == '[':
            # layout is '[' + dna base + est base + ']': keep the est base
            pieces.append(est[pos+2])
            pos += 4
        else:
            pieces.append(est[pos])
            pos += 1

    return "".join(pieces).lower()
71
72
class QPalma:
    """
    Training and prediction driver for the QPalma project.

    The instance is configured by a `run` object that is indexed like a
    dictionary (run['mode'], run['anzpath'], ...). Training writes its
    results (pickled parameters and a log file) into a freshly created
    experiment directory; prediction writes a pickle with all predictions.
    """

    def __init__(self,run):
        """
        Store the run settings and derive self.use_quality_scores from
        run['mode'].
        """
        self.ARGS = Param()
        self.run = run
        self._set_quality_mode()


    def _set_quality_mode(self):
        """Set self.use_quality_scores according to run['mode']."""
        mode = self.run['mode']

        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            # same exception type as the former assert(False), but it is not
            # stripped when Python runs with -O
            raise AssertionError('unknown run mode: %r' % (mode,))


    def _dump_pickle(self,obj,filename):
        """Pickle obj into filename and close the file again."""
        fh = open(filename,'w+')
        try:
            cPickle.dump(obj,fh)
        finally:
            fh.close()


    def _load_pickle(self,filename):
        """Return the object pickled in filename and close the file again."""
        # NOTE: unpickling executes code from the file, only load trusted files
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()


    def _zero_finite_scores(self,scores):
        """Overwrite in place every entry of scores that is not -inf with 0.0."""
        for idx,elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0


    def plog(self,string):
        """Write string to the log file handle self.logfh and flush it."""
        self.logfh.write(string)
        self.logfh.flush()


    def do_alignment(self,dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,prediction_mode):
        """
        Given the needed input this method calls the QPalma C module which
        calculates a dynamic programming in order to obtain an alignment.

        current_num_path is the number of alignments requested from the C
        module. If prediction_mode is true the alignment arrays (dna_array,
        est_array) are fetched as well, otherwise both are None.

        Returns the tuple (newSpliceAlign, newEstAlign, newWeightMatch,
        newDPScores, newQualityPlifsFeatures, dna_array, est_array). The
        first five are column matrices holding the results of all paths one
        after another.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
        mm_len = run['matchmatrixRows']*run['matchmatrixCols']

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
        # The two flags are read from the run settings: they used to be
        # referenced as bare names which are only locals of train/predict.
        currentAlignment.myalign( current_num_path, dna, dna_len,\
            est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
            acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],
            run['print_matrix'])

        # output buffers that are filled by getAlignmentResults below
        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)

        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(run['totalQualSuppPoints']*current_num_path))

        if prediction_mode:
            # part that is only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
            c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))

            currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)

            dna_array = [0.0]*result_len
            est_array = [0.0]*result_len

            for r_idx in range(result_len):
                dna_array[r_idx] = c_dna_array[r_idx]
                est_array[r_idx] = c_est_array[r_idx]

        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
            c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = zeros((current_num_path*dna_len,1))
        newEstAlign = zeros((est_len*current_num_path,1))
        newWeightMatch = zeros((current_num_path*mm_len,1))
        newDPScores = zeros((current_num_path,1))
        newQualityPlifsFeatures = zeros((run['totalQualSuppPoints']*current_num_path,1))

        # copy the results element by element out of the C arrays
        for i in range(dna_len*current_num_path):
            newSpliceAlign[i] = c_SpliceAlign[i]

        for i in range(est_len*current_num_path):
            newEstAlign[i] = c_EstAlign[i]

        for i in range(mm_len*current_num_path):
            newWeightMatch[i] = c_WeightMatch[i]

        for i in range(current_num_path):
            newDPScores[i] = c_DPScores[i]

        # the quality features stay zero when quality scores are not used
        if self.use_quality_scores:
            for i in range(run['totalQualSuppPoints']*current_num_path):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
            newQualityPlifsFeatures, dna_array, est_array


    def train(self):
        """
        Run the training loop described by self.run.

        Creates the directory run['experiment_path']/run['name'] (which must
        not exist yet), changes the working directory to it and writes there
        the pickled run object, the log file '_qpalma_train.log' and one
        'param_<n>.pickle' file after every solver call plus a final one.
        """
        # math is used below but is not imported explicitly at file level
        import math

        run = self.run

        full_working_path = os.path.join(run['experiment_path'],run['name'])

        # never reuse the directory of an earlier experiment
        assert not os.path.exists(full_working_path)
        os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self._dump_pickle(run,'run_object.pickle')

        self.logfh = open('_qpalma_train.log','w+')

        self.plog("Settings are:\n")
        self.plog("%s\n"%str(run))

        data_filename = self.run['dataset_filename']

        # Load the whole dataset
        SeqInfo, Exons, OriginalEsts, Qualities,\
        AlternativeSequences = paths_load_data(data_filename,'training',None,self.ARGS)

        self._set_quality_mode()

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts= OriginalEsts
        self.Qualities = Qualities

        beg = run['training_begin']
        end = run['training_end']

        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts= OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        # number of training instances
        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N and len(Qualities) == N,\
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # Initialize parameter vector
        param = Conf.fixedParam[:run['numFeatures']]

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        # Initialize solver
        self.plog('Initializing problem...\n')

        solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'],1))
        totalQualityPenalties = zeros((totalQualSP,1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        featureVectors = zeros((run['numFeatures'],numExamples))

        # the main training loop
        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx in range(numExamples):
                if (exampleIdx%100) == 0:
                    print('Current example nr %d' % exampleIdx)

                currentSeqInfo = SeqInfo[exampleIdx]
                chr,strand,up_cut,down_cut = currentSeqInfo

                # the read without bracket notation and without gaps
                est = OriginalEsts[exampleIdx]
                est = "".join(est)
                est = est.lower()
                est = unbracket_est(est)
                est = est.replace('-','')

                # a failing check drops into the debugger before the
                # AssertionError is raised
                assert len(est) == run['read_size'], pdb.set_trace()

                # the read as stored in the dataset (with bracket notation)
                original_est = OriginalEsts[exampleIdx]
                original_est = "".join(original_est)
                original_est = original_est.lower()

                # NOTE(review): hard-coded genome location; the prediction code
                # reads the location from run['dna_flat_files'] instead.
                dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
                dna, acc_supp, don_supp = get_seq_and_scores(chr,strand,up_cut,down_cut,dna_flat_files)
                dna_len = len(dna)

                # splice score is located at g of ag
                ag_tuple_pos = [p for p,e in enumerate(dna) if p>1 and dna[p-1]=='a' and dna[p]=='g' ]
                assert ag_tuple_pos == [p for p,e in enumerate(acc_supp) if e != -inf and p > 1], pdb.set_trace()

                # donor score is located at the g of gt resp. gc
                gt_tuple_pos = [p for p,e in enumerate(dna) if p>0 and p<len(dna)-1 and e=='g' and (dna[p+1]=='t' or dna[p+1]=='c')]
                assert gt_tuple_pos == [p for p,e in enumerate(don_supp) if e != -inf and p > 0], pdb.set_trace()

                if run['mode'] == 'normal':
                    quality = [40]*len(est)

                if run['mode'] == 'using_quality_scores':
                    quality = Qualities[exampleIdx][0]

                if not run['enable_quality_scores']:
                    quality = [40]*len(est)

                # shift the exon boundaries so that they are relative to the
                # fetched dna fragment
                original_exons = Exons[exampleIdx]
                exons = original_exons - (up_cut-1)
                exons[0,0] -= 1

                fetched_dna_subseq = dna[exons[0,0]:exons[0,1]] + dna[exons[1,0]:exons[1,1]]

                donor_elem = dna[exons[0,1]:exons[0,1]+2]
                acceptor_elem = dna[exons[1,0]-2:exons[1,0]]

                if not ( donor_elem == 'gt' or donor_elem == 'gc' ):
                    print('invalid donor in example %d'% exampleIdx)

                if not ( acceptor_elem == 'ag' ):
                    print('invalid acceptor in example %d'% exampleIdx)

                assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

                # rebuild the bracket notation from dna and read; examples
                # whose stored read differs from it are skipped
                new_string_parts = []
                for idx in range(len(est)):
                    dna_char = fetched_dna_subseq[idx]
                    est_char = est[idx]
                    if dna_char == est_char:
                        new_string_parts.append(est_char)
                    else:
                        new_string_parts.append('[%s%s]'%(dna_char,est_char))
                new_string = ''.join(new_string_parts)

                if new_string != original_est:
                    print('%s %s' % (new_string,original_est))
                    continue

                if not run['enable_splice_signals']:
                    self._zero_finite_scores(don_supp)
                    self._zero_finite_scores(acc_supp)

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
                    computeSpliceAlignWithQuality(dna, exons, est, original_est,\
                    quality, qualityPlifs,run)
                    # dna_calc only exists in this branch; stripping it after
                    # the if/else raised a NameError in 'normal' mode
                    dna_calc = dna_calc.replace('-','')
                else:
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                print('right before computeSpliceWeights exampleIdx %d' % exampleIdx)
                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron =\
                computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((run['numFeatures'],num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                # NOTE(review): calc_alignment shifts the acceptor list by one
                # position before aligning, here it is passed on unshifted -
                # confirm that this difference is intended.
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
                newQualityPlifsFeatures, unneeded1, unneeded2 =\
                self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
                newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mmatrixSP)

                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(num_path[exampleIdx],run['totalQualSuppPoints'])

                # Calculate weights of the respective alignments. Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(num_path[exampleIdx]+1)
                true_map[0] = 1

                # the current parameters as one column vector; they do not
                # change while the paths of one example are processed
                hpen = mat(h.penalties).reshape(len(h.penalties),1)
                dpen = mat(d.penalties).reshape(len(d.penalties),1)
                apen = mat(a.penalties).reshape(len(a.penalties),1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(num_path[exampleIdx]):
                    print('right before computeSpliceWeights(2) exampleIdx %d' % exampleIdx)
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,\
                    h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,\
                    acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr,:].T
                    # store the weights in the remaining columns of the matrix
                    allWeights[:,pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures[:]])

                    featureVectors[:,exampleIdx] = allWeights[:,pathNr+1]

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check whether scalar product + loss equals viterbi score
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"%AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"%(newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                # the true alignment scores higher than every decoded one
                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" %exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" %exampleIdx)

                # if every decoded path was considered true all n-best paths
                # are too close to each other, so
                # we have to extend the n-best search to a (n+1)-best
                if len([elem for elem in true_map if elem == 1]) == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights

                    const_added = solver.addConstraint(differenceVector, exampleIdx)

                    const_added_ctr += 1
                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue,w,self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'% numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    # leaves the loop over the examples only
                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)

                    print('sum of slacks is %f'% sum_xis)
                    self.plog('sum of slacks is %f\n'% sum_xis)

                    # take over the solution and rebuild the Plifs from it
                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param,'param_%d.pickle'%param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] =\
                    set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" %suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dump_pickle(param,'param_%d.pickle'%param_idx)
        self.logfh.close()


    ###############################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###############################################################################


    def evaluate(self,param_filename):
        """
        Load the dataset and predict on the examples [0,prediction_begin)
        (flagged 'TRAIN') and [prediction_begin,prediction_end) (flagged
        'TEST') using the parameters pickled in param_filename.

        The log is written to '_qpalma_predict.log' in the current directory.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        data_filename = self.run['dataset_filename']
        # NOTE(review): train() unpacks five values from the same call of
        # paths_load_data, here ten are expected - confirm against the
        # current version of that function.
        Sequences, Acceptors, Donors, Exons, Ests, OriginalEsts, Qualities,\
        UpCut, StartPos, AlternativeSequences=\
        paths_load_data(data_filename,'training',None,self.ARGS)

        self.Sequences = Sequences
        self.Exons = Exons
        self.Ests = Ests
        self.OriginalEsts= OriginalEsts
        self.Qualities = Qualities
        self.Donors = Donors
        self.Acceptors = Acceptors
        self.UpCut = UpCut
        self.StartPos = StartPos

        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log','w+')

        try:
            # predict on training set
            self.plog('##### Prediction on the training set #####\n')

            self.predict(param_filename,0,beg,'TRAIN')

            # predict on test set
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename,beg,end,'TEST')

            self.plog('##### Finished prediction #####\n')
        finally:
            # close the log even if a prediction fails
            self.logfh.close()


    def predict(self,param_filename,beg,end,set_flag):
        """
        Predict alignments for the examples beg:end of the dataset loaded by
        evaluate().

        For every example one alignment against the ground truth dna fragment
        (label True) and one for each alternative genomic location is
        calculated. The list of all predictions is pickled into the file
        '<run name>_allPredictions_<set_flag>'.
        """

        run = self.run

        self._set_quality_mode()

        Sequences = self.Sequences[beg:end]
        Exons = self.Exons[beg:end]
        Ests = self.Ests[beg:end]
        Qualities = self.Qualities[beg:end]
        Acceptors = self.Acceptors[beg:end]
        Donors = self.Donors[beg:end]
        StartPos = self.StartPos[beg:end]

        AlternativeSequences = self.AlternativeSequences[beg:end]

        # number of instances
        N = numExamples = len(Sequences)
        assert len(Exons) == N and len(Ests) == N\
        and len(Qualities) == N and len(Acceptors) == N\
        and len(Donors) == N, 'The Exons,Acc,Don,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        param = self._load_pickle(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] =\
        set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        #############################################################################################
        # Prediction
        #############################################################################################
        self.plog('Starting prediction...\n')

        # where we store the predictions
        allPredictions = []

        # beginning of the prediction loop
        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n'%exampleIdx)

            dna = Sequences[exampleIdx]
            # NOTE(review): est is passed on as loaded, its bracket notation
            # is not stripped here - confirm that this is intended.
            est = Ests[exampleIdx]

            exons = Exons[exampleIdx]

            current_start_pos = StartPos[exampleIdx]

            currentAlternatives = AlternativeSequences[exampleIdx]

            if self.run['mode'] == 'normal':
                quality = [40]*len(est)

            if self.run['mode'] == 'using_quality_scores':
                # NOTE(review): train() uses Qualities[exampleIdx][0] - confirm
                quality = Qualities[exampleIdx]

            if not run['enable_quality_scores']:
                quality = [40]*len(est)

            don_supp = Donors[exampleIdx]
            acc_supp = Acceptors[exampleIdx]

            if not run['enable_splice_signals']:
                self._zero_finite_scores(don_supp)
                self._zero_finite_scores(acc_supp)

            current_example_predictions = []

            # first make a prediction on the dna fragment which comes from the ground truth
            current_prediction = self.calc_alignment(dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs)
            current_prediction['exampleIdx'] = exampleIdx
            current_prediction['start_pos'] = current_start_pos
            current_prediction['label'] = True

            current_example_predictions.append(current_prediction)

            # then make predictions for all dna fragments that were occurring in
            # the vmatch results
            for alternative_alignment in currentAlternatives:
                chr, strand, genomicSeq_start, genomicSeq_stop, currentLabel = alternative_alignment
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._zero_finite_scores(currentDon)
                    self._zero_finite_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,\
                quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['start_pos'] = current_start_pos
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        # end of the prediction loop we save all predictions in a pickle file and exit
        self._dump_pickle(allPredictions,'%s_allPredictions_%s'%(run['name'],set_flag))
        print('Prediction completed')


    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Given two sequences and the parameters we calculate one alignment.

        Returns a dictionary with the keys 'predExons' (boundaries derived
        from the decoded splice alignment), 'trueExons' (the exons argument),
        'dna', 'est' and 'DPScores'.
        """

        run = self.run

        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        # myalign wants the acceptor site on the g of the ag
        acceptor = acceptor[1:]
        acceptor.append(-inf)

        dna = str(dna)
        est = str(est)
        dna_len = len(dna)

        ps = h.convert2SWIG()

        # a single path is decoded; the alignment arrays are not used here
        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
        newQualityPlifsFeatures, dna_array, est_array =\
        self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,1,True)

        newSpliceAlign = newSpliceAlign.reshape(1,dna_len)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons':newExons, 'trueExons':exons,\
        'dna':dna, 'est':est, 'DPScores':newDPScores}

        return current_prediction


    def calculatePredictedExons(self,SpliceAlign):
        """
        Return the exon boundaries encoded in SpliceAlign as a flat list.

        SpliceAlign has to offer flatten().tolist()[0] yielding the list of
        labels (as a 1xN numpy matrix does). A run of zeros is an exon: the
        position of its first zero and the position right after its last
        zero are appended to the result, run after run.
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # sentinel which closes an exon reaching the end of the sequence
        labels.append(-1)

        newExons = []
        oldElem = -1
        for pos,elem in enumerate(labels):
            if oldElem != 0 and elem == 0: # start of exon
                newExons.append(pos)

            if oldElem == 0 and elem != 0: # end of exon
                newExons.append(pos)

            oldElem = elem

        return newExons
901
902
903 ###########################
904 # A simple command line
905 # interface
906 ###########################
907
if __name__ == '__main__':
    # Expected invocations:
    #   train   <run_object_file>
    #   predict <run_object_file> <param_file>
    # The arguments are validated before anything is read from sys.argv by
    # index, so a missing argument prints the usage instead of raising an
    # IndexError.
    is_train = len(sys.argv) == 3 and sys.argv[1] == 'train'
    is_predict = len(sys.argv) == 4 and sys.argv[1] == 'predict'

    if not (is_train or is_predict):
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')
    else:
        # NOTE: unpickling executes code from the file, only pass trusted files
        run_obj_fh = open(sys.argv[2])
        try:
            run_obj = cPickle.load(run_obj_fh)
        finally:
            run_obj_fh.close()

        # not called 'qpalma' in order to keep the imported module of that
        # name accessible
        qpalma_instance = QPalma(run_obj)

        if is_train:
            qpalma_instance.train()
        else:
            param_filename = sys.argv[3]
            assert os.path.exists(param_filename)
            qpalma_instance.evaluate(param_filename)