+ training works,
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
import sys
import cPickle
import pdb
import re
import math
import os.path

from compile_dataset import getSpliceScores, get_seq_and_scores

import numpy
from numpy.matlib import mat,zeros,ones,inf
from numpy.linalg import norm

import QPalmaDP
import qpalma

#from qpalma.SIQP_CPX import SIQPSolver
from qpalma.SIQP_CVXOPT import SIQPSolver

from qpalma.DataProc import *
from qpalma.computeSpliceWeights import *
from qpalma.set_param_palma import *
from qpalma.computeSpliceAlignWithQuality import *
from qpalma.penalty_lookup_new import *
from qpalma.compute_donacc import *
from qpalma.TrainingParam import Param
from qpalma.Plif import Plf

from qpalma.tools.splicesites import getDonAccScores
from qpalma.Configuration import *
46
# These two imports are needed for the genome-loading and interval-query
# functions.
49 from Genefinding import *
50 from genome_utils import load_genomic
51 from Utils import calc_stat, calc_info, pprint_alignment
52
53
def unbracket_est(est):
    """
    Strip the mismatch annotation from a read string.

    In the annotated read a mismatch is written as the four-character group
    '[xy]' (x = genomic base, y = read base).  For every such group only the
    read base y is kept; every other character is copied unchanged.  The
    result is lower-cased.

    >>> unbracket_est('ac[gt]a')
    'acta'
    """
    chars = []
    pos = 0
    while pos < len(est):
        if est[pos] == '[':
            # '[xy]': keep the read base y and skip the whole group
            chars.append(est[pos + 2])
            pos += 4
        else:
            chars.append(est[pos])
            pos += 1
    return ''.join(chars).lower()
