+ incorporated feature count for prb and chastity Plifs in result_align
[qpalma.git] / python / qpalma.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14
15 from numpy.matlib import mat,zeros,ones,inf
16 from numpy.linalg import norm
17
18 import QPalmaDP
19
20 from SIQP_CPX import SIQPSolver
21
22 from paths_load_data import *
23 from paths_load_data_pickle import *
24
25 from computeSpliceWeights import *
26 from set_param_palma import *
27 from computeSpliceAlign import *
28 from penalty_lookup_new import *
29 from compute_donacc import *
30 from TrainingParam import Param
31 from export_param import *
32
33 import Configuration
34
35
36
def initializeQualityScoringFunctions(numPlifs,numSuppPoints):
    """
    Create the list of quality scoring functions (Plifs).

    numPlifs      -- number of Plf objects to create
    numSuppPoints -- number of support points (limits/penalties) per Plf

    Every Plf gets numSuppPoints limits spaced evenly between -5 and 5 and
    an all-zero penalty list. The first three penalties of Plf 0 and Plf 1
    are overwritten with fixed values. Each Plf is converted with
    convert2SWIG() before it is stored.

    Returns a list holding the numPlifs converted objects.
    """
    # intron length bounds; assigned here but not used by this function
    min_intron_len = 20
    max_intron_len = 1000

    # range covered by the support points
    svmScoreRange = (-5, 5)

    # attributes that get the same constant value on every Plf
    fixedAttributes = (
        ('transform', ''),
        ('name', ''),
        ('max_len', 100),
        ('min_len', -100),
        ('id', 1),
        ('use_svm', 0),
        ('next_id', 0),
    )

    # hard-coded leading penalties, keyed by the index of the Plf
    presetPenalties = {
        0: (11, 22, 33),
        1: (99, 100, 101),
    }

    qualityPlifs = []

    for plifNr in range(numPlifs):
        plif = Plf()
        plif.limits = linspace(svmScoreRange[0],svmScoreRange[1],numSuppPoints)
        plif.penalties = [0]*numSuppPoints

        for attrName, attrValue in fixedAttributes:
            setattr(plif, attrName, attrValue)

        # only Plf 0 and Plf 1 have preset values, all others stay at zero
        for pos, value in enumerate(presetPenalties.get(plifNr, ())):
            plif.penalties[pos] = value

        qualityPlifs.append(plif.convert2SWIG())

    return qualityPlifs
