+ added license text
authorFabio <fabio@congo.fml.local>
Tue, 2 Sep 2008 10:37:52 +0000 (12:37 +0200)
committerFabio <fabio@congo.fml.local>
Tue, 2 Sep 2008 10:37:52 +0000 (12:37 +0200)
+ changed parameter storage inside qpalma object
+ changed training code to use the new c++ interface

scripts/qpalma_main.py
tests/test_qpalma_prediction.py

index e6fc376..7af70dd 100644 (file)
@@ -1,18 +1,13 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-###########################################################
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
 #
-# The QPalma project aims at extending the Palma project 
-# to be able to use Solexa reads together with their 
-# quality scores.
-# 
-# This file represents the conversion of the main matlab 
-# training loop for Palma to Python.
-# 
-# Author: Fabio De Bona
-# 
-###########################################################
+# Written (W) 2008 Fabio De Bona
+# Copyright (C) 2008 Max-Planck-Society
 
 import array
 import cPickle
@@ -63,6 +58,8 @@ def getData(training_set,exampleKey,run):
    dna_flat_files =  '/fml/ag-raetsch/share/projects/genomes/A_thaliana_best/genome/'
    dna, acc_supp, don_supp = get_seq_and_scores(chr,strand,up_cut,down_cut,dna_flat_files)
 
+   #currentDNASeq, currentAcc, currentDon = seqInfo.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
+
    original_exons = currentExons
    exons = original_exons - (up_cut-1)
    exons[0,0] -= 1
@@ -87,16 +84,17 @@ def getData(training_set,exampleKey,run):
    return dna,est,acc_supp,don_supp,exons,original_est,currentQualities
 
 
-
 class QPalma:
    """
    This class wraps the training and prediction functions for 
    the alignment.
    """
    
-   def __init__(self,dmode=False):
+   def __init__(self,run,seqInfo,dmode=False):
       self.ARGS = Param()
       self.qpalma_debug_mode = dmode
+      self.run = run
+      self.seqInfo = seqInfo
 
 
    def plog(self,string):
