+ fixed a bug in the C++ interface
[qpalma.git] / standalone / palma / signal_detectors.py
1 #
2 # This program is free software; you can redistribute it and/or modify
3 # it under the terms of the GNU General Public License as published by
4 # the Free Software Foundation; either version 2 of the License, or
5 # (at your option) any later version.
6 #
7 # Written (W) 2006-2007 Soeren Sonnenburg
8 # Written (W) 2007 Gunnar Raetsch
9 # Copyright (C) 2007 Fraunhofer Institute FIRST and Max-Planck-Society
10 #
11
12 import sys
13 import numpy
14 import seqdict
15 import model
16
17 from shogun.Classifier import SVM
18 from shogun.Features import StringCharFeatures,DNA
19 # from shogun.Kernel import WeightedDegreeCharKernel
20 # In the svn version of shogun, this has been renamed.
21 from shogun.Kernel import WeightedDegreeStringKernel
22
class svm_splice_model(object):
    """Weighted-degree string-kernel SVM that scores candidate splice sites.

    Candidate sites are occurrences of the consensus dimers (e.g. 'AG') in a
    DNA string.  Each candidate is scored by classifying a fixed-length
    window of ``window_left + window_right + 2`` characters around it.
    """

    def __init__(self, order, traindat, alphas, b, window, consensus):
        """Build the SVM from its support vectors and weights.

        :param order: degree of the weighted-degree string kernel.
        :param traindat: support-vector strings, handed to
            ``StringCharFeatures.set_string_features``.
        :param alphas: SVM weights, one per support vector.
        :param b: SVM bias.
        :param window: 3-tuple ``(window_left, offset, window_right)``.  A
            candidate reported at position ``p`` has its consensus dimer
            starting at ``p - offset``; the classified window covers
            ``window_left`` characters before the dimer and
            ``window_right + 2`` characters from the dimer on.
        :param consensus: list of dimers that mark candidate sites.
        """
        # Explicit unpacking instead of Python-2-only tuple parameters;
        # callers passing the 3-tuple positionally are unaffected.
        window_left, offset, window_right = window

        features = StringCharFeatures(DNA)
        features.set_string_features(traindat)
        # The svn version of shogun renamed WeightedDegreeCharKernel to this.
        wd_kernel = WeightedDegreeStringKernel(features, features, int(order))
        wd_kernel.io.set_target_to_stderr()

        # Support vector i is traindat[i], so the index vector is 0..n-1.
        self.svm = SVM(wd_kernel, alphas,
                       numpy.arange(len(alphas), dtype=numpy.int32), b)
        self.svm.io.set_target_to_stderr()
        self.svm.parallel.set_num_threads(self.svm.parallel.get_num_cpus())

        self.window_left = int(window_left)
        self.window_right = int(window_right)
        self.offset = int(offset)

        self.consensus = consensus
        self.wd_kernel = wd_kernel
        self.traindat = features

    def _find_positions(self, sequence):
        """Return the sorted positions of all consensus occurrences.

        Every occurrence (overlapping ones included) of every consensus
        dimer starting at index ``idx`` is reported as ``idx + offset``.
        No boundary filtering is done; windows that run past either end of
        the sequence are padded by ``_extract_window``.
        """
        positions = []
        for cons in self.consensus:
            idx = sequence.find(cons)
            while idx > -1:
                positions.append(idx + self.offset)
                idx = sequence.find(cons, idx + 1)
        positions.sort()
        return positions

    def _extract_window(self, sequence, position):
        """Return the classification window for one candidate position.

        With ``i = position - offset`` (for positions produced by
        ``_find_positions``, the index where the consensus dimer starts),
        the window is ``sequence[i - window_left : i + window_right + 2]``.
        Parts of that range outside ``sequence`` are filled with 'A', so for
        positions inside the sequence the result always has
        ``window_left + window_right + 2`` characters.
        """
        i = position - self.offset
        start = i - self.window_left
        stop = i + self.window_right + 2
        seqlen = len(sequence)
        # Pad only the out-of-range parts; every real base inside the window
        # is kept, including the last base of the sequence.
        left_pad = 'A' * max(0, -start)
        right_pad = 'A' * max(0, stop - seqlen)
        return left_pad + sequence[max(0, start):min(stop, seqlen)] + right_pad

    def _classify(self, testdat):
        """Classify the window strings ``testdat`` and return the SVM outputs."""
        test_features = StringCharFeatures(DNA)
        test_features.set_string_features(testdat)

        self.wd_kernel.init(self.traindat, test_features)
        labels = self.svm.classify().get_labels()
        sys.stderr.write("\n...done...\n")
        return labels

    def get_positions(self, sequence):
        """DEPRECATED: Please use get_positions_from_seqdict

        Return the sorted candidate positions in ``sequence``.
        """
        print("DEPRECATED: Please use get_positions_from_seqdict")
        return self._find_positions(sequence)

    def get_predictions(self, sequence, positions):
        """DEPRECATED: Please use get_predictions_from_seqdict

        Return the SVM outputs for ``positions`` (as produced by
        get_positions) in ``sequence``, one output per position.
        """
        print("DEPRECATED: Please use get_predictions_from_seqdict")
        testdat = [self._extract_window(sequence, pos) for pos in positions]
        return self._classify(testdat)

    def get_predictions_from_seqdict(self, seqdic, site):
        """Score every candidate position of ``site`` in all of ``seqdic``.

        We need to generate one huge test features object containing the
        windows of all locations found in every seqdict sequence.  This is
        necessary to compute the splice outputs efficiently (fast, low
        memory).  The outputs are then split back per sequence and stored
        with ``entry.preds[site].set_scores``.

        :param seqdic: re-iterable collection (it is traversed twice) whose
            entries expose ``seq`` and ``preds[site].positions``, e.g. as
            filled by get_positions_from_seqdict.
        :param site: key into ``preds`` (``'acceptor'`` or ``'donor'`` in
            this module).
        """
        testdat = []
        for seq_no, entry in enumerate(seqdic, 1):
            sequence = entry.seq
            positions = entry.preds[site].positions
            npos = len(positions)
            verbose = npos > 50000  # progress output only for huge sequences
            for j, pos in enumerate(positions):
                if verbose and j % 10000 == 0:
                    sys.stderr.write('sequence %i: progress %1.2f%%\r'
                                     % (seq_no, 100.0 * j / npos))
                testdat.append(self._extract_window(sequence, pos))
            if verbose:
                sys.stderr.write('\n')

        labels = self._classify(testdat)

        # Windows were collected sequence by sequence, so each sequence owns
        # a contiguous slice of the outputs, in the same order.
        k = 0
        for entry in seqdic:
            npos = len(entry.preds[site].positions)
            entry.preds[site].set_scores(list(labels[k:k + npos]))
            k += npos

    def get_positions_from_seqdict(self, seqdic, site):
        """Find candidate positions in every entry of ``seqdic``.

        For each entry ``d`` the sorted positions found in ``d.seq`` are
        stored via ``d.preds[site].set_positions``.
        """
        for entry in seqdic:
            entry.preds[site].set_positions(self._find_positions(entry.seq))
