git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8599 e1793c9e...
[qpalma.git] / scripts / qpalma_main.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 # The QPalma project aims at extending the Palma project
7 # to be able to use Solexa reads together with their
8 # quality scores.
9 #
10 # This file represents the conversion of the main matlab
11 # training loop for Palma to Python.
12 #
13 # Author: Fabio De Bona
14 #
15 ###########################################################
16
17 import sys
18 import cPickle
19 import pdb
20 import re
21 import os.path
22
23 from compile_dataset import getSpliceScores, get_seq_and_scores
24
25 import numpy
26 from numpy.matlib import mat,zeros,ones,inf
27 from numpy.linalg import norm
28
29 import QPalmaDP
30 import qpalma
31
32 from qpalma.SIQP_CPX import SIQPSolver
33 #from qpalma.SIQP_CVXOPT import SIQPSolver
34
35 from qpalma.DataProc import *
36 from qpalma.computeSpliceWeights import *
37 from qpalma.set_param_palma import *
38 from qpalma.computeSpliceAlignWithQuality import *
39 from qpalma.penalty_lookup_new import *
40 from qpalma.compute_donacc import *
41 from qpalma.TrainingParam import Param
42 from qpalma.Plif import Plf
43
44 from qpalma.tools.splicesites import getDonAccScores
45 from qpalma.Configuration import *
46
47 # this two imports are needed for the load genomic resp. interval query
48 # functions
49 from Genefinding import *
50 from genome_utils import load_genomic
51 from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
52
class SpliceSiteException(Exception):
    """Raised by getData when an example's annotated donor is not gt/gc or its acceptor is not ag.

    Derives from Exception so it can be raised on new-style-only Pythons and
    caught by ``except SpliceSiteException`` as before.
    """
    pass
55
56
def unbracket_est(est):
    """
    Strip mismatch bracket groups from an annotated read and lower-case it.

    A mismatch is written as a four-character group ``[xy]``. Only ``y``, the
    character at offset 2 inside the group, is kept (presumably the read base
    of a genome/read mismatch pair). Every other character is copied
    unchanged. A '[' too close to the end of the string raises IndexError,
    exactly as before.

    >>> unbracket_est('ac[ag]t')
    'acgt'
    """
    chars = []  # collect pieces and join once instead of quadratic str +=
    e = 0
    n = len(est)
    while e < n:
        if est[e] == '[':
            # skip '[', the first char, keep the second char, skip ']'
            chars.append(est[e + 2])
            e += 4
        else:
            chars.append(est[e])
            e += 1
    return "".join(chars).lower()
