+ changed python and c code in order to perform easier parameter handling and
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import array
18 import cPickle
19 import os.path
20 import pdb
21 import sys
22
23 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,unbracket_seq
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 #from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.computeSpliceWeights import *
36 from qpalma.set_param_palma import *
37 from qpalma.computeSpliceAlignWithQuality import *
38 from qpalma.TrainingParam import Param
39 from qpalma.Plif import Plf,compute_donacc
40
41 # these two imports are needed for the load genomic resp. interval query
42 # functions
43 #from Genefinding import *
44 #from genome_utils import load_genomic
45 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
46
class SpliceSiteException(Exception):
    """
    Raised by getData() for a two-exon example whose donor dinucleotide is
    not 'gt'/'gc' or whose acceptor dinucleotide is not 'ag'.

    train() catches it and skips the example. Deriving from Exception makes
    the class raisable as a proper exception (old-style, non-Exception classes
    are deprecated as exceptions in Python 2 and rejected by Python 3).
    """
    pass
49
50
def getData(training_set,exampleKey,run):
    """
    Fetch one training example, normalise its read and load its genomic region.

    Parameters:
      training_set -- mapping exampleKey -> (seqInfo, exons, original_est,
                      qualities), where seqInfo is the 5-tuple
                      (id, chromosome, strand, up_cut, down_cut).
      exampleKey   -- key of the example (formatted with %d in messages, so
                      presumably an int).
      run          -- run configuration; run['read_size'] is the expected
                      length of the normalised read.

    Returns:
      (dna, est, acc_supp, don_supp, exons, original_est, currentQualities)
      where est is the lower-cased read with brackets resolved and gaps
      removed, original_est is the lower-cased read as stored, and exons are
      the exon coordinates shifted to index into dna.

    Raises:
      SpliceSiteException -- two-exon example with a donor other than
                             'gt'/'gc' or an acceptor other than 'ag'.
      AssertionError      -- read length differs from run['read_size'], or
                             the two exons together differ in length from
                             the read.
    """
    currentSeqInfo,currentExons,original_est,currentQualities = training_set[exampleKey]
    seq_id,chromo,strand,up_cut,down_cut = currentSeqInfo

    # lower-case copy of the read exactly as stored in the training set
    original_est = "".join(original_est).lower()

    # the plain read: brackets resolved and alignment gaps removed
    # NOTE(review): only unbracket_seq is imported explicitly at the top of the
    # file; unbracket_est must come from one of the star imports - verify.
    est = unbracket_est(original_est).replace('-','')

    assert len(est) == run['read_size'],\
        'example %s: read length %d differs from read_size %s' % (exampleKey,len(est),run['read_size'])

    # TODO(review): hard-coded genome location; consider moving it into run.
    dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    dna, acc_supp, don_supp = get_seq_and_scores(chromo,strand,up_cut,down_cut,dna_flat_files)

    # Shift the exon coordinates by the region start (up_cut-1) so they index
    # into dna. The start of the first two exons is decremented once more
    # (presumably 1-based -> 0-based slice start; confirm).
    exons = currentExons - (up_cut-1)
    # a single-exon example has no second row to adjust
    for row in range(min(2,exons.shape[0])):
        exons[row,0] -= 1

    if exons.shape == (2,2):
        # spliced example: the two exons must be joined by canonical splice
        # sites and together be exactly as long as the read
        fetched_dna_subseq = dna[exons[0,0]:exons[0,1]] + dna[exons[1,0]:exons[1,1]]

        donor_elem = dna[exons[0,1]:exons[0,1]+2]
        acceptor_elem = dna[exons[1,0]-2:exons[1,0]]

        if donor_elem not in ('gt','gc'):
            print('invalid donor in example %d'% exampleKey)
            raise SpliceSiteException

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d'% exampleKey)
            raise SpliceSiteException

        # fetched_dna_subseq only exists in this branch; checking it outside
        # raised NameError for examples with more than two exons
        assert len(fetched_dna_subseq) == len(est),\
            'example %s: exon length %d differs from read length %d' % (exampleKey,len(fetched_dna_subseq),len(est))

    return dna,est,acc_supp,don_supp,exons,original_est,currentQualities
