+ got rid of some unnecessary code
[qpalma.git] / scripts / qpalma_main.py
index b3edea8..e6fc376 100644 (file)
@@ -26,23 +26,18 @@ import numpy
 from numpy.matlib import mat,zeros,ones,inf
 from numpy.linalg import norm
 
-import QPalmaDP
-import qpalma
-
 #from qpalma.SIQP_CPX import SIQPSolver
 #from qpalma.SIQP_CVXOPT import SIQPSolver
 
+import QPalmaDP
+import qpalma
 from qpalma.computeSpliceWeights import *
 from qpalma.set_param_palma import *
 from qpalma.computeSpliceAlignWithQuality import *
 from qpalma.TrainingParam import Param
 from qpalma.Plif import Plf,compute_donacc
 
-# these two imports are needed for the load genomic resp. interval query
-# functions
-#from Genefinding import *
-#from genome_utils import load_genomic
-from Utils import calc_stat, calc_info, pprint_alignment, get_alignment
+from Utils import pprint_alignment, get_alignment
 
 class SpliceSiteException:
    pass
@@ -68,12 +63,6 @@ def getData(training_set,exampleKey,run):
    dna_flat_files =  '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    dna, acc_supp, don_supp = get_seq_and_scores(chr,strand,up_cut,down_cut,dna_flat_files)
 
-   # splice score is located at g of ag
-   #ag_tuple_pos = [p for p,e in enumerate(dna) if p>1 and dna[p-1]=='a' and dna[p]=='g' ]
-   #assert ag_tuple_pos == [p for p,e in enumerate(acc_supp) if e != -inf and p > 1], pdb.set_trace()
-   #gt_tuple_pos = [p for p,e in enumerate(dna) if p>0 and p<len(dna)-1 and e=='g' and (dna[p+1]=='t' or dna[p+1]=='c')]
-   #assert gt_tuple_pos == [p for p,e in enumerate(don_supp) if e != -inf and p > 0], pdb.set_trace()
-
    original_exons = currentExons
    exons = original_exons - (up_cut-1)
    exons[0,0] -= 1
@@ -145,80 +134,17 @@ class QPalma:
        acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],\
        run['print_matrix'] )
 
-      c_SpliceAlign       = QPalmaDP.createIntArrayFromList([0]*(dna_len*current_num_path))
-      c_EstAlign          = QPalmaDP.createIntArrayFromList([0]*(est_len*current_num_path))
-      c_WeightMatch       = QPalmaDP.createIntArrayFromList([0]*(mm_len*current_num_path))
-      c_DPScores   = QPalmaDP.createDoubleArrayFromList([.0]*current_num_path)
-
-      c_qualityPlifsFeatures = QPalmaDP.createDoubleArrayFromList([.0]*(run['totalQualSuppPoints']*current_num_path))
-
       if prediction_mode:
          # part that is only needed for prediction
          result_len = currentAlignment.getResultLength()
-         c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
-         c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
-
-         currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)
-         _dna_array,_est_array = currentAlignment.getAlignmentArraysNew()
-
-         dna_array = [0.0]*result_len
-         est_array = [0.0]*result_len
-
-         for r_idx in range(result_len):
-            dna_array[r_idx] = c_dna_array[r_idx]
-            est_array[r_idx] = c_est_array[r_idx]
-
+         dna_array,est_array = currentAlignment.getAlignmentArraysNew()
       else:
          dna_array = None
          est_array = None
 
-      assert dna_array == _dna_array
-      assert est_array == _est_array
-
-      currentAlignment.getAlignmentResults(c_SpliceAlign, c_EstAlign,\
-      c_WeightMatch, c_DPScores, c_qualityPlifsFeatures)
-
-      _SpliceAlign, _EstAlign, _WeightMatch, _DPScores, _newQualityPlifsFeatures =\
+      newSpliceAlign, newEstAlign, newWeightMatch, newDPScores, newQualityPlifsFeatures =\
       currentAlignment.getAlignmentResultsNew()
 
-      newSpliceAlign = zeros((current_num_path*dna_len,1))
-      newEstAlign    = zeros((est_len*current_num_path,1))
-      newWeightMatch = zeros((current_num_path*mm_len,1))
-      newDPScores    = zeros((current_num_path,1))
-      newQualityPlifsFeatures = zeros((run['totalQualSuppPoints']*current_num_path,1))
-
-      for i in range(dna_len*current_num_path):
-         newSpliceAlign[i] = c_SpliceAlign[i]
-
-      for i in range(est_len*current_num_path):
-         newEstAlign[i] = c_EstAlign[i]
-
-      for i in range(mm_len*current_num_path):
-         newWeightMatch[i] = c_WeightMatch[i]
-
-      for i in range(current_num_path):
-         newDPScores[i] = c_DPScores[i]
-
-      if self.use_quality_scores:
-         for i in range(run['totalQualSuppPoints']*current_num_path):
-            newQualityPlifsFeatures[i] = c_qualityPlifsFeatures[i]
-
-      del c_SpliceAlign
-      del c_EstAlign
-      del c_WeightMatch
-      del c_DPScores
-      del c_qualityPlifsFeatures
-      del currentAlignment
-
-
-      assert newSpliceAlign.flatten().tolist()[0] == _SpliceAlign
-      assert newEstAlign.flatten().tolist()[0]    == _EstAlign
-      assert newWeightMatch.flatten().tolist()[0]  == _WeightMatch
-      assert newDPScores.flatten().tolist()[0]     == _DPScores 
-      assert newQualityPlifsFeatures.flatten().tolist()[0]  == _newQualityPlifsFeatures 
-
-      pdb.set_trace()
-
       return newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
       newQualityPlifsFeatures, dna_array, est_array
 