73
74
def getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run, dna_flat_files=None):
    """
    Fetch and sanity-check the genomic fragment and read of one example.

    Parameters:
      SeqInfo        -- per-example tuples (id, chr, strand, up_cut, down_cut, true_cut)
      OriginalEsts   -- per-example reads; each entry is joined with "".join and
                        may encode mismatches as bracket groups "[xy]"
      Exons          -- per-example exon coordinate arrays; only entries
                        [0,0], [0,1], [1,0], [1,1] are used
                        (NOTE(review): assumes two exons as [start, stop] rows — confirm)
      exampleIdx     -- index of the example to load
      run            -- run configuration; 'read_size' is required and
                        'dna_flat_files' is used as genome directory when present
      dna_flat_files -- optional explicit genome directory; overrides run

    Returns (dna, est, acc_supp, don_supp, exons, original_est):
      est          -- lower-cased read with bracket groups resolved and '-' removed
      exons        -- exon coordinates shifted relative to the fetched fragment
      original_est -- lower-cased read including its bracket annotation

    Raises SpliceSiteException if the donor dinucleotide is not gt/gc or the
    acceptor dinucleotide is not ag; AssertionError on inconsistent data.
    """
    default_dna_flat_files = '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    if dna_flat_files is None:
        # Prefer the configured genome location (predict() uses the same key
        # for the alternative fragments) and keep the old path as fallback.
        try:
            dna_flat_files = run['dna_flat_files']
        except KeyError:
            dna_flat_files = default_dna_flat_files

    read_id, chrom, strand, up_cut, down_cut, true_cut = SeqInfo[exampleIdx]

    original_est = "".join(OriginalEsts[exampleIdx]).lower()
    est = unbracket_est(original_est).replace('-', '')

    assert len(est) == run['read_size'], \
        'example %d: read length %d differs from read_size %s' % (exampleIdx, len(est), run['read_size'])

    dna, acc_supp, don_supp = get_seq_and_scores(chrom, strand, up_cut, down_cut, dna_flat_files)

    # The acceptor score is located at the g of an 'ag' dinucleotide.
    ag_tuple_pos = [p for p in range(2, len(dna)) if dna[p - 1] == 'a' and dna[p] == 'g']
    assert ag_tuple_pos == [p for p, e in enumerate(acc_supp) if e != -inf and p > 1], \
        'example %d: acceptor scores do not match ag positions' % exampleIdx

    # The donor score is located at the g of a 'gt' / 'gc' dinucleotide.
    gt_tuple_pos = [p for p in range(1, len(dna) - 1) if dna[p] == 'g' and dna[p + 1] in ('t', 'c')]
    assert gt_tuple_pos == [p for p, e in enumerate(don_supp) if e != -inf and p > 0], \
        'example %d: donor scores do not match gt/gc positions' % exampleIdx

    # Shift exon coordinates into fragment coordinates. The subtraction
    # produces a new array, so the caller's Exons entry is not modified.
    exons = Exons[exampleIdx] - (up_cut - 1)
    exons[0, 0] -= 1
    exons[1, 0] -= 1

    fetched_dna_subseq = dna[exons[0, 0]:exons[0, 1]] + dna[exons[1, 0]:exons[1, 1]]

    donor_elem = dna[exons[0, 1]:exons[0, 1] + 2]
    acceptor_elem = dna[exons[1, 0] - 2:exons[1, 0]]

    if donor_elem not in ('gt', 'gc'):
        print('invalid donor in example %d' % exampleIdx)
        raise SpliceSiteException

    if acceptor_elem != 'ag':
        print('invalid acceptor in example %d' % exampleIdx)
        raise SpliceSiteException

    assert len(fetched_dna_subseq) == len(est), \
        'example %d: exon sequence length %d differs from read length %d' % (exampleIdx, len(fetched_dna_subseq), len(est))

    return dna, est, acc_supp, don_supp, exons, original_est