99
100
101
class QPalma:
    """
    Training and prediction driver for QPalma alignments.

    train() fits the parameter vector, init_prediction()/predict() align
    reads with a trained vector; the dynamic programming itself is done by
    the QPalmaDP C module (see do_alignment).
    """

    def __init__(self,dmode=False):
        # dmode enables debug mode: predict() then returns the predictions
        # instead of pickling them and writes no log messages.
        self.qpalma_debug_mode = dmode
        self.ARGS = Param()
111
112
113 def plog(self,string):
114 self.logfh.write(string)
115 self.logfh.flush()
116
117
118 def do_alignment(self,dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,current_num_path,prediction_mode):
119 """
120 Given the needed input this method calls the QPalma C module which
121 calculates a dynamic programming in order to obtain an alignment
122 """
123 run = self.run
124
125 dna_len = len(dna)
126 est_len = len(est)
127
128 prb = QPalmaDP.createDoubleArrayFromList(quality)
129 chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)
130
131 matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
132 mm_len = run['matchmatrixRows']*run['matchmatrixCols']
133
134 d_len = len(donor)
135 donor = QPalmaDP.createDoubleArrayFromList(donor)
136 a_len = len(acceptor)
137 acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)
138
139 # Create the alignment object representing the interface to the C/C++ code.
140 currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
141 c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])
142 # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
143 currentAlignment.myalign( current_num_path, dna, dna_len,\
144 est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
145 acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],\
146 run['print_matrix'] )
147
148 c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
149 c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
150 c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
151 c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)
152
153 c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(run['totalQualSuppPoints']*current_num_path))
154
155 if prediction_mode:
156 # part that is only needed for prediction
157 result_len = currentAlignment.getResultLength()
158 c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
159 c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
160
161 currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)
162 _dna_array,_est_array = currentAlignment.getAlignmentArraysNew()
163
164 dna_array = [0.0]*result_len
165 est_array = [0.0]*result_len
166
167 for r_idx in range(result_len):
168 dna_array[r_idx] = c_dna_array[r_idx]
169 est_array[r_idx] = c_est_array[r_idx]
170
171 else:
172 dna_array = None
173 est_array = None
174
175 assert dna_array == _dna_array
176 assert est_array == _est_array
177
178 currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
179 c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)
180
181 _SpliceAlign, _EstAlign, _WeightMatch, _DPScores, _newQualityPlifsFeatures =\
182 currentAlignment.getAlignmentResultsNew()
183
184 newSpliceAlign = zeros((current_num_path*dna_len,1))
185 newEstAlign = zeros((est_len*current_num_path,1))
186 newWeightMatch = zeros((current_num_path*mm_len,1))
187 newDPScores = zeros((current_num_path,1))
188 newQualityPlifsFeatures = zeros((run['totalQualSuppPoints']*current_num_path,1))
189
190 for i in range(dna_len*current_num_path):
191 newSpliceAlign[i] = c_SpliceAlign[i]
192
193 for i in range(est_len*current_num_path):
194 newEstAlign[i] = c_EstAlign[i]
195
196 for i in range(mm_len*current_num_path):
197 newWeightMatch[i] = c_WeightMatch[i]
198
199 for i in range(current_num_path):
200 newDPScores[i] = c_DPScores[i]
201
202 if self.use_quality_scores:
203 for i in range(run['totalQualSuppPoints']*current_num_path):
204 newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]
205
206 del c_SpliceAlign
207 del c_EstAlign
208 del c_WeightMatch
209 del c_DPScores
210 del c_qualityPlifsFeatures
211 del currentAlignment
212
213
214 assert newSpliceAlign.flatten().tolist()[0] == _SpliceAlign
215 assert newEstAlign.flatten().tolist()[0] == _EstAlign
216 assert newWeightMatch.flatten().tolist()[0] == _WeightMatch
217 assert newDPScores.flatten().tolist()[0] == _DPScores
218 assert newQualityPlifsFeatures.flatten().tolist()[0] == _newQualityPlifsFeatures
219
220 pdb.set_trace()
221
222 return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
223 newQualityPlifsFeatures, dna_array, est_array
224
225
    def train(self,run,training_set):
        """
        Train the QPalma parameter vector on training_set.

        Side effects: creates (if missing) and changes the working directory
        to os.path.join(run['alignment_dir'], run['name']), then writes the log
        '_qpalma_train.log', the pickled run object 'run_obj.pickle', a
        'param_<n>.pickle' after every solver call and a final one at the end.

        Parameters:
          run          -- run configuration indexed like a dict ('mode',
                          'iter_steps', 'anzpath', 'numFeatures', the
                          support-point counts, ...)
          training_set -- mapping example key -> example tuple as consumed by
                          getData()

        For each example, the feature vector of the true alignment is compared
        with those of the num_path best decoded alignments. The first decoded
        alignment whose features differ from the true ones (norm >= 1e-5)
        gives a constraint (true minus false features) for the solver. Every
        numConstPerRound examples the solver is called and the parameters are
        replaced by its solution.

        NOTE(review): as written this method cannot run to completion:
          * `solver` is never bound (its SIQPSolver construction is commented
            out), so solver.addConstraint raises NameError;
          * in 'normal' mode `dna_calc` is not assigned before
            dna_calc.replace(...) (NameError), and computeSpliceAlignWithQuality
            is called with only two arguments there;
          * `math` is not imported explicitly in the visible imports -
            presumably it comes from one of the star imports; verify.
        """
        self.run = run

        full_working_path = os.path.join(run['alignment_dir'],run['name'])

        #assert not os.path.exists(full_working_path)
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_train.log','w+')
        cPickle.dump(run,open('run_obj.pickle','w+'))

        self.plog("Settings are:\n")
        self.plog("%s\n"%str(run))

        if self.run['mode'] == 'normal':
            self.use_quality_scores = False

        elif self.run['mode'] == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            assert(False)

        numExamples = len(training_set)
        self.plog('Number of training examples: %d\n'% numExamples)

        # convergence bookkeeping: counts solver calls without objective change
        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        remove_duplicate_scores = run['remove_duplicate_scores']
        print_matrix = run['print_matrix']
        anzpath = run['anzpath']

        # Initialize parameter vector (random start)
        #param = Conf.fixedParam[:run['numFeatures']]
        param = numpy.matlib.rand(run['numFeatures'],1)

        # sizes of the consecutive sections of the parameter/feature vector:
        # intron length | donor | acceptor | match matrix | quality plifs
        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
        numq = run['numQualSuppPoints']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

        # Initialize solver
        self.plog('Initializing problem...\n')

        # NOTE(review): with the block below commented out, `solver` is undefined.
        #try:
        #   solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)
        #except:
        #   self.plog('Got no license. Telling queue to reschedule job...\n')
        #   sys.exit(99)

        #solver.enforceMonotonicity(lengthSP,lengthSP+donSP)
        #solver.enforceMonotonicity(lengthSP+donSP,lengthSP+donSP+accSP)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*numExamples
        # stores the gap for each example
        gap = [0.0]*numExamples
        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'],1))
        totalQualityPenalties = zeros((totalQualSP,1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        featureVectors = zeros((run['numFeatures'],numExamples))

        # the main training loop
        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx,example_key in enumerate(training_set.keys()):
                print 'Current example %d' % example_key
                try:
                    dna,est,acc_supp,don_supp,exons,original_est,currentQualities =\
                    getData(training_set,example_key,run)
                except SpliceSiteException:
                    # non-canonical splice sites: skip this example
                    continue

                dna_len = len(dna)

                # constant quality 40 unless quality scores are used and enabled
                if run['mode'] == 'normal':
                    quality = [40]*len(est)

                if run['mode'] == 'using_quality_scores':
                    quality = currentQualities[0]

                if not run['enable_quality_scores']:
                    quality = [40]*len(est)

                # without splice signals every possible site (score != -inf) scores 0
                if not run['enable_splice_signals']:
                    for idx,elem in enumerate(don_supp):
                        if elem != -inf:
                            don_supp[idx] = 0.0

                    for idx,elem in enumerate(acc_supp):
                        if elem != -inf:
                            acc_supp[idx] = 0.0

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                if run['mode'] == 'using_quality_scores':
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
                    computeSpliceAlignWithQuality(dna, exons, est, original_est,\
                    quality, qualityPlifs,run)
                else:
                    # NOTE(review): dna_calc is not assigned in this branch
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                dna_calc = dna_calc.replace('-','')

                #print 'right before computeSpliceWeights exampleIdx %d' % exampleIdx
                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron =\
                computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                # currentPhi: current parameters laid out in the same section
                # order as trueWeight
                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[lengthSP:lengthSP+donSP] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[lengthSP+donSP:lengthSP+donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP+donSP+accSP+mmatrixSP:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((run['numFeatures'],num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                #myalign wants the acceptor site on the g of the ag
                #acceptor = acceptor[1:]
                #acceptor.append(-inf)

                #donor = [-inf] + donor[:-1]

                ps = h.convert2SWIG()

                _newSpliceAlign, _newEstAlign, _newWeightMatch, _newDPScores,\
                _newQualityPlifsFeatures, unneeded1, unneeded2 =\
                self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)
                mm_len = run['matchmatrixRows']*run['matchmatrixCols']

                newSpliceAlign = _newSpliceAlign
                newEstAlign = _newEstAlign
                newWeightMatch = _newWeightMatch
                newDPScores = _newDPScores
                newQualityPlifsFeatures = _newQualityPlifsFeatures

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
                newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mm_len)

                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(num_path[exampleIdx],run['totalQualSuppPoints'])
                # Calculate weights of the respective alignments. Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(num_path[exampleIdx]+1)
                true_map[0] = 1

                for pathNr in range(num_path[exampleIdx]):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,\
                    h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,\
                    acc_supp)

                    decodedQualityFeatures = zeros((run['totalQualSuppPoints'],1))
                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr,:].T
                    # Store the weights of this path in column pathNr+1 of allWeights
                    allWeights[:,pathNr+1] = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures[:]])

                    hpen = mat(h.penalties).reshape(len(h.penalties),1)
                    dpen = mat(d.penalties).reshape(len(d.penalties),1)
                    apen = mat(a.penalties).reshape(len(a.penalties),1)
                    features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                    featureVectors[:,exampleIdx] = allWeights[:,pathNr+1]

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # distinct_scores is computed but not used afterwards
                    distinct_scores = False
                    if math.fabs(AlignmentScores[pathNr] - AlignmentScores[pathNr+1]) > 1e-5:
                        distinct_scores = True

                    # Check whether scalar product + loss equals viterbi score;
                    # on mismatch this drops into the interactive debugger
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n"%AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n"%(newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print "suboptimal_example %d\n" %exampleIdx
                    #trueSpliceAlign, trueWeightMatch, trueWeightQuality dna_calc=\
                    #computeSpliceAlignWithQuality(dna, exons, est, original_est, quality, qualityPlifs)

                    #pdb.set_trace()
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" %exampleIdx)

                # the true label sequence should not have a larger score than the maximal one WHYYYYY?
                # this means that all n-best paths are to close to each other
                # we have to extend the n-best search to a (n+1)-best
                if len([elem for elem in true_map if elem == 1]) == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # Dead code: never executes (and currentAlignment is not
                # defined in this method).
                if False:
                    self.plog("Is considered as: %d\n" % true_map[1])

                    result_len = currentAlignment.getResultLength()
                    c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
                    c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))

                    currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)

                    dna_array = [0.0]*result_len
                    est_array = [0.0]*result_len

                    for r_idx in range(result_len):
                        dna_array[r_idx] = c_dna_array[r_idx]
                        est_array[r_idx] = c_est_array[r_idx]

                    _newSpliceAlign = newSpliceAlign[0].flatten().tolist()[0]
                    _newEstAlign = newEstAlign[0].flatten().tolist()[0]

                    #line1,line2,line3 = pprint_alignment(_newSpliceAlign,_newEstAlign, dna_array, est_array)
                    #self.plog(line1+'\n')
                    #self.plog(line2+'\n')
                    #self.plog(line3+'\n')

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights
                    #pdb.set_trace()

                    #print 'NOT ADDING ANY CONSTRAINTS'
                    # NOTE(review): `solver` is undefined (see docstring)
                    const_added = solver.addConstraint(differenceVector, exampleIdx)

                    const_added_ctr += 1
                #
                # end of one example processing
                #

                # call solver every nth example //added constraint
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue,w,self.slacks = solver.solve()
                    solver_call_ctr += 1

                    # after five solver calls, solve less often
                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n'% numConstPerRound)

                    if math.fabs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    # NOTE(review): this break only leaves the loop over the
                    # examples, not the outer while loop
                    if self.noImprovementCtr == numExamples+1:
                        break

                    self.oldObjValue = objValue
                    print "objValue is %f" % objValue

                    sum_xis = 0
                    for elem in self.slacks:
                        sum_xis += elem

                    print 'sum of slacks is %f'% sum_xis
                    self.plog('sum of slacks is %f\n'% sum_xis)

                    # take over the solver's solution and rebuild the Plifs from it
                    for i in range(len(param)):
                        param[i] = w[i]

                    cPickle.dump(param,open('param_%d.pickle'%param_idx,'w+'))
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] =\
                    set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)

            #
            # end of one iteration through all examples
            #

            self.plog("suboptimal rounds %d\n" %suboptimal_example)

            if self.noImprovementCtr == numExamples*2:
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        print 'Training completed'

        cPickle.dump(param,open('param_%d.pickle'%param_idx,'w+'))
        self.logfh.close()