@@ -570,7 +496,6 @@ class QPalma:
       cPickle.dump(param,open('param_%d.pickle'%param_idx,'w+'))
       self.logfh.close()
 
-
 ###############################################################################
 #
 # End of the code needed for training 
@@ -602,7 +527,6 @@ class QPalma:
 
       self.logfh = open('_qpalma_predict_%s.log'%set_name,'w+')
 
-
       # number of prediction instances
       self.plog('Number of prediction examples: %d\n'% len(prediction_keys))
 
@@ -641,21 +565,9 @@ class QPalma:
       [h,d,a,mmatrix,qualityPlifs] =\
       set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)
 
-      #############################################################################################
-      # Prediction
-      #############################################################################################
-      
       if not self.qpalma_debug_mode:
          self.plog('Starting prediction...\n')
 
-      #donSP       = self.run['numDonSuppPoints']
-      #accSP       = self.run['numAccSuppPoints']
-      #lengthSP    = self.run['numLengthSuppPoints']
-      #mmatrixSP   = run['matchmatrixRows']*run['matchmatrixCols']
-      #numq        = self.run['numQualSuppPoints']
-      #totalQualSP = self.run['totalQualSuppPoints']
-      #totalQualityPenalties = zeros((totalQualSP,1))
-
       self.problem_ctr = 0
 
       # where we store the predictions
@@ -664,23 +576,9 @@ class QPalma:
       # we take the first quality vector of the tuple of quality vectors
       quality_index = 0
 
-      # fetch the data needed
-      #g_dir    = run['dna_flat_files'] #'/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
-      #acc_dir  = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/acc'
-      #don_dir  = '/fml/ag-raetsch/home/fabio/tmp/interval_query_files/don'
-
-      #g_fmt = 'chr%d.dna.flat'
-      #s_fmt = 'contig_%d%s'
-
-      #num_chromo = 6
-
-      #accessWrapper = DataAccessWrapper(g_dir,acc_dir,don_dir,g_fmt,s_fmt)
-      #seqInfo = SeqSpliceInfo(accessWrapper,range(1,num_chromo))
-
       # beginning of the prediction loop
       for example_key in prediction_set.keys():
          print 'Current example %d' % example_key
-
          for example in prediction_set[example_key]:
 
             currentSeqInfo,read,currentQualities = example
@@ -691,13 +589,9 @@ class QPalma:
             if not self.qpalma_debug_mode:
                self.plog('Loading example id: %d...\n'% int(id))
 
-            if run['mode'] == 'normal':
-               quality = [40]*len(read)
-
-            if run['mode'] == 'using_quality_scores':
+            if run['enable_quality_scores']:
                quality = currentQualities[quality_index]
-
-            if not run['enable_quality_scores']:
+            else:
                quality = [40]*len(read)
 
             try:
@@ -750,17 +644,11 @@ class QPalma:
       run = self.run
       donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
 
-      dna = str(dna)
-      read = str(read)
-
       if '-' in read:
          self.plog('found gap\n')
          read = read.replace('-','')
          assert len(read) == Conf.read_size
 
-      dna_len = len(dna)
-      read_len = len(read)
-
       ps = h.convert2SWIG()
 
       newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
@@ -769,17 +657,12 @@ class QPalma:
 
       mm_len = run['matchmatrixRows']*run['matchmatrixCols']
 
-      # old code removed
-      newSpliceAlign = newSpliceAlign.reshape(1,dna_len)
-      newWeightMatch = newWeightMatch.reshape(1,mm_len)
       true_map    = [0]*2
       true_map[0] = 1
       pathNr      = 0
 
-      #pdb.set_trace()
-
-      _newSpliceAlign   = array.array('B',newSpliceAlign.flatten().tolist()[0])
-      _newEstAlign      = array.array('B',newEstAlign.flatten().tolist()[0])
+      _newSpliceAlign   = array.array('B',newSpliceAlign)
+      _newEstAlign      = array.array('B',newEstAlign)
        
       alignment = get_alignment(_newSpliceAlign,_newEstAlign, dna_array, read_array) #(qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds, tExonSizes, tStarts, tEnds)
 
@@ -798,7 +681,6 @@ class QPalma:
    def calculatePredictedExons(self,SpliceAlign):
       newExons = []
       oldElem = -1
-      SpliceAlign = SpliceAlign.flatten().tolist()[0]
       SpliceAlign.append(-1)
       for pos,elem in enumerate(SpliceAlign):
          if pos == 0: