git-svn-id: http://svn.tuebingen.mpg.de/ag-raetsch/projects/QPalma@8640 e1793c9e...
[qpalma.git] / scripts / createAlignmentFileFromPrediction.py
index 5d21804..78f7ebe 100644 (file)
@@ -22,7 +22,13 @@ def prediction_on(filename):
       exampleId = gt_example['id']
 
       for elem_nr,current_pred in enumerate(current_example_pred[1:]):
-         exampleId = current_pred['id']
+         exampleId   = current_pred['id']
+         chr         = current_pred['chr']
+         strand      = current_pred['strand']
+         true_cut    = current_pred['true_cut']
+         start_pos   = current_pred['alternative_start_pos']
+         alignment   = current_pred['alignment']
+
          predExons = current_pred['predExons']
          trueExons = current_pred['trueExons']
          predPositions = [elem + current_pred['alternative_start_pos'] for elem in predExons]
@@ -36,13 +42,13 @@ def prediction_on(filename):
 
             #line = '%d\t%d\t%d\t%d\t%d\n' % (exampleId,p1,p2,p3,p4)
             #print line
-            allPositions[exampleId] = (chr,strand,true_cut,p1,p2,p3,p4)
+            allPositions[exampleId] = (chr,strand,start_pos,true_cut,p1,p2,p3,p4,alignment)
 
 
    return allPositions
 
 
-def writePredictions(fname,allPositions)
+def writePredictions(fname,allPositions):
 
    out_fh = open(fname,'w+')
 
@@ -57,16 +63,31 @@ def writePredictions(fname,allPositions)
 
 
    for id,elems in allPositions.items():
+      id += 1000000000000
       seq,q1,q2,q3 = allEntries[id]
-      chr,strand,true_cut,p1,p2,p3,p4 = elems
+      chr,strand,start_pos,true_cut,p1,p2,p3,p4,alignment = elems
+
+      p1 += start_pos
+      p2 += start_pos
+      p3 += start_pos
+      p4 += start_pos
+
+      #pdb.set_trace()
+
+      (qStart, qEnd, tStart, tEnd, num_exons, qExonSizes, qStarts, qEnds,\
+      tExonSizes,tStarts, tEnds) = alignment 
 
-      new_line = '%d\t%d\t%s\t%s\t%s\t%s\t%s\t%d\t%d\t%d\t%d\t%d\n' %\
-      (id,chr,strand,seq,q1,q2,q3,true_cut,p1,p2,p3,p4)
+      new_line = '%d\t%d\t%s\t%s\t%s\t%d\t%d\t%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s\t%s\t%s\n' %\
+      (id,chr,strand,seq,q1,start_pos,qStart,qEnd,tStart,tEnd,num_exons,\
+      str(qExonSizes)[1:-1].replace(' ',''),str(qStarts)[1:-1].replace(' ',''),\
+      str(qEnds)[1:-1].replace(' ',''),str(tExonSizes)[1:-1].replace(' ',''),\
+      str(tStarts)[1:-1].replace(' ',''),str(tEnds)[1:-1].replace(' ',''))
 
-      out_fh.write(line)
+      out_fh.write(new_line)
 
+   out_fh.close()
 
-def collect_prediction(current_dir,run_name):
+def collect_prediction(current_dir):
    """
    Given the toplevel directory this function takes care that for each distinct
    experiment the training and test predictions are evaluated.
@@ -75,27 +96,32 @@ def collect_prediction(current_dir,run_name):
    train_suffix   = '_allPredictions_TRAIN'
    test_suffix    = '_allPredictions_TEST'
 
+   run_name = 'run_+_quality_+_splicesignals_+_intron_len_1'
+
    jp = os.path.join
    b2s = ['-','+']
 
-   currentRun = cPickle.load(open(jp(current_dir,'run_object.pickle')))
+   currentRun = cPickle.load(open(jp(current_dir,'run_object_1.pickle')))
    QFlag    = currentRun['enable_quality_scores']
    SSFlag   = currentRun['enable_splice_signals']
    ILFlag   = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])
    
-   filename =  jp(current_dir,run_name)+train_suffix
+   #filename =  jp(current_dir,run_name)+train_suffix
+   #print 'Prediction on: %s' % filename
+   #prediction_on(filename)
+
+   filename =  jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
-   prediction_on(filename)
+   test_result = prediction_on(filename)
 
-   #filename =  jp(current_dir,run_name)+test_suffix
-   #print 'Prediction on: %s' % filename
-   #test_result = prediction_on(filename)
+   fname = 'predictions.txt'
+   writePredictions(fname,test_result)
 
    #return train_result,test_result,currentRunId
 
 
 if __name__ == '__main__':
    dir = sys.argv[1]
-   run_name = sys.argv[2]
-   collect_prediction(dir,run_name)
+   #run_name = sys.argv[2]
+   collect_prediction(dir)