+ adapted training method to use new dataset standard
author fabio <fabio@e1793c9e-67f9-0310-80fc-b846ff1f7b36>
Tue, 13 May 2008 11:34:37 +0000 (11:34 +0000)
committer fabio <fabio@e1793c9e-67f9-0310-80fc-b846ff1f7b36>
Tue, 13 May 2008 11:34:37 +0000 (11:34 +0000)
git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8984 e1793c9e-67f9-0310-80fc-b846ff1f7b36

scripts/qpalma_main.py

index 0e8814a..2dfc753 100644 (file)
@@ -71,11 +71,11 @@ def unbracket_est(est):
    return "".join(new_est).lower()
 
 
-def getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run):
-   currentSeqInfo = SeqInfo[exampleIdx]
+def getData(training_set,exampleKey,run):
+   currentSeqInfo,currentExons,original_est,currentQualities = training_set[exampleKey]
    id,chr,strand,up_cut,down_cut = currentSeqInfo
 
-   est = OriginalEsts[exampleIdx] 
+   est = original_est
    est = "".join(est)
    est = est.lower()
    est = unbracket_est(est)
@@ -84,7 +84,7 @@ def getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run):
    assert len(est) == run['read_size'], pdb.set_trace()
    est_len = len(est)
 
-   original_est = OriginalEsts[exampleIdx] 
+   #original_est = OriginalEsts[exampleIdx]
    original_est = "".join(original_est)
    original_est = original_est.lower()
 
@@ -98,47 +98,40 @@ def getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run):
    gt_tuple_pos = [p for p,e in enumerate(dna) if p>0 and p<len(dna)-1 and e=='g' and (dna[p+1]=='t' or dna[p+1]=='c')]
    assert gt_tuple_pos == [p for p,e in enumerate(don_supp) if e != -inf and p > 0], pdb.set_trace()
 
-   original_exons = Exons[exampleIdx]
-
+   #original_exons = Exons[exampleIdx]
+   original_exons = currentExons
    exons = original_exons - (up_cut-1)
    exons[0,0] -= 1
    exons[1,0] -= 1
 
    if exons.shape == (2,2):
       fetched_dna_subseq = dna[exons[0,0]:exons[0,1]] + dna[exons[1,0]:exons[1,1]]
-      
+     
       donor_elem = dna[exons[0,1]:exons[0,1]+2]
       acceptor_elem = dna[exons[1,0]-2:exons[1,0]]
 
       if not ( donor_elem == 'gt' or donor_elem == 'gc' ):
-         print 'invalid donor in example %d'% exampleIdx
+         print 'invalid donor in example %d'% exampleKey
          raise SpliceSiteException
 
       if not ( acceptor_elem == 'ag' ):
-         print 'invalid acceptor in example %d'% exampleIdx
+         print 'invalid acceptor in example %d'% exampleKey
          raise SpliceSiteException
 
       assert len(fetched_dna_subseq) == len(est), pdb.set_trace()
 
-   return dna,est,acc_supp,don_supp,exons,original_est
+   return dna,est,acc_supp,don_supp,exons,original_est,currentQualities
+
 
 
 class QPalma:
    """
-   A training method for the QPalma project
+   This class wraps the training and prediction functions for 
+   the alignment.
    """
    
-   def __init__(self,run):
+   def __init__(self):
       self.ARGS = Param()
-      self.run = run
-
-      if self.run['mode'] == 'normal':
-         self.use_quality_scores = False
-
-      elif self.run['mode'] == 'using_quality_scores':
-         self.use_quality_scores = True
-      else:
-         assert(False)
 
 
    def plog(self,string):
@@ -156,8 +149,6 @@ class QPalma:
       dna_len = len(dna)
       est_len = len(est)
 
-      #pdb.set_trace()
-
       prb = QPalmaDP.createDoubleArrayFromList(quality)
       chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)
 
