+ changed prediction
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores, get_seq_and_scores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.Configuration import *
45
46 # These two imports are needed for the genomic-sequence loading and the
47 # interval-query functions, respectively.
48 from Genefinding import *
49 from genome_utils import load_genomic
50 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
51
class SpliceSiteException(Exception):
    """
    Raised by getData() when a two-exon example's intron boundaries are not
    canonical (donor dinucleotide not 'gt'/'gc', or acceptor not 'ag').

    Callers in this module catch it and skip the example. Deriving from
    Exception makes it raisable under new-style exception rules.
    """
    pass
54
55
56 def unbracket_est(est):
57 new_est = ''
58 e = 0
59
60 while True:
61 if e >= len(est):
62 break
63
64 if est[e] == '[':
65 new_est += est[e+2]
66 e += 4
67 else:
68 new_est += est[e]
69 e += 1
70
71 return "".join(new_est).lower()
72
73
def getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run):
    """
    Load the genomic window, read and exon annotation for one example.

    SeqInfo      -- per-example tuples (id, chr, strand, up_cut, down_cut,
                    true_cut); only chr, strand and up_cut/down_cut are used.
    OriginalEsts -- per-example reads, possibly carrying bracketed mismatch
                    annotations (see unbracket_est).
    Exons        -- per-example exon coordinate arrays; row i presumably holds
                    (start, stop) of exon i in genome coordinates.
    exampleIdx   -- index of the example to load.
    run          -- run configuration. run['read_size'] is required.
                    run['dna_flat_files'] is used when present.

    Returns (dna, est, acc_supp, don_supp, exons, original_est):
      dna          -- genomic window from get_seq_and_scores,
      est          -- lower-case read with annotations resolved and '-' removed,
      acc_supp /
      don_supp     -- acceptor/donor scores for the window (-inf = no site),
      exons        -- exon coordinates relative to the window, starts shifted
                      one further to the left,
      original_est -- the annotated read, lower-cased.

    Raises SpliceSiteException (after printing a message) when a two-exon
    example has a non 'gt'/'gc' donor or a non 'ag' acceptor.
    The consistency asserts drop into pdb when they fail (debugging aid).
    """
    seq_id, chromo, strand, up_cut, down_cut, true_cut = SeqInfo[exampleIdx]

    original_est = "".join(OriginalEsts[exampleIdx]).lower()

    # Plain read: resolve mismatch brackets, then drop gap characters.
    est = unbracket_est(original_est).replace('-', '')

    assert len(est) == run['read_size'], pdb.set_trace()

    # predict() already takes the genome location from the run configuration.
    # Do the same here and keep the historical path as a fallback.
    try:
        dna_flat_files = run['dna_flat_files']
    except KeyError:
        dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'

    dna, acc_supp, don_supp = get_seq_and_scores(chromo, strand, up_cut, down_cut, dna_flat_files)

    # Acceptor scores sit on the 'g' of 'ag': every such position (p > 1)
    # must have a finite acceptor score and vice versa.
    ag_tuple_pos = [p for p, e in enumerate(dna) if p > 1 and dna[p-1] == 'a' and dna[p] == 'g']
    assert ag_tuple_pos == [p for p, e in enumerate(acc_supp) if e != -inf and p > 1], pdb.set_trace()

    # Donor scores sit on the 'g' of 'gt'/'gc'.
    gt_tuple_pos = [p for p, e in enumerate(dna) if p > 0 and p < len(dna)-1 and e == 'g' and (dna[p+1] == 't' or dna[p+1] == 'c')]
    assert gt_tuple_pos == [p for p, e in enumerate(don_supp) if e != -inf and p > 0], pdb.set_trace()

    # Translate exon coordinates into the window and shift the starts of
    # (at most) the first two exons one more position to the left. The
    # clamped slice keeps single-exon examples from raising IndexError,
    # which direct indexing of row 1 did. For 2+ rows it is identical.
    exons = Exons[exampleIdx] - (up_cut-1)
    exons[:2, 0] -= 1

    if exons.shape == (2, 2):
        fetched_dna_subseq = dna[exons[0,0]:exons[0,1]] + dna[exons[1,0]:exons[1,1]]

        donor_elem = dna[exons[0,1]:exons[0,1]+2]
        acceptor_elem = dna[exons[1,0]-2:exons[1,0]]

        if donor_elem not in ('gt', 'gc'):
            print('invalid donor in example %d' % exampleIdx)
            raise SpliceSiteException

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d' % exampleIdx)
            raise SpliceSiteException

        # The spliced genomic sequence must be exactly as long as the read.
        assert len(fetched_dna_subseq) == len(est), pdb.set_trace()

    return dna, est, acc_supp, don_supp, exons, original_est
