+ added exons read boundary output to filterReads output
[qpalma.git] / scripts / qpalma_train.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 import qpalma
22 from qpalma.SIQP_CPX import SIQPSolver
23 from qpalma.DataProc import *
24
25 from qpalma.generateEvaluationData import *
26 from qpalma.computeSpliceWeights import *
27 from qpalma.set_param_palma import *
28 from qpalma.computeSpliceAlignWithQuality import *
29 from qpalma.penalty_lookup_new import *
30 from qpalma.compute_donacc import *
31 from qpalma.TrainingParam import Param
32 from qpalma.export_param import *
33
34 import qpalma.Configuration
35 from qpalma.Plif import Plf
36 from qpalma.Helpers import *
37
38 from qpalma.tools.splicesites import getDonAccScores
39
def getQualityFeatureCounts(qualityPlifs):
    """Stack the penalty vectors of all quality Plifs into one array.

    Each Plif's ``penalties`` becomes one row of the result, in the order of
    ``qualityPlifs``.  With exactly one Plif its ``penalties`` object is
    returned unchanged (no copy, no reshaping).

    :param qualityPlifs: non-empty sequence of objects with a ``penalties``
        attribute.
    :returns: the single ``penalties`` object, or ``numpy.vstack`` of all of
        them.
    :raises IndexError: if ``qualityPlifs`` is empty.
    """
    # numpy is only reachable through star imports at file level; import it
    # explicitly so this helper does not depend on them.
    import numpy

    # Index first so an empty sequence still raises IndexError.
    first = qualityPlifs[0].penalties
    if len(qualityPlifs) == 1:
        return first
    # One vstack over all rows instead of re-stacking the growing result for
    # every Plif (which copied the accumulated array each time: quadratic).
    return numpy.vstack([first] + [plif.penalties for plif in qualityPlifs[1:]])