174
175
class signal_detectors(object):
    """Acceptor and donor splice-site detectors built from one trained model.

    ``self.acceptor`` scores 'AG' candidates with offset 2; ``self.donor``
    scores 'GT' candidates (and 'GC' ones when ``donor_splice_use_gc``) with
    offset 0.
    """

    def __init__(self, model, donor_splice_use_gc):
        """Create both detectors.

        :param model: object exposing the ``acceptor_splice_*`` and
            ``donor_splice_*`` attributes (order, svs, alphas, b,
            window_left, window_right).  Inside this method the parameter
            shadows the module-level ``model`` import.
        :param donor_splice_use_gc: if true, 'GC' is accepted as a donor
            consensus in addition to 'GT'.
        """
        if donor_splice_use_gc:
            donor_consensus = ['GC', 'GT']
        else:
            donor_consensus = ['GT']

        self.acceptor = svm_splice_model(
            model.acceptor_splice_order,
            model.acceptor_splice_svs,
            numpy.array(model.acceptor_splice_alphas).flatten(),
            model.acceptor_splice_b,
            (model.acceptor_splice_window_left, 2,
             model.acceptor_splice_window_right),
            ['AG'])
        self.donor = svm_splice_model(
            model.donor_splice_order,
            model.donor_splice_svs,
            numpy.array(model.donor_splice_alphas).flatten(),
            model.donor_splice_b,
            (model.donor_splice_window_left, 0,
             model.donor_splice_window_right),
            donor_consensus)

    def set_sequence(self, seq):
        """Forward ``seq`` to both detectors, if they support it.

        svm_splice_model keeps no per-sequence state (every prediction method
        takes the sequence as an argument) and defines no ``set_sequence``.
        Previously this method therefore always died with an opaque
        AttributeError.

        :raises NotImplementedError: if a detector has no ``set_sequence``.
        """
        for detector in (self.acceptor, self.donor):
            setter = getattr(detector, 'set_sequence', None)
            if setter is None:
                raise NotImplementedError(
                    "%s has no set_sequence(); pass the sequence to the "
                    "predict_*_sites methods instead"
                    % type(detector).__name__)
            setter(seq)

    def predict_acceptor_sites(self, seq):
        """Return ``(positions, outputs)`` for the acceptor candidates in ``seq``.

        Uses the deprecated single-sequence API of the acceptor detector.
        """
        pos = self.acceptor.get_positions(seq)
        sys.stderr.write("computing svm output for acceptor positions\n")
        pred = self.acceptor.get_predictions(seq, pos)
        return (pos, pred)

    def predict_donor_sites(self, seq):
        """Return ``(positions, outputs)`` for the donor candidates in ``seq``.

        Uses the deprecated single-sequence API of the donor detector.
        """
        pos = self.donor.get_positions(seq)
        sys.stderr.write("computing svm output for donor positions\n")
        pred = self.donor.get_predictions(seq, pos)
        return (pos, pred)

    def predict_acceptor_sites_from_seqdict(self, seqs):
        """Fill positions and scores under ``preds['acceptor']`` for all ``seqs``."""
        self.acceptor.get_positions_from_seqdict(seqs, 'acceptor')
        sys.stderr.write("computing svm output for acceptor positions\n")
        self.acceptor.get_predictions_from_seqdict(seqs, 'acceptor')

    def predict_donor_sites_from_seqdict(self, seqs):
        """Fill positions and scores under ``preds['donor']`` for all ``seqs``."""
        self.donor.get_positions_from_seqdict(seqs, 'donor')
        sys.stderr.write("computing svm output for donor positions\n")
        self.donor.get_predictions_from_seqdict(seqs, 'donor')
216