+ added splicesite scores to training
[qpalma.git] / scripts / qpalma_train.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 import qpalma
22 from qpalma.SIQP_CPX import SIQPSolver
23 from qpalma.DataProc import *
24
25 from qpalma.generateEvaluationData import *
26 from qpalma.computeSpliceWeights import *
27 from qpalma.set_param_palma import *
28 from qpalma.computeSpliceAlignWithQuality import *
29 from qpalma.penalty_lookup_new import *
30 from qpalma.compute_donacc import *
31 from qpalma.TrainingParam import Param
32 from qpalma.export_param import *
33
34 import qpalma.Configuration
35 from qpalma.Plif import Plf
36 from qpalma.Helpers import *
37
38 from qpalma.tools.splicesites import getDonAccScores
39
def getQualityFeatureCounts(qualityPlifs):
    """Stack the penalty vectors of all quality plifs row by row.

    The first plif's ``penalties`` object is used as the starting value
    (returned unchanged when there is only one plif); every further plif's
    penalties are appended below it with ``numpy.vstack``.

    :param qualityPlifs: non-empty sequence of objects with a ``penalties``
        attribute.
    :returns: the stacked penalties.
    """
    head, tail = qualityPlifs[0], qualityPlifs[1:]
    stacked = head.penalties
    for plif in tail:
        stacked = numpy.vstack([stacked, plif.penalties])
    return stacked
