+ renamed main dir in order to create python module hierarchy
[qpalma.git] / qpalma / qpalma.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 ###########################################################
5 #
6 #
7 #
8 ###########################################################
9
import cPickle
import math
import os.path
import pdb
import subprocess
import sys

import numpy
import scipy.io
from numpy.linalg import norm
from numpy.matlib import mat,zeros,ones,inf

import QPalmaDP
from SIQP_CPX import SIQPSolver

from paths_load_data_pickle import *
from paths_load_data_solexa import *

from generateEvaluationData import *

from computeSpliceWeights import *
from set_param_palma import *
from computeSpliceAlignWithQuality import *
from penalty_lookup_new import *
from compute_donacc import *
from TrainingParam import Param
from export_param import *

import Configuration
from Plif import Plf
from Helpers import *
38
def getQualityFeatureCounts(qualityPlifs):
    """
    Stack the ``penalties`` of all quality Plifs row-wise.

    qualityPlifs -- non-empty sequence of objects with a ``penalties``
                    attribute.

    Returns the ``penalties`` of the only element unchanged when the
    sequence has length one; otherwise the result of a single
    ``numpy.vstack`` over all ``penalties`` in order.

    Raises IndexError when ``qualityPlifs`` is empty.
    """
    penalties = [currentPlif.penalties for currentPlif in qualityPlifs]

    if not penalties:
        raise IndexError('qualityPlifs must contain at least one element')

    # A single Plif is handed back as is (no stacking), exactly as before.
    if len(penalties) == 1:
        return penalties[0]

    # One vstack call instead of re-stacking the growing matrix per Plif,
    # which copied all previous rows on every iteration.
    return numpy.vstack(penalties)
