9f47d330da8ae65573a83337852e475e546ea239
[qpalma.git] / python / qpalma.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 from SIQP_CPX import SIQPSolver
22
23 from paths_load_data import *
24 from paths_load_data_pickle import *
25 from paths_load_data_solexa import *
26
27 from computeSpliceWeights import *
28 from set_param_palma import *
29 from computeSpliceAlignWithQuality import *
30 from penalty_lookup_new import *
31 from compute_donacc import *
32 from TrainingParam import Param
33 from export_param import *
34
35 import Configuration
36 from Plif import Plf
37
def getQualityFeatureCounts(qualityPlifs):
    """Stack the penalty vectors of all quality Plifs row-wise.

    qualityPlifs -- non-empty sequence of objects exposing a `penalties`
                    attribute (the quality Plif objects).

    Returns the first Plif's penalties unchanged when there is exactly one
    Plif; otherwise returns numpy.vstack of all penalty vectors, in order.
    Raises IndexError for an empty sequence.
    """
    # The module never imports numpy by name (numpy.matlib's from-import
    # does not bind it), so bring it into scope explicitly.
    import numpy

    penalties = [plif.penalties for plif in qualityPlifs]
    first = penalties[0]  # IndexError on empty input, as before
    if len(penalties) == 1:
        return first
    # One vstack over all rows instead of re-stacking inside a loop, which
    # copied the growing array on every iteration.
    return numpy.vstack(penalties)