@@ -207,8 +198,6 @@ class QPalma:
       currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
       c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)
 
-      #print 'After calling getAlignmentResults...'
-
       newSpliceAlign = zeros((current_num_path*dna_len,1))
       newEstAlign    = zeros((est_len*current_num_path,1))
       newWeightMatch = zeros((current_num_path*mm_len,1))
@@ -242,10 +231,10 @@ class QPalma:
       newQualityPlifsFeatures, dna_array, est_array
 
 
-   def train(self):
-      run = self.run
+   def train(self,run,training_set):
+      self.run = run
 
-      full_working_path = os.path.join(run['experiment_path'],run['name'])
+      full_working_path = os.path.join(run['alignment_dir'],run['name'])
 
       #assert not os.path.exists(full_working_path)
       if not os.path.exists(full_working_path):
@@ -256,23 +245,11 @@ class QPalma:
       # ATTENTION: Changing working directory
       os.chdir(full_working_path)
 
-      cPickle.dump(run,open('run_object.pickle','w+'))
-
       self.logfh = open('_qpalma_train.log','w+')
 
       self.plog("Settings are:\n")
       self.plog("%s\n"%str(run))
 
-      data_filename = self.run['dataset_filename']
-
-      dataset = cPickle.load(open(data_filename))
-
-      SeqInfo, OriginalEsts, Qualities  = dataset
-
-      #SeqInfo, Exons, OriginalEsts, Qualities,\
-      #AlternativeSequences = paths_load_data(data_filename)
-
-      # Load the whole dataset 
       if self.run['mode'] == 'normal':
          self.use_quality_scores = False
 
@@ -280,27 +257,8 @@ class QPalma:
          self.use_quality_scores = True
       else:
          assert(False)
-
-      self.SeqInfo     = SeqInfo
-      #self.Exons       = Exons
-      self.OriginalEsts= OriginalEsts
-      self.Qualities   = Qualities
-
-      #calc_info(self.Exons,self.Qualities)
-
-      beg = run['training_begin']
-      end = run['training_end']
-
-      SeqInfo     = SeqInfo[beg:end]
-      Exons       = Exons[beg:end]
-      OriginalEsts= OriginalEsts[beg:end]
-      Qualities   = Qualities[beg:end]
-
-      # number of training instances
-      N = numExamples = len(SeqInfo)
-      assert len(Exons) == N and len(OriginalEsts) == N and len(Qualities) == N,\
-      'The Exons,Acc,Don,.. arrays are of different lengths'
-
+   
+      numExamples = len(training_set)
       self.plog('Number of training examples: %d\n'% numExamples)
 
       self.noImprovementCtr = 0
@@ -335,9 +293,9 @@ class QPalma:
       try:
          solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)
       except:
+         self.plog('Got no license. Telling queue to reschedule job...\n')
          sys.exit(99)
 
-      #solver = None
       #solver.enforceMonotonicity(lengthSP,lengthSP+donSP)
       #solver.enforceMonotonicity(lengthSP+donSP,lengthSP+donSP+accSP)
 
@@ -368,27 +326,21 @@ class QPalma:
          if iteration_nr == iteration_steps:
             break
 
-         for exampleIdx in range(numExamples):
-            if (exampleIdx%100) == 0:
-               print 'Current example nr %d' % exampleIdx
-
+         for exampleIdx,example_key in enumerate(training_set.keys()):
+            print 'Current example %d' % example_key
             try:
-               dna,est,acc_supp,don_supp,exons,original_est =\
-               getData(SeqInfo,OriginalEsts,Exons,exampleIdx,run)
+               dna,est,acc_supp,don_supp,exons,original_est,currentQualities =\
+               getData(training_set,example_key,run)
             except SpliceSiteException:
                continue
 
             dna_len = len(dna)
 
-            #if new_string != original_est:
-            #   print new_string,original_est
-            #   continue
-
             if run['mode'] == 'normal':
                quality = [40]*len(est)
 
             if run['mode'] == 'using_quality_scores':