73
class QPalma:
    """
    A training method for the QPalma project

    Constructing an instance opens the log file, asks matlab to export the
    genome information and fixes the layout of the feature vector.
    run() then iterates over the training examples, decodes alignments
    with QPalmaDP and (when Python runs with -O, i.e. __debug__ is False)
    feeds constraints to the SIQPSolver.
    """

    def __init__(self):
        """
        Set up logging, load the genome information and define the sizes
        of the individual parts of the parameter vector.

        Side effects: opens 'qpalma.log' for writing, runs matlab in a
        subprocess and reads the 'genome_info.mat' file it writes into the
        current directory.
        """
        self.ARGS = Param()

        self.logfh = open('qpalma.log','w+')
        gen_file= '%s/genome.config' % self.ARGS.basedir

        # let matlab build the genome_info structure and dump it to a .mat file
        cmd = ['']*4
        cmd[0] = 'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils'
        cmd[1] = 'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes'
        cmd[2] = 'genome_info = init_genome(\'%s\')' % gen_file
        cmd[3] = 'save genome_info.mat genome_info'
        full_cmd = "matlab -nojvm -nodisplay -r \"%s; %s; %s; %s; exit\"" % (cmd[0],cmd[1],cmd[2],cmd[3])

        obj = subprocess.Popen(full_cmd,shell=True,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
        out,err = obj.communicate()
        # any output on stderr is treated as a failure of the matlab call
        assert err == '', 'An error occured!\n%s'%err

        ginfo = scipy.io.loadmat('genome_info.mat')
        self.genome_info = ginfo['genome_info']

        self.plog('genome_info.basedir is %s\n'%self.genome_info.basedir)

        self.C=1.0

        # 'normal' means work like Palma
        # 'using_quality_scores' means work like Palma plus using sequencing
        # quality scores
        self.mode = 'normal'
        #self.mode = 'using_quality_scores'

        # Here we specify the total number of parameters.
        # When using quality scores our scoring function is defined as
        #
        # f: S x R x S -> R
        #
        # as opposed to a usage without quality scores when we only have
        #
        # f: S x S -> R
        #
        self.numDonSuppPoints = 30
        self.numAccSuppPoints = 30
        self.numLengthSuppPoints = 30
        if self.mode == 'normal':
            self.sizeMMatrix = 36
        elif self.mode == 'using_quality_scores':
            self.sizeMMatrix = 728
        else:
            assert False, 'Wrong operation mode specified'

        # this number defines the number of support points for one tuple (a,b)
        # where 'a' comes with a quality score
        # (the value 10 is immediately overridden by 0 in the next line)
        self.numQualSuppPoints = 10
        self.numQualSuppPoints = 0

        # note: numQualSuppPoints is not part of this sum
        self.numFeatures = self.numDonSuppPoints + self.numAccSuppPoints\
        + self.numLengthSuppPoints + self.sizeMMatrix

        self.plog('Initializing problem...\n')


    def plog(self,string):
        """Write string to the log file handle without adding a newline."""
        self.logfh.write(string)


    def run(self):
        """
        Run the training loop.

        Loads the pickled training set, then repeats iteration_steps times:
        for every example compute the weights of the true alignment, decode
        num_path[exampleIdx] alignments with QPalmaDP, compute their
        weights and pick the first decoded alignment that differs from the
        true one. Only if __debug__ is False (python -O) is a solver
        created, the constraint added, the problem solved and the
        parameters updated.

        Finally the parameters are exported to 'elegans.param' and the log
        file is closed.
        """
        # Load the whole dataset
        #Sequences, Acceptors, Donors, Exons, Ests, Noises = paths_load_data('training',self.genome_info,self.ARGS)
        Sequences, Acceptors, Donors, Exons, Ests, Noises = paths_load_data_pickle('training',self.genome_info,self.ARGS)

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        # Donors is indexed in lockstep with the other arrays below, so its
        # length has to be checked as well
        assert N == len(Acceptors) and N == len(Donors) and N == len(Exons)\
        and N == len(Ests), 'The Seq,Accept,Donor,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% N)

        #iteration_steps = 200 ; #upper bound on iteration steps
        iteration_steps = 2 ; #upper bound on iteration steps

        remove_duplicate_scores = False
        print_matrix = False
        anzpath = 2

        # Initialize parameter vector
        # param = numpy.matlib.rand(126,1)
        param = Configuration.fixedParam

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

        # delete splicesite-score-information
        # NOTE(review): Acceptors[i] / Donors[i] are handed to
        # compute_donacc per example further down; the scalar comparison
        # here may not do what is intended if they are sequences - confirm
        if not self.ARGS.train_with_splicesitescoreinformation:
            for i in range(len(Acceptors)):
                if Acceptors[i] > -20:
                    Acceptors[i] = 1
                if Donors[i] >-20:
                    Donors[i] = 1

        # Initialize solver (only when running without debug mode)
        if not __debug__:
            solver = SIQPSolver(self.numFeatures,self.numExamples,self.C,self.logfh)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*N
        # stores the gap for each example
        gap = [0.0]*N

        qualityMatrix = zeros((self.numQualSuppPoints,1))

        numPlifs = 24
        numSuppPoints = 30

        qualityPlifs = initializeQualityScoringFunctions(numPlifs,numSuppPoints)

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        iteration_nr = 0

        while True:
            if iteration_nr == iteration_steps:
                break

            for exampleIdx in range(self.numExamples):
                if (exampleIdx%10) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]

                exons = Exons[exampleIdx]
                # NoiseMatrix = Noises[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                # The result of the first call is overwritten by the second one.
                trueSpliceAlign, trueWeightMatch = computeSpliceAlign(dna, exons)
                trueSpliceAlign, trueWeightMatch, trueQualityPlifs = computeSpliceAlignWithQuality(dna, exons, qualityPlifs)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, qualityMatrix ])

                # NOTE(review): trueWeight is stacked as intron, donor,
                # acceptor, match while currentPhi is filled in the order
                # d, a, h, mmatrix - confirm that this ordering is intended
                currentPhi = zeros((self.numFeatures,1))
                currentPhi[0:30] = mat(d.penalties[:]).reshape(30,1)
                currentPhi[30:60] = mat(a.penalties[:]).reshape(30,1)
                currentPhi[60:90] = mat(h.penalties[:]).reshape(30,1)
                currentPhi[90:126] = mmatrix[:]
                currentPhi[126:] = qualityMatrix[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((self.numFeatures,num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################

                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                #myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                dna = str(dna)
                est = str(est)
                dna_len = len(dna)
                est_len = len(est)
                ps = h.convert2SWIG()

                # quality information per est position, all zero for now
                prb = QPalmaDP.createDoubleArrayFromList([.0]*est_len)
                chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

                matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
                mm_len = 36

                d_len = len(donor)
                donor = QPalmaDP.createDoubleArrayFromList(donor)
                a_len = len(acceptor)
                acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

                currentAlignment = QPalmaDP.Alignment()
                qualityMat = QPalmaDP.createDoubleArrayFromList(qualityMatrix)
                currentAlignment.setQualityMatrix(qualityMat,self.numQualSuppPoints)

                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(qualityPlifs)

                #print 'PYTHON: Calling myalign...'
                # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
                currentAlignment.myalign( num_path[exampleIdx], dna, dna_len,\
                est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
                acceptor, a_len, c_qualityPlifs, remove_duplicate_scores, print_matrix)
                #print 'PYTHON: After myalign call...'

                # buffers that getAlignmentResults fills with the decoded alignments
                c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*num_path[exampleIdx]))
                c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*num_path[exampleIdx]))
                c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*num_path[exampleIdx]))
                c_AlignmentScores = QPalmaDP.createDoubleArrayFromList([.0]*num_path[exampleIdx])

                # one empty Plif per (quality Plif, path) pair as result buffer;
                # the list has to be built from the converted empty Plif itself
                emptyPlif = Plf()
                emptyPlif = emptyPlif.convert2SWIG()
                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([emptyPlif]*(numPlifs*num_path[exampleIdx]))

                currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
                c_WeightMatch, c_AlignmentScores, c_qualityPlifs)
                del currentAlignment

                newSpliceAlign = zeros((num_path[exampleIdx]*dna_len,1))
                newEstAlign = zeros((est_len*num_path[exampleIdx],1))
                newWeightMatch = zeros((num_path[exampleIdx]*mm_len,1))

                # copy the results element by element into numpy matrices
                for i in range(dna_len*num_path[exampleIdx]):
                    newSpliceAlign[i] = c_SpliceAlign[i]

                for i in range(est_len*num_path[exampleIdx]):
                    newEstAlign[i] = c_EstAlign[i]

                for i in range(mm_len*num_path[exampleIdx]):
                    newWeightMatch[i] = c_WeightMatch[i]

                for i in range(num_path[exampleIdx]):
                    AlignmentScores[i+1] = c_AlignmentScores[i]

                # one row per decoded path
                newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
                newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mm_len)
                # Calculate weights of the respective alignments Note that we are
                # calculating n-best alignments without any hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(num_path[exampleIdx]+1)
                true_map[0] = 1
                path_loss = [0]*(num_path[exampleIdx]+1)

                for pathNr in range(num_path[exampleIdx]):
                    #dna_numbers = dnaest{1,pathNr}
                    #est_numbers = dnaest{2,pathNr}

                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a, h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp, acc_supp)

                    # NOTE(review): called without arguments and the result
                    # is not used afterwards - looks unfinished, confirm
                    qualityWeights = computeSpliceQualityWeights()

                    # sum up positionwise loss between alignments
                    for alignPosIdx in range(len(newSpliceAlign[pathNr,:])):
                        if newSpliceAlign[pathNr,alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr+1] += 1

                    # Store the weights in the remaining columns of the matrix
                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, qualityMatrix ])
                    allWeights[:,pathNr+1] = wp

                    hpen = mat(h.penalties).reshape(len(h.penalties),1)
                    dpen = mat(d.penalties).reshape(len(d.penalties),1)
                    apen = mat(a.penalties).reshape(len(a.penalties),1)

                    # the score returned by the decoder is replaced by the
                    # scalar product of weights and current parameters
                    features = numpy.vstack([hpen , dpen , apen , mmatrix[:]])
                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check wether scalar product + loss equals viterbi score
                    #assert math.fabs(newAlignmentScores[pathNr] - AlignmentScores[pathNr+1]) < 1e-6,\
                    #'Scalar prod + loss is not equal Viterbi score. Respective values are %f, %f' % \
                    #(newAlignmentScores[pathNr],AlignmentScores[pathNr+1])

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                # the true label sequence should not have a larger score than the maximal one WHYYYYY?

                # this means that all n-best paths are to close to each other
                # we have to extend the n-best search to a (n+1)-best
                if len([elem for elem in true_map if elem == 1]) == len(true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # Choose true and first false alignment for extending A
                firstFalseIdx = -1
                for map_idx,elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
                    trueWeights = allWeights[:,0]
                    firstFalseWeights = allWeights[:,firstFalseIdx]

                    # LMM.py code:
                    deltas = firstFalseWeights - trueWeights
                    # the solver only exists when __debug__ is False, so all
                    # statements depending on its result stay in this branch
                    if not __debug__:
                        const_added = solver.addConstraint(deltas, exampleIdx)
                        objValue,w,self.slacks = solver.solve()

                        sum_xis = 0
                        for elem in self.slacks:
                            sum_xis += elem

                        # take over the new solution as parameter vector
                        for i in range(len(param)):
                            param[i] = w[i]

                        [h,d,a,mmatrix] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

                #
                # end of one example processing
                #
                #if exampleIdx == 100:
                #    break

                #break

            #
            # end of one iteration through all examples
            #
            iteration_nr += 1

        #
        # end of optimization
        #
        export_param('elegans.param',h,d,a,mmatrix)
        self.logfh.close()
        print('Training completed')
419
# Script entry point: construct the trainer (opens the log file and loads
# the genome information) and start the training run right away.
if __name__ == '__main__':
    qpalma = QPalma(); qpalma.run()