138
139
class QPalma:
    """
    Training and prediction driver for the QPalma project.

    The object is configured by a ``run`` dictionary (keys such as 'mode',
    'dataset_filename', 'numFeatures', 'anzpath', ... are read below).
    ``train`` learns a parameter vector by repeatedly aligning reads with the
    QPalmaDP C module and adding margin constraints to an SIQP solver.
    ``evaluate``/``predict`` load a pickled parameter vector and align each
    read against its true fragment and its alternative fragments.
    """

    def __init__(self, run):
        self.ARGS = Param()
        self.run = run
        self._set_quality_mode()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _set_quality_mode(self):
        """Set ``self.use_quality_scores`` from ``run['mode']``.

        Raises ValueError for any mode other than 'normal' or
        'using_quality_scores'.
        """
        mode = self.run['mode']
        if mode == 'normal':
            self.use_quality_scores = False
        elif mode == 'using_quality_scores':
            self.use_quality_scores = True
        else:
            raise ValueError('unknown run mode: %r' % (mode,))

    def _get_quality(self, est, Qualities, exampleIdx):
        """Return the per-base quality values used for one read.

        The first quality track of the example is used only in
        'using_quality_scores' mode with 'enable_quality_scores' set;
        otherwise every base gets the constant quality 40.
        """
        run = self.run
        if run['mode'] == 'using_quality_scores' and run['enable_quality_scores']:
            return Qualities[exampleIdx][0]
        return [40] * len(est)

    def _mask_splice_scores(self, don_supp, acc_supp):
        """Zero every finite donor/acceptor score in place if splice signals are disabled."""
        if self.run['enable_splice_signals']:
            return
        for scores in (don_supp, acc_supp):
            for idx, elem in enumerate(scores):
                if elem != -inf:
                    scores[idx] = 0.0

    def _load_dataset(self):
        """Load ``run['dataset_filename']``, cache all arrays on self and return them."""
        SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences = \
            paths_load_data(self.run['dataset_filename'])

        self.SeqInfo = SeqInfo
        self.Exons = Exons
        self.OriginalEsts = OriginalEsts
        self.Qualities = Qualities
        self.AlternativeSequences = AlternativeSequences

        return SeqInfo, Exons, OriginalEsts, Qualities, AlternativeSequences

    @staticmethod
    def _dump_pickle(obj, filename):
        """Pickle ``obj`` into ``filename`` (mode 'w+') and close the file."""
        with open(filename, 'w+') as fh:
            cPickle.dump(obj, fh)

    @staticmethod
    def _to_column(c_array, length):
        """Copy the first ``length`` entries of an indexable array into a (length, 1) matrix."""
        column = zeros((length, 1))
        for i in range(length):
            column[i] = c_array[i]
        return column

    def plog(self, string):
        """Write ``string`` to the current log file and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    # ------------------------------------------------------------------
    # Alignment
    # ------------------------------------------------------------------

    def do_alignment(self, dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, current_num_path, prediction_mode):
        """
        Align ``est`` against ``dna`` with the QPalmaDP C module.

        Computes the ``current_num_path`` best alignments and returns
        (SpliceAlign, EstAlign, WeightMatch, DPScores, QualityPlifsFeatures,
        dna_array, est_array). The first five are column matrices holding the
        per-path results back to back. dna_array/est_array are lists of the
        aligned symbols when ``prediction_mode`` is true, otherwise None.
        Quality features stay zero unless quality scores are in use.
        """
        run = self.run

        dna_len = len(dna)
        est_len = len(est)
        mm_len = run['matchmatrixRows'] * run['matchmatrixCols']
        qual_len = run['totalQualSuppPoints'] * current_num_path

        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)
        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        c_donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        c_acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # The alignment object is the interface to the C/C++ code.
        currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'], run['numQualSuppPoints'], self.use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

        # The duplicate-score and matrix-printing switches come from the run
        # configuration (train/predict read the same keys).
        currentAlignment.myalign(current_num_path, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 c_donor, d_len, c_acceptor, a_len, c_qualityPlifs,
                                 run['remove_duplicate_scores'], run['print_matrix'])

        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * current_num_path))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * current_num_path))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * current_num_path))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * current_num_path)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * qual_len)

        if prediction_mode:
            # The aligned symbol arrays are only needed for prediction.
            result_len = currentAlignment.getResultLength()
            c_dna_array = QPalmaDP.createIntArrayFromList([0] * result_len)
            c_est_array = QPalmaDP.createIntArrayFromList([0] * result_len)

            currentAlignment.getAlignmentArrays(c_dna_array, c_est_array)

            dna_array = [c_dna_array[i] for i in range(result_len)]
            est_array = [c_est_array[i] for i in range(result_len)]
        else:
            dna_array = None
            est_array = None

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

        newSpliceAlign = self._to_column(c_SpliceAlign, dna_len * current_num_path)
        newEstAlign = self._to_column(c_EstAlign, est_len * current_num_path)
        newWeightMatch = self._to_column(c_WeightMatch, mm_len * current_num_path)
        newDPScores = self._to_column(c_DPScores, current_num_path)

        if self.use_quality_scores:
            newQualityPlifsFeatures = self._to_column(c_qualityPlifsFeatures, qual_len)
        else:
            newQualityPlifsFeatures = zeros((qual_len, 1))

        return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, est_array

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train(self):
        """
        Train the parameter vector and pickle it.

        Side effects: creates ``run['experiment_path']/run['name']`` if needed
        and CHANGES THE WORKING DIRECTORY into it. Writes
        ``run_object.pickle`` and ``_qpalma_train.log``, one
        ``param_<k>.pickle`` per solver call, and a final parameter pickle.
        """
        run = self.run

        full_working_path = os.path.join(run['experiment_path'], run['name'])
        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)
        assert os.path.exists(full_working_path)

        # ATTENTION: changing working directory; all files below are relative.
        os.chdir(full_working_path)

        self._dump_pickle(run, 'run_object.pickle')

        self.logfh = open('_qpalma_train.log', 'w+')
        try:
            self._train_loop()
        finally:
            # also closes the log when the solver setup calls sys.exit
            self.logfh.close()

    def _train_loop(self):
        """Body of ``train``: load data, run the constraint-generation loop, pickle params."""
        run = self.run

        self.plog("Settings are:\n")
        self.plog("%s\n" % str(run))

        SeqInfo, Exons, OriginalEsts, Qualities, _ = self._load_dataset()
        self._set_quality_mode()

        beg = run['training_begin']
        end = run['training_end']

        SeqInfo = SeqInfo[beg:end]
        Exons = Exons[beg:end]
        OriginalEsts = OriginalEsts[beg:end]
        Qualities = Qualities[beg:end]

        # number of training instances
        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N and len(Qualities) == N, \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        iteration_steps = run['iter_steps']
        anzpath = run['anzpath']

        # random initial parameter vector
        param = numpy.matlib.rand(run['numFeatures'], 1)

        lengthSP = run['numLengthSuppPoints']
        donSP = run['numDonSuppPoints']
        accSP = run['numAccSuppPoints']
        mmatrixSP = run['matchmatrixRows'] * run['matchmatrixCols']
        totalQualSP = run['totalQualSuppPoints']

        # no intron length model
        if not run['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits / penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Initializing problem...\n')
        try:
            solver = SIQPSolver(run['numFeatures'], numExamples, run['C'], self.logfh, run)
        except Exception as err:
            self.plog('Could not initialize the SIQP solver: %s\n' % err)
            sys.exit(99)

        # number of alignments (best, second best, ...) computed per example
        num_path = [anzpath] * numExamples

        self.plog('Starting training...\n')

        currentPhi = zeros((run['numFeatures'], 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = run['numConstraintsPerRound']
        solver_call_ctr = 0
        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0

        # the main training loop
        while iteration_nr != iteration_steps:
            for exampleIdx in range(numExamples):
                if exampleIdx % 100 == 0:
                    print('Current example nr %d' % exampleIdx)

                try:
                    dna, est, acc_supp, don_supp, exons, original_est = \
                        getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run)
                except SpliceSiteException:
                    continue

                dna_len = len(dna)
                quality = self._get_quality(est, Qualities, exampleIdx)
                self._mask_splice_scores(don_supp, acc_supp)

                # Features of the true alignment (with the current, untrained d, a, h ...).
                if run['mode'] == 'using_quality_scores':
                    # the fourth return value is not used here
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
                        computeSpliceAlignWithQuality(dna, exons, est, original_est, quality, qualityPlifs, run)
                else:
                    trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                        computeSpliceAlignWithQuality(dna, exons)

                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

                if run['mode'] == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]

                # w'phi(x,y): total score of the true alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                numPaths = num_path[exampleIdx]

                # column 0: features of the true alignment, columns 1..: decoded paths
                allWeights = zeros((run['numFeatures'], numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                # ---- decode the numPaths best (possibly wrong) alignments ----
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
                ps = h.convert2SWIG()

                newSpliceAlign, _, newWeightMatch, newDPScores, newQualityPlifsFeatures, _, _ = \
                    self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, numPaths, False)

                newSpliceAlign = newSpliceAlign.reshape(numPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(numPaths, totalQualSP)

                # The n-best alignments are computed without Hamming loss, so
                # some of them may coincide with the true alignment. true_map
                # flags, per column of allWeights, whether it is equivalent to
                # the true alignment; those must not become constraints.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1

                # The parameter vector is the same for every decoded path.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0], don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    allWeights[:, pathNr + 1] = numpy.vstack([weightIntron, weightDon, weightAcc,
                                                              newWeightMatch[pathNr, :].T, decodedQualityFeatures[:]])

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # The Viterbi score must equal the scalar product of the decoded path.
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n" % (newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # a decoded path (almost) identical to the true one counts as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # If every decoded path equals the true one, the n-best paths
                # are too close to each other: search (n+1) paths next time.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] += 1

                # first decoded path that differs from the true alignment
                firstFalseIdx = true_map.index(0) if 0 in true_map else -1

                # add a constraint when there is at least one useful false alignment
                if firstFalseIdx != -1:
                    differenceVector = trueWeight - allWeights[:, firstFalseIdx]
                    solver.addConstraint(differenceVector, exampleIdx)

                # call the solver every numConstPerRound-th example
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                        if self.noImprovementCtr == numExamples + 1:
                            break

                    self.oldObjValue = objValue
                    print("objValue is %f" % objValue)

                    sum_xis = sum(self.slacks)
                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

            # end of one iteration through all examples
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                break

            iteration_nr += 1

        print('Training completed')
        self._dump_pickle(param, 'param_%d.pickle' % param_idx)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def evaluate(self, param_filename):
        """Load the dataset and predict on examples [prediction_begin, prediction_end).

        Logs to ``_qpalma_predict_<run['id']>.log`` in the current directory.
        """
        run = self.run
        beg = run['prediction_begin']
        end = run['prediction_end']

        self._load_dataset()

        self.logfh = open('_qpalma_predict_%d.log' % run['id'], 'w+')
        try:
            self.plog('##### Prediction on the test set #####\n')
            self.predict(param_filename, beg, end, 'TEST')
            self.plog('##### Finished prediction #####\n')
        finally:
            self.logfh.close()

    def predict(self, param_filename, beg, end, set_flag):
        """
        Align examples [beg, end) with the parameters pickled in ``param_filename``.

        For every example the read is aligned to its true fragment
        (label True) and to each alternative fragment on chromosomes 1-5
        (label taken from the alternative). The list of per-example
        prediction lists is pickled to ``<run['name']>_allPredictions_<set_flag>``.
        Requires ``evaluate`` (or ``_load_dataset``) to have been called first.
        """
        run = self.run
        self._set_quality_mode()

        SeqInfo = self.SeqInfo[beg:end]
        Exons = self.Exons[beg:end]
        OriginalEsts = self.OriginalEsts[beg:end]
        Qualities = self.Qualities[beg:end]
        AlternativeSequences = self.AlternativeSequences[beg:end]

        N = numExamples = len(SeqInfo)
        assert len(Exons) == N and len(OriginalEsts) == N \
            and len(Qualities) == N and len(AlternativeSequences) == N, \
            'The Exons,Acc,Don,.. arrays are of different lengths'

        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        # NOTE: unpickling runs arbitrary code; only load trusted parameter files.
        with open(param_filename) as fh:
            param = cPickle.load(fh)

        [h, d, a, mmatrix, qualityPlifs] = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, run)

        self.plog('Starting prediction...\n')

        allPredictions = []

        for exampleIdx in range(numExamples):
            self.plog('Loading example nr. %d...\n' % exampleIdx)

            read_id, chrom, strand, up_cut, down_cut, true_cut = SeqInfo[exampleIdx]

            try:
                dna, est, acc_supp, don_supp, exons, original_est = \
                    getData(SeqInfo, OriginalEsts, Exons, exampleIdx, run)
            except SpliceSiteException:
                continue

            quality = self._get_quality(est, Qualities, exampleIdx)
            self._mask_splice_scores(don_supp, acc_supp)

            current_example_predictions = []

            # first the dna fragment that comes from the ground truth
            prediction = self.calc_alignment(dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs)
            prediction['exampleIdx'] = exampleIdx
            prediction['id'] = read_id
            prediction['start_pos'] = up_cut
            prediction['label'] = True
            prediction['true_cut'] = true_cut
            prediction['chr'] = chrom
            prediction['strand'] = strand
            current_example_predictions.append(prediction)

            # then every dna fragment that occurred in the vmatch results
            for alt_chr, alt_strand, genomicSeq_start, genomicSeq_stop, currentLabel in AlternativeSequences[exampleIdx]:
                if alt_chr not in range(1, 6):
                    continue

                currentDNASeq, currentAcc, currentDon = get_seq_and_scores(
                    alt_chr, alt_strand, genomicSeq_start, genomicSeq_stop, run['dna_flat_files'])
                self._mask_splice_scores(currentDon, currentAcc)

                prediction = self.calc_alignment(currentDNASeq, est, exons, quality,
                                                 currentDon, currentAcc, d, a, h, mmatrix, qualityPlifs)
                prediction['exampleIdx'] = exampleIdx
                prediction['id'] = read_id
                prediction['start_pos'] = up_cut
                prediction['alternative_start_pos'] = genomicSeq_start
                prediction['label'] = currentLabel
                prediction['true_cut'] = true_cut
                prediction['chr'] = alt_chr
                prediction['strand'] = alt_strand
                current_example_predictions.append(prediction)

            allPredictions.append(current_example_predictions)

        self._dump_pickle(allPredictions, '%s_allPredictions_%s' % (run['name'], set_flag))
        print('Prediction completed')

    def calc_alignment(self, dna, est, exons, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
        """
        Compute the single best alignment of ``est`` against ``dna``.

        Returns a dict with keys 'predExons', 'trueExons', 'dna', 'est',
        'DPScores' and 'alignment' (the result of get_alignment).
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        dna = str(dna)
        est = str(est)
        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, _, newDPScores, _, dna_array, est_array = \
            self.do_alignment(dna, est, quality, mmatrix, donor, acceptor, ps, qualityPlifs, 1, True)

        newSpliceAlign = newSpliceAlign.reshape(1, len(dna))

        # (qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)
        alignment = get_alignment(newSpliceAlign.flatten().tolist()[0],
                                  newEstAlign.flatten().tolist()[0], dna_array, est_array)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        return {'predExons': newExons, 'trueExons': exons,
                'dna': dna, 'est': est, 'DPScores': newDPScores,
                'alignment': alignment}

    def calculatePredictedExons(self, SpliceAlign):
        """
        Return the positions where label-0 runs of ``SpliceAlign`` start and end.

        ``SpliceAlign`` is a numpy matrix of per-position labels. A position
        is reported whenever the "label == 0" state changes relative to the
        previous position. Position 0 is compared against a virtual -1, and a
        virtual -1 is appended so that a run reaching the end is closed.
        """
        labels = SpliceAlign.flatten().tolist()[0] + [-1]
        newExons = []
        previous = -1
        for pos, elem in enumerate(labels):
            if (previous == 0) != (elem == 0):  # exon start or exon end
                newExons.append(pos)
            previous = elem
        return newExons
