+ restructured test cases
[qpalma.git] / qpalma / gridtools.py
1 # This program is free software; you can redistribute it and/or modify
2 # it under the terms of the GNU General Public License as published by
3 # the Free Software Foundation; either version 2 of the License, or
4 # (at your option) any later version.
5 #
6 # Written (W) 2008 Fabio De Bona
7 # Copyright (C) 2008 Max-Planck-Society
8
9 import cPickle
10 import math
11 import time
12 import os
13 import os.path
14 import pdb
15 import sys
16 import time
17
18 from threading import Thread
19
20 from pythongrid import KybJob, Usage
21 from pythongrid import process_jobs, submit_jobs, collect_jobs, get_status
22
23 from qpalma.OutputFormat import createAlignmentOutput
24
25 from PipelineHeuristic import *
26
27 import gridtools
28
29 from utils import get_slices,split_file,combine_files
30
31 from qpalma.sequence_utils import SeqSpliceInfo,DataAccessWrapper
32
33 from qpalma_main import QPalma
34
# Shorthand for joining path components.
jp = os.path.join

def pjoin(*args):
    """Join any number of path components.

    os.path.join is already variadic, so the former reduce()-based lambda
    (pjoin = lambda *args: reduce(lambda x,y: jp(x,y),args)) was unnecessary.
    """
    return jp(*args)
37
38
class ClusterTask(Thread):
    """
    Base class for one stage of the QPalma cluster pipeline.

    Every task creates a status file.
    All cluster jobs submit then their exit status to this status file.

    Every cluster task subclass should override/implement the methods:

    1. __init__
    2. CreateJobs
    3. TaskStarter
    """

    def __init__(self,settings):
        # NOTE(review): Thread.__init__ is never invoked although the class
        # derives from Thread -- confirm the threading machinery is unused.
        self.sleep_time = 0

        # this list stores the cluster/local jobs objects
        self.functionJobs = []

        # this object stores the configuration
        self.settings = settings


    def CreateJobs(self):
        """
        This method create an array of jobs called self.functionJobs. It is only
        virtual in this base class and has to be implemented specifically for
        each cluster task.
        """
        pass


    def Submit(self):
        """
        After creation of jobs this function submits them to the cluster.

        submit_jobs returns a session id and the list of job ids; both are
        required later by collect_jobs in CheckIfTaskFinished.
        """
        self.sid, self.jobids = submit_jobs(self.functionJobs)
        #self.processedFunctionJobs = process_jobs(self.functionJobs,local=True,maxNumThreads=1)


    def Restart(self,id):
        # Virtual: restart the job with the given id. Not implemented here.
        pass


    def collectResults(self):
        # Virtual: combine per-job result files. Overridden by subclasses.
        pass


    def CheckIfTaskFinished(self):
        """
        This function is responsible for checking whether the submitted jobs were
        completed successfully.
        """

        print 'collecting jobs'
        # Blocks until all jobs of this session finished; each returned job
        # object carries the remote return value in its .ret field.
        retjobs = collect_jobs(self.sid, self.jobids, self.functionJobs, True)
        print "ret fields AFTER execution on cluster"
        for (i, job) in enumerate(retjobs):
            print "Job #", i, "- ret: ", job.ret

        print '--------------'

        # delegate to the subclass to merge the per-job result files
        self.collectResults()