-               quality = Qualities[exampleIdx][0]
+               quality = currentQualities[0]
 
             if not run['enable_quality_scores']:
                quality = [40]*len(est)
@@ -402,9 +354,6 @@ class QPalma:
                   if elem != -inf:
                      acc_supp[idx] = 0.0
 
-            #pdb.set_trace()
-            #continue
-
             # Berechne die Parameter des wirklichen Alignments (but with untrained d,a,h ...)    
             if run['mode'] == 'using_quality_scores':
                trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
@@ -452,7 +401,6 @@ class QPalma:
             #acceptor.append(-inf)
 
             #donor = [-inf] + donor[:-1]
-            #pdb.set_trace()
 
             ps = h.convert2SWIG()
 
@@ -461,8 +409,6 @@ class QPalma:
             self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)
             mm_len = run['matchmatrixRows']*run['matchmatrixCols']
 
-            # old code removed
-
             newSpliceAlign = _newSpliceAlign
             newEstAlign    = _newEstAlign
             newWeightMatch = _newWeightMatch
@@ -484,8 +430,6 @@ class QPalma:
             true_map[0] = 1
 
             for pathNr in range(num_path[exampleIdx]):
-               #print 'decodedWeights' 
-               #print 'right before computeSpliceWeights(2) exampleIdx %d' % exampleIdx
                weightDon, weightAcc, weightIntron = computeSpliceWeights(d, a,\
                h, newSpliceAlign[pathNr,:].flatten().tolist()[0], don_supp,\
                acc_supp)
@@ -642,46 +586,24 @@ class QPalma:
 #
 ###############################################################################
 
-   def evaluate(self,param_filename):
-      run = self.run
-      beg = run['prediction_begin']
-      end = run['prediction_end']
-
-      data_filename = self.run['dataset_filename']
-
-      dataset = cPickle.load(open(data_filename))
-      SeqInfo, OriginalEsts, Qualities  = dataset
-
-      #AlternativeSequences = paths_load_data(data_filename,'training',None,self.ARGS)
-
-      self.SeqInfo     = SeqInfo
-      #self.Exons       = Exons
-      self.OriginalEsts= OriginalEsts
-      self.Qualities   = Qualities
-      #self.AlternativeSequences = AlternativeSequences
+   def predict(self,run,prediction_set,param):
+      """
+      Performing a prediction takes...
+      """
+      self.run = run
 
-      #calc_info(self.Acceptors,self.Donors,self.Exons,self.Qualities)
-      #print 'leaving constructor...'
+      full_working_path = os.path.join(run['alignment_dir'],run['name'])
 
-      self.logfh = open('_qpalma_predict_%d.log'%run['id'],'w+')
+      #assert not os.path.exists(full_working_path)
+      if not os.path.exists(full_working_path):
+         os.mkdir(full_working_path)
 
-      # predict on training set
-      #self.plog('##### Prediction on the training set #####\n')
-      #self.predict(param_filename,0,beg,'TRAIN')
-      
-      # predict on test set
-      self.plog('##### Prediction on the test set #####\n')
-      self.predict(param_filename,beg,end,'TEST')
-   
-      self.plog('##### Finished prediction #####\n')
-      self.logfh.close()
+      assert os.path.exists(full_working_path)
 
-   def predict(self,param_filename,beg,end,set_flag):
-      """
-      Performing a prediction takes...
+      # ATTENTION: Changing working directory
+      os.chdir(full_working_path)
 
-      """
-      run = self.run
+      self.logfh = open('_qpalma_train.log','w+')
 
       if self.run['mode'] == 'normal':
          self.use_quality_scores = False
@@ -691,25 +613,8 @@ class QPalma:
       else:
          assert(False)
 