@@ -109,7 +107,6 @@ class QPalma:
       Given the needed input this method calls the QPalma C module which
       calculates a dynamic programming in order to obtain an alignment
       """
-      run = self.run
 
       dna_len = len(dna)
       est_len = len(est)
@@ -118,7 +115,7 @@ class QPalma:
       chastity = QPalmaDP.createDoubleArrayFromList([.0]*est_len)
 
       matchmatrix = QPalmaDP.createDoubleArrayFromList(mmatrix.flatten().tolist()[0])
-      mm_len = run['matchmatrixRows']*run['matchmatrixCols']
+      mm_len = self.run['matchmatrixRows']*self.run['matchmatrixCols']
 
       d_len = len(donor)
       donor = QPalmaDP.createDoubleArrayFromList(donor)
@@ -126,13 +123,13 @@ class QPalma:
       acceptor = QPalmaDP.createDoubleArrayFromList(acceptor)
 
       # Create the alignment object representing the interface to the C/C++ code.
-      currentAlignment = QPalmaDP.Alignment(run['numQualPlifs'],run['numQualSuppPoints'], self.use_quality_scores)
+      currentAlignment = QPalmaDP.Alignment(self.run['numQualPlifs'],self.run['numQualSuppPoints'], self.use_quality_scores)
       c_qualityPlifs = QPalmaDP.createPenaltyArrayFromList([elem.convert2SWIG() for elem in qualityPlifs])
       # calculates SpliceAlign, EstAlign, weightMatch, Gesamtscores, dnaest
       currentAlignment.myalign( current_num_path, dna, dna_len,\
        est, est_len, prb, chastity, ps, matchmatrix, mm_len, donor, d_len,\
-       acceptor, a_len, c_qualityPlifs, run['remove_duplicate_scores'],\
-       run['print_matrix'] )
+       acceptor, a_len, c_qualityPlifs, self.run['remove_duplicate_scores'],\
+       self.run['print_matrix'] )
 
       if prediction_mode:
          # part that is only needed for prediction
@@ -149,10 +146,8 @@ class QPalma:
       newQualityPlifsFeatures, dna_array, est_array
 
 
-   def train(self,run,training_set):
-      self.run = run
-
-      full_working_path = os.path.join(run['alignment_dir'],run['name'])
+   def init_train(self,training_set):
+      full_working_path = os.path.join(self.run['alignment_dir'],self.run['name'])
 
       #assert not os.path.exists(full_working_path)
       if not os.path.exists(full_working_path):
@@ -164,10 +159,10 @@ class QPalma:
       os.chdir(full_working_path)
 
       self.logfh = open('_qpalma_train.log','w+')
-      cPickle.dump(run,open('run_obj.pickle','w+'))
+      cPickle.dump(self.run,open('run_obj.pickle','w+'))
 
       self.plog("Settings are:\n")
-      self.plog("%s\n"%str(run))
+      self.plog("%s\n"%str(self.run))
 
       if self.run['mode'] == 'normal':
          self.use_quality_scores = False
@@ -176,56 +171,60 @@ class QPalma:
          self.use_quality_scores = True
       else:
          assert(False)
-   
+
+
+   def setUpSolver(self):
+      # Initialize solver 
+      self.plog('Initializing problem...\n')
+      
+      try:
+         solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)
+      except:
+         self.plog('Got no license. Telling queue to reschedule job...\n')
+         sys.exit(99)
+
+      solver.enforceMonotonicity(lengthSP,lengthSP+donSP)
+      solver.enforceMonotonicity(lengthSP+donSP,lengthSP+donSP+accSP)
+
+      return solver
+
+
+   def train(self,training_set):
       numExamples = len(training_set)
       self.plog('Number of training examples: %d\n'% numExamples)
 
       self.noImprovementCtr = 0
       self.oldObjValue = 1e8
 
-      iteration_steps         = run['iter_steps']
-      remove_duplicate_scores = run['remove_duplicate_scores']
-      print_matrix            = run['print_matrix']
-      anzpath                 = run['anzpath']
+      iteration_steps         = self.run['iter_steps']
+      remove_duplicate_scores = self.run['remove_duplicate_scores']
+      print_matrix            = self.run['print_matrix']
+      anzpath                 = self.run['anzpath']
 
-      # Initialize parameter vector  / 
-      #param = Conf.fixedParam[:run['numFeatures']]
+      # Initialize parameter vector
       param = numpy.matlib.rand(run['numFeatures'],1)
    
-      lengthSP    = run['numLengthSuppPoints']
-      donSP       = run['numDonSuppPoints']
-      accSP       = run['numAccSuppPoints']
-      mmatrixSP   = run['matchmatrixRows']*run['matchmatrixCols']
-      numq        = run['numQualSuppPoints']
-      totalQualSP = run['totalQualSuppPoints']
+      lengthSP    = self.run['numLengthSuppPoints']
+      donSP       = self.run['numDonSuppPoints']
+      accSP       = self.run['numAccSuppPoints']
+      mmatrixSP   = self.run['matchmatrixRows']*run['matchmatrixCols']
+      numq        = self.run['numQualSuppPoints']
+      totalQualSP = self.run['totalQualSuppPoints']
 
       # no intron length model
-      if not run['enable_intron_length']:
+      if not self.run['enable_intron_length']:
          param[:lengthSP] *= 0.0
 
       # Set the parameters such as limits penalties for the Plifs
-      [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)
-
-      # Initialize solver 
-      self.plog('Initializing problem...\n')
-      
-      #try:
-      #   solver = SIQPSolver(run['numFeatures'],numExamples,run['C'],self.logfh,run)
-      #except:
-      #   self.plog('Got no license. Telling queue to reschedule job...\n')
-      #   sys.exit(99)
+      [h,d,a,mmatrix,qualityPlifs] = set_param_palma(param,self.ARGS.train_with_intronlengthinformation,self.run)
 
-      #solver.enforceMonotonicity(lengthSP,lengthSP+donSP)
-      #solver.enforceMonotonicity(lengthSP+donSP,lengthSP+donSP+accSP)
+      solver = self.setUpSolver()
 
       # stores the number of alignments done for each example (best path, second-best path etc.)
       num_path = [anzpath]*numExamples
+
       # stores the gap for each example
       gap      = [0.0]*numExamples
-      #############################################################################################
-      # Training
-      #############################################################################################
-      self.plog('Starting training...\n')
 
       currentPhi = zeros((run['numFeatures'],1))
       totalQualityPenalties = zeros((totalQualSP,1))
@@ -240,6 +239,7 @@ class QPalma:
 
       featureVectors = zeros((run['numFeatures'],numExamples))
 
+      self.plog('Starting training...\n')
       # the main training loop
       while True:
          if iteration_nr == iteration_steps:
@@ -255,14 +255,10 @@ class QPalma:
 
             dna_len = len(dna)
 
-            if run['mode'] == 'normal':
-               quality = [40]*len(est)
-
-            if run['mode'] == 'using_quality_scores':
-               quality = currentQualities[0]
-
-            if not run['enable_quality_scores']:
-               quality = [40]*len(est)
+            if run['enable_quality_scores']:
+               quality = currentQualities[quality_index]
+            else:
+               quality = [40]*len(read)
 
             if not run['enable_splice_signals']:
                for idx,elem in enumerate(don_supp):
@@ -315,33 +311,20 @@ class QPalma:
             # returns two double lists
             donor, acceptor = compute_donacc(don_supp, acc_supp, d, a)
 
-            #myalign wants the acceptor site on the g of the ag
-            #acceptor = acceptor[1:]
-            #acceptor.append(-inf)
-
-            #donor = [-inf] + donor[:-1]
-
             ps = h.convert2SWIG()
 
-            _newSpliceAlign, _newEstAlign, _newWeightMatch, _newDPScores,\
-            _newQualityPlifsFeatures, unneeded1, unneeded2 =\
+            newSpliceAlign, newEstAlign, newWeightMatch, newDPScores,\
+            newQualityPlifsFeatures, unneeded1, unneeded2 =\
             self.do_alignment(dna,est,quality,mmatrix,donor,acceptor,ps,qualityPlifs,num_path[exampleIdx],False)
             mm_len = run['matchmatrixRows']*run['matchmatrixCols']
 
-            newSpliceAlign = _newSpliceAlign
-            newEstAlign    = _newEstAlign
-            newWeightMatch = _newWeightMatch
-            newDPScores    = _newDPScores
-            newQualityPlifsFeatures = _newQualityPlifsFeatures
-
             newSpliceAlign = newSpliceAlign.reshape(num_path[exampleIdx],dna_len)
             newWeightMatch = newWeightMatch.reshape(num_path[exampleIdx],mm_len)
 
             newQualityPlifsFeatures = newQualityPlifsFeatures.reshape(num_path[exampleIdx],run['totalQualSuppPoints'])
-            # Calculate weights of the respective alignments. Note that we are
-            # calculating n-best alignments without hamming loss, so we
-            # have to keep track which of the n-best alignments correspond to
-            # the true one in order not to incorporate a true alignment in the
+            # Calculate weights of the respective alignments. Note that we are calculating n-best alignments without 
+            # hamming loss, so we have to keep track which of the n-best alignments correspond to the true one in order 
+            # not to incorporate a true alignment in the
             # constraints. To keep track of the true and false alignments we
             # define an array true_map with a boolean indicating the
             # equivalence to the true alignment for each decoded alignment.
@@ -408,27 +391,13 @@ class QPalma:
                if False:
                   self.plog("Is considered as: %d\n" % true_map[1])
 
-                  result_len = currentAlignment.getResultLength()
-                  c_dna_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
-                  c_est_array = QPalmaDP.createIntArrayFromList([0]*(result_len))
-
-                  currentAlignment.getAlignmentArrays(c_dna_array,c_est_array)
-
-                  dna_array = [0.0]*result_len
-                  est_array = [0.0]*result_len
+                  #result_len = currentAlignment.getResultLength()
 
-                  for r_idx in range(result_len):
-                     dna_array[r_idx] = c_dna_array[r_idx]
-                     est_array[r_idx] = c_est_array[r_idx]
+                  dna_array,est_array = currentAlignment.getAlignmentArraysNew()
 
                   _newSpliceAlign = newSpliceAlign[0].flatten().tolist()[0]
                   _newEstAlign = newEstAlign[0].flatten().tolist()[0]
 
-                  #line1,line2,line3 = pprint_alignment(_newSpliceAlign,_newEstAlign, dna_array, est_array)
-                  #self.plog(line1+'\n')
-                  #self.plog(line2+'\n')
-                  #self.plog(line3+'\n')
-
                # if there is at least one useful false alignment add the
                # corresponding constraints to the optimization problem
                if firstFalseIdx != -1:
@@ -436,13 +405,10 @@ class QPalma:
                   differenceVector  = trueWeight - firstFalseWeights
                   #pdb.set_trace()
 
-                  #print 'NOT ADDING ANY CONSTRAINTS'
                   const_added = solver.addConstraint(differenceVector, exampleIdx)
-
                   const_added_ctr += 1
-               #
+
                # end of one example processing 
-               #
 
             # call solver every nth example //added constraint
             if exampleIdx != 0 and exampleIdx % numConstPerRound == 0:
@@ -475,26 +441,31 @@ class QPalma:
                cPickle.dump(param,open('param_%d.pickle'%param_idx,'w+'))
                param_idx += 1
                [h,d,a,mmatrix,qualityPlifs] =\
-               set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)
+               set_param_palma(param,self.ARGS.train_with_intronlengthinformation,self.run)
 
-         #
-         # end of one iteration through all examples
-         #
+         ##############################################
+         # end of one iteration through all examples  #
+         ##############################################
 
          self.plog("suboptimal rounds %d\n" %suboptimal_example)
 
          if self.noImprovementCtr == numExamples*2:
-            break
+            FinalizeTraining(param,'param_%d.pickle'%param_idx)
 
          iteration_nr += 1
 
       #
       # end of optimization 
       #  
-      print 'Training completed'
+      FinalizeTraining(param,'param_%d.pickle'%param_idx)
+
 
-      cPickle.dump(param,open('param_%d.pickle'%param_idx,'w+'))
+   def FinalizeTraining(self,vector,name):
+      self.plog("Training completed")
+      cPickle.dump(param,open(name,'w+'))
       self.logfh.close()
+      sys.exit(0)
+   
 
 ###############################################################################
 #
@@ -504,15 +475,14 @@ class QPalma:
 #
 ###############################################################################
 
-   def init_prediction(self,run,dataset_fn,prediction_keys,param_fn,seqInfo,set_name,):
+   def init_prediction(self,dataset_fn,prediction_keys,param_fn,set_name):
       """
       Performing a prediction takes...
       """
-      self.run = run
       self.set_name = set_name
 
       #full_working_path = os.path.join(run['alignment_dir'],run['name'])
-      full_working_path = run['result_dir']
+      full_working_path = self.run['result_dir']
 
       print 'full_working_path is %s' % full_working_path 
 
@@ -543,16 +513,14 @@ class QPalma:
       # Load parameter vector to predict with
       param = cPickle.load(open(param_fn))
 
-      self.predict(run,prediction_set,param,seqInfo)
+      self.predict(prediction_set,param)
 
 
-   def predict(self,run,prediction_set,param,seqInfo):
+   def predict(self,prediction_set,param):
       """
       This method...
       """
 
-      self.run = run
-
       if self.run['mode'] == 'normal':
          self.use_quality_scores = False
 
@@ -563,7 +531,7 @@ class QPalma:
 
       # Set the parameters such as limits/penalties for the Plifs
       [h,d,a,mmatrix,qualityPlifs] =\
-      set_param_palma(param,self.ARGS.train_with_intronlengthinformation,run)
+      set_param_palma(param,self.ARGS.train_with_intronlengthinformation,self.run)
 
       if not self.qpalma_debug_mode:
          self.plog('Starting prediction...\n')
@@ -589,18 +557,18 @@ class QPalma:
             if not self.qpalma_debug_mode:
                self.plog('Loading example id: %d...\n'% int(id))
 
-            if run['enable_quality_scores']:
+            if self.run['enable_quality_scores']:
                quality = currentQualities[quality_index]
             else:
                quality = [40]*len(read)
 
             try:
-               currentDNASeq, currentAcc, currentDon = seqInfo.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
+               currentDNASeq, currentAcc, currentDon = self.seqInfo.get_seq_and_scores(chromo,strand,genomicSeq_start,genomicSeq_stop)
             except:
                self.problem_ctr += 1
                continue
 
-            if not run['enable_splice_signals']:
+            if not self.run['enable_splice_signals']:
                for idx,elem in enumerate(currentDon):
                   if elem != -inf:
                      currentDon[idx] = 0.0
@@ -620,20 +588,20 @@ class QPalma:
             allPredictions.append(current_prediction)
 
       if not self.qpalma_debug_mode:
-         self.finalizePrediction(allPredictions)
+         self.FinalizePrediction(allPredictions)
       else:
          return allPredictions
 
 
-   def finalizePrediction(self,allPredictions):
-      # end of the prediction loop we save all predictions in a pickle file and exit
+   def FinalizePrediction(self,allPredictions):
+      """ End of the prediction loop we save all predictions in a pickle file and exit """
+
       cPickle.dump(allPredictions,open('%s.predictions.pickle'%(self.set_name),'w+'))
-      print 'Prediction completed'
       self.plog('Prediction completed\n')
       mes =  'Problem ctr %d' % self.problem_ctr
-      print mes
       self.plog(mes+'\n')
       self.logfh.close()
+      sys.exit(0)
 
 
    def calc_alignment(self, dna, read, quality, don_supp, acc_supp, d, a, h, mmatrix, qualityPlifs):
index 8e5e37a..f1211d3 100644 (file)
@@ -111,9 +111,9 @@ class TestQPalmaPrediction(unittest.TestCase):
       accessWrapper = DataAccessWrapper(g_dir,acc_dir,don_dir,g_fmt,s_fmt)
       seqInfo = SeqSpliceInfo(accessWrapper,range(1,num_chromo))
 
-      qp = QPalma(True)
+      qp = QPalma(run,seqInfo,True)
       #qp.init_prediction(run,set_name)
-      allPredictions = qp.predict(run,self.prediction_set,param,seqInfo )
+      allPredictions = qp.predict(self.prediction_set,param)
 
       for current_prediction in allPredictions:
          align_str = print_prediction(current_prediction)