104
105
106 class ApproximationTask(ClusterTask):
107 """
108 This task represents the first step towards a valid QPalma dataset.
109 It starts an approximative QPalma model on the putative unspliced reads to
110 identify true spliced reads.
111 """
112
113 def __init__(self,settings):
114 ClusterTask.__init__(self,settings)
115
116
117 def CreateJobs(self):
118 """
119 Create...
120 """
121
122 num_splits = self.settings['num_splits']
123
124 #run_dir = '/fml/ag-raetsch/home/fabio/tmp/newest_run/alignment/run_enable_quality_scores_+_enable_splice_signals_+_enable_intron_length_+'
125 #param_fname = jp(run_dir,'param_526.pickle')
126 param_fname = self.settings['prediction_param_fn']
127 #run_fname = jp(run_dir,'run_obj.pickle')
128
129 #result_dir = '/fml/ag-raetsch/home/fabio/tmp/vmatch_evaluation/main'
130 result_dir = self.settings['approximation_dir']
131
132 original_map_fname = self.settings['unspliced_reads_fn']
133 split_file(original_map_fname,result_dir,num_splits)
134
135 self.result_files = []
136
137 for idx in range(0,num_splits):
138 data_fname = jp(result_dir,'map.part_%d'%idx)
139 result_fname = jp(result_dir,'map.vm.part_%d'%idx)
140 self.result_files.append(result_fname)
141
142 current_job = KybJob(gridtools.ApproximationTaskStarter,[data_fname,param_fname,result_fname,self.settings])
143 current_job.h_vmem = '25.0G'
144 #current_job.express = 'True'
145
146 print "job #1: ", current_job.nativeSpecification
147
148 self.functionJobs.append(current_job)
149
150
151 def collectResults(self):
152 result_dir = self.settings['approximation_dir']
153 combined_fn = jp(result_dir,'map.vm.spliced')
154 self.result_files = map(lambda x:x+'.spliced',self.result_files)
155 combine_files(self.result_files,combined_fn)
156 combine_files([combined_fn,self.settings['spliced_reads_fn']],'map.vm')
157
158
def ApproximationTaskStarter(data_fname,param_fname,result_fname,settings):
    """
    Entry point executed on a cluster node: run the approximative QPalma
    filter on one chunk of the unspliced reads and write the surviving
    (putatively spliced) hits to result_fname.
    """
    heuristic = PipelineHeuristic(data_fname,param_fname,result_fname,settings)
    heuristic.filter()

    return 'finished filtering set %s.' % data_fname
164
165
class PreprocessingTask(ClusterTask):
    """
    This class encapsules some...
    """

    def __init__(self,settings):
        # BUG FIX: ClusterTask.__init__ requires the settings object; the old
        # no-argument call raised a TypeError on construction. Accept and
        # forward settings like the sibling task classes do.
        ClusterTask.__init__(self,settings)
173
174
175 class AlignmentTask(ClusterTask):
176 """
177 This task represents the main part of QPalma.
178 """
179
180 def __init__(self,settings):
181 ClusterTask.__init__(self,settings)
182
183
184 def CreateJobs(self):
185 """
186
187 """
188
189 num_splits = self.settings['num_splits']
190
191 jp = os.path.join
192
193 dataset_fn = self.settings['prediction_dataset_fn']
194 prediction_keys_fn = self.settings['prediction_dataset_keys_fn']
195
196 prediction_keys = cPickle.load(open(prediction_keys_fn))
197
198 print 'Found %d keys for prediction.' % len(prediction_keys)
199
200 slices = get_slices(len(prediction_keys),num_splits)
201 chunks = []
202 for idx,slice in enumerate(slices):
203 #if idx != 0:
204 c_name = 'chunk_%d' % idx
205 chunks.append((c_name,prediction_keys[slice[0]:slice[1]]))
206
207 for c_name,current_chunk in chunks:
208 current_job = KybJob(gridtools.AlignmentTaskStarter,[self.settings,dataset_fn,current_chunk,c_name])
209 current_job.h_vmem = '2.0G'
210 current_job.express = 'True'
211
212 print "job #1: ", current_job.nativeSpecification
213
214 self.functionJobs.append(current_job)
215
216 sum = 0
217 for size in [len(elem) for name,elem in chunks]:
218 sum += size
219
220 print 'Got %d job(s)' % len(self.functionJobs)
221
222
def AlignmentTaskStarter(settings,dataset_fn,prediction_keys,set_name):
    """
    Entry point executed on a cluster node: predict alignments for one
    chunk (set_name) of dataset keys.
    """
    access_wrapper = DataAccessWrapper(settings)
    seq_info = SeqSpliceInfo(access_wrapper,settings['allowed_fragments'])
    aligner = QPalma(seq_info)
    aligner.init_prediction(dataset_fn,prediction_keys,settings,set_name)
    return 'finished prediction of set %s.' % set_name
