241410da04fe9251c3b9f7d581c2ff2ff58f971c
[qpalma.git] / scripts / qpalma_predict.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 import qpalma
22 from qpalma.SIQP_CPX import SIQPSolver
23 from qpalma.DataProc import *
24
25 from qpalma.generateEvaluationData import *
26 from qpalma.computeSpliceWeights import *
27 from qpalma.set_param_palma import *
28 from qpalma.computeSpliceAlignWithQuality import *
29 from qpalma.penalty_lookup_new import *
30 from qpalma.compute_donacc import *
31 from qpalma.TrainingParam import Param
32 from qpalma.export_param import *
33
34 import qpalma.Configuration
35 from qpalma.Plif import Plf
36 from qpalma.Helpers import *
37
def getQualityFeatureCounts(qualityPlifs):
    """
    Stack the penalties of all quality Plifs row-wise.

    qualityPlifs -- non-empty sequence of objects with a `penalties`
                    attribute.

    Returns the `penalties` of the first element unchanged when there is
    only one element, otherwise the result of numpy.vstack over all
    `penalties` (one row per element, in input order).

    Raises IndexError when qualityPlifs is empty.
    """
    # numpy is not imported by name in the visible module header; import it
    # here so numpy.vstack resolves regardless of what the star imports
    # happen to re-export.
    import numpy

    # Indexing first keeps the IndexError on empty input.
    firstPenalties = qualityPlifs[0].penalties
    if len(qualityPlifs) == 1:
        return firstPenalties

    # One vstack call instead of re-stacking (and re-copying) the growing
    # matrix once per Plif.
    return numpy.vstack([plif.penalties for plif in qualityPlifs])
