git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8653 e1793c9e...
[qpalma.git] / scripts / PipelineHeuristic.py
index 7548512..4ea1cb9 100644 (file)
@@ -3,11 +3,48 @@
 
 import cPickle
 import sys
-import pydb
 import pdb
 import os
 import os.path
 import math
+import resource
+
+from qpalma.DataProc import *
+from qpalma.computeSpliceWeights import *
+from qpalma.set_param_palma import *
+from qpalma.computeSpliceAlignWithQuality import *
+from qpalma.penalty_lookup_new import *
+from qpalma.compute_donacc import *
+from qpalma.TrainingParam import Param
+from qpalma.Plif import Plf
+
+from qpalma.tools.splicesites import getDonAccScores
+from qpalma.Configuration import *
+
+from compile_dataset import getSpliceScores, get_seq_and_scores
+
+from numpy.matlib import mat,zeros,ones,inf
+from numpy import inf,mean
+
+from qpalma.parsers import PipelineReadParser
+
+
+def unbracket_est(est):
+   new_est = ''
+   e = 0
+
+   while True:
+      if e >= len(est):
+         break
+
+      if est[e] == '[':
+         new_est += est[e+2]
+         e += 4
+      else:
+         new_est += est[e]
+         e += 1
+
+   return "".join(new_est).lower()
 
 
 class PipelineHeuristic:
