+ fixed if/else cases bug in the increaseFeatureCount function of result_align
[qpalma.git] / python / qpalma.py
#!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 from SIQP_CPX import SIQPSolver
22
23 from paths_load_data import *
24 from paths_load_data_pickle import *
25 from paths_load_data_solexa import *
26
27 from computeSpliceWeights import *
28 from set_param_palma import *
29 from computeSpliceAlignWithQuality import *
30 from penalty_lookup_new import *
31 from compute_donacc import *
32 from TrainingParam import Param
33 from export_param import *
34
35 import Configuration
36 from Plif import Plf
37
def getQualityFeatureCounts(qualityPlifs):
    """
    Stack the penalty vectors of the quality Plifs into one array.

    :param qualityPlifs: non-empty sequence of objects with a ``penalties``
        attribute
    :returns: the first Plif's ``penalties`` unchanged when only one Plif is
        given, otherwise ``numpy.vstack`` of all penalty vectors (one row per
        Plif, in input order)
    :raises IndexError: if ``qualityPlifs`` is empty
    """
    # The module-level imports only bind numpy submodule names, not ``numpy``.
    import numpy

    # A single Plif is returned as-is (no vstack), as before; an empty
    # sequence raises IndexError here, as before.
    if len(qualityPlifs) <= 1:
        return qualityPlifs[0].penalties

    # One vstack over all rows instead of re-copying the growing array per Plif.
    return numpy.vstack([plif.penalties for plif in qualityPlifs])
44
class QPalma:
    """
    A training method for the QPalma project.

    For every selected training example the true alignment and the decoded
    n-best alignments are scored; the difference between the weight vector
    of the first decoded alignment that is not equivalent to the true one and
    the true weight vector is added as a constraint to an SIQP solver. The
    solver is only created and called when Python runs with -O (i.e. when
    ``__debug__`` is False).
    """

    def __init__(self):
        """Read the training parameters, open the log file and load the genome info."""
        self.ARGS = Param()

        # 'w+' truncates any log left over from a previous run.
        self.logfh = open('qpalma.log', 'w+')

        ginfo_filename = 'genome_info.pickle'
        self.genome_info = fetch_genome_info(ginfo_filename)

        self.plog('genome_info.basedir is %s\n' % self.genome_info.basedir)

    def plog(self, string):
        """Write ``string`` to the log file and flush it immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def run(self):
        """
        Train the alignment parameters and export them to 'elegans.param'.

        Closes the log file when training has finished.

        :raises ValueError: if ``Configuration.mode`` is unknown or the loaded
            data arrays have different lengths
        """
        import numpy

        # Load the whole dataset
        if Configuration.mode == 'normal':
            Sequences, Acceptors, Donors, Exons, Ests, Noises = paths_load_data_pickle('training', self.genome_info, self.ARGS)
            # NOTE(review): no Qualities are loaded in this mode, but the
            # training loop reads Qualities[exampleIdx] -- this mode currently
            # fails with a NameError there.
        elif Configuration.mode == 'using_quality_scores':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = paths_load_data_solexa('training', self.genome_info, self.ARGS)
        else:
            # Explicit raise: an assert would be stripped under -O, which this
            # script needs for the solver to run at all.
            raise ValueError('Unknown Configuration.mode: %r' % (Configuration.mode,))

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        if not (N == len(Acceptors) == len(Donors) == len(Exons) == len(Ests)):
            raise ValueError('The Seq,Accept,Donor,.. arrays are of different lengths')
        self.plog('Number of training examples: %d\n' % N)
        print('Number of features: %d\n' % Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Initial parameter vector. It aliases Configuration.fixedParam and is
        # updated in place with the solver results below.
        param = Configuration.fixedParam

        # Set the parameters such as limits and penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)

        # delete splice site score information
        if not self.ARGS.train_with_splicesitescoreinformation:
            for i in range(len(Acceptors)):
                if Acceptors[i] > -20:
                    Acceptors[i] = 1
                if Donors[i] > -20:
                    Donors[i] = 1

        # Initialize solver (only used when running with -O)
        if not __debug__:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures, self.numExamples, Configuration.C, self.logfh)

        # number of alignments computed per example (best path, second-best path etc.)
        num_path = [anzpath] * N
        # stores the gap for each example
        gap = [0.0] * N

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        donSP = Configuration.numDonSuppPoints
        accSP = Configuration.numAccSuppPoints
        lengthSP = Configuration.numLengthSuppPoints
        mmatrixSP = Configuration.sizeMatchmatrix[0] * Configuration.sizeMatchmatrix[1]
        totalQualSP = Configuration.totalQualSuppPoints

        # Start offsets of the feature groups, in the order used by every weight
        # vector below: intron length, donor, acceptor, match matrix, quality.
        donStart = lengthSP
        accStart = donStart + donSP
        mmStart = accStart + accSP
        qualStart = mmStart + mmatrixSP

        currentPhi = zeros((Configuration.numFeatures, 1))

        iteration_nr = 0
        while iteration_nr != iteration_steps:

            # NOTE(review): only examples 6..10 are processed; this looks like
            # a debugging restriction of range(self.numExamples).
            for exampleIdx in range(6, 11):
                if exampleIdx % 10 == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]
                quality = Qualities[exampleIdx]
                # NOTE(review): the loaded quality values are replaced by a
                # constant 40 per position -- looks like a debugging override.
                quality = [40] * len(quality)
                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)

                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])
                totalQualityPenalties = param[-totalQualSP:]

                # Current parameters, laid out in the same order as trueWeight.
                currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP, 1)
                currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP, 1)
                currentPhi[accStart:mmStart] = mat(a.penalties[:]).reshape(accSP, 1)
                currentPhi[mmStart:qualStart] = mmatrix[:]
                currentPhi[qualStart:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0, 0]

                # Path count for this example; num_path[exampleIdx] is only
                # incremented after its last use below.
                numPaths = num_path[exampleIdx]

                # allWeights holds the weight vector of the true alignment in
                # column 0 followed by those of the numPaths decoded alignments.
                allWeights = zeros((Configuration.numFeatures, numPaths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (numPaths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################

                # compute_donacc returns two lists of donor / acceptor scores
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                # for now we don't use donor/acceptor scores
                donor = [-inf] * len(donor)
                acceptor = [-inf] * len(acceptor)

                dna = str(dna)
                est = str(est)
                dna_len = len(dna)
                est_len = len(est)
                ps = h.convert2SWIG()

                prb = QPalmaDP.createDoubleArrayFromList(quality)
                chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)

                matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
                # One match weight per match-matrix feature; must equal the size
                # of the match block in allWeights.
                mm_len = mmatrixSP

                d_len = len(donor)
                donor = QPalmaDP.createDoubleArrayFromList(donor)
                a_len = len(acceptor)
                acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

                currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs, Configuration.numQualSuppPoints)

                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

                # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
                currentAlignment.myalign(numPaths, dna, dna_len,
                                         est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,
                                         acceptor, a_len, c_qualityPlifs, remove_duplicate_scores, print_matrix)

                print('Python: after myalign...')

                c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * numPaths))
                c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * numPaths))
                c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * numPaths))
                c_AlignmentScores = QPalmaDP.createDoubleArrayFromList([.0] * numPaths)

                c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * (Configuration.numQualPlifs * numPaths))

                # NOTE(review): the call that fills the c_* buffers is disabled,
                # so the decoded alignments below are read from zero-filled arrays:
                # currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                #     c_WeightMatch, c_AlignmentScores, c_qualityPlifsFeatures)

                print('Python: after getAlignmentResults...')

                newSpliceAlign = zeros((numPaths * dna_len, 1))
                newEstAlign = zeros((est_len * numPaths, 1))
                newWeightMatch = zeros((numPaths * mm_len, 1))

                # Copy the SWIG buffers into numpy matrices.
                for i in range(dna_len * numPaths):
                    newSpliceAlign[i] = c_SpliceAlign[i]
                for i in range(est_len * numPaths):
                    newEstAlign[i] = c_EstAlign[i]
                for i in range(mm_len * numPaths):
                    newWeightMatch[i] = c_WeightMatch[i]
                for i in range(numPaths):
                    AlignmentScores[i + 1] = c_AlignmentScores[i]

                print("Calling destructors")
                del c_SpliceAlign
                del c_EstAlign
                del c_WeightMatch
                del c_AlignmentScores
                del c_qualityPlifsFeatures
                del currentAlignment

                newSpliceAlign = newSpliceAlign.reshape(numPaths, dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths, mm_len)

                # The n-best alignments are computed without hamming loss, so
                # some of them may coincide with the true alignment. true_map
                # marks, per column of allWeights, whether that alignment is
                # equivalent to the true one, so that it is not used as a
                # constraint.
                true_map = [0] * (numPaths + 1)
                true_map[0] = 1
                path_loss = [0] * (numPaths + 1)

                print('Calculating decoded features')

                # Current penalties in weight-vector order; quality features of
                # decoded paths are not computed yet, hence zeros. These do not
                # change inside the path loop.
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], zeros((Configuration.totalQualSuppPoints, 1))])

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0], don_supp, acc_supp)

                    # Would be filled from c_qualityPlifsFeatures once
                    # getAlignmentResults is enabled again.
                    decodedQualityFeatures = zeros((Configuration.totalQualSuppPoints, 1))

                    # sum up positionwise loss between alignments
                    # NOTE(review): newSpliceAlign is a numpy matrix, so the row
                    # slice has shape (1, dna_len) and len() is 1 -- only the
                    # first position is compared. path_loss is not used further.
                    for alignPosIdx in range(len(newSpliceAlign[pathNr, :])):
                        if newSpliceAlign[pathNr, alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr + 1] += 1

                    # Store the weights in the remaining columns of the matrix
                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr, :].T, decodedQualityFeatures])
                    allWeights[:, pathNr + 1] = wp

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                # All decoded paths are equivalent to the true alignment, so the
                # n-best paths are too close to each other: extend the search to
                # (n+1)-best for this example.
                if all(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx] + 1

                # Choose true and first false alignment for extending the constraint set
                firstFalseIdx = true_map.index(0) if 0 in true_map else -1

                # if there is at least one useful false alignment add the
                # corresponding constraint to the optimization problem
                if firstFalseIdx != -1:
                    trueWeights = allWeights[:, 0]
                    firstFalseWeights = allWeights[:, firstFalseIdx]
                    differenceVector = firstFalseWeights - trueWeights

                    if not __debug__:
                        solver.addConstraint(differenceVector, exampleIdx)

                # call the solver every 3rd example and reload the Plifs from
                # the updated parameter vector
                if exampleIdx != 0 and exampleIdx % 3 == 0 and not __debug__:
                    objValue, w, self.slacks = solver.solve()

                    print("objValue is %f" % objValue)

                    for i in range(len(param)):
                        param[i] = w[i]

                    [h, d, a, mmatrix, qualityPlifs] = set_param_palma(param, self.ARGS.train_with_intronlengthinformation)

            # end of one iteration through all examples
            iteration_nr += 1

        # end of optimization
        print('Training completed')
        export_param('elegans.param', h, d, a, mmatrix)
        self.logfh.close()
368
def fetch_genome_info(ginfo_filename, gen_file=None):
    """
    Return the genome info structure, cached in a pickle file.

    If ``ginfo_filename`` exists, its pickled content is returned. Otherwise
    MATLAB is run to build ``genome_info`` from the genome config file and save
    it to 'genome_info.mat'. The structure is then read back, pickled to
    ``ginfo_filename`` and returned.

    :param ginfo_filename: path of the pickle cache
    :param gen_file: genome config passed to ``init_genome``; defaults to
        '<basedir>/genome.config' with ``basedir`` taken from ``Param()``
    :raises RuntimeError: if MATLAB writes anything to stderr
    """
    try:
        import cPickle as pickle
    except ImportError:  # Python 3
        import pickle

    if os.path.exists(ginfo_filename):
        # NOTE: pickle is only safe on trusted files; this is a local cache.
        with open(ginfo_filename, 'rb') as fh:
            return pickle.load(fh)

    if gen_file is None:
        gen_file = '%s/genome.config' % Param().basedir

    cmd = [
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
    ]
    full_cmd = "matlab -nojvm -nodisplay -r \"%s; exit\"" % '; '.join(cmd)

    obj = subprocess.Popen(full_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = obj.communicate()
    # Explicit check: an assert would be stripped under -O.
    if err:
        raise RuntimeError('An error occured!\n%s' % err)

    genome_info = scipy.io.loadmat('genome_info.mat')['genome_info']
    # Cache exactly what is returned, so later calls load the same object.
    with open(ginfo_filename, 'wb') as fh:
        pickle.dump(genome_info, fh)
    return genome_info
388
if __name__ == '__main__':
    # Script entry point: build a trainer with the default parameters and train.
    QPalma().run()