44
class QPalma:
    """
    A training method for the QPalma project.

    For every training example the true spliced alignment and the n-best
    alignments decoded by QPalmaDP are turned into feature vectors. The
    difference between the first decoded alignment that differs from the
    true one and the true alignment is added as a constraint to an
    SIQPSolver, which is solved every 10th example to update the
    parameters. The final parameters are exported to 'elegans.param'.
    """

    def __init__(self):
        """Load the training parameters, open 'qpalma.log' and fetch genome info."""
        self.ARGS = Param()

        self.logfh = open('qpalma.log', 'w+')

        ginfo_filename = 'genome_info.pickle'
        self.genome_info = fetch_genome_info(ginfo_filename)

        self.plog('genome_info.basedir is %s\n' % self.genome_info.basedir)

    def plog(self, string):
        """Append string to the log file and flush it right away."""
        self.logfh.write(string)
        self.logfh.flush()

    def run(self):
        """Train the parameters and export them to 'elegans.param'.

        The log file is closed when training finishes, also if it raises.
        """
        try:
            self._train()
        finally:
            self.logfh.close()

    def _load_training_data(self):
        """Load the training set for the configured Configuration.mode.

        Returns (Sequences, Acceptors, Donors, Exons, Ests, Qualities,
        use_quality_scores); Qualities is None in 'normal' mode.
        Raises ValueError for any other mode.
        """
        if Configuration.mode == 'normal':
            Sequences, Acceptors, Donors, Exons, Ests, _noises = \
                paths_load_data_pickle('training', self.genome_info, self.ARGS)
            return Sequences, Acceptors, Donors, Exons, Ests, None, False
        if Configuration.mode == 'using_quality_scores':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = \
                paths_load_data_solexa('training', self.genome_info, self.ARGS)
            return Sequences, Acceptors, Donors, Exons, Ests, Qualities, True
        # Explicit raise: the former assert(False) vanished under python -O.
        raise ValueError('Unknown Configuration.mode: %r' % (Configuration.mode,))

    def _build_phi(self, h, d, a, mmatrix, param, use_quality_scores):
        """Assemble the current parameter vector as a (numFeatures x 1) column.

        Layout: intron length (h), donor (d), acceptor (a), match matrix,
        quality penalties -- the same order in which the feature vectors
        (trueWeight / wp) are stacked in _train(). The quality block is
        taken from the tail of param when quality scores are used and is
        left at zero otherwise.
        """
        lengthSP = Configuration.numLengthSuppPoints
        donSP = Configuration.numDonSuppPoints
        accSP = Configuration.numAccSuppPoints
        mmatrixSP = Configuration.sizeMatchmatrix[0] * Configuration.sizeMatchmatrix[1]
        totalQualSP = Configuration.totalQualSuppPoints

        donStart = lengthSP
        accStart = donStart + donSP
        mmStart = accStart + accSP
        qualStart = mmStart + mmatrixSP

        phi = zeros((Configuration.numFeatures, 1))
        phi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP, 1)
        phi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP, 1)
        phi[accStart:mmStart] = mat(a.penalties[:]).reshape(accSP, 1)
        phi[mmStart:qualStart] = mmatrix[:]
        if use_quality_scores:
            # len()-based slice so totalQualSP == 0 does not select all of param
            phi[qualStart:] = param[len(param) - totalQualSP:]
        return phi

    def _decode_alignments(self, numPaths, dna, est, quality, h, mmatrix,
                           donor, acceptor, qualityPlifs,
                           remove_duplicate_scores, print_matrix,
                           use_quality_scores):
        """Run the QPalmaDP n-best decoder (myalign) for one example.

        Returns (spliceAligns, weightMatches, viterbiScores, qualityFeatures):
          spliceAligns    -- matrix of shape (numPaths, len(str(dna)))
          weightMatches   -- matrix of shape (numPaths, match matrix size)
          viterbiScores   -- list of the numPaths scores from getAlignmentResults
          qualityFeatures -- column matrix holding
                             numQualSuppPoints * numQualPlifs features per
                             path, path after path
        """
        dna = str(dna)
        est = str(est)
        dna_len = len(dna)
        est_len = len(est)
        numQual = Configuration.numQualSuppPoints * Configuration.numQualPlifs

        ps = h.convert2SWIG()
        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

        mmatrix_list = mmatrix.flatten().tolist()[0]
        # Length of the array actually handed to the decoder (was a literal 36).
        mm_len = len(mmatrix_list)
        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix_list)

        d_len = len(donor)
        c_donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        c_acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs, Configuration.numQualSuppPoints)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([plif.convert2SWIG() for plif in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
        currentAlignment.myalign(numPaths, dna, dna_len,
                                 est, est_len, prb, chastity, ps, matchmatrix, mm_len,
                                 c_donor, d_len, c_acceptor, a_len, c_qualityPlifs,
                                 remove_duplicate_scores, print_matrix, use_quality_scores)

        # output buffers filled in by getAlignmentResults
        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * numPaths))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * numPaths))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * numPaths))
        c_AlignmentScores = QPalmaDP.createDoubleArrayFromList([.0] * numPaths)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * (numQual * numPaths))

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_AlignmentScores, c_qualityPlifsFeatures)

        spliceAligns = zeros((numPaths * dna_len, 1))
        for i in range(numPaths * dna_len):
            spliceAligns[i] = c_SpliceAlign[i]

        weightMatches = zeros((numPaths * mm_len, 1))
        for i in range(numPaths * mm_len):
            weightMatches[i] = c_WeightMatch[i]

        viterbiScores = [c_AlignmentScores[i] for i in range(numPaths)]

        qualityFeatures = zeros((numQual * numPaths, 1))
        for i in range(numQual * numPaths):
            qualityFeatures[i] = c_qualityPlifsFeatures[i]

        # release the SWIG-side objects explicitly
        del c_SpliceAlign, c_EstAlign, c_WeightMatch, c_AlignmentScores
        del c_qualityPlifsFeatures, currentAlignment

        return (spliceAligns.reshape(numPaths, dna_len),
                weightMatches.reshape(numPaths, mm_len),
                viterbiScores,
                qualityFeatures)

    def _train(self):
        """The training loop proper; see the class docstring."""
        import numpy

        Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores = \
            self._load_training_data()

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        assert N == len(Acceptors) and N == len(Donors) and N == len(Exons)\
            and N == len(Ests), 'The Seq,Accept,Donor,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n' % N)
        print('Number of features: %d\n' % Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Initial parameter vector. NOTE: this is Configuration.fixedParam
        # itself, so the solver update below mutates it in place.
        param = Configuration.fixedParam

        # Set the parameters such as limits and penalties for the Plifs
        h, d, a, mmatrix, qualityPlifs = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)
        # phi only changes when the parameters do, so build it here and
        # after each solver update instead of once per example.
        phi = self._build_phi(h, d, a, mmatrix, param, use_quality_scores)

        # delete splice-site-score information
        # NOTE(review): assumes Acceptors[i]/Donors[i] compare to -20 as
        # scalars -- TODO confirm against the data loaders.
        if not self.ARGS.train_with_splicesitescoreinformation:
            for i in range(len(Acceptors)):
                if Acceptors[i] > -20:
                    Acceptors[i] = 1
                if Donors[i] > -20:
                    Donors[i] = 1

        # NOTE(review): the solver is only created and used when Python runs
        # with -O (__debug__ is False); a plain run decodes but never adds
        # constraints or updates param. Confirm this is intended.
        solver = None
        if not __debug__:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures, self.numExamples, Configuration.C, self.logfh)

        # number of alignments decoded per example (best path, second-best path etc.)
        num_path = [anzpath] * N
        # per-path stride of the decoder's quality feature buffer
        qualPerPath = Configuration.numQualSuppPoints * Configuration.numQualPlifs

        self.plog('Starting training...\n')

        for iteration_nr in range(iteration_steps):
            for exampleIdx in range(self.numExamples):

                if exampleIdx % 10 == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]
                if use_quality_scores:
                    quality = Qualities[exampleIdx]
                else:
                    quality = [40] * len(est)

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Features of the true alignment (under the current d, a, h)
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                numPaths = num_path[exampleIdx]

                # column 0: true alignment, columns 1..numPaths: decoded ones
                allWeights = zeros((Configuration.numFeatures, numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                # w'phi(x,y) per alignment; kept for diagnostics only
                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = (trueWeight.T * phi)[0, 0]

                ################## Calculate wrong alignment(s) ######################

                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                # for now we don't use donor/acceptor scores
                donor = [-inf] * len(donor)
                acceptor = [-inf] * len(acceptor)

                spliceAligns, weightMatches, viterbiScores, qualityFeatures = \
                    self._decode_alignments(numPaths, dna, est, quality, h, mmatrix,
                                            donor, acceptor, qualityPlifs,
                                            remove_duplicate_scores, print_matrix,
                                            use_quality_scores)
                AlignmentScores[1:] = viterbiScores

                # The n-best alignments are computed without Hamming loss, so
                # a decoded alignment may coincide with the true one. true_map
                # flags each decoded alignment equivalent to the true one so
                # it is not used as a constraint.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1

                for pathNr in range(numPaths):
                    decodedSplice = spliceAligns[pathNr, :].flatten().tolist()[0]
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a, h, decodedSplice, don_supp, acc_supp)

                    # copy this path's block of the decoder's quality features
                    decodedQualityFeatures = zeros((Configuration.totalQualSuppPoints, 1))
                    start = pathNr * qualPerPath
                    for offset in range(qualPerPath):
                        decodedQualityFeatures[offset] = qualityFeatures[start + offset]

                    # store the weights in the remaining columns of the matrix
                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, weightMatches[pathNr, :].T, decodedQualityFeatures])
                    allWeights[:, pathNr + 1] = wp
                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * phi)[0, 0]

                    # a decoded alignment very close to the true one counts as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                # All n-best paths equal the true alignment: extend the search
                # to (n+1)-best for this example.
                if all(true_map):
                    num_path[exampleIdx] += 1

                # first decoded alignment that differs from the true one
                firstFalseIdx = next((idx for idx, isTrue in enumerate(true_map) if not isTrue), -1)

                # if there is at least one useful false alignment add the
                # corresponding constraint to the optimization problem
                if firstFalseIdx != -1 and not __debug__:
                    differenceVector = allWeights[:, firstFalseIdx] - allWeights[:, 0]
                    solver.addConstraint(differenceVector, exampleIdx)

                # call the solver every 10th example
                if exampleIdx != 0 and exampleIdx % 10 == 0 and not __debug__:
                    objValue, w, self.slacks = solver.solve()

                    print("objValue is %f" % objValue)

                    for i in range(len(param)):
                        param[i] = w[i]

                    h, d, a, mmatrix, qualityPlifs = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)
                    phi = self._build_phi(h, d, a, mmatrix, param, use_quality_scores)

        print('Training completed')
        export_param('elegans.param', h, d, a, mmatrix)
