# Source: qpalma.git, qpalma/qpalma_main.py (revision 1564ad85db970e8fd8a178c6279bd5860479fdc4)
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Written (W) 2008 Fabio De Bona
# Copyright (C) 2008 Max-Planck-Society

9 import array
10 import cPickle
11 import os.path
12 import pdb
13 import sys
14
15 import numpy
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 #from qpalma.SIQP_CPX import SIQPSolver
20 #from qpalma.SIQP_CVXOPT import SIQPSolver
21
22 import QPalmaDP
23 import qpalma
24 from qpalma.computeSpliceWeights import *
25 from qpalma.set_param_palma import *
26 from qpalma.computeSpliceAlignWithQuality import *
27 from qpalma.TrainingParam import Param
28 from qpalma.Plif import Plf,compute_donacc
29
30 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper,unbracket_seq
31 from qpalma.utils import pprint_alignment, get_alignment
32
# Short alias used for building paths below.
jp = os.path.join


def dotProd(x, y):
    """Return element [0, 0] of the matrix product ``x.T * y``.

    For column matrices (``numpy.matrix`` operands, where ``*`` is the matrix
    product) this is the scalar inner product x'y.
    """
    return (x.T * y)[0, 0]
35
36
class SpliceSiteException(Exception):
    """Raised by preprocessExample() when a two-exon training example has a
    non-canonical donor (not 'gt'/'gc') or acceptor (not 'ag') dinucleotide.

    train() catches it and skips the example. Deriving from Exception makes
    the class a proper exception type (raising old-style classes is
    deprecated in Python 2 and an error in Python 3).
    """
    pass
39
40
def preprocessExample(training_set, exampleKey, seqInfo, settings):
    """
    Fetch the genomic context of one training example and sanity-check it.

    training_set[exampleKey] is unpacked as (seq_info, read, qualities, exons),
    where seq_info is (id, chromosome, strand, up_cut, down_cut).
    seqInfo.get_seq_and_scores(chromosome, strand, up_cut, down_cut) supplies
    the DNA and the acceptor/donor scores for that window.
    The exon coordinates are shifted by up_cut so that they index into the
    fetched DNA.

    For two-exon examples (exons.shape == (2, 2)) the dinucleotide after the
    first exon must be 'gt' or 'gc' and the one before the second exon must be
    'ag'. Otherwise SpliceSiteException is raised. The spliced exon sequence
    must also have the same length as the read (AssertionError otherwise).

    settings is accepted for interface compatibility and is not used.

    Returns (dna, read, acc_supp, don_supp, exons, originalRead, qualities),
    where read is unbracket_seq(originalRead) with all '-' removed.
    """
    currentSeqInfo, originalRead, currentQualities, currentExons = training_set[exampleKey]
    _example_id, chromo, strand, up_cut, down_cut = currentSeqInfo

    read = unbracket_seq(originalRead).replace('-', '')

    dna, acc_supp, don_supp = seqInfo.get_seq_and_scores(chromo, strand, up_cut, down_cut)

    # Make exon coordinates relative to the fetched DNA window.
    exons = currentExons - up_cut

    if exons.shape == (2, 2):
        fetched_dna_subseq = dna[exons[0, 0]:exons[0, 1]] + dna[exons[1, 0]:exons[1, 1]]

        donor_elem = dna[exons[0, 1]:exons[0, 1] + 2]
        acceptor_elem = dna[exons[1, 0] - 2:exons[1, 0]]

        if donor_elem not in ('gt', 'gc'):
            print('invalid donor in example %d' % exampleKey)
            raise SpliceSiteException()

        if acceptor_elem != 'ag':
            print('invalid acceptor in example %d' % exampleKey)
            raise SpliceSiteException()

        # The assertion "message" used to be pdb.set_trace(). That opened an
        # interactive debugger on failure, which blocks batch jobs, and then
        # raised with no message. Report the mismatch instead.
        assert len(fetched_dna_subseq) == len(read), \
            'example %s: spliced exon length %d != read length %d' % \
            (exampleKey, len(fetched_dna_subseq), len(read))

    return dna, read, acc_supp, don_supp, exons, originalRead, currentQualities
