#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cPickle
import os
import os.path

from pythongrid import KybJob, processJobs, processJobsLocally

from qpalma_main import *

# The module imports itself so that the jobs below can reference
# grid_predict.g_predict by its fully qualified name when pythongrid
# serializes them for the cluster nodes.
import grid_predict


def get_slices(dataset_size, num_nodes):
    """
    Partition the index range [0, dataset_size) into num_nodes contiguous
    (begin, end) slices. The last slice absorbs the remainder of the
    integer division, so the slices cover every index exactly once.
    """
    all_instances = []

    part = dataset_size / num_nodes
    begin = 0
    end = 0
    for idx in range(1, num_nodes + 1):
        begin = end
        if idx == num_nodes:
            end = dataset_size
        else:
            end = begin + part

        all_instances.append((begin, end))

    return all_instances
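
# Worked example (Python 2 integer division), derived from the code above:
#   get_slices(10, 3) -> [(0, 3), (3, 6), (6, 10)]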


def makeJobs(run, dataset_fn, chunks, param):
    """
    Create one KybJob per chunk; each job calls g_predict with its chunk of
    prediction keys on a cluster node.
    """
    jobs = []

    for c_name, current_chunk in chunks:
        current_job = KybJob(grid_predict.g_predict, [run, dataset_fn, current_chunk, param, c_name])
        current_job.h_vmem = '5.0G'
        #current_job.express = 'True'

        print "native specification: ", current_job.nativeSpecification

        jobs.append(current_job)

    return jobs
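
# Note: h_vmem is a Sun Grid Engine style memory request; pythongrid
# presumably folds it into the nativeSpecification printed above.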


def create_and_submit():
    """
    Load the pickled run and parameters, split the prediction keys into
    chunks, and submit one prediction job per chunk to the cluster.
    """
    jp = os.path.join

    run_dir = '/fml/ag-raetsch/home/fabio/tmp/newest_run/alignment/run_enable_quality_scores_+_enable_splice_signals_+_enable_intron_length_+'

    run = cPickle.load(open(jp(run_dir, 'run_obj.pickle')))
    param = cPickle.load(open(jp(run_dir, 'param_526.pickle')))

    dataset_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_12_05_08.test.pickle'
    prediction_keys_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_12_05_08.test_keys.pickle'

    prediction_keys = cPickle.load(open(prediction_keys_fn))

    print 'Found %d keys for prediction.' % len(prediction_keys)

    num_splits = 15
    slices = get_slices(len(prediction_keys), num_splits)
    chunks = []
    for idx, (begin, end) in enumerate(slices):
        c_name = 'chunk_%d' % idx
        chunks.append((c_name, prediction_keys[begin:end]))

    functionJobs = makeJobs(run, dataset_fn, chunks, param)

    # Sanity check: the chunks together must cover every prediction key.
    total = 0
    for size in [len(elem) for name, elem in chunks]:
        total += size

    assert total == len(prediction_keys)

    print 'Got %d job(s)' % len(functionJobs)

    print "output ret field in each job before sending it onto the cluster"
    for (i, job) in enumerate(functionJobs):
        print "Job with id: ", i, "- ret: ", job.ret

    print ""
    print "sending function jobs to cluster"
    print ""

    processedFunctionJobs = processJobs(functionJobs)

    print "ret fields AFTER execution on cluster"
    for (i, job) in enumerate(processedFunctionJobs):
        print "Job with id: ", i, "- ret: ", job.ret


def g_predict(run, dataset_fn, prediction_keys, param, set_name):
    """
    Cluster-side entry point: run the QPalma prediction on one chunk of
    prediction keys. This is the function wrapped by each KybJob.
    """
    qp = QPalma()
    qp.predict(run, dataset_fn, prediction_keys, param, set_name)

    return 'finished prediction of set %s.' % set_name
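

# A minimal local-debugging sketch (an assumption, not part of the original
# workflow): processJobsLocally is assumed to mirror the processJobs signature
# and run the same jobs in the current process, so g_predict can be tested
# without a cluster.
def process_locally(functionJobs):
    processedJobs = processJobsLocally(functionJobs)
    for (i, job) in enumerate(processedJobs):
        print "Job with id: ", i, "- ret: ", job.ret
    return processedJobs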


if __name__ == '__main__':
    create_and_submit()