124
125
class QPalma(object):
    """
    Training and prediction driver for QPalma.

    A QPalma object wraps a run configuration `run` (indexed like a dict) and
    offers:
      * train()    -- learn the parameter vector with the SIQP solver and
                      pickle it into the experiment directory,
      * evaluate() -- load a pickled parameter vector and predict alignments
                      for the slice run['prediction_begin']:run['prediction_end'].

    run['mode'] must be 'normal' or 'using_quality_scores'.
    """

    def __init__(self, run):
        """
        Store the run configuration and derive the quality-score mode.

        Raises ValueError if run['mode'] is not recognised.
        """
        self.ARGS = Param()
        self.run = run
        self._init_quality_mode()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _init_quality_mode(self):
        """
        Set self.use_quality_scores from run['mode'].

        'normal' -> False, 'using_quality_scores' -> True. Any other mode
        raises ValueError. This is an explicit raise rather than
        assert(False), so the check survives `python -O`.
        """
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            raise ValueError('unknown run mode: %r' % (mode,))

    def _example_quality(self, est, Qualities, exampleIdx):
        """
        Return the per-base quality list used when aligning `est`.

        Real scores (Qualities[exampleIdx][0]) are used only in
        'using_quality_scores' mode with run['enable_quality_scores'] set.
        Otherwise every base gets the constant quality 40.
        """
        run = self.run
        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return Qualities[exampleIdx][0]
        return [40] * len(est)

    @staticmethod
    def _neutralize_splice_scores(scores):
        """
        Set every splice-site score that is not -inf to 0.0, in place.

        Used when run['enable_splice_signals'] is off. Positions scored -inf
        (presumably "no site possible") keep their value.
        """
        for idx, elem in enumerate(scores):
            if elem != -inf:
                scores[idx] = 0.0

    @staticmethod
    def _dump_pickle(obj, filename):
        """Pickle `obj` into `filename` (opened 'w+') and close the file."""
        with open(filename, 'w+') as fh:
            cPickle.dump(obj, fh)

    @staticmethod
    def _c_array_to_column(c_array, length):
        """Copy the first `length` entries of a QPalmaDP array into a new
        (length, 1) zeros() matrix."""
        column = zeros((length, 1))
        for i in range(length):
            column[i] = c_array[i]
        return column

    def plog(self, string):
        """Write `string` to the open log file (self.logfh) and flush it."""
        self.logfh.write(string)
        self.logfh.flush()

    # ------------------------------------------------------------------
    # Dynamic programming
    # ------------------------------------------------------------------

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, current_num_path, prediction_mode):
        """
        Align `est` against `dna` with the QPalmaDP C module.

        current_num_path -- number of best paths the DP should return.
        prediction_mode  -- if true, also fetch the aligned dna/est arrays.

        Returns (newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,
        newQualityPlifsFeatures, dna_array, est_array). The first five are
        column matrices holding the results of all paths back to back.
        The quality features are all zero unless quality scores are in use.
        dna_array/est_array are None outside prediction mode.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']
        qual_len = run['totalQualSuppPoints']

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)
        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # The alignment object is the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'], run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # Computes SpliceAlign, EstAlign, weightMatch, DP scores and dna/est.
        # The two trailing flags come from the run configuration. They used
        # to be looked up as module globals, which train()/predict() never
        # defined (they only held them as locals).
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,
                                 acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],
                                 run['print_matrix'])

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * (qual_len * current_num_path))

        if prediction_mode:
            # Only prediction needs the aligned sequences themselves.
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[r_idx] for r_idx in range(result_len)]
            est_array = [c_est_array[r_idx] for r_idx in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = self._c_array_to_column(c_SpliceAlign, dna_len * current_num_path)
        newEstAlign = self._c_array_to_column(c_EstAlign, est_len * current_num_path)
        newWeightMatch = self._c_array_to_column(c_WeightMatch, mm_len * current_num_path)
        newDPScores = self._c_array_to_column(c_DPScores, current_num_path)
        if self.use_quality_scores:
            newQualityPlifsFeatures = self._c_array_to_column(c_qualityPlifsFeatures, qual_len * current_num_path)
        else:
            newQualityPlifsFeatures = zeros((qual_len * current_num_path, 1))

        # Release the SWIG-side buffers and the alignment object right away.
        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train(self):
        """
        Train the QPalma parameter vector.

        Side effects: creates run['experiment_path']/run['name'] if needed and
        chdir()s into it (process-wide). Writes into that directory:
          run_object.pickle -- the run configuration,
          _qpalma_train.log -- the training log,
          param_<k>.pickle  -- the parameters after each solver call, plus a
                               final dump when training ends.
        Exits the process with status 99 if the QP solver cannot be created.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'], run['name'])
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)
        assert os.path.exists(full_working_path)

        # ATTENTION: changing the working directory; every relative path
        # written below lands in full_working_path.
        os.chdir(full_working_path)

        self._dump_pickle(run, 'run_object.pickle')

        self.logfh = open('_qpalma_train.log', 'w+')
        try:
            self._run_training()
        finally:
            # Close the log even when training aborts.
            self.logfh.close()

    def _run_training(self):
        """
        Body of train(): load the training slice, then alternate between
        adding one margin constraint per example and re-solving the QP.
        Expects self.logfh to be open.
        """
        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences = \
            paths_load_data(run['dataset_filename'])

        self._init_quality_mode()

        # Keep the complete data set on the instance...
        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities

        # ...but train only on the configured slice.
        beg = run['training_begin']
        end = run['training_end']
        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        numExamples = len(SeqInfo)
        assert len(Exons) == numExamples and len(OriginalEsts) == numExamples \
            and len(Qualities) == numExamples, \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        numFeatures = run['numFeatures']

        # Sizes of the feature blocks. Parameter and feature vectors are laid
        # out as [intron length | donor | acceptor | match matrix | quality].
        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']
        qualStart = lengthSP + donSP + accSP + mmatrixSP

        # Random initial parameter vector.
        param = numpy.matlib.rand(numFeatures, 1)

        # No intron length model: zero the intron-length block.
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Turn the parameter vector into penalty functions / match matrix.
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Initializing problem...\n')
        try:
            solver = SIQPSolver(numFeatures, numExamples, run['C'], self.logfh, run)
        except Exception as err:
            self.plog('Could not initialize the SIQP solver: %s\n' % err)
            sys.exit(99)

        # Number of decoded paths (best, second best, ...) per example. It
        # grows when every decoded path coincides with the true alignment.
        num_path = [run['anzpath']] * numExamples

        self.plog('Starting training...\n')

        currentPhi = zeros((numFeatures, 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0
        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        while iteration_nr != iteration_steps:
            for exampleIdx in range(numExamples):
                if exampleIdx % 100 == 0:
                    print('Current example nr %d' % exampleIdx)

                try:
                    dna, est, acc_supp, don_supp, exons, original_est = \
                        getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run)
                except SpliceSiteException:
                    continue

                dna_len = len(dna)
                quality = self._example_quality(est, Qualities, exampleIdx)

                if not run['enable_splice_signals']:
                    self._neutralize_splice_scores(don_supp)
                    self._neutralize_splice_scores(acc_supp)

                # Features of the true alignment (current, untrained d/a/h).
                if run['mode'] == 'using_quality_scores':
                    # The fourth return value (aligned DNA) is not needed.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, _ = \
                        computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                      quality, qualityPlifs, run)
                else:
                    # NOTE(review): this two-argument call differs from the
                    # call above; confirm it matches the helper's signature.
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                # Current parameter vector w, assembled block by block.
                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[lengthSP + donSP + accSP:qualStart] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qualStart:] = totalQualityPenalties[:]

                # w'phi(x, y): total score of the true alignment.
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                numPaths = num_path[exampleIdx]

                # Column 0: true alignment; columns 1..numPaths: decoded paths.
                allWeights = zeros((numFeatures, numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                # Decode the numPaths best alignments.
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
                    newQualityPlifsFeatures, _, _ = \
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor,
                                      ps, qualityPlifs, numPaths, False)

                newSpliceAlign = newSpliceAlign.reshape(numPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(numPaths, totalQualSP)

                # The n-best paths are computed without Hamming loss, so some
                # of them may equal the true alignment. true_map flags those
                # so they never enter a constraint.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1

                # Parameter vector used to score the decoded paths. Its
                # inputs are fixed for this example, so build it once.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    allWeights[:, pathNr + 1] = numpy.vstack([weightIntron, weightDon, weightAcc,
                                                              newWeightMatch[pathNr, :].T,
                                                              decodedQualityFeatures[:]])

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # The scalar product must reproduce the Viterbi score.
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n" % (newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # A decoded path this close to the true one counts as true.
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # All decoded paths equal the true one: search one path
                # deeper for this example next time.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] = num_path[exampleIdx] + 1

                # Add a constraint for the first decoded path that differs
                # from the true alignment, if there is one.
                firstFalseIdx = next((map_idx for map_idx, elem in enumerate(true_map) if elem == 0), -1)
                if firstFalseIdx != -1:
                    differenceVector = trueWeight - allWeights[:, firstFalseIdx]
                    solver.addConstraint(differenceVector, exampleIdx)

                # Re-solve the QP every numConstPerRound examples.
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1
                        if self.noImprovementCtr == numExamples + 1:
                            break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)
                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            # End of one pass over all examples.
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                break

            iteration_nr += 1

        print('Training completed')
        self._dump_pickle(param, 'param_%d.pickle' % param_idx)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def evaluate(self, param_filename):
        """
        Predict on run['prediction_begin']:run['prediction_end'] with the
        parameters pickled in `param_filename`.

        Writes the log '_qpalma_predict_<run id>.log' in the current
        directory. predict() pickles the predictions.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences = \
            paths_load_data(run['dataset_filename'])

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.AlternativeSequences = AlternativeSequences

        self.logfh = open('_qpalma_predict_%d.log' % run['id'], 'w+')
        try:
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename, beg, end, 'TEST')
            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()

    def predict(self, param_filename, beg, end, set_flag):
        """
        Align every read of examples [beg:end) against each of its alternative
        genomic windows (chromosomes 1-5 only).

        Needs self.SeqInfo/Exons/OriginalEsts/Qualities/AlternativeSequences
        (set by evaluate()) and an open self.logfh. Examples rejected by
        getData() are skipped. The per-example prediction lists are pickled
        into '<run name>_allPredictions_<set_flag>' in the current directory.
        """
        run = self.run
        self._init_quality_mode()

        SeqInfo = self.SeqInfo[beg:end]
        Exons = self.Exons[beg:end]
        OriginalEsts = self.OriginalEsts[beg:end]
        Qualities = self.Qualities[beg:end]
        AlternativeSequences = self.AlternativeSequences[beg:end]

        numExamples = len(SeqInfo)
        assert len(Exons) == numExamples and len(OriginalEsts) == numExamples \
            and len(Qualities) == numExamples and len(AlternativeSequences) == numExamples, \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        with open(param_filename) as param_fh:
            param = cPickle.load(param_fh)

        # Turn the parameter vector into penalty functions / match matrix.
        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Starting prediction...\n')

        allPredictions = []

        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n' % exampleIdx)

            seq_id, _, _, up_cut, _, true_cut = SeqInfo[exampleIdx]

            try:
                dna, est, acc_supp, don_supp, exons, original_est = \
                    getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run)
            except SpliceSiteException:
                continue

            quality = self._example_quality(est, Qualities, exampleIdx)

            current_example_predictions = []

            # Align against every genomic window reported for this read.
            for alternative_alignment in AlternativeSequences[exampleIdx]:
                alt_chr, alt_strand, genomicSeq_start, genomicSeq_stop, currentLabel = alternative_alignment

                if alt_chr not in range(1, 6):
                    continue

                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    alt_chr, alt_strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])

                if not run['enable_splice_signals']:
                    self._neutralize_splice_scores(currentDon)
                    self._neutralize_splice_scores(currentAcc)

                current_prediction = self.calc_alignment(currentDNASeq, est, exons,
                                                         quality, currentDon, currentAcc,
                                                         d, a, h, mmatrix, qualityPlifs)
                current_prediction['exampleIdx'] = exampleIdx
                current_prediction['id'] = seq_id
                current_prediction['start_pos'] = up_cut
                current_prediction['alternative_start_pos'] = genomicSeq_start
                current_prediction['label'] = currentLabel
                current_prediction['true_cut'] = true_cut
                current_prediction['chr'] = alt_chr
                current_prediction['strand'] = alt_strand

                current_example_predictions.append(current_prediction)

            allPredictions.append(current_example_predictions)

        self._dump_pickle(allPredictions, '%s_allPredictions_%s' % (run['name'], set_flag))
        print('Prediction completed')

    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Compute the single best alignment of `est` against `dna`.

        Returns a dict with keys 'predExons' (see calculatePredictedExons),
        'trueExons' (the given `exons`), 'dna', 'est' (as str), 'DPScores'
        and 'alignment' (the result of get_alignment).
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)
        dna_len = len(dna)

        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array = \
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, 1, True)

        newSpliceAlign = newSpliceAlign.reshape(1, dna_len)

        # (qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds,
        #  tExonSizes, tStarts, tEnds)
        alignment = get_alignment(newSpliceAlign.flatten().tolist()[0],
                                  newEstAlign.flatten().tolist()[0],
                                  dna_array, est_array)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        return {'predExons': newExons, 'trueExons': exons,
                'dna': dna, 'est': est, 'DPScores': newDPScores,
                'alignment': alignment}

    def calculatePredictedExons(self, SpliceAlign):
        """
        Return exon boundary positions from a splice-label row.

        SpliceAlign is a 1-row matrix of per-position labels. Label 0 marks
        exonic positions. The result lists, in order, each position where a
        run of 0s starts and the position just after it ends. A sentinel -1
        closes a run reaching the end.
        """
        labels = SpliceAlign.flatten().tolist()[0] + [-1]
        newExons = []
        previous = -1
        for pos, elem in enumerate(labels):
            # Exactly one of (previous, elem) is exonic -> exon start or end.
            if (previous == 0) != (elem == 0):
                newExons.append(pos)
            previous = elem
        return newExons
912 return newExons
913
914
915 ###########################
916 # A simple command line
917 # interface
918 ###########################
919
if __name__ == '__main__':
    # Command line interface:
    #   qpalma_main.py train   <run_object_file>
    #   qpalma_main.py predict <run_object_file> <param_file>
    # NOTE: both files are unpickled; only pass files produced by a trusted
    # QPalma run.
    args = sys.argv[1:]
    mode = args[0] if args else None

    # Validate the argument count before touching sys.argv entries.
    valid_call = (mode == 'train' and len(args) == 2) or \
                 (mode == 'predict' and len(args) == 3)
    if not valid_call:
        print('You have to choose between training or prediction mode:')
        print('python qpalma_main.py (train|predict) <run_object_file> [<param_file>]')
        sys.exit(1)

    with open(args[1]) as run_obj_fh:
        run_obj = cPickle.load(run_obj_fh)

    # Named 'trainer' so it does not shadow the imported qpalma package.
    trainer = QPalma(run_obj)

    if mode == 'train':
        trainer.train()
    else:
        param_filename = args[2]
        if not os.path.exists(param_filename):
            print('Parameter file %s does not exist' % param_filename)
            sys.exit(1)
        trainer.evaluate(param_filename)