44
class QPalma:
    """
    Prediction driver for the QPalma project.

    Loads a data set and a pickled parameter vector, decodes the best
    alignment of each example through the QPalmaDP interface and prints
    the score reported by the decoder next to the score recomputed from
    the decoded feature vector.
    """

    # Upper bound on the number of examples processed by run().
    MAX_EXAMPLES = 200

    def __init__(self):
        """Create the parameter object and open the log file 'qpalma.log'."""
        self.ARGS = Param()
        self.logfh = open('qpalma.log','w+')

    def plog(self,string):
        """Write `string` to the log file and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    @staticmethod
    def _check_example_counts(N, Acceptors, Donors, Exons, Ests):
        """Raise AssertionError unless all per-example lists have length N."""
        # Raised explicitly (not via `assert`) so the check survives -O.
        if not (N == len(Acceptors) and N == len(Donors)
                and N == len(Exons) and N == len(Ests)):
            raise AssertionError('The Seq,Accept,Donor,.. arrays are of different lengths')

    def run(self):
        """
        Run the prediction over the loaded examples.

        The log file is closed afterwards, also when prediction raises.
        Raises AssertionError for an unknown Conf.mode or when the data
        arrays differ in length.
        """
        try:
            self._predict()
            print('Prediction completed')
        finally:
            self.logfh.close()

    def _predict(self):
        """Decode each example and print decoder score vs. recomputed score."""
        # numpy is not imported by name in the visible module header; import
        # it here so numpy.vstack resolves regardless of the star imports.
        import numpy

        if Conf.mode == 'normal':
            use_quality_scores = False
        elif Conf.mode == 'using_quality_scores':
            use_quality_scores = True
        else:
            raise AssertionError('Unknown Conf.mode: %r' % (Conf.mode,))

        # Both modes currently work on the same artificial data set.
        Sequences, Acceptors, Donors, Exons, Ests, Qualities = loadArtificialData(1000)
        # Placeholder split positions; needed in both modes because every
        # example reads SplitPos below.
        SplitPos = [1]*len(Qualities)

        # number of examples
        N = len(Sequences)
        self.numExamples = N
        self._check_example_counts(N, Acceptors, Donors, Exons, Ests)
        self.plog('Number of training examples: %d\n'% N)
        print('Number of features: %d\n'% Conf.numFeatures)

        remove_duplicate_scores = Conf.remove_duplicate_scores
        print_matrix = Conf.print_matrix

        # Load the parameter vector from a pickle file.
        param_filename='/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/param_30.pickle'
        param = load_param(param_filename)

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

        self.plog('Starting training...\n')

        donSP = Conf.numDonSuppPoints
        accSP = Conf.numAccSuppPoints
        lengthSP = Conf.numLengthSuppPoints
        mmatrixSP = Conf.sizeMatchmatrix[0]\
            *Conf.sizeMatchmatrix[1]
        totalQualSP = Conf.totalQualSuppPoints

        currentPhi = zeros((Conf.numFeatures,1))
        totalQualityPenalties = zeros((totalQualSP,1))

        # Never index past the end of the data set.
        for exampleIdx in range(min(self.MAX_EXAMPLES, N)):

            dna = Sequences[exampleIdx]
            est = Ests[exampleIdx]
            currentSplitPos = SplitPos[exampleIdx]

            if use_quality_scores:
                quality = Qualities[exampleIdx]
            else:
                # constant quality value of 40 for every EST position
                quality = [40]*len(est)

            exons = Exons[exampleIdx]
            don_supp = Donors[exampleIdx]
            acc_supp = Acceptors[exampleIdx]

            # Skip examples whose exons[-1,1] lies beyond the DNA sequence
            # (presumably the end of the last exon).
            if exons[-1,1] > len(dna):
                continue

            # Compute the parameters of the true alignment (but with untrained d,a,h ...)
            trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

            # Calculate the weights
            trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
            trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

            # NOTE(review): currentPhi is laid out as [d, a, h, mmatrix, quality]
            # whereas trueWeight above is stacked as [intron, don, acc, match,
            # quality] -- confirm that this ordering difference is intended.
            currentPhi[0:donSP] = mat(d.penalties[:]).reshape(donSP,1)
            currentPhi[donSP:donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
            currentPhi[donSP+accSP:donSP+accSP+lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
            currentPhi[donSP+accSP+lengthSP:donSP+accSP+lengthSP+mmatrixSP] = mmatrix[:]

            if use_quality_scores:
                totalQualityPenalties = param[-totalQualSP:]
                currentPhi[donSP+accSP+lengthSP+mmatrixSP:] = totalQualityPenalties[:]

            # Calculate w'phi(x,y) the total score of the alignment
            trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

            # allWeights stores the weight vector of the true alignment
            # (column 0) and of the one decoded alignment (column 1).
            allWeights = zeros((Conf.numFeatures,1+1))
            allWeights[:,0] = trueWeight[:,0]

            AlignmentScores = [0.0]*(1+1)
            AlignmentScores[0] = trueAlignmentScore

            ################## Calculate wrong alignment(s) ######################

            # Compute donor, acceptor with penalty_lookup_new
            # returns two double lists
            donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

            # myalign wants the acceptor site on the g of the ag
            acceptor = acceptor[1:]
            acceptor.append(-inf)

            dna = str(dna)
            est = str(est)
            dna_len = len(dna)
            est_len = len(est)

            ps = h.convert2SWIG()

            prb = QPalmaDP.createDoubleArrayFromList(quality)
            chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

            matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
            mm_len = Conf.sizeMatchmatrix[0]*Conf.sizeMatchmatrix[1]

            d_len = len(donor)
            donor = QPalmaDP.createDoubleArrayFromList(donor)
            a_len = len(acceptor)
            acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

            # Create the alignment object representing the interface to the C/C++ code.
            currentAlignment = QPalmaDP.Alignment(Conf.numQualPlifs,Conf.numQualSuppPoints, use_quality_scores)

            c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

            # calculates SpliceAlign, EstAlign, weightMatch, total scores, dnaest
            currentAlignment.myalign(1, dna, dna_len,\
                est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
                acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
                print_matrix)

            # Output buffers filled by getAlignmentResults below.
            c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*1))
            c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*1))
            c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*1))
            c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*1)

            c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(Conf.totalQualSuppPoints))

            currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
                c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

            newSpliceAlign = zeros((dna_len,1))
            newEstAlign = zeros((est_len,1))
            newWeightMatch = zeros((mm_len,1))
            newQualityPlifsFeatures = zeros((Conf.totalQualSuppPoints,1))

            # Copy the results element by element into matrices.
            for i in range(dna_len):
                newSpliceAlign[i] = c_SpliceAlign[i]

            for i in range(est_len):
                newEstAlign[i] = c_EstAlign[i]

            for i in range(mm_len):
                newWeightMatch[i] = c_WeightMatch[i]

            newDPScores = c_DPScores[0]

            for i in range(Conf.totalQualSuppPoints):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

            # equals palma up to here

            # Drop the wrapper objects now that their contents are copied.
            del c_SpliceAlign
            del c_EstAlign
            del c_WeightMatch
            del c_DPScores
            del c_qualityPlifsFeatures
            del currentAlignment

            newSpliceAlign = newSpliceAlign.reshape(1,dna_len)
            newWeightMatch = newWeightMatch.reshape(1,mm_len)

            # true_map marks, per decoded alignment, whether it is
            # equivalent to the true alignment. Entry 0 is the true
            # alignment itself. The map is filled in here but not read
            # anywhere else in this method.
            true_map = [0]*2
            true_map[0] = 1
            pathNr = 0

            weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a, h, newSpliceAlign.flatten().tolist()[0], don_supp, acc_supp)

            # Store the weights of the decoded alignment in the remaining
            # column of the matrix (vstack builds a fresh matrix, so no
            # separate copy of the quality features is needed).
            wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch.T, newQualityPlifsFeatures])
            allWeights[:,pathNr+1] = wp

            hpen = mat(h.penalties).reshape(len(h.penalties),1)
            dpen = mat(d.penalties).reshape(len(d.penalties),1)
            apen = mat(a.penalties).reshape(len(a.penalties),1)
            features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties])

            AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

            # Check whether the scalar product equals the viterbi score
            print('%f vs. %f' % (newDPScores, AlignmentScores[pathNr+1]))

            # if the pathNr-best alignment is very close to the true alignment consider it as true
            if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                true_map[pathNr+1] = 1

            evaluateExample(dna,est,exons,newSpliceAlign,newEstAlign,currentSplitPos)
337
def fetch_genome_info(ginfo_filename, gen_file=None):
    """
    Return the genome info object, using `ginfo_filename` as a pickle cache.

    ginfo_filename -- path of the pickle cache file.
    gen_file       -- genome config passed to the Matlab `init_genome`
                      call; only needed when the cache file does not exist.

    If the cache file exists its unpickled content is returned. Otherwise
    Matlab is run to write 'genome_info.mat', the 'genome_info' entry of
    that file is pickled to `ginfo_filename` and returned.

    Raises ValueError when the cache is missing and no gen_file is given,
    AssertionError when the Matlab process writes anything to stderr.
    """
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    if os.path.exists(ginfo_filename):
        # NOTE: unpickling executes arbitrary code; only use trusted caches.
        with open(ginfo_filename, 'rb') as cache_file:
            return pickle.load(cache_file)

    if gen_file is None:
        raise ValueError('gen_file is required because %s does not exist' % ginfo_filename)

    matlab_statements = [
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
        'exit',
    ]
    # Pass the script as one argument instead of going through a shell, so
    # the content of gen_file is not interpreted by the shell.
    matlab_cmd = ['matlab', '-nojvm', '-nodisplay', '-r', '; '.join(matlab_statements)]

    obj = subprocess.Popen(matlab_cmd,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
    out,err = obj.communicate()
    if err:
        raise AssertionError('An error occured!\n%s' % err)

    ginfo = scipy.io.loadmat('genome_info.mat')
    genome_info = ginfo['genome_info']
    with open(ginfo_filename, 'wb') as cache_file:
        pickle.dump(genome_info, cache_file)
    return genome_info
357
def plifs2param(h,d,a,mmatrix,qualityPlifs):
    """
    Assemble one parameter column vector from the given Plifs.

    The vector has Conf.numFeatures rows and is filled in this order:
    d.penalties, a.penalties, h.penalties, mmatrix, then the penalties of
    each element of qualityPlifs (Conf.numQualSuppPoints rows each).
    """
    numDon = Conf.numDonSuppPoints
    numAcc = Conf.numAccSuppPoints
    numLength = Conf.numLengthSuppPoints
    numMatch = Conf.sizeMatchmatrix[0] * Conf.sizeMatchmatrix[1]

    # Cumulative end offsets of the four leading segments.
    accEnd = numDon + numAcc
    lengthEnd = accEnd + numLength
    matchEnd = lengthEnd + numMatch

    result = zeros((Conf.numFeatures,1))
    result[0:numDon] = mat(d.penalties[:]).reshape(numDon,1)
    result[numDon:accEnd] = mat(a.penalties[:]).reshape(numAcc,1)
    result[accEnd:lengthEnd] = mat(h.penalties[:]).reshape(numLength,1)
    result[lengthEnd:matchEnd] = mmatrix[:]

    # The quality Plifs follow back to back after the match matrix.
    for plifIdx, plif in enumerate(qualityPlifs):
        numQual = Conf.numQualSuppPoints
        first = matchEnd + plifIdx*numQual
        result[first:first+numQual] = mat(plif.penalties).reshape(numQual,1)

    return result
379
def load_param(filename):
    """
    Load and return the pickled parameter object stored in `filename`.

    Raises IOError/OSError when the file cannot be opened and whatever the
    pickle module raises for a corrupt file.
    """
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    # NOTE: unpickling executes arbitrary code; only load trusted files.
    # Binary mode plus a context manager: the handle is closed again and
    # the pickle is read without newline translation.
    with open(filename, 'rb') as param_file:
        return pickle.load(param_file)
392
def evaluateExample(dna,est,exons,SpliceAlign,newEstAlign,spos):
    """
    Extract exon boundary positions from a decoded splice alignment.

    SpliceAlign -- object whose `.flatten().tolist()[0]` yields the list of
                   per-position labels (e.g. a 1 x n numpy matrix).
    dna, est, exons, newEstAlign, spos are accepted but not used yet.

    Returns a list of positions in order of occurrence: `pos-1` for every
    position where label 2 is followed by label 0 (start of exon) and
    `pos` for every position where label 0 is followed by label 1 (end of
    exon). Comparing these against `exons` is not implemented yet.
    """
    labels = SpliceAlign.flatten().tolist()[0]

    newExons = []
    previous = -1   # label before the first position
    for pos, current in enumerate(labels):
        if previous == 2 and current == 0: # start of exon
            newExons.append(pos-1)

        if previous == 0 and current == 1: # end of exon
            newExons.append(pos)

        previous = current

    # The debugger breakpoint that used to sit here stopped every call;
    # return the boundaries instead so callers can inspect them.
    return newExons
424
if __name__ == '__main__':
    # Use a distinct name so the instance does not rebind the imported
    # `qpalma` package at module level.
    predictor = QPalma()
    predictor.run()