73
74
def performAlignment(dna, read, quality, mmatrix, donor, acceptor, ps,
                     qualityPlifs, current_num_path, prediction_mode, settings):
    """
    Run the QPalma dynamic-programming aligner (C module QPalmaDP).

    The Python inputs are converted to the C array types expected by the
    extension, QPalmaDP.Alignment.myalign() is called, and its results are
    returned. The aligner is asked for current_num_path alignments.

    Returns a 7-tuple (splice_align, est_align, weight_match, dp_scores,
    quality_plif_features, dna_array, read_array). dna_array and read_array
    are only fetched when prediction_mode is true; otherwise they are None.
    """
    to_c_doubles = QPalmaDP.createDoubleArrayFromList

    # Per-position read qualities. Chastity values are not used: all zeros.
    c_quality = to_c_doubles(quality)
    c_chastity = to_c_doubles([0.0] * len(read))

    c_matchmatrix = to_c_doubles(mmatrix.flatten().tolist()[0])
    matchmatrix_size = settings['matchmatrixRows'] * settings['matchmatrixCols']

    c_donor = to_c_doubles(donor)
    c_acceptor = to_c_doubles(acceptor)

    # The Alignment object is the interface to the C/C++ code.
    aligner = QPalmaDP.Alignment(settings['numQualPlifs'],
                                 settings['numQualSuppPoints'],
                                 settings['enable_quality_scores'])

    c_plifs = QPalmaDP.createPenaltyArrayFromList(
        [plif.convert2SWIG() for plif in qualityPlifs])

    # Computes SpliceAlign, EstAlign, weightMatch, DP scores and dnaest.
    aligner.myalign(current_num_path,
                    dna, len(dna),
                    read, len(read),
                    c_quality, c_chastity, ps,
                    c_matchmatrix, matchmatrix_size,
                    c_donor, len(c_donor),
                    c_acceptor, len(c_acceptor),
                    c_plifs,
                    settings['remove_duplicate_scores'],
                    settings['print_matrix'])

    dna_array = read_array = None
    if prediction_mode:
        # Only needed for prediction. The result length is queried but not
        # used further.
        aligner.getResultLength()
        dna_array, read_array = aligner.getAlignmentArraysNew()

    (splice_align, est_align, weight_match,
     dp_scores, quality_features) = aligner.getAlignmentResultsNew()

    return (splice_align, est_align, weight_match, dp_scores,
            quality_features, dna_array, read_array)