-      SeqInfo     = self.SeqInfo[beg:end]
-      OriginalEsts= self.OriginalEsts[beg:end]
-      Qualities   = self.Qualities[beg:end]
-
-      # number of training instances
-      N = numExamples = len(SeqInfo) 
-      assert len(OriginalEsts) == N and len(Qualities) == N,\
-      'The Exons,Acc,Don,.. arrays are of different lengths'
-
-      self.plog('Number of training examples: %d\n'% numExamples)
-
-      #self.noImprovementCtr = 0
-      #self.oldObjValue = 1e8
-
-      #remove_duplicate_scores = self.run['remove_duplicate_scores']
-      #print_matrix            = self.run['print_matrix']
-      #anzpath                 = self.run['anzpath']
-
-      param = cPickle.load(open(param_filename))
+      # number of prediction instances
+      self.plog('Number of prediction examples: %d\n'% len(prediction_set))
 
       # Set the parameters such as limits penalties for the Plifs
       [h,d,a,mmatrix,qualityPlifs] =\
@@ -735,35 +640,33 @@ class QPalma:
       allPredictions = []
 
       # beginning of the prediction loop
-      for exampleIdx in range(numExamples):
+      for example_key in prediction_set.keys():
+         print 'Current example %d' % example_key
 
-         currentSeqInfo = SeqInfo[exampleIdx]
+         currentSeqInfo,original_est,currentQualities = prediction_set[id]
 
          id,chr,strand,genomicSeq_start,genomicSeq_stop =\
          currentSeqInfo 
 
+         if not chr in range(1,6):
+            continue
+
          self.plog('Loading example nr. %d (id: %d)...\n'%(exampleIdx,int(id)))
 
-         est = OriginalEsts[exampleIdx]
+         est = original_est
          est = unbracket_est(est)
 
          if run['mode'] == 'normal':
             quality = [40]*len(est)
 
          if run['mode'] == 'using_quality_scores':
-            quality = Qualities[exampleIdx][0]
+            quality = currentQualities[0]
 
          if not run['enable_quality_scores']:
             quality = [40]*len(est)
 
          current_example_predictions = []
 
-         # then make predictions for all dna fragments that where occurring in
-         # the vmatch results
-
-         if not chr in range(1,6):
-            continue
-
          try:
             currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
          except:
@@ -794,6 +697,7 @@ class QPalma:
       cPickle.dump(allPredictions,open('%s_allPredictions_%s'%(run['name'],set_flag),'w+'))
       print 'Prediction completed'
       print 'Problem ctr %d' % problem_ctr
+      self.logfh.close()
 
 
    def calc_alignment(self, dna, est, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
@@ -866,28 +770,25 @@ class QPalma:
 
       return newExons
 
-
 ###########################
 # A simple command line 
 # interface
 ###########################
 
 if __name__ == '__main__':
-   mode = sys.argv[1]
-   run_obj_fn = sys.argv[2]
-
-   run_obj = cPickle.load(open(run_obj_fn))
+   assert len(sys.argv) == 4
 
-   qpalma = QPalma(run_obj)
+   run_fn      = sys.argv[1]
+   dataset_fn  = sys.argv[2]
+   param_fn    = sys.argv[3]
 
+   run_obj = cPickle.load(open(run_fn))
+   dataset_obj = cPickle.load(open(dataset_fn))
 
-   if len(sys.argv) == 3 and mode == 'train':
-      qpalma.train()
+   qpalma = QPalma()
 
-   elif len(sys.argv) == 4 and mode == 'predict':
-      param_filename = sys.argv[3]
-      assert os.path.exists(param_filename)
-      qpalma.evaluate(param_filename)
+   if param_fn == 'train':
+      qpalma.train(run_obj,dataset_obj)
    else:
-      print 'You have to choose between training or prediction mode:'
-      print 'python qpalma. py (train|predict) <param_file>' 
+      param_obj = cPickle.load(open(param_fn))
+      qpalma.predict(run_obj,dataset_obj,param_obj)