244eb7f6795d1dfa96698eea58958e7e93285d79
[qpalma.git] / scripts / qpalma_train.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
10 import sys
11 import subprocess
12 import scipy.io
13 import pdb
14 import os.path
15
16 from numpy.matlib import mat,zeros,ones,inf
17 from numpy.linalg import norm
18
19 import QPalmaDP
20
21 import qpalma
22 from qpalma.SIQP_CPX import SIQPSolver
23 from qpalma.DataProc import *
24
25 from qpalma.generateEvaluationData import *
26 from qpalma.computeSpliceWeights import *
27 from qpalma.set_param_palma import *
28 from qpalma.computeSpliceAlignWithQuality import *
29 from qpalma.penalty_lookup_new import *
30 from qpalma.compute_donacc import *
31 from qpalma.TrainingParam import Param
32 from qpalma.export_param import *
33
34 import qpalma.Configuration
35 from qpalma.Plif import Plf
36 from qpalma.Helpers import *
37
38 from qpalma.tools.splicesites import getDonAccScores
39
def getQualityFeatureCounts(qualityPlifs):
    """
    Stack the ``penalties`` of every quality Plif on top of each other.

    The penalties of the first Plif are taken as the starting value and the
    penalties of each following Plif are appended below it with
    ``numpy.vstack``. For a single Plif its ``penalties`` object is returned
    unchanged (no stacking takes place); an empty sequence raises IndexError.
    """
    penaltyRows = [plif.penalties for plif in qualityPlifs]

    stacked = penaltyRows[0]
    for row in penaltyRows[1:]:
        stacked = numpy.vstack((stacked, row))

    return stacked