45
class QPalma:
    """
    A training method for the QPalma project.

    Constructing an instance opens the log file 'qpalma.log' and loads the
    genome information through fetch_genome_info(); run() performs the
    training and closes the log file again.
    """

    def __init__(self):
        """Load the training parameters, open the log and fetch the genome info."""
        self.ARGS = Param()

        self.logfh = open('qpalma.log','w+')
        gen_file = '%s/genome.config' % self.ARGS.basedir

        ginfo_filename = 'genome_info.pickle'
        # gen_file is only needed when the pickle cache has to be (re)built.
        self.genome_info = fetch_genome_info(ginfo_filename, gen_file)

        self.plog('genome_info.basedir is %s\n'%self.genome_info.basedir)

    def plog(self,string):
        """Write ``string`` to the log file and flush immediately."""
        self.logfh.write(string)
        self.logfh.flush()

    def run(self):
        """
        Run the training and close the log file afterwards.

        The log file is closed even when the training raises, so the
        messages written so far are not lost with an open handle.
        """
        try:
            self._train()
        finally:
            self.logfh.close()

    def _load_training_data(self):
        """
        Load the whole dataset according to Configuration.mode.

        Returns the tuple (Sequences, Acceptors, Donors, Exons, Ests,
        Qualities, use_quality_scores).

        Raises AssertionError for an unsupported Configuration.mode.
        """
        if Configuration.mode == 'normal':
            data = loadArtificialData(1000)
            use_quality_scores = False
        elif Configuration.mode == 'using_quality_scores':
            data = paths_load_data_solexa('training',self.genome_info,self.ARGS)
            use_quality_scores = True
        else:
            # Raised explicitly (instead of ``assert False``) so that the
            # check is not stripped when Python runs with -O.
            raise AssertionError('Unsupported Configuration.mode: %r' % (Configuration.mode,))

        Sequences, Acceptors, Donors, Exons, Ests, Qualities = data
        return Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores

    def _train(self):
        """
        Iterate Configuration.iter_steps times over all training examples.

        For every example the true alignment is compared against the
        n-best alignments decoded by QPalmaDP; the first decoded alignment
        that differs from the true one contributes a constraint to the
        solver (only when Configuration.USE_OPT is set). The solver is
        called on every 20th example and the resulting parameter vector is
        pickled to 'param_<idx>.pickle'.

        Note that Configuration.fixedParam is updated in place.
        """
        Sequences, Acceptors, Donors, Exons, Ests, Qualities, use_quality_scores = \
            self._load_training_data()

        # number of training instances
        N = len(Sequences)
        self.numExamples = N
        # The original check compared against Acceptors twice and never
        # looked at Donors.
        assert N == len(Acceptors) and N == len(Donors) and N == len(Exons)\
            and N == len(Ests), 'The Seq,Accept,Donor,.. arrays are of different lengths'
        self.plog('Number of training examples: %d\n'% N)
        print('Number of features: %d\n' % Configuration.numFeatures)

        iteration_steps = Configuration.iter_steps  # upper bound on iteration steps
        remove_duplicate_scores = Configuration.remove_duplicate_scores
        print_matrix = Configuration.print_matrix
        anzpath = Configuration.anzpath

        # Initialize parameter vector
        param = Configuration.fixedParam

        # Set the parameters such as limits penalties for the Plifs
        [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

        # delete splicesite-score-information
        # NOTE(review): this compares and replaces whole per-example entries,
        # not individual scores -- confirm that this is intended.
        if not self.ARGS.train_with_splicesitescoreinformation:
            for i in range(len(Acceptors)):
                if Acceptors[i] > -20:
                    Acceptors[i] = 1
                if Donors[i] >-20:
                    Donors[i] = 1

        # Initialize solver
        if Configuration.USE_OPT:
            self.plog('Initializing problem...\n')
            solver = SIQPSolver(Configuration.numFeatures,self.numExamples,Configuration.C,self.logfh)

        # stores the number of alignments done for each example
        # (best path, second-best path etc.)
        num_path = [anzpath]*N

        #############################################################################################
        # Training
        #############################################################################################
        self.plog('Starting training...\n')

        donSP = Configuration.numDonSuppPoints
        accSP = Configuration.numAccSuppPoints
        lengthSP = Configuration.numLengthSuppPoints
        mmatrixSP = Configuration.sizeMatchmatrix[0]\
            *Configuration.sizeMatchmatrix[1]
        totalQualSP = Configuration.totalQualSuppPoints

        currentPhi = zeros((Configuration.numFeatures,1))
        totalQualityPenalties = zeros((totalQualSP,1))

        iteration_nr = 0
        param_idx = 0
        const_added_ctr = 0

        while iteration_nr != iteration_steps:

            for exampleIdx in range(self.numExamples):

                if (exampleIdx%10) == 0:
                    print('Current example nr %d' % exampleIdx)

                dna = Sequences[exampleIdx]
                est = Ests[exampleIdx]

                if Configuration.mode == 'normal':
                    quality = [40]*len(est)

                if Configuration.mode == 'using_quality_scores':
                    quality = Qualities[exampleIdx]

                exons = Exons[exampleIdx]
                don_supp = Donors[exampleIdx]
                acc_supp = Acceptors[exampleIdx]

                # Skip examples whose last exon ends beyond the sequence.
                # This also skips the solver call at the end of the loop body.
                if exons[-1,1] > len(dna):
                    continue

                # Compute the parameters of the true alignment
                # (but with untrained d,a,h ...)
                trueSpliceAlign, trueWeightMatch, trueWeightQuality = computeSpliceAlignWithQuality(dna, exons)

                # Calculate the weights
                trueWeightDon, trueWeightAcc, trueWeightIntron = computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
                trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])

                # NOTE(review): currentPhi is laid out as (d, a, h, ...) while
                # trueWeight is stacked as (intron, don, acc, ...) -- confirm
                # that the orderings are meant to differ.
                currentPhi[0:donSP] = mat(d.penalties[:]).reshape(donSP,1)
                currentPhi[donSP:donSP+accSP] = mat(a.penalties[:]).reshape(accSP,1)
                currentPhi[donSP+accSP:donSP+accSP+lengthSP] = mat(h.penalties[:]).reshape(lengthSP,1)
                currentPhi[donSP+accSP+lengthSP:donSP+accSP+lengthSP+mmatrixSP] = mmatrix[:]

                if Configuration.mode == 'using_quality_scores':
                    totalQualityPenalties = param[-totalQualSP:]
                    currentPhi[donSP+accSP+lengthSP+mmatrixSP:] = totalQualityPenalties[:]

                # Calculate w'phi(x,y) the total score of the alignment
                trueAlignmentScore = (trueWeight.T * currentPhi)[0,0]

                # Number of alignments decoded for this example; it is only
                # changed after all paths of this example were processed.
                numPaths = num_path[exampleIdx]

                # The allWeights matrix stores the weight vector of the true
                # alignment (column 0) as well as the weight vectors of the
                # numPaths decoded alignments.
                allWeights = zeros((Configuration.numFeatures,numPaths+1))
                allWeights[:,0] = trueWeight[:,0]

                AlignmentScores = [0.0]*(numPaths+1)
                AlignmentScores[0] = trueAlignmentScore

                ################## Calculate wrong alignment(s) ######################

                # Compute donor, acceptor with penalty_lookup_new
                # returns two double lists
                donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)

                # myalign wants the acceptor site on the g of the ag
                acceptor = acceptor[1:]
                acceptor.append(-inf)

                dna = str(dna)
                est = str(est)
                dna_len = len(dna)
                est_len = len(est)

                ps = h.convert2SWIG()

                prb = QPalmaDP.createDoubleArrayFromList(quality)
                chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)

                matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
                mm_len = Configuration.sizeMatchmatrix[0]*Configuration.sizeMatchmatrix[1]

                d_len = len(donor)
                donor = QPalmaDP.createDoubleArrayFromList(donor)
                a_len = len(acceptor)
                acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)

                # Create the alignment object representing the interface to
                # the C/C++ code.
                currentAlignment = QPalmaDP.Alignment(Configuration.numQualPlifs,Configuration.numQualSuppPoints, use_quality_scores)

                c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])

                # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
                currentAlignment.myalign( numPaths, dna, dna_len,\
                    est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
                    acceptor, a_len, c_qualityPlifs, remove_duplicate_scores,
                    print_matrix)

                c_SpliceAlign = QPalmaDP.createIntArrayFromList([0]*(dna_len*numPaths))
                c_EstAlign = QPalmaDP.createIntArrayFromList([0]*(est_len*numPaths))
                c_WeightMatch = QPalmaDP.createIntArrayFromList([0]*(mm_len*numPaths))
                c_DPScores = QPalmaDP.createDoubleArrayFromList([.0]*numPaths)

                c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(totalQualSP*numPaths))

                currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
                    c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)

                # Copy the results out of the C arrays. The EST alignment is
                # fetched from the C side but not used by the training.
                newSpliceAlign = _copy_c_array(c_SpliceAlign, numPaths*dna_len)
                newWeightMatch = _copy_c_array(c_WeightMatch, numPaths*mm_len)
                newDPScores = _copy_c_array(c_DPScores, numPaths)

                if use_quality_scores:
                    newQualityPlifsFeatures = _copy_c_array(c_qualityPlifsFeatures, totalQualSP*numPaths)
                else:
                    newQualityPlifsFeatures = zeros((totalQualSP*numPaths,1))

                # equals palma up to here

                # Drop the references to the C-side objects.
                del c_SpliceAlign
                del c_EstAlign
                del c_WeightMatch
                del c_DPScores
                del c_qualityPlifsFeatures
                del currentAlignment

                newSpliceAlign = newSpliceAlign.reshape(numPaths,dna_len)
                newWeightMatch = newWeightMatch.reshape(numPaths,mm_len)

                # Calculate weights of the respective alignments. Note that we
                # are calculating n-best alignments without hamming loss, so we
                # have to keep track which of the n-best alignments correspond
                # to the true one in order not to incorporate a true alignment
                # in the constraints. To keep track of the true and false
                # alignments we define an array true_map with a boolean
                # indicating the equivalence to the true alignment for each
                # decoded alignment.
                true_map = [0]*(numPaths+1)
                true_map[0] = 1
                path_loss = [0]*numPaths

                # The parameter vector in the layout of the weight vectors;
                # it does not depend on the path, so it is built once.
                hpen = mat(h.penalties).reshape(len(h.penalties),1)
                dpen = mat(d.penalties).reshape(len(d.penalties),1)
                apen = mat(a.penalties).reshape(len(a.penalties),1)
                features = numpy.vstack([hpen, dpen, apen, mmatrix[:], totalQualityPenalties])

                for pathNr in range(numPaths):
                    weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a, h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp, acc_supp)

                    qualStart = pathNr*totalQualSP
                    decodedQualityFeatures = newQualityPlifsFeatures[qualStart:qualStart+totalQualSP]

                    path_loss[pathNr] = 0
                    # sum up positionwise loss between alignments
                    for alignPosIdx in range(newSpliceAlign[pathNr,:].shape[1]):
                        if newSpliceAlign[pathNr,alignPosIdx] != trueSpliceAlign[alignPosIdx]:
                            path_loss[pathNr] += 1

                    # Store the weights in the remaining columns of the matrix
                    wp = numpy.vstack([weightIntron, weightDon, weightAcc, newWeightMatch[pathNr,:].T, decodedQualityFeatures])
                    allWeights[:,pathNr+1] = wp

                    AlignmentScores[pathNr+1] = (allWeights[:,pathNr+1].T * features)[0,0]

                    # Check whether scalar product equals viterbi score
                    print('Example nr.: %d, path nr. %d, scores: %f vs %f' % (exampleIdx,pathNr,newDPScores[pathNr,0], AlignmentScores[pathNr+1]))

                    # NOTE(review): drops into the debugger when the decoder
                    # score and the recomputed score disagree; this blocks
                    # unattended runs.
                    if not math.fabs(newDPScores[pathNr,0] - AlignmentScores[pathNr+1]) <= 1e-5:
                        pdb.set_trace()

                    # if the pathNr-best alignment is very close to the true
                    # alignment consider it as true
                    if norm( allWeights[:,0] - allWeights[:,pathNr+1] ) < 1e-5:
                        true_map[pathNr+1] = 1

                # If every decoded alignment equals the true one, all n-best
                # paths are too close to each other and we have to extend the
                # n-best search to a (n+1)-best for this example.
                if all(elem == 1 for elem in true_map):
                    num_path[exampleIdx] = num_path[exampleIdx]+1

                # If there is at least one useful false alignment add the
                # constraint built from the true and the first false
                # alignment to the optimization problem.
                if 0 in true_map:
                    firstFalseIdx = true_map.index(0)
                    trueWeights = allWeights[:,0]
                    firstFalseWeights = allWeights[:,firstFalseIdx]
                    differenceVector = trueWeights - firstFalseWeights

                    if Configuration.USE_OPT:
                        solver.addConstraint(differenceVector, exampleIdx)
                        const_added_ctr += 1
                #
                # end of one example processing
                #

                # call solver every nth example
                if exampleIdx != 0 and exampleIdx % 20 == 0 and Configuration.USE_OPT:
                    objValue,w,self.slacks = solver.solve()

                    print("objValue is %f" % objValue)

                    for i in range(len(param)):
                        param[i] = w[i]

                    _dump_pickle(param, 'param_%d.pickle'%param_idx)
                    param_idx += 1
                    [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation)

            #
            # end of one iteration through all examples
            #
            iteration_nr += 1

        #
        # end of optimization
        #
        print('Training completed')

        pa = para()
        pa.h = h
        pa.d = d
        pa.a = a
        pa.mmatrix = mmatrix
        pa.qualityPlifs = qualityPlifs

        _dump_pickle(param, 'param_%d.pickle'%param_idx)