70
71
class QPalma:
    """
    Training and prediction driver for the QPalma project.

    ``run`` is the run configuration.  It is indexed with string keys such
    as 'mode', 'numFeatures', 'anzpath' or 'dataset_filename'.

    ``train()`` learns a parameter vector with the SIQP solver and pickles it
    as 'param_<n>.pickle'.  ``evaluate()`` loads such a file and pickles the
    resulting alignment predictions.
    """

    # Genome flat files used for the ground-truth DNA fragment of every
    # example.  NOTE(review): predict() fetches the alternative fragments from
    # run['dna_flat_files'] instead -- confirm the two locations should differ.
    DNA_FLAT_FILES = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    def __init__(self, run):
        self.ARGS = Param()
        self.run = run
        self.use_quality_scores = self._uses_quality_scores()

    def _uses_quality_scores(self):
        """
        Map run['mode'] to the use_quality_scores flag.

        Returns False for 'normal' and True for 'using_quality_scores'.
        Raises ValueError for any other mode.  Unlike ``assert``, an explicit
        exception is not stripped when Python runs with -O.
        """
        mode = self.run['mode']
        if mode == 'normal':
            return False
        if mode == 'using_quality_scores':
            return True
        raise ValueError('unknown run mode: %r' % (mode,))

    def plog(self, string):
        """Append ``string`` to the open log file and flush it right away."""
        self.logfh.write(string)
        self.logfh.flush()

    @staticmethod
    def _pickle_dump(obj, filename):
        """Pickle ``obj`` into ``filename``, always closing the file."""
        fh = open(filename, 'w+')
        try:
            cPickle.dump(obj, fh)
        finally:
            fh.close()

    @staticmethod
    def _pickle_load(filename):
        """
        Return the object pickled in ``filename``, always closing the file.

        Unpickling can execute arbitrary code, so only load files written by
        this program.
        """
        fh = open(filename)
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    @staticmethod
    def _neutralize_splice_scores(scores):
        """Set every entry of ``scores`` that is not -inf to 0.0, in place."""
        for idx, elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0

    @staticmethod
    def _c_array_to_column(c_array, length):
        """Copy the first ``length`` entries of an indexable QPalmaDP array
        into a new (length, 1) matrix."""
        column = zeros((length, 1))
        for idx in range(length):
            column[idx] = c_array[idx]
        return column

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps,
                     qualityPlifs, current_num_path, prediction_mode):
        """
        Run the QPalmaDP dynamic program for one read against one DNA fragment.

        Computes the ``current_num_path`` best alignments of ``est`` against
        ``dna``.  Arguments:
          - ``donor``/``acceptor``: lists of splice scores;
          - ``ps``: the converted intron-length Plif;
          - ``mmatrix``: the match matrix;
          - ``qualityPlifs``: the quality Plifs.

        Returns the tuple
            (SpliceAlign, EstAlign, WeightMatch, DPScores,
             QualityPlifsFeatures, dna_array, est_array).

        The first five entries are column matrices holding
        current_num_path blocks of length len(dna), len(est), match-matrix
        size, 1 and run['totalQualSuppPoints'] respectively.
        QualityPlifsFeatures stays all-zero unless quality scores are used.
        dna_array/est_array are lists in prediction mode and None otherwise.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']
        qual_len = run['totalQualSuppPoints']

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        # chastity values are all zero
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # Create the alignment object representing the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],
                                              run['numQualSuppPoints'],
                                              self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(
            [elem.convert2SWIG() for elem in qualityPlifs])

        # Calculates SpliceAlign, EstAlign, weightMatch, total scores and dnaest.
        # The two trailing flags are read from the run configuration, the
        # same keys train()/predict() read.
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 donor, d_len, acceptor, a_len, c_qualityPlifs,
                                 run['remove_duplicate_scores'], run['print_matrix'])

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList(
            [.0] * (qual_len * current_num_path))

        if prediction_mode:
            # the explicit alignment arrays are only needed for prediction
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[idx] for idx in range(result_len)]
            est_array = [c_est_array[idx] for idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores,
                                             c_qualityPlifsFeatures)

        newSpliceAlign = self._c_array_to_column(c_SpliceAlign, dna_len * current_num_path)
        newEstAlign = self._c_array_to_column(c_EstAlign, est_len * current_num_path)
        newWeightMatch = self._c_array_to_column(c_WeightMatch, mm_len * current_num_path)
        newDPScores = self._c_array_to_column(c_DPScores, current_num_path)
        if self.use_quality_scores:
            newQualityPlifsFeatures = self._c_array_to_column(
                c_qualityPlifsFeatures, qual_len * current_num_path)
        else:
            newQualityPlifsFeatures = zeros((qual_len * current_num_path, 1))

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array

    def _prepare_example(self, example_idx, seq_info, raw_est, raw_quality, original_exons):
        """
        Fetch the genomic fragment of one example and sanity-check it.

        Arguments:
          - ``seq_info``: (chromosome, strand, up_cut, down_cut);
          - ``raw_est``: the annotated read, where mismatches are written as
            '[<dna base><read base>]';
          - ``original_exons``: a 2x2 array of exon (start, stop) pairs,
            presumably in 1-based genomic coordinates (TODO confirm).

        Returns None, after printing the rebuilt and the stored read, when the
        read rebuilt from the DNA fragment does not reproduce the annotated
        read.  Otherwise returns the tuple
            (up_cut, dna, est, original_est, quality, exons, don_supp, acc_supp)
        with the splice scores zeroed if run['enable_splice_signals'] is off.
        """
        run = self.run
        chrom, strand, up_cut, down_cut = seq_info

        original_est = "".join(raw_est).lower()
        # plain read sequence: drop the mismatch annotation and gap symbols
        est = unbracket_est(original_est).replace('-', '')

        assert len(est) == run['read_size'], pdb.set_trace()

        dna, acc_supp, don_supp = get_seq_and_scores(chrom, strand, up_cut, down_cut,
                                                     self.DNA_FLAT_FILES)

        # splice score is located at g of ag
        ag_tuple_pos = [p for p, e in enumerate(dna)
                        if p > 1 and dna[p-1] == 'a' and dna[p] == 'g']
        assert ag_tuple_pos == [p for p, e in enumerate(acc_supp)
                                if e != -inf and p > 1], pdb.set_trace()

        # donor score is located at the g of gt / gc
        gt_tuple_pos = [p for p, e in enumerate(dna)
                        if 0 < p < len(dna) - 1 and e == 'g' and dna[p+1] in ('t', 'c')]
        assert gt_tuple_pos == [p for p, e in enumerate(don_supp)
                                if e != -inf and p > 0], pdb.set_trace()

        if self.use_quality_scores and run['enable_quality_scores']:
            quality = raw_quality[0]
        else:
            # constant quality whenever quality scores are not used
            quality = [40] * len(est)

        # Fragment-relative exon coordinates.  The starts are decremented once
        # more so that dna[start:stop] slices the exon.
        exons = original_exons - (up_cut - 1)
        exons[0, 0] -= 1
        exons[1, 0] -= 1

        fetched_dna_subseq = dna[exons[0, 0]:exons[0, 1]] + dna[exons[1, 0]:exons[1, 1]]

        donor_elem = dna[exons[0, 1]:exons[0, 1] + 2]
        acceptor_elem = dna[exons[1, 0] - 2:exons[1, 0]]

        if donor_elem not in ('gt', 'gc'):
            print('invalid donor in example %d' % example_idx)

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d' % example_idx)

        assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

        # Re-encode the read against the fetched exonic DNA in the annotated
        # format; it must reproduce the stored annotated read exactly.
        reencoded = ''.join([est_char if dna_char == est_char
                             else '[%s%s]' % (dna_char, est_char)
                             for dna_char, est_char in zip(fetched_dna_subseq, est)])

        if reencoded != original_est:
            print('%s %s' % (reencoded, original_est))
            return None

        if not run['enable_splice_signals']:
            self._neutralize_splice_scores(don_supp)
            self._neutralize_splice_scores(acc_supp)

        return up_cut, dna, est, original_est, quality, exons, don_supp, acc_supp

    def train(self):
        """
        Train the QPalma parameters on run['training_begin']:run['training_end'].

        Creates the directory <experiment_path>/<name>, which must not exist
        yet, and changes the working directory into it.  The following files
        are written there:
          - the pickled run object;
          - the log '_qpalma_train.log';
          - one 'param_<n>.pickle' per solver call, plus a final one.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'], run['name'])

        assert not os.path.exists(full_working_path)
        os.mkdir(full_working_path)
        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self._pickle_dump(run, 'run_object.pickle')

        self.logfh = open('_qpalma_train.log', 'w+')
        try:
            param, param_idx = self._train_loop()
            print('Training completed')
            self._pickle_dump(param, 'param_%d.pickle' % param_idx)
        finally:
            # close the log even if training aborts with an exception
            self.logfh.close()

    def _train_loop(self):
        """
        Load the training data and run the constraint-generation loop.

        Returns the final parameter vector and the index to use for the last
        'param_<n>.pickle' file.
        """
        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences = \
            paths_load_data(run['dataset_filename'], 'training', None, self.ARGS)

        self.use_quality_scores = self._uses_quality_scores()

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities

        beg = run['training_begin']
        end = run['training_end']
        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        # number of training instances
        numExamples = len(SeqInfo)
        assert (len(Exons) == numExamples and len(OriginalEsts) == numExamples
                and len(Qualities) == numExamples), \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']

        # Initial parameter vector.
        # NOTE(review): if Conf.fixedParam is a numpy array/matrix this slice is
        # a view, so the in-place updates below also modify Conf.fixedParam.
        param = Conf.fixedParam[:run['numFeatures']]

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # offsets of the feature blocks inside the weight/parameter vectors:
        # [intron length | donor | acceptor | match matrix | quality]
        donBeg = lengthSP
        accBeg = donBeg + donSP
        mmBeg = accBeg + accSP
        qualBeg = mmBeg + mmatrixSP

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        # Initialize solver
        self.plog('Initializing problem...\n')
        solver = SIQPSolver(run['numFeatures'], numExamples, run['C'], self.logfh, run)

        # number of alignments computed per example (best path, second-best path etc.)
        num_path = [run['anzpath']] * numExamples

        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0
        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        # the main training loop
        while iteration_nr != iteration_steps:
            for exampleIdx in range(numExamples):
                if exampleIdx % 100 == 0:
                    print('Current example nr %d' % exampleIdx)

                example = self._prepare_example(exampleIdx, SeqInfo[exampleIdx],
                                                OriginalEsts[exampleIdx],
                                                Qualities[exampleIdx],
                                                Exons[exampleIdx])
                if example is None:
                    continue
                _up_cut, dna, est, original_est, quality, exons, don_supp, acc_supp = example
                dna_len = len(dna)
                n_paths = num_path[exampleIdx]

                # Parameters of the true alignment (with the current, untrained d, a, h ...)
                if self.use_quality_scores:
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                      quality, qualityPlifs, run)
                else:
                    # NOTE(review): here the helper is called with only
                    # (dna, exons) and unpacked into three values, while the
                    # quality branch passes seven arguments and receives four;
                    # confirm this call form is supported.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                currentPhi[0:donBeg] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[donBeg:accBeg] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[accBeg:mmBeg] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[mmBeg:qualBeg] = mmatrix[:]

                if self.use_quality_scores:
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qualBeg:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # column 0: weights of the true alignment;
                # column k: weights of the k-th decoded alignment
                allWeights = zeros((run['numFeatures'], n_paths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (n_paths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
                ps = h.convert2SWIG()

                newSpliceAlign, _, newWeightMatch, newDPScores, \
                    newQualityPlifsFeatures, _, _ = \
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                                      qualityPlifs, n_paths, False)

                newSpliceAlign = newSpliceAlign.reshape(n_paths, dna_len)
                newWeightMatch = newWeightMatch.reshape(n_paths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(n_paths, totalQualSP)

                # current parameter vector; it does not change inside the path loop
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                # The n-best alignments are decoded without hamming loss, so
                # some of them may equal the true alignment.  Those must not
                # become constraints.  true_map[k] is 1 when entry k (0 = the
                # true alignment itself) is considered equal to the truth.
                true_map = [0] * (n_paths + 1)
                true_map[0] = 1

                for pathNr in range(n_paths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    # store the weights of this path in its column of allWeights
                    allWeights[:, pathNr + 1] = numpy.vstack(
                        [weightIntron, weightDon, weightAcc,
                         newWeightMatch[pathNr, :].T, decodedQualityFeatures[:]])

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # Check whether scalar product + loss equals viterbi score
                    if not math.fabs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"
                              % (newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                    if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                        print("suboptimal_example %d\n" % exampleIdx)
                        suboptimal_example += 1
                        self.plog("suboptimal_example %d\n" % exampleIdx)

                # If every decoded path equals the true alignment, no constraint
                # can be formed: decode one more path for this example next time.
                if 0 not in true_map:
                    num_path[exampleIdx] += 1

                # If there is at least one useful false alignment, add the
                # constraint true vs. first false alignment to the problem.
                if 0 in true_map:
                    firstFalseIdx = true_map.index(0)
                    differenceVector = trueWeight - allWeights[:, firstFalseIdx]
                    solver.addConstraint(differenceVector, exampleIdx)

                # call the solver every numConstPerRound-th example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1
                        if self.noImprovementCtr == numExamples + 1:
                            break

                    self.oldObjValue = objValue
                    print('objValue is %f' % objValue)

                    # explicit loop: numpy star imports may shadow the builtin sum
                    sum_xis = 0
                    for elem in self.slacks:
                        sum_xis += elem

                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._pickle_dump(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            # end of one iteration through all examples
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            # the counter can grow several times per pass, so use >= here
            if self.noImprovementCtr >= numExamples * 2:
                break

            iteration_nr += 1

        return param, param_idx

    ###############################################################################
    #
    # End of the code needed for training
    #
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def evaluate(self, param_filename):
        """
        Predict with the parameters pickled in ``param_filename``.

        Loads the data set, then runs predict() twice:
          - on examples 0:run['prediction_begin'] with flag 'TRAIN';
          - on run['prediction_begin']:run['prediction_end'] with flag 'TEST'.

        Progress is logged to '_qpalma_predict.log' in the current directory.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences = \
            paths_load_data(run['dataset_filename'], 'training', None, self.ARGS)

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict.log', 'w+')
        try:
            # predict on training set
            self.plog('##### Prediction on the training set #####\n')
            self.predict(param_filename, 0, beg, 'TRAIN')

            # predict on test set
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename, beg, end, 'TEST')

            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()

    def predict(self, param_filename, beg, end, set_flag):
        """
        Align the examples beg:end using the parameters in ``param_filename``.

        Each read is aligned against its ground-truth DNA fragment (label True)
        and against every alternative fragment listed in AlternativeSequences
        (with the label stored there).  All predictions are pickled to
        '<run name>_allPredictions_<set_flag>' in the current directory.

        Requires the data set attributes loaded by evaluate().
        """
        run = self.run
        self.use_quality_scores = self._uses_quality_scores()

        SeqInfo = self.SeqInfo[beg:end]
        Exons = self.Exons[beg:end]
        OriginalEsts = self.OriginalEsts[beg:end]
        Qualities = self.Qualities[beg:end]
        AlternativeSequences = self.AlternativeSequences[beg:end]

        numExamples = len(SeqInfo)
        assert (len(Exons) == numExamples and len(OriginalEsts) == numExamples
                and len(Qualities) == numExamples
                and len(AlternativeSequences) == numExamples), \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        param = self._pickle_load(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Starting prediction...\n')

        # where we store the predictions
        allPredictions = []

        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n' % exampleIdx)

            example = self._prepare_example(exampleIdx, SeqInfo[exampleIdx],
                                            OriginalEsts[exampleIdx],
                                            Qualities[exampleIdx],
                                            Exons[exampleIdx])
            if example is None:
                print(exampleIdx)
                continue
            up_cut, dna, est, _original_est, quality, exons, don_supp, acc_supp = example

            # first make a prediction on the dna fragment which comes from the ground truth
            current_prediction = self.calc_alignment(dna, est, exons, quality, don_supp,
                                                     acc_supp, d, a, h, mmatrix, qualityPlifs)
            current_prediction['exampleIdx'] = exampleIdx
            current_prediction['start_pos'] = up_cut
            current_prediction['label'] = True
            current_example_predictions = [current_prediction]

            # then make predictions for all dna fragments that occurred in
            # the vmatch results
            for chrom, strand, genomicSeq_start, genomicSeq_stop, currentLabel \
                    in AlternativeSequences[exampleIdx]:
                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    chrom, strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._neutralize_splice_scores(currentDon)
                    self._neutralize_splice_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,
                                                         quality, currentDon, currentAcc,
                                                         d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['start_pos'] = up_cut
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        # end of the prediction loop: save all predictions in a pickle file
        self._pickle_dump(allPredictions, '%s_allPredictions_%s' % (run['name'], set_flag))
        print('Prediction completed')

    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h,
                       mmatrix, qualityPlifs):
        """
        Align ``est`` against ``dna`` with the given parameters (best path only).

        Returns a dict with these keys:
          - 'predExons': predicted exon boundaries (see calculatePredictedExons);
          - 'trueExons': the given exons;
          - 'dna', 'est': the two sequences as strings;
          - 'DPScores': the (1, 1) DP score matrix.
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)

        ps = h.convert2SWIG()

        newSpliceAlign, _, _, newDPScores, _, _, _ = \
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                              qualityPlifs, 1, True)

        newExons = self.calculatePredictedExons(newSpliceAlign.reshape(1, len(dna)))

        return {'predExons': newExons, 'trueExons': exons,
                'dna': dna, 'est': est, 'DPScores': newDPScores}

    def calculatePredictedExons(self, SpliceAlign):
        """
        Turn a splice-alignment label row into a flat list of exon boundaries.

        Positions labelled 0 are exonic.  For every maximal run of zeros the
        result gets the run's first index and the index one past its end,
        e.g. [0, 0, 2, 2, 0] -> [0, 2, 4, 5].
        """
        labels = SpliceAlign.flatten().tolist()[0]
        # pad with -1 on both sides so runs touching either end are closed
        previous = [-1] + labels
        current = labels + [-1]
        return [pos for pos, (old, new) in enumerate(zip(previous, current))
                if (old == 0) != (new == 0)]
919
920
921 ###########################
922 # A simple command line
923 # interface
924 ###########################
925
def main(argv):
    """
    Command line entry point.

    Usage:
        qpalma_main.py train <run_object_file>
        qpalma_main.py predict <run_object_file> <param_file>

    The run object file is unpickled, so it must come from a trusted source
    (unpickling can execute arbitrary code).  Returns 0 after a train or
    predict run and 1 when the arguments are invalid.
    """
    if len(argv) > 1:
        mode = argv[1]
    else:
        mode = None

    valid = (mode == 'train' and len(argv) == 3) or \
            (mode == 'predict' and len(argv) == 4)

    if not valid:
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py train <run_object_file>')
        print('python qpalma_main.py predict <run_object_file> <param_file>')
        return 1

    fh = open(argv[2])
    try:
        run_obj = cPickle.load(fh)
    finally:
        fh.close()

    # local name: do not shadow the imported qpalma package
    trainer = QPalma(run_obj)

    if mode == 'train':
        trainer.train()
    else:
        param_filename = argv[3]
        assert os.path.exists(param_filename)
        trainer.evaluate(param_filename)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))