+ added some testcases
[qpalma.git] / scripts / grid_predict.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import cPickle
5 import sys
6 import time
7 import pdb
8 import os
9 import os.path
10 import math
11
12 from pythongrid import KybJob, Usage
13 from pythongrid import process_jobs, submit_jobs, collect_jobs, get_status
14
15 from qpalma_main import *
16
17 import grid_predict
18
19
def get_slices(dataset_size,num_nodes):
    """
    Split the index range [0, dataset_size) into num_nodes consecutive slices.

    Every slice except the last has length dataset_size // num_nodes; the
    last slice runs up to dataset_size and therefore absorbs the remainder.
    If dataset_size is smaller than num_nodes, all slices but the last are
    empty.

    Arguments:
    dataset_size -- total number of items to distribute
    num_nodes    -- number of slices to produce

    Returns a list of (begin, end) integer tuples, end exclusive, suitable
    for sequence[begin:end]. The list is empty if num_nodes is negative.

    Raises ZeroDivisionError if num_nodes is 0.
    """
    # Floor division keeps the bounds integral under true division as well
    # (Python 3 or "from __future__ import division"); float bounds cannot
    # be used as slice indices.
    part = dataset_size // num_nodes

    all_instances = []
    last_idx = num_nodes - 1
    for idx in range(num_nodes):
        begin = idx * part
        if idx == last_idx:
            # The final slice takes whatever is left over.
            end = dataset_size
        else:
            end = begin + part
        all_instances.append((begin, end))

    return all_instances
40
41
def makeJobs(run,dataset_fn,chunks,param):
    """
    Build one KybJob per chunk of prediction keys.

    Each job is set up to call
    grid_predict.g_predict(run, dataset_fn, chunk_keys, param, chunk_name)
    and gets its h_vmem attribute set to '20.0G'.

    Arguments:
    run        -- run object, passed through to every job unchanged
    dataset_fn -- dataset file name, passed through to every job unchanged
    chunks     -- iterable of (chunk_name, chunk_keys) pairs
    param      -- parameter object, passed through to every job unchanged

    Returns the list of created jobs, in the order of chunks (empty if
    chunks is empty).
    """
    jobs = []

    for c_name, current_chunk in chunks:
        # The target is referenced through the imported grid_predict module
        # rather than as a bare name.
        # NOTE(review): presumably so the serialized job resolves g_predict
        # by module name even when this file runs as __main__ -- confirm.
        job_args = [run, dataset_fn, current_chunk, param, c_name]
        current_job = KybJob(grid_predict.g_predict, job_args)
        current_job.h_vmem = '20.0G'
        #current_job.express = 'True'

        # Label the line with the chunk name; the old message said
        # "job #1" for every job.
        print('job %s: %s' % (c_name, current_job.nativeSpecification))

        jobs.append(current_job)

    return jobs
58
59
def _load_pickle(path):
    """
    Return the object unpickled from the file at path.

    The file is opened in binary mode and is always closed again, even if
    unpickling raises.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    handle = open(path, 'rb')
    try:
        return cPickle.load(handle)
    finally:
        handle.close()


def create_and_submit(num_splits=25):
    """
    Split the prediction keys into chunks and submit one grid job per chunk.

    Loads the pickled run object and parameters from the hard-coded saved
    run directory, points the run at the hard-coded result directory, loads
    the list of prediction keys, cuts it into num_splits consecutive chunks
    named 'chunk_0', 'chunk_1', ... and submits one job per chunk.

    The function returns right after submission; it neither waits for the
    jobs nor collects their results (see the disabled snippet at the end).

    Arguments:
    num_splits -- number of chunks, and therefore jobs, to create
                  (default 25)
    """

    jp = os.path.join

    run_dir = '/fml/ag-raetsch/home/fabio/tmp/newest_run/alignment/saved_run'

    run = _load_pickle(jp(run_dir,'run_obj.pickle'))
    run['name'] = 'saved_run'

    param = _load_pickle(jp(run_dir,'param_526.pickle'))

    # Previously used datasets, kept for reference:
    #
    #dataset_fn = '/fml/ag-raetsch/home/fabio/tmp/sandbox/dataset_neg_strand_testcase.pickle'
    #prediction_keys_fn = '/fml/ag-raetsch/home/fabio/tmp/sandbox/dataset_neg_strand_testcase.keys.pickle'
    #
    #dataset_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/dataset_transcriptome_run_1.pickle'
    #prediction_keys_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/dataset_transcriptome_run_1.keys.pickle'
    #
    #run['result_dir'] = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_1/'
    #dataset_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_1/dataset_run_1.pickle.pickle'
    #prediction_keys_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_1/dataset_run_1.pickle.keys.pickle'

    run['result_dir'] = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_2/'
    dataset_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_2/dataset_run_2.pickle.pickle'
    prediction_keys_fn = '/fml/ag-raetsch/home/fabio/tmp/transcriptome_data/run_2/dataset_run_2.pickle.keys.pickle'

    prediction_keys = _load_pickle(prediction_keys_fn)

    print('Found %d keys for prediction.' % len(prediction_keys))

    # One (name, keys) pair per slice; the keys themselves are shipped
    # with the job, the dataset is only passed by file name.
    chunks = []
    for idx, (begin, end) in enumerate(get_slices(len(prediction_keys),num_splits)):
        c_name = 'chunk_%d' % idx
        chunks.append((c_name,prediction_keys[begin:end]))

    functionJobs = makeJobs(run,dataset_fn,chunks,param)

    print('Got %d job(s)' % len(functionJobs))

    (sid, jobids) = submit_jobs(functionJobs)

    # Waiting for and collecting the results is currently disabled:
    #
    #print 'checking whether finished'
    #while not get_status(sid, jobids):
    #    time.sleep(7)
    #print 'collecting jobs'
    #retjobs = collect_jobs(sid, jobids, functionJobs)
    #print "ret fields AFTER execution on cluster"
    #for (i, job) in enumerate(retjobs):
    #    print "Job #", i, "- ret: ", job.ret
133
134
def g_predict(run,dataset_fn,prediction_keys,param,set_name):
    """
    Run a QPalma prediction for one set of prediction keys.

    Creates a fresh QPalma instance and hands all five arguments to its
    predict method unchanged.

    Returns a short status string that names the processed set.
    """
    predictor = QPalma()
    predictor.predict(run,dataset_fn,prediction_keys,param,set_name)

    status = 'finished prediction of set %s.' % set_name
    return status
144
145
# Script entry point: submit the prediction jobs only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    create_and_submit()