46
class QPalma:
    """
    A training method for the QPalma project.

    ``run`` loads the training examples, computes for every example the
    feature counts of the true alignment and of the n-best decoded
    alignments, and (if ``Configuration.USE_OPT`` is set) feeds the difference
    between the true and the first false alignment as a constraint to the
    solver. Progress is written to ``qpalma.log`` and the parameter vector is
    pickled to ``param_<n>.pickle`` after every solver call.
    """

    def __init__(self):
        """Create the training parameters and open the log file 'qpalma.log'."""
        self.ARGS = Param()
        self.logfh = open('qpalma.log','w+')

        # Loading of real genome information (see fetch_genome_info) is
        # currently disabled; training runs on artificial data only.

    def plog(self,string):
        """Write ``string`` to the log file and flush it immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def _copyCArray(self, c_array, length):
        """Copy the first ``length`` entries of an indexable C array into a (length,1) matrix."""
        column = zeros((length,1))
        for i in range(length):
            column[i] = c_array[i]
        return column

    def _dumpParam(self, param, param_idx):
        """Pickle ``param`` to 'param_<param_idx>.pickle' and close the file afterwards."""
        fh = open('param_%d.pickle'%param_idx,'w+')
        try:
            cPickle.dump(param,fh)
        finally:
            fh.close()

    def run(self):
        """
        Run the training loop.

        Reads its settings from ``Configuration`` (mode, iter_steps, anzpath,
        fixedParam, USE_OPT, the numbers of supporting points, ...). The
        parameter vector ``Configuration.fixedParam`` is updated in place
        whenever the solver has been called. Sets ``self.numExamples`` and
        ``self.slacks`` and closes the log file at the end.

        All feature/parameter vectors built here use the block order
        [intron length, donor, acceptor, match matrix, quality].
        """
        # Load the whole dataset
        if Configuration.mode == 'normal':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = loadArtificialData(1000)
            # the splice site scores that came with the data are replaced here
            Donors, Acceptors = getDonAccScores(Sequences)
            use_quality_scores = False
        elif Configuration.mode == 'using_quality_scores':
            Sequences, Acceptors, Donors, Exons, Ests, Qualities = loadArtificialData(1000)
            use_quality_scores = True
        else:
            assert False, 'Unknown Configuration.mode: %s' % Configuration.mode

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        assert N == len(Acceptors) and N == len(Donors) and N == len(Exons)\
            and N == len(Ests), 'The Seq,Accept,Donor,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% N)
        print('Number of features: %d\n'% Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Initialize parameter vector
        param = Configuration.fixedParam

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

        # Initialize solver
        if Configuration.USE_OPT:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures,self.numExamples,Configuration.C,self.logfh)

        # stores the number of alignments done for each example (best path, second-best path etc.)
        num_path = [anzpath]*N

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        donSP = Configuration.numDonSuppPoints
        accSP = Configuration.numAccSuppPoints
        lengthSP = Configuration.numLengthSuppPoints
        mmatrixSP = Configuration.sizeMatchmatrix[0]\
            *Configuration.sizeMatchmatrix[1]
        totalQualSP = Configuration.totalQualSuppPoints

        # Block offsets inside the parameter vector. The order has to match the
        # order in which the feature counts are stacked below:
        # [intron length, donor, acceptor, match matrix, quality].
        donStart = lengthSP
        accStart = donStart + donSP
        matchStart = accStart + accSP
        qualStart = matchStart + mmatrixSP

        currentPhi = zeros((Configuration.numFeatures,1))
        totalQualityPenalties = zeros((totalQualSP,1))

        iteration_nr = 0
        param_idx = 0
        while iteration_nr != iteration_steps:

            for exampleIdx in range(self.numExamples):
                print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]

                if use_quality_scores:
                    quality = Qualities[exampleIdx]
                else:
                    # without quality information every position gets the same score
                    quality = [40]*len(est)

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # NOTE(review): skipping the example here also skips the
                # periodic solver call at the end of the loop body -- confirm
                # that this is intended.
                if exons[-1,1] > len(dna):
                    continue

                # Compute the parameters of the true alignment (but with untrained d,a,h ...)
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                # Current parameter vector, laid out in the same block order as trueWeight
                currentPhi[0:donStart] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[donStart:accStart] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[accStart:matchStart] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[matchStart:qualStart] = mmatrix[:]

                if Configuration.mode == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[qualStart:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # The allWeights vector is supposed to store the weight parameter
                # of the true alignment as well as the weight parameters of the
                # num_path[exampleIdx] other alignments
                allWeights = zeros((Configuration.numFeatures,num_path[exampleIdx]+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(num_path[exampleIdx]+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################

                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                dna = str(dna)
                est = str(est)
                dna_len = len(dna)
                est_len = len(est)

                ps = h.convert2SWIG()

                prb = QPalmaDP.createDoubleArrayFromList(quality)
                chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

                matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
                mm_len = mmatrixSP

                d_len = len(donor)
                donor = QPalmaDP.createDoubleArrayFromList(donor)
                a_len = len(acceptor)
                acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

                # Create the alignment object representing the interface to the C/C++ code.
                currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs,Configuration.numQualSuppPoints, use_quality_scores)

                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

                # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
                currentAlignment.myalign( num_path[exampleIdx], dna, dna_len,\
                    est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
                    acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
                    print_matrix)

                # Output buffers that getAlignmentResults fills, one slot per
                # position and decoded path.
                nPaths = num_path[exampleIdx]
                c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*nPaths))
                c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*nPaths))
                c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*nPaths))
                c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*nPaths)

                c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(totalQualSP*nPaths))

                currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
                    c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

                newSpliceAlign = self._copyCArray(c_SpliceAlign, dna_len*nPaths)
                newEstAlign = self._copyCArray(c_EstAlign, est_len*nPaths)
                newWeightMatch = self._copyCArray(c_WeightMatch, mm_len*nPaths)
                newDPScores = self._copyCArray(c_DPScores, nPaths)

                # the quality features are only read back when quality scores are in use
                if use_quality_scores:
                    newQualityPlifsFeatures = self._copyCArray(c_qualityPlifsFeatures, totalQualSP*nPaths)
                else:
                    newQualityPlifsFeatures = zeros((totalQualSP*nPaths,1))

                # equals palma up to here

                # release the C side objects explicitly
                del c_SpliceAlign
                del c_EstAlign
                del c_WeightMatch
                del c_DPScores
                del c_qualityPlifsFeatures
                del currentAlignment

                newSpliceAlign = newSpliceAlign.reshape(nPaths,dna_len)
                newWeightMatch = newWeightMatch.reshape(nPaths,mm_len)

                # Calculate weights of the respective alignments Note that we are
                # calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond to
                # the true one in order not to incorporate a true alignment in the
                # constraints. To keep track of the true and false alignments we
                # define an array true_map with a boolean indicating the
                # equivalence to the true alignment for each decoded alignment.
                true_map = [0]*(nPaths+1)
                true_map[0] = 1
                path_loss = [0]*nPaths

                # Parameter vector used to score the decoded alignments; it does
                # not change between paths, so it is built once per example.
                hpen = mat(h.penalties).reshape(len(h.penalties),1)
                dpen = mat(d.penalties).reshape(len(d.penalties),1)
                apen = mat(a.penalties).reshape(len(a.penalties),1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties])

                for pathNr in range(nPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,
                        h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,
                        acc_supp,True)

                    # quality feature counts belonging to this path
                    qualOffset = pathNr*totalQualSP
                    decodedQualityFeatures = newQualityPlifsFeatures[qualOffset:qualOffset+totalQualSP]

                    # sum up positionwise loss between alignments
                    # (currently informational only, it is not passed to the solver)
                    path_loss[pathNr] = 0
                    for alignPosIdx in range(newSpliceAlign[pathNr,:].shape[1]):
                        if newSpliceAlign[pathNr,alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr] += 1

                    # Store the weights in the remaining columns of the matrix
                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures])
                    allWeights[:,pathNr+1] = wp

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check whether scalar product equals viterbi score
                    print('Example nr.: %d, path nr. %d, scores: %f vs %f' % (exampleIdx,pathNr,newDPScores[pathNr,0], AlignmentScores[pathNr+1]))

                    # "not ... <=" (instead of ">") so that a NaN score is reported as well
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        self.plog('Warning: example %d, path %d: DP score %f differs from recomputed score %f\n'\
                            % (exampleIdx,pathNr,newDPScores[pathNr,0],AlignmentScores[pathNr+1]))

                    # if the pathNr-best alignment is very close to the true alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                if 0 not in true_map:
                    # all n-best paths are too close to the true alignment:
                    # we have to extend the n-best search to a (n+1)-best
                    num_path[exampleIdx] = num_path[exampleIdx]+1
                else:
                    # Choose true and first false alignment for extending A:
                    # there is at least one useful false alignment, so add the
                    # corresponding constraint to the optimization problem
                    firstFalseIdx = true_map.index(0)

                    if Configuration.USE_OPT:
                        differenceVector = allWeights[:,0] - allWeights[:,firstFalseIdx]
                        solver.addConstraint(differenceVector, exampleIdx)
                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % 20 == 0 and Configuration.USE_OPT:
                    objValue,w,self.slacks = solver.solve()

                    print("objValue is %f" % objValue)

                    # update the parameter vector in place with the new solution
                    for i in range(len(param)):
                        param[i] = w[i]

                    self._dumpParam(param, param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

            #
            # end of one iteration through all examples
            #
            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        self._dumpParam(param, param_idx)
        self.logfh.close()
439
def fetch_genome_info(ginfo_filename, gen_file=None):
    """
    Return the genome_info object, using ``ginfo_filename`` as a pickle cache.

    If ``ginfo_filename`` exists it is unpickled and returned. Otherwise
    MATLAB is started to run ``init_genome(gen_file)``, the result is read
    back from 'genome_info.mat' (written to the current working directory),
    pickled to ``ginfo_filename`` and returned.

    ginfo_filename -- path of the pickle cache file
    gen_file       -- genome config file handed to MATLAB's init_genome;
                      only needed when the cache file does not exist yet

    Raises ValueError if the cache is missing and no ``gen_file`` was given,
    and RuntimeError if MATLAB wrote anything to stderr.
    """
    try:
        import cPickle as pickle   # Python 2
    except ImportError:
        import pickle

    if os.path.exists(ginfo_filename):
        # NOTE: unpickling can execute arbitrary code -- only point this at
        # cache files that were written by this function.
        fh = open(ginfo_filename,'rb')
        try:
            return pickle.load(fh)
        finally:
            fh.close()

    if gen_file is None:
        raise ValueError('Cache file %s does not exist and no gen_file was given to create it' % ginfo_filename)

    # single quotes inside a MATLAB string literal are escaped by doubling them
    quoted_gen_file = str(gen_file).replace("'","''")

    cmd = ['']*4
    cmd[0] = 'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils'
    cmd[1] = 'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes'
    cmd[2] = 'genome_info = init_genome(\'%s\')' % quoted_gen_file
    cmd[3] = 'save genome_info.mat genome_info'
    matlab_script = "%s; %s; %s; %s; exit" % (cmd[0],cmd[1],cmd[2],cmd[3])

    # pass the arguments as a list so that no shell has to parse the script
    obj = subprocess.Popen(['matlab','-nojvm','-nodisplay','-r',matlab_script],
        stdout=subprocess.PIPE,stderr=subprocess.PIPE)
    out,err = obj.communicate()
    if err:
        raise RuntimeError('An error occured!\n%s'%err)

    ginfo = scipy.io.loadmat('genome_info.mat')
    genome_info = ginfo['genome_info']

    # cache exactly the object that is returned, so that later calls give the same result
    fh = open(ginfo_filename,'wb')
    try:
        pickle.dump(genome_info,fh)
    finally:
        fh.close()

    return genome_info
459
if __name__ == '__main__':
    # Script entry point: build the trainer and run the complete training.
    # The instance is deliberately not called `qpalma`, because that would
    # rebind the name of the imported `qpalma` package at module level.
    trainer = QPalma()
    trainer.run()