114
115
116
class QPalma:
    """
    This class wraps the training and prediction functions for
    the alignment.

    Both entry points (init_training / init_prediction) change the process
    working directory. They open a log file there (self.logfh), which plog()
    writes to.
    """

    def __init__(self, seqInfo, dmode=False):
        """
        seqInfo -- object whose get_seq_and_scores(chr, strand, start, stop)
                   returns (dna, acceptor scores, donor scores).
        dmode   -- debug mode. predict() then skips most logging and returns
                   its predictions instead of pickling them.
        """
        self.ARGS = Param()
        self.qpalma_debug_mode = dmode
        self.seqInfo = seqInfo

    def plog(self, string):
        """Write string to the open log file and flush it immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    @staticmethod
    def _load_pickle(fn):
        """Unpickle and return the object stored in file fn, closing the file."""
        fh = open(fn, 'rb')
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    @staticmethod
    def _dump_pickle(obj, fn):
        """Pickle obj into file fn (truncating it), closing the file afterwards."""
        fh = open(fn, 'w+')
        try:
            cPickle.dump(obj, fh)
        finally:
            fh.close()

    def init_training(self, dataset_fn, training_keys, settings, set_name):
        """
        Create <settings['training_dir']>/<set_name>, chdir into it, open the
        training log, load the pickled training set from dataset_fn and run
        train() on it.

        NOTE(review): training_keys is not used; the whole pickled set is
        trained on. dataset_fn is opened after the chdir, so a relative path
        is resolved against the training directory.
        """
        # Previously joined with the undefined name ``run_name`` (NameError).
        full_working_path = jp(settings['training_dir'], set_name)

        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_train.log', 'w+')

        training_set = self._load_pickle(dataset_fn)

        self.train(training_set, settings, set_name)

    def train(self, training_set, settings, set_name):
        """
        The main loop for training.

        training_set -- dict mapping example keys to the tuples unpacked by
                        preprocessExample().
        settings     -- configuration dict (support-point counts, 'C',
                        'anzpath', 'iter_steps', 'numConstraintsPerRound',
                        feature switches, ...).
        set_name     -- name of the data set (not used here).

        A constraint is added to the solver for each example whose decoded
        paths include a wrong alignment. The solver is called every
        numConstPerRound examples, and the resulting parameters are pickled to
        param_<n>.pickle in the current directory. FinalizeTraining() writes
        the final vector and closes the log.
        """
        numExamples = len(training_set)
        self.plog('Number of training examples: %d\n' % numExamples)

        self.noImprovementCtr = 0
        self.oldObjValue = 1e8

        lengthSP = settings['numLengthSuppPoints']
        donSP = settings['numDonSuppPoints']
        accSP = settings['numAccSuppPoints']
        mmatrixSP = settings['matchmatrixRows'] * settings['matchmatrixCols']
        numq = settings['numQualSuppPoints']
        totalQualSP = settings['totalQualSuppPoints']

        # calculate the total number of features
        numFeatures = lengthSP + donSP + accSP + mmatrixSP * numq

        # Initialize parameter vector (random start)
        param = numpy.matlib.rand(numFeatures, 1)

        # no intron length model
        if not settings['enable_intron_length']:
            param[:lengthSP] *= 0.0

        # Set the parameters such as limits penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, settings)

        # Initialize solver
        self.plog('Initializing problem...\n')

        # NOTE(review): the SIQPSolver imports at the top of this file are
        # commented out, so this fails with NameError unless SIQPSolver is
        # provided some other way.
        try:
            solver = SIQPSolver(numFeatures, numExamples, settings['C'], self.logfh, settings)
        except Exception:
            # Exit code 99 tells the queue to reschedule the job. Also log the
            # actual error: the former bare except hid any failure (e.g. a
            # NameError) behind the "no license" message.
            self.plog('Got no license. Telling queue to reschedule job...\n')
            self.plog('Solver initialization failed: %r\n' % (sys.exc_info()[1],))
            sys.exit(99)

        solver.enforceMonotonicity(lengthSP, lengthSP + donSP)
        solver.enforceMonotonicity(lengthSP + donSP, lengthSP + donSP + accSP)

        # Number of alignments decoded per example (best path, second-best
        # path, ...). It must be a per-example list because it is indexed and
        # incremented per example below. The original multiplied
        # settings['anzpath'] by numExamples, which gives an unindexable int.
        num_path = [settings['anzpath']] * numExamples

        currentPhi = zeros((numFeatures, 1))
        totalQualityPenalties = zeros((totalQualSP, 1))

        numConstPerRound = settings['numConstraintsPerRound']
        solver_call_ctr = 0

        suboptimal_example = 0
        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        # Use the first quality vector of each example's tuple of quality
        # vectors (same convention as predict(); this was undefined here).
        quality_index = 0

        featureVectors = zeros((numFeatures, numExamples))

        self.plog('Starting training...\n')
        # the main training loop
        while iteration_nr != settings['iter_steps']:
            for exampleIdx, example_key in enumerate(training_set.keys()):
                print('Current example %d' % example_key)

                try:
                    dna, est, acc_supp, don_supp, exons, original_est, currentQualities = \
                        preprocessExample(training_set, example_key, self.seqInfo, settings)
                except SpliceSiteException:
                    continue

                if settings['enable_quality_scores']:
                    quality = currentQualities[quality_index]
                else:
                    # previously len(read): ``read`` is undefined in this scope
                    quality = [40] * len(est)

                if not settings['enable_splice_scores']:
                    # Keep the -inf markers; neutralize all other site scores.
                    for idx, elem in enumerate(don_supp):
                        if elem != -inf:
                            don_supp[idx] = 0.0

                    for idx, elem in enumerate(acc_supp):
                        if elem != -inf:
                            acc_supp[idx] = 0.0

                # Bound once per example. num_path[exampleIdx] is only
                # incremented after its last use below.
                nPaths = num_path[exampleIdx]

                # Features of the true alignment (but with untrained d, a, h ...).
                # NOTE(review): the quality-free branch used to call
                # computeSpliceAlignWithQuality(dna, exons), unpack three values
                # and then read an unset dna_calc. Both cases now share the full
                # call; quality is a constant vector when scores are disabled.
                trueSpliceAlign, trueWeightMatch, trueWeightQuality, _dna_calc = \
                    computeSpliceAlignWithQuality(dna, exons, est, original_est,
                                                  quality, qualityPlifs, settings)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = \
                    computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc,
                                           trueWeightMatch, trueWeightQuality])

                currentPhi[0:lengthSP] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[lengthSP:lengthSP + donSP] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[lengthSP + donSP:lengthSP + donSP + accSP] = \
                    mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[lengthSP + donSP + accSP:lengthSP + donSP + accSP + mmatrixSP] = mmatrix[:]

                if settings['enable_quality_scores']:
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[lengthSP + donSP + accSP + mmatrixSP:] = totalQualityPenalties[:]

                # w'phi(x,y): the total score of the true alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # Column 0 holds the weights of the true alignment; columns
                # 1..nPaths hold those of the decoded alignments.
                allWeights = zeros((numFeatures, nPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (nPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                ps = h.convert2SWIG()

                newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
                    newQualityPlifsFeatures, unneeded1, unneeded2 = \
                    performAlignment(dna, est, quality, mmatrix, donor, acceptor, ps,
                                     qualityPlifs, nPaths, False, settings)

                newSpliceAlign = newSpliceAlign.reshape(nPaths, len(dna))
                newWeightMatch = newWeightMatch.reshape(nPaths, mmatrixSP)
                newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(nPaths, totalQualSP)

                # The n-best alignments are computed without hamming loss, so
                # one of them may equal the true alignment. true_map[i] == 1
                # marks column i of allWeights as (equivalent to) the true
                # alignment so that it is not used as a constraint.
                true_map = [0] * (nPaths + 1)
                true_map[0] = 1

                # Current parameter vector; invariant across the decoded paths.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties[:]])

                for pathNr in range(nPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr, :].T
                    # store the weights in the remaining columns of the matrix
                    allWeights[:, pathNr + 1] = numpy.vstack(
                        [weightIntron, weightDon, weightAcc,
                         newWeightMatch[pathNr, :].T, decodedQualityFeatures[:]])

                    featureVectors[:, exampleIdx] = allWeights[:, pathNr + 1]

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # Check whether scalar product + loss equals viterbi score
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog("Scalar prod. + loss not equals Viterbi output!\n")
                        # NOTE(review): opens an interactive debugger; this
                        # blocks non-interactive (queued) runs.
                        pdb.set_trace()

                    self.plog(" scalar prod (correct) : %f\n" % AlignmentScores[0])
                    self.plog(" scalar prod (pred.) : %f %f\n" %
                              (newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # if the pathNr-best alignment is very close to the true
                    # alignment consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                if not trueAlignmentScore <= max(AlignmentScores[1:]) + 1e-6:
                    print("suboptimal_example %d\n" % exampleIdx)
                    suboptimal_example += 1
                    self.plog("suboptimal_example %d\n" % exampleIdx)

                # If every decoded path equals the true alignment, no violated
                # constraint is available: decode one more path next time.
                if 0 not in true_map:
                    num_path[exampleIdx] += 1

                # Choose true and first false alignment for extending
                if 0 in true_map:
                    firstFalseIdx = true_map.index(0)
                else:
                    firstFalseIdx = -1

                # if there is at least one useful false alignment add the
                # corresponding constraint to the optimization problem
                if firstFalseIdx != -1:
                    firstFalseWeights = allWeights[:, firstFalseIdx]
                    differenceVector = trueWeight - firstFalseWeights
                    solver.addConstraint(differenceVector, exampleIdx)
                    const_added_ctr += 1

                # end of one example processing

                # call solver every nth example / added constraint
                if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
                    objValue, w, self.slacks = solver.solve()
                    solver_call_ctr += 1

                    if solver_call_ctr == 5:
                        numConstPerRound = 200
                        self.plog('numConstPerRound is now %d\n' % numConstPerRound)

                    if abs(objValue - self.oldObjValue) <= 1e-6:
                        self.noImprovementCtr += 1

                    if self.noImprovementCtr == numExamples + 1:
                        break

                    self.oldObjValue = objValue
                    print('objValue is %f' % objValue)

                    sum_xis = sum(self.slacks)
                    print('sum of slacks is %f' % sum_xis)
                    self.plog('sum of slacks is %f\n' % sum_xis)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_pickle(param, 'param_%d.pickle' % param_idx)
                    param_idx += 1
                    h, d, a, mmatrix, qualityPlifs = \
                        set_param_palma(param, self.ARGS.train_with_intronlengthinformation, settings)

            # end of one iteration through all examples
            self.plog("suboptimal rounds %d\n" % suboptimal_example)

            if self.noImprovementCtr == numExamples * 2:
                # Converged: stop and let the single FinalizeTraining() call
                # below save the parameters. The original called an
                # unqualified, undefined FinalizeTraining() here and then kept
                # training after the log had been closed.
                break

            iteration_nr += 1

        #
        # end of optimization
        #
        self.FinalizeTraining(param, 'param_%d.pickle' % param_idx)

    def FinalizeTraining(self, vector, name):
        """Pickle the parameter vector into file name and close the log."""
        self.plog("Training completed\n")
        # previously dumped the undefined name ``param`` instead of ``vector``
        self._dump_pickle(vector, name)
        self.logfh.close()

    ###############################################################################
    #
    # End of the code needed for training
    #
    # Begin of code for prediction
    #
    ###############################################################################

    def init_prediction(self, dataset_fn, prediction_keys, settings, set_name):
        """
        Create settings['prediction_dir'] if needed, chdir into it, open the
        log _qpalma_predict_<set_name>.log, load the examples listed in
        prediction_keys from the pickled dataset_fn and run predict() on them.

        NOTE(review): dataset_fn is opened after the chdir, so a relative path
        is resolved against the prediction directory.
        """
        self.set_name = set_name

        full_working_path = settings['prediction_dir']

        print('full_working_path is %s' % full_working_path)

        if not os.path.exists(full_working_path):
            os.mkdir(full_working_path)

        assert os.path.exists(full_working_path)

        # ATTENTION: Changing working directory
        os.chdir(full_working_path)

        self.logfh = open('_qpalma_predict_%s.log' % set_name, 'w+')

        # number of prediction instances
        self.plog('Number of prediction examples: %d\n' % len(prediction_keys))

        # load dataset and fetch instances that shall be predicted
        dataset = self._load_pickle(dataset_fn)

        prediction_set = dict((key, dataset[key]) for key in prediction_keys)

        # we do not need the full dataset anymore
        del dataset

        self.predict(prediction_set, settings)

    def predict(self, prediction_set, settings):
        """
        Align every example of prediction_set using the parameter vector
        pickled in settings['prediction_param_fn'].

        prediction_set maps a key to a list of examples. Each example is a
        tuple (seq_info, read, qualities), with seq_info =
        (id, chromosome, strand, genomic start, genomic stop).
        Examples whose sequence cannot be fetched are skipped and counted in
        self.problem_ctr.

        In debug mode the list of prediction dicts is returned. Otherwise it
        is pickled by FinalizePrediction() and None is returned.
        """
        # Load parameter vector to predict with
        param = self._load_pickle(settings['prediction_param_fn'])

        # Set the parameters such as limits/penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = \
            set_param_palma(param, self.ARGS.train_with_intronlengthinformation, settings)

        if not self.qpalma_debug_mode:
            self.plog('Starting prediction...\n')

        self.problem_ctr = 0

        # where we store the predictions
        allPredictions = []

        # we take the first quality vector of the tuple of quality vectors
        quality_index = 0

        # beginning of the prediction loop
        for example_key in prediction_set.keys():
            print('Current example %d' % example_key)
            for example in prediction_set[example_key]:

                currentSeqInfo, read, currentQualities = example

                example_id, chromo, strand, genomicSeq_start, genomicSeq_stop = currentSeqInfo

                if not self.qpalma_debug_mode:
                    self.plog('Loading example id: %d...\n' % int(example_id))

                if settings['enable_quality_scores']:
                    quality = currentQualities[quality_index]
                else:
                    quality = [40] * len(read)

                try:
                    currentDNASeq, currentAcc, currentDon = self.seqInfo.get_seq_and_scores(
                        chromo, strand, genomicSeq_start, genomicSeq_stop)
                except Exception:
                    # Count the failure and skip the example. Exception rather
                    # than a bare except, so Ctrl-C still stops the run.
                    self.problem_ctr += 1
                    print(sys.exc_info())
                    continue

                if not settings['enable_splice_scores']:
                    # Keep the -inf markers; neutralize all other site scores.
                    for idx, elem in enumerate(currentDon):
                        if elem != -inf:
                            currentDon[idx] = 0.0

                    for idx, elem in enumerate(currentAcc):
                        if elem != -inf:
                            currentAcc[idx] = 0.0

                current_prediction = self.calc_alignment(
                    currentDNASeq, read, quality, currentDon, currentAcc,
                    d, a, h, mmatrix, qualityPlifs, settings)

                current_prediction['id'] = example_id
                current_prediction['chr'] = chromo
                current_prediction['strand'] = strand
                current_prediction['start_pos'] = genomicSeq_start

                allPredictions.append(current_prediction)

        if not self.qpalma_debug_mode:
            self.FinalizePrediction(allPredictions)
        else:
            return allPredictions

    def FinalizePrediction(self, allPredictions):
        """ End of the prediction loop we save all predictions in a pickle file and exit """
        self._dump_pickle(allPredictions, '%s.predictions.pickle' % (self.set_name))
        self.plog('Prediction completed\n')
        self.plog('Problem ctr %d\n' % self.problem_ctr)
        self.logfh.close()

    def calc_alignment(self, dna, read, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs, settings):
        """
        Align read against dna with the given parameters (best path only).

        Gaps ('-') are removed from the read first. Returns a dict with keys
        'predExons', 'dna', 'read', 'DPScores', 'alignment', 'spliceAlign',
        'estAlign', 'dna_array' and 'read_array'.
        """
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

        if '-' in read:
            self.plog('found gap\n')
            read = read.replace('-', '')
            # The original compared against Conf.read_size, but Conf is never
            # imported here, so every gapped read raised NameError.
            # NOTE(review): assumes the configured read length, if any, is
            # settings['read_size'] -- confirm.
            if 'read_size' in settings:
                assert len(read) == settings['read_size']

        ps = h.convert2SWIG()

        newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, \
            newQualityPlifsFeatures, dna_array, read_array = \
            performAlignment(dna, read, quality, mmatrix, donor, acceptor, ps,
                             qualityPlifs, 1, True, settings)

        _newSpliceAlign = array.array('B', newSpliceAlign)
        _newEstAlign = array.array('B', newEstAlign)

        #(qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)
        alignment = get_alignment(_newSpliceAlign, _newEstAlign, dna_array, read_array)

        dna_array = array.array('B', dna_array)
        read_array = array.array('B', read_array)

        newExons = self.calculatePredictedExons(newSpliceAlign)

        current_prediction = {'predExons': newExons, 'dna': dna, 'read': read,
                              'DPScores': newDPScores, 'alignment': alignment,
                              'spliceAlign': _newSpliceAlign, 'estAlign': _newEstAlign,
                              'dna_array': dna_array, 'read_array': read_array}

        return current_prediction

    def calculatePredictedExons(self, SpliceAlign):
        """
        Return the positions where runs of 0 labels in SpliceAlign start and
        end (exclusive end), as a flat list [start1, end1, start2, end2, ...].

        The input is not modified. The original appended a -1 sentinel to the
        caller's list; here the sentinel is added to a private copy, which
        also makes tuples and arrays acceptable inputs.
        """
        labels = list(SpliceAlign) + [-1]
        newExons = []
        previous = -1
        for pos, elem in enumerate(labels):
            # A switch between "0" and "not 0" marks an exon start or end.
            if (previous == 0) != (elem == 0):
                newExons.append(pos)
            previous = elem

        return newExons