def _copy_c_array(c_array, length):
    """Copy the first ``length`` entries of an indexable C array into a (length,1) matrix."""
    column = zeros((length,1))
    for i in range(length):
        column[i] = c_array[i]
    return column


def _dump_pickle(obj, filename):
    """Pickle ``obj`` to ``filename`` and close the file even when dumping fails."""
    fh = open(filename,'wb')
    try:
        cPickle.dump(obj,fh)
    finally:
        fh.close()


def fetch_genome_info(ginfo_filename, gen_file=None):
    """
    Return the genome info, using ``ginfo_filename`` as a pickle cache.

    When ``ginfo_filename`` exists its unpickled content is returned.
    Otherwise matlab is started to run ``init_genome`` on ``gen_file``,
    the resulting 'genome_info.mat' is loaded, its 'genome_info' entry is
    pickled to ``ginfo_filename`` and returned.

    ginfo_filename -- path of the pickle cache.
    gen_file       -- path of the genome configuration handed to matlab;
                      only needed when the cache does not exist yet.

    Raises ValueError when the cache is missing and no ``gen_file`` was
    given, and AssertionError when matlab wrote to stderr.

    NOTE: the cache is read with pickle, so only trusted files must be used.
    """
    if os.path.exists(ginfo_filename):
        fh = open(ginfo_filename,'rb')
        try:
            return cPickle.load(fh)
        finally:
            fh.close()

    if gen_file is None:
        raise ValueError('%s does not exist and no gen_file was given to create it' % ginfo_filename)

    cmd = [
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/utils',
        'addpath /fml/ag-raetsch/home/fabio/svn/tools/genomes',
        'genome_info = init_genome(\'%s\')' % gen_file,
        'save genome_info.mat genome_info',
    ]
    matlab_script = '%s; %s; %s; %s; exit' % tuple(cmd)

    # The arguments are passed as a list (no shell), so the content of
    # gen_file is never interpreted by a shell.
    obj = subprocess.Popen(['matlab','-nojvm','-nodisplay','-r',matlab_script],
                           stdout=subprocess.PIPE,stderr=subprocess.PIPE)
    out,err = obj.communicate()
    if err:
        raise AssertionError('An error occured!\n%s'%err)

    ginfo = scipy.io.loadmat('genome_info.mat')
    genome_info = ginfo['genome_info']
    _dump_pickle(genome_info, ginfo_filename)
    return genome_info
441
if __name__ == '__main__':
    # Script entry point: build the trainer (opens the log and loads the
    # genome info) and run the training.
    trainer = QPalma()
    trainer.run()