386
def fetch_genome_info(ginfo_filename, gen_file=None):
    """Return the genome_info structure, caching it as a pickle.

    If ginfo_filename exists, it is unpickled and returned. Otherwise MATLAB
    builds genome_info from gen_file (via init_genome) and saves it to
    'genome_info.mat' in the current directory. The loaded structure is
    then pickled to ginfo_filename and returned.

    ginfo_filename -- path of the pickle cache.
    gen_file       -- genome config passed to init_genome. Defaults to
                      '<Param().basedir>/genome.config', the same path
                      QPalma.__init__ builds.

    Raises RuntimeError if MATLAB writes anything to stderr.
    """
    try:
        import cPickle as pickle
    except ImportError:  # Python 3
        import pickle

    if os.path.exists(ginfo_filename):
        with open(ginfo_filename, 'rb') as fh:
            return pickle.load(fh)

    # gen_file used to be an undefined name here (NameError).
    if gen_file is None:
        gen_file = '%s/genome.config' % Param().basedir

    matlab_script = '; '.join([
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
        'exit',
    ])

    # Argument list instead of a shell string: gen_file is never shell-parsed.
    proc = subprocess.Popen(['matlab', '-nojvm', '-nodisplay', '-r', matlab_script],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    err = proc.communicate()[1]
    # Explicit raise: an assert is stripped under python -O. Truthiness also
    # works for the bytes that communicate() returns on Python 3.
    if err:
        raise RuntimeError('An error occured!\n%s' % err)

    genome_info = scipy.io.loadmat('genome_info.mat')['genome_info']
    # Cache what is returned (the old code referenced an undefined `self`).
    with open(ginfo_filename, 'wb') as fh:
        pickle.dump(genome_info, fh)
    return genome_info
406
if __name__ == '__main__':
    # Script entry point: build the trainer and run the complete training.
    QPalma().run()