@@ -16,77 +53,486 @@ class PipelineHeuristic:
    vmatch is spliced an should be then newly aligned using QPalma or not.
    """
 
-   def __init__(self,filename):
-      self.data_filename = filename
+   def __init__(self,run_fname,data_fname,param_fname):
+      """
+      We need a run object holding information about the nr. of support points
+      etc.
+      """
 
+      run = cPickle.load(open(run_fname))
+      self.run = run
 
-   def filter(self):
-      SeqInfo, Exons, OriginalEsts, Qualities,\
-      AlternativeSequences = paths_load_data(data_filename,'training',None,self.ARGS)
+      start = cpu()
+
+      self.data_fname = data_fname
+
+      self.param = cPickle.load(open(param_fname))
+      
+      # Set the parameters such as limits penalties for the Plifs
+      [h,d,a,mmatrix,qualityPlifs] = set_param_palma(self.param,True,run)
+
+      self.h = h
+      self.d = d
+      self.a = a
+      self.mmatrix = mmatrix
+      self.qualityPlifs = qualityPlifs
+
+      # when we look for alternative alignments with introns this value is the
+      # mean of intron size
+      self.intron_size  = 250
+
+      self.read_size    = 36
    
-      for idx in range(len(SeqInfo)):
-         currentVMatchAligment
-         vMatchScore = calcAlignmentScore(currentVMatchAligment)
-
-         alternativeAlignments = calcAlternativeAligments(currentVMatchAligment)
-
-         maxScore    = 0.0
-         maxAligment = None
-         for currentAlternative in alternativeAlignments:
-            if currentScore > maxScore:
-               maxScore = alternativeScores 
-               maxAlignment = currentAlternative
-
-            currentScore =calcAlignmentScore(currentAlternative))
-
-
-         # Seems that according to our learned parameters VMatch found a good
-         # alignment of the current read
-         if maxScore < vMatchScore:
-            pass
-         # We found an alternative aligment considering splice sites that scores
-         # higher than the VMatch alignment
-         else:
-            pass
+      self.original_reads = {}
+
+      for line in open('/fml/ag-raetsch/share/projects/qpalma/solexa/allReads.pipeline'):
+         line = line.strip()
+         id,seq,q1,q2,q3 = line.split()
+         id = int(id)
+         self.original_reads[id] = seq
+
+      lengthSP    = run['numLengthSuppPoints']
+      donSP       = run['numDonSuppPoints']
+      accSP       = run['numAccSuppPoints']
+      mmatrixSP   = run['matchmatrixRows']*run['matchmatrixCols']
+      numq        = run['numQualSuppPoints']
+      totalQualSP = run['totalQualSuppPoints']
+
+      currentPhi = zeros((run['numFeatures'],1))
+      currentPhi[0:lengthSP]                                            = mat(h.penalties[:]).reshape(lengthSP,1)
+      currentPhi[lengthSP:lengthSP+donSP]                               = mat(d.penalties[:]).reshape(donSP,1)
+      currentPhi[lengthSP+donSP:lengthSP+donSP+accSP]                   = mat(a.penalties[:]).reshape(accSP,1)
+      currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP]   = mmatrix[:]
+
+      totalQualityPenalties = self.param[-totalQualSP:]
+      currentPhi[lengthSP+donSP+accSP+mmatrixSP:]                    = totalQualityPenalties[:]
+      self.currentPhi = currentPhi
+
+      # we want to identify spliced reads 
+      # so true pos are spliced reads that are predicted "spliced"
+      self.true_pos  = 0
+      
+      # as false positives we count all reads that are not spliced but predicted
+      # as "spliced"
+      self.false_pos = 0
+
+      self.true_neg  = 0
+      self.false_neg = 0
+
+      # total time spend for get seq and scores
+      self.get_time  = 0.0
+      self.calcAlignmentScoreTime = 0.0
+      self.alternativeScoresTime = 0.0
+
+      self.count_time = 0.0
+      self.read_parsing = 0.0
+      self.main_loop = 0.0
+      self.splice_site_time = 0.0
+      self.computeSpliceAlignWithQualityTime = 0.0
+      self.computeSpliceWeightsTime = 0.0
+      self.DotProdTime = 0.0
+      self.array_stuff = 0.0
+      stop = cpu()
+
+      self.init_time = stop-start
+
+   def filter(self):
+      """
+      This method...
+      """
+      run = self.run 
+
+      start = cpu()
+
+      rrp = PipelineReadParser(self.data_fname)
+      all_remapped_reads = rrp.parse()
+
+      stop = cpu()
+
+      self.read_parsing = stop-start
+
+      ctr = 0
+      unspliced_ctr  = 0
+      spliced_ctr    = 0
+
+      print 'Starting filtering...'
+      _start = cpu()
+
+      for readId,currentReadLocations in all_remapped_reads.items():
+         for location in currentReadLocations[:1]:
+
+            id       = location['id']
+            chr      = location['chr']
+            pos      = location['pos']
+            strand   = location['strand']
+            mismatch = location['mismatches']
+            length   = location['length']
+            off      = location['offset']
+            seq      = location['seq']
+            prb      = location['prb']
+            cal_prb  = location['cal_prb']
+            chastity = location['chastity']
+
+            id = int(id)
+
+            if strand == '-':
+               continue
+
+            if ctr == 1000:
+               break
+
+            #if pos > 10000000:
+            #   continue
+      
+            unb_seq = unbracket_est(seq)
+            effective_len = len(unb_seq)
+
+            genomicSeq_start  = pos
+            genomicSeq_stop   = pos+effective_len-1
+
+            start = cpu()
+            #print genomicSeq_start,genomicSeq_stop
+            currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
+            stop = cpu()
+            self.get_time += stop-start
+
+            dna            = currentDNASeq
+            exons          = zeros((2,1))
+            exons[0,0]     = 0
+            exons[1,0]     = effective_len
+            est            = unb_seq
+            original_est   = seq
+            quality        = prb
+
+            #pdb.set_trace()
+
+            currentVMatchAlignment = dna, exons, est, original_est, quality,\
+            currentAcc, currentDon
+            vMatchScore = self.calcAlignmentScore(currentVMatchAlignment)
+
+            alternativeAlignmentScores = self.calcAlternativeAlignments(location)
+
+            start = cpu()
+            # found no alternatives
+            if alternativeAlignmentScores == []:
+               continue
             
+            maxAlternativeAlignmentScore = max(alternativeAlignmentScores)
+            #print 'vMatchScore/alternativeScore: %f %f ' % (vMatchScore,maxAlternativeAlignmentScore)
+            #print 'all candidates %s' % str(alternativeAlignmentScores)
+
+            new_id = id - 1000000300000
 
+            unspliced = False
+            # unspliced 
+            if new_id > 0: 
+               unspliced = True
 
+            # Seems that according to our learned parameters VMatch found a good
+            # alignment of the current read
+            if maxAlternativeAlignmentScore < vMatchScore:
+               unspliced_ctr += 1
 
-   def calcAlignmentScore(self,dna,exons,est):
+               if unspliced:
+                  self.true_neg += 1
+               else:
+                  self.false_neg += 1
+
+            # We found an alternative alignment considering splice sites that scores
+            # higher than the VMatch alignment
+            else:
+               spliced_ctr += 1
+
+               if unspliced:
+                  self.false_pos += 1
+               else:
+                  self.true_pos += 1
+
+            ctr += 1
+            stop = cpu()
+            self.count_time = stop-start
+
+      _stop = cpu()
+      self.main_loop = _stop-_start
+
+      print 'Unspliced/Splice: %d %d'%(unspliced_ctr,spliced_ctr)
+      print 'True pos / false pos : %d %d'%(self.true_pos,self.false_pos)
+      print 'True neg / false neg : %d %d'%(self.true_neg,self.false_neg)
+
+
+   def findHighestScoringSpliceSites(self, currentAcc, currentDon, DNA, max_intron_size, read_size, splice_thresh):
+
+      def signum(a):
+          if a>0: 
+              return 1
+         elif a<0:
+              return -1
+          else:
+              return 0
+
+      proximal_acc   = []
+      for idx in xrange(max_intron_size, max_intron_size+read_size):
+          if currentAcc[idx]>= splice_thresh:
+            proximal_acc.append((idx,currentAcc[idx]))
+
+      proximal_acc.sort(lambda x,y: signum(x[1]-y[1])) 
+      proximal_acc=proximal_acc[-2:]
+
+      distal_acc   = []
+      for idx in xrange(max_intron_size+read_size, len(currentAcc)):
+          if currentAcc[idx]>= splice_thresh:
+            distal_acc.append((idx, currentAcc[idx], DNA[idx+1:idx+read_size]))
+
+      #distal_acc.sort(lambda x,y: signum(x[1]-y[1])) 
+      #distal_acc=distal_acc[-2:]
+
+
+      proximal_don   = []
+      for idx in xrange(max_intron_size, max_intron_size+read_size):
+         if currentDon[idx] >= splice_thresh:
+            proximal_don.append((idx, currentDon[idx]))
+
+      proximal_don.sort(lambda x,y: signum(x[1]-y[1]))
+      proximal_don=proximal_don[-2:]
+
+      distal_don   = []
+      for idx in xrange(1, max_intron_size):
+         if currentDon[idx] >= splice_thresh:
+            distal_don.append((idx, currentDon[idx], DNA[idx-read_size:idx]))
+
+      #distal_don.sort(lambda x,y: signum(x[1]-y[1]))
+      #distal_don=distal_don[-2:]
+
+      return proximal_acc,proximal_don,distal_acc,distal_don
+
+   def calcAlternativeAlignments(self,location):
+      """
+      Given an alignment proposed by Vmatch this function calculates possible
+      alternative alignments taking into account for example matched
+      donor/acceptor positions.
+      """
+
+      run = self.run
+      splice_thresh = 0.01
+      max_intron_size = 2000 
+
+      id       = location['id']
+      chr      = location['chr']
+      pos      = location['pos']
+      strand   = location['strand']
+      original_est = location['seq']
+      quality      = location['prb']
+      cal_prb  = location['cal_prb']
+      
+      est = unbracket_est(original_est)
+      effective_len = len(est)
+
+      genomicSeq_start  = pos - max_intron_size
+      genomicSeq_stop   = pos + max_intron_size + len(est)
+
+      start = cpu()
+      currentDNASeq, currentAcc, currentDon = get_seq_and_scores(chr,strand,genomicSeq_start,genomicSeq_stop,run['dna_flat_files'])
+      stop = cpu()
+      self.get_time += stop-start
+      dna   = currentDNASeq
+
+      proximal_acc,proximal_don,distal_acc,distal_don = self.findHighestScoringSpliceSites(currentAcc,currentDon, dna, max_intron_size, len(est), splice_thresh)
+       
+      print proximal_acc
+      print proximal_don
+      print distal_acc
+      print distal_don
+      
+      alternativeScores = []
+      
+      # inlined
+      h = self.h
+      d = self.d
+      a = self.a
+      mmatrix = self.mmatrix
+      qualityPlifs = self.qualityPlifs
+      # inlined
+
+      # compute dummy scores
+      #IntronScore = calculatePlif(h, [math.fabs(max_acc_pos-30)])[0]
+      #dummyAcceptorScore = calculatePlif(a, [max_acc_score])[0] 
+      IntronScore = calculatePlif(h, [self.intron_size])[0] - 0.5
+      dummyAcceptorScore = calculatePlif(a, [0.25])[0] 
+      dummyDonorScore = calculatePlif(d, [0.25])[0]
+      
+      _start = cpu()
+      for (don_pos,don_score) in proximal_don:
+         # remove mismatching positions in the second exon
+         original_est_cut=''
+
+         est_ptr=0
+         dna_ptr=0
+         ptr=0 
+         while ptr<len(original_est):
+             
+            if original_est[ptr]=='[':
+                dnaletter=original_est[ptr+1]
+                estletter=original_est[ptr+2]
+                if dna_ptr < don_pos:
+                    original_est_cut+=original_est[ptr:ptr+4] 
+                else:
+                    #original_est_cut+=estletter # EST letter
+                    original_est_cut+=dnaletter # DNA letter
+                ptr+=4 
+            else:
+                dnaletter=original_est[ptr]
+                estletter=dnaletter
+                
+                original_est_cut+=estletter # EST letter
+                ptr+=1
+
+            if estletter=='-':
+                dna_ptr+=1 
+            elif dnaletter=='-':
+                est_ptr+=1
+            else:
+                dna_ptr+=1 
+                est_ptr+=1
+                         
+         assert(dna_ptr<=len(dna))
+         assert(est_ptr<=len(est))
+
+         #print "Donor"
+         DonorScore = calculatePlif(d, [don_score])[0]
+         #print DonorScore,don_score,don_pos
+         
+         score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
+         score += dummyAcceptorScore + IntronScore + DonorScore
+         
+         #print 'diff %f,%f,%f' % ((trueWeight.T * self.currentPhi)[0,0] - score,(trueWeight.T * self.currentPhi)[0,0], score)
+         alternativeScores.append(score)
+
+      _stop = cpu()
+      self.alternativeScoresTime += _stop-_start
+
+      _start = cpu()
+      for (acc_pos,acc_score) in alt_acc:
+         # remove mismatching positions in the second exon
+         original_est_cut=''
+
+         est_ptr=0
+         dna_ptr=0
+         ptr=0 
+         while ptr<len(original_est):
+             
+            if original_est[ptr]=='[':
+                dnaletter=original_est[ptr+1]
+                estletter=original_est[ptr+2]
+                if est_ptr>=acc_pos:
+                    original_est_cut+=original_est[ptr:ptr+4] 
+                else:
+                    original_est_cut+=estletter # EST letter
+                ptr+=4 
+            else:
+                dnaletter=original_est[ptr]
+                estletter=dnaletter
+                
+                original_est_cut+=estletter # EST letter
+                ptr+=1
+
+            if estletter=='-':
+                dna_ptr+=1 
+            elif dnaletter=='-':
+                est_ptr+=1
+            else:
+                dna_ptr+=1 
+                est_ptr+=1
+                         
+         assert(dna_ptr<=len(dna))
+         assert(est_ptr<=len(est))
+
+         #print original_est,original_est_cut
+         
+         AcceptorScore = calculatePlif(d, [acc_score])[0]
+         #print "Acceptor"
+         #print AcceptorScore,acc_score,acc_pos
+
+         #if acc_score<0.1:
+         #    print currentAcc[0:50]
+         #    print currentDon[0:50]
+         
+         score = computeSpliceAlignScoreWithQuality(original_est_cut, quality, qualityPlifs, run, self.currentPhi)
+         score += AcceptorScore + IntronScore + dummyDonorScore
+         
+         #print 'diff %f,%f,%f' % ((trueWeight.T * self.currentPhi)[0,0] - score,(trueWeight.T * self.currentPhi)[0,0], score)
+         alternativeScores.append(score)
+
+      _stop = cpu()
+      self.alternativeScoresTime += _stop-_start
+
+      return alternativeScores
+
+
+   def calcAlignmentScore(self,alignment):
       """
       Given an alignment (dna,exons,est) and the current parameter for QPalma
       this function calculates the dot product of the feature representation of
       the alignment and the parameter vector i.e the alignment score. 
       """
 
-      currentPhi = zeros((run['numFeatures'],1))
+      start = cpu()
+      run = self.run
 
-      # Berechne die Parameter des wirklichen Alignments (but with untrained d,a,h ...)    
-      trueSpliceAlign, trueWeightMatch, trueWeightQuality ,dna_calc =\
-      computeSpliceAlignWithQuality(dna, exons, est, original_est,\
-      quality, qualityPlifs, run)
+      # Lets start calculation
+      dna, exons, est, original_est, quality, acc_supp, don_supp = alignment
 
-      # Calculate the weights
-      trueWeightDon, trueWeightAcc, trueWeightIntron =\
-      computeSpliceWeights(d, a, h, trueSpliceAlign, don_supp, acc_supp)
-      trueWeight = numpy.vstack([trueWeightIntron, trueWeightDon, trueWeightAcc, trueWeightMatch, trueWeightQuality])
+      score = computeSpliceAlignScoreWithQuality(original_est, quality, self.qualityPlifs, run, self.currentPhi)
 
-      currentPhi[0:lengthSP]                                            = mat(h.penalties[:]).reshape(lengthSP,1)
-      currentPhi[lengthSP:lengthSP+donSP]                               = mat(d.penalties[:]).reshape(donSP,1)
-      currentPhi[lengthSP+donSP:lengthSP+donSP+accSP]                   = mat(a.penalties[:]).reshape(accSP,1)
-      currentPhi[lengthSP+donSP+accSP:lengthSP+donSP+accSP+mmatrixSP]   = mmatrix[:]
-
-      if run['mode'] == 'using_quality_scores':
-         totalQualityPenalties = param[-totalQualSP:]
-         currentPhi[lengthSP+donSP+accSP+mmatrixSP:]                    = totalQualityPenalties[:]
+      stop = cpu()
+      self.calcAlignmentScoreTime += stop-start
+      
+      return score 
 
-      # Calculate w'phi(x,y) the total score of the alignment
-      return (trueWeight.T * currentPhi)[0,0]
 
+def cpu():
+   return (resource.getrusage(resource.RUSAGE_SELF).ru_utime+\
+   resource.getrusage(resource.RUSAGE_SELF).ru_stime) 
 
 
 if __name__ == '__main__':
-   filename = sys.argv[1]
-   out_filename = sys.argv[2]
-   ph1 = PipelineHeuristic(filename)
+   #run_fname = sys.argv[1]
+   #data_fname = sys.argv[2]
+   #param_filename = sys.argv[3]
+
+   dir = '/fml/ag-raetsch/home/fabio/tmp/QPalma_test/run_+_quality_+_splicesignals_+_intron_len'
+   jp = os.path.join
+
+   run_fname   = jp(dir,'run_object.pickle')
+   #data_fname = '/fml/ag-raetsch/share/projects/qpalma/solexa/current_data/map.vm_unspliced_1k'
+
+   data_fname  = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_2k'
+   #data_fname  = '/fml/ag-raetsch/share/projects/qpalma/solexa/pipeline_data/map.vm_100'
+
+   param_fname = jp(dir,'param_500.pickle')
+
+   ph1 = PipelineHeuristic(run_fname,data_fname,param_fname)
+
+   start = cpu()
    ph1.filter()
+   stop = cpu()
+
+   print 'total time elapsed: %f' % (stop-start)
+   print 'time spend for get seq: %f' % ph1.get_time
+   print 'time spend for calcAlignmentScoreTime: %f' %  ph1.calcAlignmentScoreTime
+   print 'time spend for alternativeScoresTime: %f' % ph1.alternativeScoresTime
+   print 'time spend for count time: %f' % ph1.count_time
+   print 'time spend for init time: %f' % ph1.init_time
+   print 'time spend for read_parsing time: %f' % ph1.read_parsing
+   print 'time spend for main_loop time: %f' % ph1.main_loop
+   print 'time spend for splice_site_time time: %f' % ph1.splice_site_time
+
+   print 'time spend for computeSpliceAlignWithQualityTime time: %f'% ph1.computeSpliceAlignWithQualityTime
+   print 'time spend for computeSpliceWeightsTime time: %f'% ph1.computeSpliceWeightsTime
+   print 'time spend for DotProdTime time: %f'% ph1.DotProdTime
+   print 'time spend forarray_stuff time: %f'% ph1.array_stuff
+   #import cProfile
+   #cProfile.run('ph1.filter()')
+
+   #import hotshot
+   #p = hotshot.Profile('profile.log')
+   #p.runcall(ph1.filter)