46
class QPalma:
    """
    A training method for the QPalma project.

    Repeatedly aligns every training example with the QPalmaDP dynamic
    program (n-best alignments) and, when Configuration.USE_OPT is set, adds
    the constraint "true alignment scores above the first wrong alignment"
    to a SIQPSolver problem whose solution becomes the new parameter vector.
    """

    def __init__(self):
        self.ARGS = Param()
        # Training log; written through plog() and closed at the end of run().
        self.logfh = open('qpalma.log', 'w+')
        # NOTE: loading genome_info through fetch_genome_info() is currently
        # disabled (it needs the path of a genome.config file).

    def plog(self, string):
        """Append ``string`` to the training log and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def _load_training_data(self):
        """Load the training set selected by ``Configuration.mode``.

        Returns (Sequences, Acceptors, Donors, Exons, Ests, Qualities,
        use_quality_scores).  Donor/acceptor scores are recomputed with
        getDonAccScores on freshly loaded sequences.

        :raises ValueError: for an unknown ``Configuration.mode``.
        """
        if Configuration.mode == 'normal':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = loadArtificialData(1000)
            Donors, Acceptors = getDonAccScores(Sequences)
            return Sequences, Acceptors, Donors, Exons, Ests, Qualities, False

        if Configuration.mode == 'using_quality_scores':
            filename = 'real_dset_%s' % 'recent'
            # The pickle cache in the else branch is deliberately disabled;
            # the original condition was ``not os.path.exists(filename)``.
            if True:
                Sequences, Acceptors, Donors, Exons, Ests, Qualities, SplitPos = \
                    paths_load_data_solexa('training', None, self.ARGS)

                # Only the first `end` examples are used for training.
                end = 200
                Sequences = Sequences[:end]
                Exons = Exons[:end]
                Ests = Ests[:end]
                Qualities = Qualities[:end]
                Acceptors = Acceptors[:end]
                Donors = Donors[:end]

                Donors, Acceptors = getDonAccScores(Sequences)
            else:
                fh = open(filename)
                try:
                    dset = cPickle.load(fh)
                finally:
                    fh.close()
                Sequences = dset.Sequences
                Acceptors = dset.Acceptors
                Donors = dset.Donors
                Exons = dset.Exons
                Ests = dset.Ests
                Qualities = dset.Qualities
            return Sequences, Acceptors, Donors, Exons, Ests, Qualities, True

        raise ValueError('Unknown Configuration.mode: %r' % (Configuration.mode,))

    def _save_param(self, param, param_idx):
        """Pickle ``param`` to ``param_<param_idx>.pickle``, closing the file."""
        fh = open('param_%d.pickle' % param_idx, 'w+')
        try:
            cPickle.dump(param, fh)
        finally:
            fh.close()

    def _nbest_alignments(self, n_paths, dna, est, quality, don_supp, acc_supp,
                          h, d, a, mmatrix, qualityPlifs, use_quality_scores,
                          remove_duplicate_scores, print_matrix):
        """Run the QPalmaDP aligner and return its ``n_paths`` best alignments.

        Returns (newSpliceAlign, newWeightMatch, newDPScores,
        newQualityPlifsFeatures): splice alignments reshaped to
        (n_paths, len(str(dna))), match-matrix features reshaped to
        (n_paths, mm_len), DP scores as an (n_paths, 1) column, and the
        quality-Plif features as one column of length
        totalQualSuppPoints * n_paths (left at zero unless
        ``use_quality_scores`` is set).
        """
        totalQualSP = Configuration.totalQualSuppPoints
        mm_len = Configuration.sizeMatchmatrix[0] * Configuration.sizeMatchmatrix[1]

        # Donor/acceptor scores as two lists of doubles.
        donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
        # myalign wants the acceptor site on the g of the ag
        acceptor = acceptor[1:]
        acceptor.append(-inf)

        dna = str(dna)
        est = str(est)
        dna_len = len(dna)
        est_len = len(est)

        ps = h.convert2SWIG()
        prb = QPalmaDP.createDoubleArrayFromList(quality)
        chastity = QPalmaDP.createDoubleArrayFromList([.0] * est_len)
        matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])

        d_len = len(donor)
        donor = QPalmaDP.createDoubleArrayFromList(donor)
        a_len = len(acceptor)
        acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

        # The alignment object is the interface to the C/C++ dynamic program.
        currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs,
                                              Configuration.numQualSuppPoints,
                                              use_quality_scores)
        c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList(
            [elem.convert2SWIG() for elem in qualityPlifs])

        # calculates SpliceAlign, EstAlign, weightMatch, scores, dnaest
        currentAlignment.myalign(n_paths, dna, dna_len, est, est_len, prb,
                                 chastity, ps, matchmatrix, mm_len, donor,
                                 d_len, acceptor, a_len, c_qualityPlifs,
                                 remove_duplicate_scores, print_matrix)

        # Output buffers filled by getAlignmentResults (n_paths results each).
        c_SpliceAlign = QPalmaDP.createIntArrayFromList([0] * (dna_len * n_paths))
        c_EstAlign = QPalmaDP.createIntArrayFromList([0] * (est_len * n_paths))
        c_WeightMatch = QPalmaDP.createIntArrayFromList([0] * (mm_len * n_paths))
        c_DPScores = QPalmaDP.createDoubleArrayFromList([.0] * n_paths)
        c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0] * (totalQualSP * n_paths))

        currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,
                                             c_WeightMatch, c_DPScores,
                                             c_qualityPlifsFeatures)

        newSpliceAlign = zeros((n_paths * dna_len, 1))
        newWeightMatch = zeros((n_paths * mm_len, 1))
        newDPScores = zeros((n_paths, 1))
        newQualityPlifsFeatures = zeros((totalQualSP * n_paths, 1))

        for i in range(dna_len * n_paths):
            newSpliceAlign[i] = c_SpliceAlign[i]
        for i in range(mm_len * n_paths):
            newWeightMatch[i] = c_WeightMatch[i]
        for i in range(n_paths):
            newDPScores[i] = c_DPScores[i]
        if use_quality_scores:
            for i in range(totalQualSP * n_paths):
                newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]

        del c_SpliceAlign
        del c_EstAlign
        del c_WeightMatch
        del c_DPScores
        del c_qualityPlifsFeatures
        del currentAlignment

        newSpliceAlign = newSpliceAlign.reshape(n_paths, dna_len)
        newWeightMatch = newWeightMatch.reshape(n_paths, mm_len)
        return newSpliceAlign, newWeightMatch, newDPScores, newQualityPlifsFeatures

    def run(self):
        """Train the parameter vector and write it to ``param_<k>.pickle``.

        The solver is called every 20th example (when Configuration.USE_OPT
        is set), and after each solve the parameters are pickled to the next
        ``param_<k>.pickle``.  The final parameters are pickled once more
        after Configuration.iter_steps passes over the data, and the log
        file is closed.
        """
        Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores = \
            self._load_training_data()

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        if not (N == len(Acceptors) and N == len(Donors) and N == len(Exons)
                and N == len(Ests)):
            raise ValueError('The Seq,Accept,Donor,.. arrays are of different lengths')
        self.plog('Number of training examples: %d\n' % N)
        print('Number of features: %d\n' % Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Initial parameter vector
        param = Configuration.fixedParam

        # Set the parameters such as limits penalties for the Plifs
        [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
            param, self.ARGS.train_with_intronlengthinformation)

        if Configuration.USE_OPT:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures, self.numExamples,
                                Configuration.C, self.logfh)

        # number of alignments computed per example (best path, second-best ...)
        num_path = [anzpath] * N

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        totalQualSP = Configuration.totalQualSuppPoints
        totalQualityPenalties = zeros((totalQualSP, 1))

        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0
        while iteration_nr != iteration_steps:
            for exampleIdx in range(self.numExamples):
                if (exampleIdx % 10) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]
                if use_quality_scores:
                    quality = Qualities[exampleIdx]
                else:
                    quality = [40] * len(est)

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Weights of the true alignment (with the current, untrained d, a, h)
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = \
                    computeSpliceAlignWithQuality(dna, exons)
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(
                    d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon,
                                           trueWeightAcc, trueWeightMatch,
                                           trueWeightQuality])

                if use_quality_scores:
                    totalQualityPenalties = param[-totalQualSP:]

                # Current parameters in the same order as the weight vectors
                # ([intron, donor, acceptor, match matrix, quality]), so that
                # weights.T * features is the score w'phi(x, y).
                hpen = mat(h.penalties).reshape(len(h.penalties), 1)
                dpen = mat(d.penalties).reshape(len(d.penalties), 1)
                apen = mat(a.penalties).reshape(len(a.penalties), 1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties])

                trueAlignmentScore = (trueWeight.T * features)[0, 0]

                n_paths = num_path[exampleIdx]

                # Column 0: true alignment; columns 1..n_paths: decoded ones.
                allWeights = zeros((Configuration.numFeatures, n_paths + 1))
                allWeights[:, 0] = trueWeight[:, 0]

                AlignmentScores = [0.0] * (n_paths + 1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################
                newSpliceAlign, newWeightMatch, newDPScores, newQualityPlifsFeatures = \
                    self._nbest_alignments(n_paths, dna, est, quality, don_supp,
                                           acc_supp, h, d, a, mmatrix,
                                           qualityPlifs, use_quality_scores,
                                           remove_duplicate_scores, print_matrix)

                # The n-best alignments are computed without hamming loss, so a
                # decoded alignment may equal the true one; true_map flags those
                # so they are not used as constraints.
                true_map = [0] * (n_paths + 1)
                true_map[0] = 1
                path_loss = [0] * n_paths

                for pathNr in range(n_paths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(
                        d, a, h, newSpliceAlign[pathNr, :].flatten().tolist()[0],
                        don_supp, acc_supp)

                    decodedQualityFeatures = zeros((totalQualSP, 1))
                    for qidx in range(totalQualSP):
                        decodedQualityFeatures[qidx] = \
                            newQualityPlifsFeatures[(pathNr * totalQualSP) + qidx]

                    # positionwise (hamming) loss between decoded and true alignment
                    path_loss[pathNr] = 0
                    for alignPosIdx in range(newSpliceAlign[pathNr, :].shape[1]):
                        if newSpliceAlign[pathNr, alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr] += 1

                    wp = numpy.vstack([weightIntron, weightDon, weightAcc,
                                       newWeightMatch[pathNr, :].T,
                                       decodedQualityFeatures])
                    allWeights[:, pathNr + 1] = wp

                    AlignmentScores[pathNr + 1] = (allWeights[:, pathNr + 1].T * features)[0, 0]

                    # Check whether the scalar product equals the DP score
                    print('Example nr.: %d, path nr. %d, scores: %f vs %f' %
                          (exampleIdx, pathNr, newDPScores[pathNr, 0],
                           AlignmentScores[pathNr + 1]))
                    # Written as `not <=` so a NaN difference is reported too.
                    if not abs(newDPScores[pathNr, 0] - AlignmentScores[pathNr + 1]) <= 1e-5:
                        self.plog('Score mismatch for example %d, path %d: DP score %f vs recomputed %f\n' %
                                  (exampleIdx, pathNr, newDPScores[pathNr, 0],
                                   AlignmentScores[pathNr + 1]))

                    # a decoded alignment very close to the true one counts as true
                    if norm(allWeights[:, 0] - allWeights[:, pathNr + 1]) < 1e-5:
                        true_map[pathNr + 1] = 1

                # All n-best paths equal the true one: search (n+1)-best next time.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] = num_path[exampleIdx] + 1

                # Choose true and first false alignment for extending the constraints
                firstFalseIdx = -1
                for map_idx, elem in enumerate(true_map):
                    if elem == 0:
                        firstFalseIdx = map_idx
                        break

                if firstFalseIdx != -1:
                    differenceVector = allWeights[:, 0] - allWeights[:, firstFalseIdx]
                    if Configuration.USE_OPT:
                        solver.addConstraint(differenceVector, exampleIdx)
                        const_added_ctr += 1

                # call solver every 20th example
                if exampleIdx != 0 and exampleIdx % 20 == 0 and Configuration.USE_OPT:
                    objValue, w, self.slacks = solver.solve()
                    print("objValue is %f" % objValue)

                    for i in range(len(param)):
                        param[i] = w[i]

                    self._save_param(param, param_idx)
                    param_idx += 1
                    [h, d, a, mmatrix, qualityPlifs] = set_param_palma(
                        param, self.ARGS.train_with_intronlengthinformation)

            # end of one iteration through all examples
            iteration_nr += 1

        print('Training completed')

        self._save_param(param, param_idx)
        self.logfh.close()
461
def fetch_genome_info(ginfo_filename, gen_file=None):
    """Return the genome_info structure, cached as a pickle in ``ginfo_filename``.

    If ``ginfo_filename`` exists, its unpickled content is returned.
    Otherwise MATLAB runs ``init_genome(gen_file)`` and saves the result to
    ``genome_info.mat`` in the current directory.  That file's
    ``genome_info`` entry is then pickled to ``ginfo_filename`` and returned.
    Later calls therefore return the same object.

    :param ginfo_filename: path of the pickle cache.
    :param gen_file: genome config file passed to MATLAB's ``init_genome``.
        It is required only when the cache does not exist yet.
    :raises ValueError: if the cache is missing and no ``gen_file`` is given.
    :raises RuntimeError: if MATLAB wrote anything to stderr.
    """
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    if os.path.exists(ginfo_filename):
        fh = open(ginfo_filename, 'rb')
        try:
            return pickle.load(fh)
        finally:
            fh.close()

    # The original read an undefined global `gen_file`; it is now explicit.
    if gen_file is None:
        raise ValueError('%s does not exist and no gen_file was given to build it'
                         % ginfo_filename)

    matlab_cmds = [
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
        'exit',
    ]
    # Pass an argument list without a shell, so gen_file is not interpreted
    # by /bin/sh (quotes, $, backticks).
    obj = subprocess.Popen(['matlab', '-nojvm', '-nodisplay', '-r', '; '.join(matlab_cmds)],
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = obj.communicate()
    # Truthiness works for both str (Py2) and bytes (Py3) stderr.
    if err:
        raise RuntimeError('An error occured!\n%s' % err)

    ginfo = scipy.io.loadmat('genome_info.mat')
    genome_info = ginfo['genome_info']
    # Cache exactly what is returned (the original referenced an undefined `self`).
    fh = open(ginfo_filename, 'wb')
    try:
        pickle.dump(genome_info, fh)
    finally:
        fh.close()
    return genome_info
479 else:
480 return cPickle.load(open(ginfo_filename))
481
482 if __name__ == '__main__':
483 qpalma = QPalma()
484 qpalma.run()