232
233
234
235 class TrainingTask(ClusterTask):
236 """
237 This class represents the cluster task of training QPalma.
238 """
239
240 def __init__(self):
241 ClusterTask.__init__(self)
242
243
244 def CreateJobs(self):
245 """
246
247 """
248
249 jp = os.path.join
250
251 dataset_fn = self.settings['training_dataset_fn']
252 training_keys = cPickle.load(open(self.settings['training_dataset_keys_fn']))
253
254 print 'Found %d keys for training.' % len(training_keys)
255
256 set_name = 'training_set'
257
258 current_job = KybJob(gridtools.AlignmentTaskStarter,[self.settings,dataset_fn,training_keys,set_name])
259 current_job.h_vmem = '2.0G'
260 current_job.express = 'True'
261
262 print "job #1: ", current_job.nativeSpecification
263
264 self.functionJobs.append(current_job)
265
266 print 'Got %d job(s)' % len(self.functionJobs)
267
268
269 def collectResults(self):
270 pass
271
def TrainingTaskStarter(settings,dataset_fn,training_keys,set_name):
    """
    Entry point executed on a cluster node: train the QPalma model on the
    given set of training keys.
    """
    accessWrapper = DataAccessWrapper(settings)
    seqInfo = SeqSpliceInfo(accessWrapper,settings['allowed_fragments'])
    qp = QPalma(seqInfo)
    qp.init_training(dataset_fn,training_keys,settings,set_name)
    # BUG FIX: the message said 'finished prediction' -- a copy-paste error
    # from AlignmentTaskStarter; this function performs training.
    return 'finished training of set %s.' % set_name
278
279
280 class PostprocessingTask(ClusterTask):
281 """
282 After QPalma predicted alignments this task postprocesses the data.
283 """
284
285 def __init__(self,settings):
286 ClusterTask.__init__(self,settings)
287
288
289 def CreateJobs(self):
290 run_dir = self.settings['prediction_dir']
291 self.result_dir = self.settings['alignment_dir']
292
293 chunks_fn = []
294 for fn in os.listdir(run_dir):
295 if fn.startswith('chunk'):
296 chunks_fn.append(fn)
297
298 functionJobs=[]
299
300 self.result_files = []
301 for chunk_fn in chunks_fn:
302 chunk_name = chunk_fn[:chunk_fn.find('.')]
303 result_fn = jp(self.result_dir,'%s.%s'%(chunk_name,self.settings['output_format']))
304 chunk_fn = jp(run_dir,chunk_fn)
305
306 self.result_files.append(result_fn)
307
308 current_job = KybJob(gridtools.PostProcessingTaskStarter,[self.settings,chunk_fn,result_fn])
309 current_job.h_vmem = '2.0G'
310 current_job.express = 'True'
311
312 print "job #1: ", current_job.nativeSpecification
313
314 self.functionJobs.append(current_job)
315
316
317 def collectResults(self):
318 combined_fn = jp(self.result_dir,'all_alignments.%s'%self.settings['output_format'])
319 combine_files(self.result_files,combined_fn)
320
321
def PostProcessingTaskStarter(settings,chunk_fn,result_fn):
    """
    Entry point executed on a cluster node: convert one prediction chunk
    (chunk_fn) into the configured alignment output format, writing the
    result to result_fn.
    """
    createAlignmentOutput(settings,chunk_fn,result_fn)
324
325
326
327 class DummyTask(ClusterTask):
328 """
329 This class represents a dummy task to make debugging easier.
330 """
331 def __init__(self):
332 ClusterTask.__init__(self)
333
334
335 def DummyTaskStarter(param):
336 create_alignment_file(chunk_fn,result_fn)
337
338
339 def CreateJobs(self):
340 run_dir = '/fml/ag-raetsch/home/fabio/tmp/vmatch_evaluation/spliced_1/prediction'
341 result_dir = '/fml/ag-raetsch/home/fabio/tmp/vmatch_evaluation/spliced_1/alignment'
342
343 for chunk_fn in chunks_fn:
344 chunk_name = chunk_fn[:chunk_fn.find('.')]
345 result_fn = jp(result_dir,'%s.align_remap'%chunk_name)
346 chunk_fn = jp(run_dir,chunk_fn)
347
348 current_job = KybJob(grid_alignment.DummyTaskStarter,[chunk_fn,result_fn])
349 current_job.h_vmem = '15.0G'
350 current_job.express = 'True'
351
352 print "job #1: ", current_job.nativeSpecification