572
573
574 ###############################################################################
575 #
576 # End of the code needed for training
577 #
578 # Begin of code for prediction
579 #
580 ###############################################################################
581
582 def init_prediction(self,run,dataset_fn,prediction_keys,param_fn,seqInfo,set_name,):
583 """
584 Performing a prediction takes...
585 """
586 self.run = run
587 self.set_name = set_name
588
589 #full_working_path = os.path.join(run['alignment_dir'],run['name'])
590 full_working_path = run['result_dir']
591
592 print 'full_working_path is %s' % full_working_path
593
594 #assert not os.path.exists(full_working_path)
595 if not os.path.exists(full_working_path):
596 os.mkdir(full_working_path)
597
598 assert os.path.exists(full_working_path)
599
600 # ATTENTION: Changing working directory
601 os.chdir(full_working_path)
602
603 self.logfh = open('_qpalma_predict_%s.log'%set_name,'w+')
604
605
606 # number of prediction instances
607 self.plog('Number of prediction examples: %d\n'% len(prediction_keys))
608
609 # load dataset and fetch instances that shall be predicted
610 dataset = cPickle.load(open(dataset_fn))
611
612 prediction_set = {}
613 for key in prediction_keys:
614 prediction_set[key] = dataset[key]
615
616 # we do not need the full dataset anymore
617 del dataset
618
619 # Load parameter vector to predict with
620 param = cPickle.load(open(param_fn))
621
622 self.predict(run,prediction_set,param,seqInfo)
623
624
625 def predict(self,run,prediction_set,param,seqInfo):
626 """
627 This method...
628 """
629
630 self.run = run
631
632 if self.run['mode'] == 'normal':
633 self.use_quality_scores = False
634
635 elif self.run['mode'] == 'using_quality_scores':
636 self.use_quality_scores = True
637 else:
638 assert(False)
639
640 # Set the parameters such as limits/penalties for the Plifs
641 [h,d,a,mmatrix,qualityPlifs] =\
642 set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)
643
644 #############################################################################################
645 # Prediction
646 #############################################################################################
647
648 if not self.qpalma_debug_mode:
649 self.plog('Starting prediction...\n')
650
651 #donSP = self.run['numDonSuppPoints']
652 #accSP = self.run['numAccSuppPoints']
653 #lengthSP = self.run['numLengthSuppPoints']
654 #mmatrixSP = run['matchmatrixRows']*run['matchmatrixCols']
655 #numq = self.run['numQualSuppPoints']
656 #totalQualSP = self.run['totalQualSuppPoints']
657 #totalQualityPenalties = zeros((totalQualSP,1))
658
659 self.problem_ctr = 0
660
661 # where we store the predictions
662 allPredictions = []
663
664 # we take the first quality vector of the tuple of quality vectors
665 quality_index = 0
666
667 # fetch the data needed
668 #g_dir = run['dna_flat_files'] #'/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
669 #acc_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc'
670 #don_dir = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don'
671
672 #g_fmt = 'chr%d.dna.flat'
673 #s_fmt = 'contig_%d%s'
674
675 #num_chromo = 6
676
677 #accessWrapper = DataAccessWrapper(g_dir,acc_dir,don_dir,g_fmt,s_fmt)
678 #seqInfo = SeqSpliceInfo(accessWrapper,range(1,num_chromo))
679
680 # beginning of the prediction loop
681 for example_key in prediction_set.keys():
682 print 'Current example %d' % example_key
683
684 for example in prediction_set[example_key]:
685
686 currentSeqInfo,read,currentQualities = example
687
688 id,chromo,strand,genomicSeq_start,genomicSeq_stop =\
689 currentSeqInfo
690
691 if not self.qpalma_debug_mode:
692 self.plog('Loading example id: %d...\n'% int(id))
693
694 if run['mode'] == 'normal':
695 quality = [40]*len(read)
696
697 if run['mode'] == 'using_quality_scores':
698 quality = currentQualities[quality_index]
699
700 if not run['enable_quality_scores']:
701 quality = [40]*len(read)
702
703 try:
704 currentDNASeq, currentAcc, currentDon = seqInfo.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
705 except:
706 self.problem_ctr += 1
707 continue
708
709 if not run['enable_splice_signals']:
710 for idx,elem in enumerate(currentDon):
711 if elem != -inf:
712 currentDon[idx] = 0.0
713
714 for idx,elem in enumerate(currentAcc):
715 if elem != -inf:
716 currentAcc[idx] = 0.0
717
718 current_prediction = self.calc_alignment(currentDNASeq, read,\
719 quality, currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)
720
721 current_prediction['id'] = id
722 current_prediction['chr'] = chromo
723 current_prediction['strand'] = strand
724 current_prediction['start_pos'] = genomicSeq_start
725
726 allPredictions.append(current_prediction)
727
728 if not self.qpalma_debug_mode:
729 self.finalizePrediction(allPredictions)
730 else:
731 return allPredictions
732
733
734 def finalizePrediction(self,allPredictions):
735 # end of the prediction loop we save all predictions in a pickle file and exit
736 cPickle.dump(allPredictions,open('%s.predictions.pickle'%(self.set_name),'w+'))
737 print 'Prediction completed'
738 self.plog('Prediction completed\n')
739 mes = 'Problem ctr %d' % self.problem_ctr
740 print mes
741 self.plog(mes+'\n')
742 self.logfh.close()
743
744
745 def calc_alignment(self, dna, read, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
746 """
747 Given two sequences and the parameters we calculate on alignment
748 """
749
750 run = self.run
751 donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
752
753 dna = str(dna)
754 read = str(read)
755
756 if '-' in read:
757 self.plog('found gap\n')
758 read = read.replace('-','')
759 assert len(read) == Conf.read_size
760
761 dna_len = len(dna)
762 read_len = len(read)
763
764 ps = h.convert2SWIG()
765
766 newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
767 newQualityPlifsFeatures, dna_array, read_array =\
768 self.do_alignment(dna,read,quality,mmatrix,donor,acceptor,ps,qualityPlifs,1,True)
769
770 mm_len = run['matchmatrixRows']*run['matchmatrixCols']
771
772 # old code removed
773 newSpliceAlign = newSpliceAlign.reshape(1,dna_len)
774 newWeightMatch = newWeightMatch.reshape(1,mm_len)
775 true_map = [0]*2
776 true_map[0] = 1
777 pathNr = 0
778
779 #pdb.set_trace()
780
781 _newSpliceAlign = array.array('B',newSpliceAlign.flatten().tolist()[0])
782 _newEstAlign = array.array('B',newEstAlign.flatten().tolist()[0])
783
784 alignment = get_alignment(_newSpliceAlign,_newEstAlign, dna_array, read_array) #(qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)
785
786 dna_array = array.array('B',dna_array)
787 read_array = array.array('B',read_array)
788
789 newExons = self.calculatePredictedExons(newSpliceAlign)
790
791 current_prediction = {'predExons':newExons, 'dna':dna, 'read':read, 'DPScores':newDPScores,\
792 'alignment':alignment,'spliceAlign':_newSpliceAlign,'estAlign':_newEstAlign,\
793 'dna_array':dna_array, 'read_array':read_array }
794
795 return current_prediction
796
797
798 def calculatePredictedExons(self,SpliceAlign):
799 newExons = []
800 oldElem = -1
801 SpliceAlign = SpliceAlign.flatten().tolist()[0]
802 SpliceAlign.append(-1)
803 for pos,elem in enumerate(SpliceAlign):
804 if pos == 0:
805 oldElem = -1
806 else:
807 oldElem = SpliceAlign[pos-1]
808
809 if oldElem != 0 and elem == 0: # start of exon
810 newExons.append(pos)
811
812 if oldElem == 0 and elem != 0: # end of exon
813 newExons.append(pos)
814
815 return newExons