46
class QPalma:
    """
    A training method for the QPalma project.

    For every training example the true alignment's feature vector is
    computed, the n-best alignments are decoded with ``QPalmaDP``, and a
    margin constraint between the true alignment and the first decoded
    alignment that differs from it is added to the QP solver (when
    ``Configuration.USE_OPT`` is set). The parameter vector is re-solved and
    pickled every 20 examples and once more after training.

    NOTE(review): ``Configuration``, ``numpy``, ``math`` and ``cPickle`` are
    not bound by the explicit imports of this file (``import
    qpalma.Configuration`` only binds ``qpalma``); they are presumably
    provided by one of the ``from qpalma.* import *`` lines -- confirm.
    """

    def __init__(self):
        self.ARGS = Param()
        # Training log; closed when run() finishes (also on error).
        self.logfh = open('qpalma.log', 'w+')

    def plog(self, string):
        """Write ``string`` to the training log and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def run(self):
        """Train the parameters; always closes the log file afterwards."""
        try:
            self._train()
        finally:
            self.logfh.close()

    def _load_training_data(self):
        """Load the training set selected by ``Configuration.mode``.

        :returns: ``(Sequences, Acceptors, Donors, Exons, Ests, Qualities,
            use_quality_scores)``.
        :raises ValueError: for an unknown ``Configuration.mode``.
        """
        if Configuration.mode == 'normal':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = loadArtificialData(1000)
            use_quality_scores = False
        elif Configuration.mode == 'using_quality_scores':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities, SplitPos = \
                paths_load_data_solexa('training', None, self.ARGS)

            # Only the first `end` examples are used for training.
            end = 30
            Sequences = Sequences[:end]
            Exons = Exons[:end]
            Ests = Ests[:end]
            Qualities = Qualities[:end]

            # Splice-site scores are recomputed from the (truncated) sequences.
            Donors, Acceptors = getDonAccScores(Sequences)
            use_quality_scores = True
        else:
            raise ValueError('Unknown Configuration.mode: %r' % (Configuration.mode,))

        return Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores

    @staticmethod
    def _dump_param(param, idx):
        """Pickle the parameter vector to ``param_<idx>.pickle``."""
        with open('param_%d.pickle' % idx, 'wb') as fh:
            cPickle.dump(param, fh)

    def _train(self):
        """Run the constraint-generation training loop (see class docstring)."""
        Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores = \
            self._load_training_data()

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        assert N == len(Acceptors) and N == len(Donors) and N == len(Exons) \
            and N == len(Ests), 'The Seq,Accept,Donor,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n' % N)
        print('Number of features: %d\n' % Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Parameter vector; updated in place after every solver call.
        param = Configuration.fixedParam

        # Build the plifs (intron length h, donor d, acceptor a, quality)
        # and the match matrix from the parameter vector.
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)

        # delete splicesite-score-information
        if not self.ARGS.train_with_splicesitescoreinformation:
            # NOTE(review): each per-example entry is compared to -20 as a
            # whole; if entries are score lists/arrays this is not an
            # elementwise replacement -- confirm the intended types.
            for i in range(len(Acceptors)):
                if Acceptors[i] > -20:
                    Acceptors[i] = 1
                if Donors[i] > -20:
                    Donors[i] = 1

        if Configuration.USE_OPT:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures, self.numExamples, Configuration.C, self.logfh)

        # Number of alignments decoded per example (best path, second-best
        # path, ...); grown when all decoded paths equal the true one.
        num_path = [anzpath] * N

        self.plog('Starting training...\n')

        totalQualSP = Configuration.totalQualSuppPoints
        mm_len = Configuration.sizeMatchmatrix[0] * Configuration.sizeMatchmatrix[1]

        # Quality penalties stay zero unless quality scores are used.
        totalQualityPenalties = zeros((totalQualSP, 1))

        param_idx = 0
        for iteration_nr in range(iteration_steps):
            for exampleIdx in range(self.numExamples):
                if exampleIdx % 10 == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]
                if use_quality_scores:
                    quality = Qualities[exampleIdx]
                else:
                    quality = [40] * len(est)

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Skip examples whose last exon ends beyond the sequence.
                if exons[-1, 1] > len(dna):
                    continue

                # Features of the true alignment (with the current d, a, h).
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                if use_quality_scores:
                    totalQualityPenalties = param[-totalQualSP:]

                # Current parameters stacked in the same order as the weight
                # vectors above: intron length, donor, acceptor, match
                # matrix, quality.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties])

                # w'phi(x,y): total score of the true alignment
                trueAlignmentScore = (trueWeight.T * features)[0, 0]

                numPaths = num_path[exampleIdx]

                # Column 0: true alignment weights; columns 1..numPaths:
                # weights of the decoded alignments.
                allWeights = zeros((Configuration.numFeatures, numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################

                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                dna = str(dna)
                est = str(est)
                dna_len = len(dna)
                est_len = len(est)

                ps = h.convert2SWIG()
                prb = QPalmaDP.createDoubleArrayFromList(quality)
                chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)
                matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

                d_len = len(donor)
                donor = QPalmaDP.createDoubleArrayFromList(donor)
                a_len = len(acceptor)
                acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

                # Alignment object: the interface to the C/C++ code.
                currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs, Configuration.numQualSuppPoints, use_quality_scores)
                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

                # calculates SpliceAlign, EstAlign, weightMatch, DP scores
                currentAlignment.myalign(numPaths, dna, dna_len,
                                         est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,
                                         acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
                                         print_matrix)

                c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * numPaths))
                c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * numPaths))
                c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * numPaths))
                c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * numPaths)
                c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * (totalQualSP * numPaths))

                currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                                     c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

                # Copy the C arrays into numpy matrices.
                newSpliceAlign = zeros((numPaths * dna_len, 1))
                newWeightMatch = zeros((numPaths * mm_len, 1))
                newDPScores = zeros((numPaths, 1))
                newQualityPlifsFeatures = zeros((totalQualSP * numPaths, 1))

                for i in range(dna_len * numPaths):
                    newSpliceAlign[i] = c_SpliceAlign[i]
                for i in range(mm_len * numPaths):
                    newWeightMatch[i] = c_WeightMatch[i]
                for i in range(numPaths):
                    newDPScores[i] = c_DPScores[i]
                if use_quality_scores:
                    for i in range(totalQualSP * numPaths):
                        newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

                del c_SpliceAlign
                del c_EstAlign
                del c_WeightMatch
                del c_DPScores
                del c_qualityPlifsFeatures
                del currentAlignment

                newSpliceAlign = newSpliceAlign.reshape(numPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths, mm_len)

                # The n-best alignments are decoded without hamming loss, so
                # some of them may coincide with the true alignment; true_map
                # marks those (index 0 is the true alignment itself) so that
                # they are not used as negative examples.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1
                path_loss = [0] * numPaths

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,
                        h, newSpliceAlign[pathNr, :].flatten().tolist()[0], don_supp,
                        acc_supp, True)

                    decodedQualityFeatures = newQualityPlifsFeatures[pathNr * totalQualSP:(pathNr + 1) * totalQualSP]

                    # Hamming distance between decoded and true splice alignment
                    # (computed but not yet used for the constraints).
                    path_loss[pathNr] = 0
                    for alignPosIdx in range(newSpliceAlign[pathNr, :].shape[1]):
                        if newSpliceAlign[pathNr, alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr] += 1

                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr, :].T, decodedQualityFeatures])
                    allWeights[:, pathNr + 1] = wp

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # w'phi of the decoded path must reproduce the DP score.
                    print('Example nr.: %d, path nr. %d, scores: %f vs %f' % (exampleIdx, pathNr, newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))
                    if not math.fabs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog('WARNING: example %d, path %d: DP score %f differs from w\'phi %f\n'
                                  % (exampleIdx, pathNr, newDPScores[pathNr, 0], AlignmentScores[pathNr + 1]))

                    # A decoded path (almost) equal to the true alignment counts as true.
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                # All decoded paths equal the true alignment, so there is no
                # negative example: decode one more path for this example.
                if all(true_map):
                    num_path[exampleIdx] += 1

                # Add a constraint between the true alignment and the first
                # decoded alignment that differs from it (if any).
                if 0 in true_map:
                    firstFalseIdx = true_map.index(0)
                    differenceVector = allWeights[:, 0] - allWeights[:, firstFalseIdx]
                    if Configuration.USE_OPT:
                        solver.addConstraint(differenceVector, exampleIdx)

                # Re-solve the QP every 20th example and update the parameters.
                if exampleIdx != 0 and exampleIdx % 20 == 0 and Configuration.USE_OPT:
                    objValue, w, self.slacks = solver.solve()
                    print('objValue is %f' % objValue)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dump_param(param, param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)

        print('Training completed')
        self._dump_param(param, param_idx)
432
def fetch_genome_info(ginfo_filename, gen_file=None):
    """Return the genome-info structure, caching it in a pickle file.

    If ``ginfo_filename`` exists, its unpickled content is returned.
    Otherwise MATLAB builds the structure via ``init_genome(gen_file)`` and
    saves it to ``genome_info.mat`` in the current directory. The
    ``genome_info`` entry is then pickled to ``ginfo_filename`` and returned.

    :param ginfo_filename: path of the pickle cache file.
    :param gen_file: genome config file passed to MATLAB's ``init_genome``;
        only needed when the cache file does not exist yet.
    :raises ValueError: if the cache is missing and ``gen_file`` is None.
    :raises RuntimeError: if MATLAB writes anything to stderr.
    """
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    if os.path.exists(ginfo_filename):
        with open(ginfo_filename, 'rb') as fh:
            return pickle.load(fh)

    if gen_file is None:
        raise ValueError('%s does not exist and no gen_file was given to build it' % ginfo_filename)

    matlab_script = '; '.join([
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
        'exit',
    ])
    # Argument list (no shell) so gen_file cannot be interpreted by a shell.
    obj = subprocess.Popen(['matlab', '-nojvm', '-nodisplay', '-r', matlab_script],
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = obj.communicate()
    if err:
        raise RuntimeError('An error occured!\n%s' % err)

    genome_info = scipy.io.loadmat('genome_info.mat')['genome_info']
    # Cache exactly what is returned, so later calls return the same data.
    with open(ginfo_filename, 'wb') as fh:
        pickle.dump(genome_info, fh)
    return genome_info
452
if __name__ == '__main__':
    # Use a distinct name so the imported `qpalma` package is not shadowed.
    trainer = QPalma()
    trainer.run()