932
933 return newExons
934
935
936 ###########################
937 # A simple command line
938 # interface
939 ###########################
940
if __name__ == '__main__':
    # Command line interface:
    #   <script> train   <run_object_file>
    #   <script> predict <run_object_file> <param_file>
    usage = ('usage: python %s train <run_object_file>\n'
             '       python %s predict <run_object_file> <param_file>' % (sys.argv[0], sys.argv[0]))

    # Check the argument count before indexing sys.argv.
    if len(sys.argv) < 3:
        print(usage)
        sys.exit(1)

    mode = sys.argv[1]
    run_obj_fn = sys.argv[2]

    # NOTE: unpickling executes arbitrary code; only load run objects you created.
    with open(run_obj_fn) as run_fh:
        run_obj = cPickle.load(run_fh)

    # Named qpalma_obj so the imported `qpalma` package is not shadowed.
    qpalma_obj = QPalma(run_obj)

    if len(sys.argv) == 3 and mode == 'train':
        qpalma_obj.train()

    elif len(sys.argv) == 4 and mode == 'predict':
        param_filename = sys.argv[3]
        assert os.path.exists(param_filename)
        qpalma_obj.evaluate(param_filename)

    else:
        print('You have to choose between training or prediction mode:')
        print(usage)
        sys.exit(1)