+ added framework code for training modus
[qpalma.git] / scripts / Evaluation.py
index 5126373..952eb4c 100644 (file)
@@ -3,45 +3,85 @@
 
 import cPickle
 import sys
-import pydb
 import pdb
 import os
 import os.path
 import math
 
+from qpalma.parsers import *
 
-def createTable(results):
+
+data = None
+
+
+def result_statistic():
+   """
+
+   """   
+
+   num_unaligned_reads = 0
+   num_incorrectly_aligned_reads = 0
+   pass
+
+   
+def createErrorVSCutPlot(results):
    """
-   This function gets the results from the evaluation and creates a tex table.
+   This function takes the results of the evaluation and creates a tex table.
    """
 
-   fh = open('result_table.tex','w+')
-   lines = ['\begin{tabular}{|c|c|c|r|}', '\hline',\
+   fh = open('error_rates_table.tex','w+')
+   lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
    'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
-   'information & site pred. & length & \multicolumn{1}{c|}{rate}\\',\
-   '\hline',\
-   
-   #'-       &    -      &    -      &  %2.3f \%\\' % ( results['---'][0], results['---'][1] ),\
-   #'+       &    -      &    -      &  %2.3f \%\\' % ( results['+--'][0], results['+--'][1] ),\
+   'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']
 
-   #'\hline',\
+   #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+   for pos,key in enumerate(['+++']):
+      res = results[key]
+      for i in range(37):
+         ctr = 0
+         try:
+            ctr = res[1][i]
+         except:
+            ctr = 0
 
-   #'-       &    +      &    -      & %2.3f \%\\' % ( results['-+-'][0], results['-+-'][1] ),\
-   #'+       &    +      &    -      & %2.3f \%\\' % ( results['++-'][0], results['++-'][1] ),\
+         lines.append( '%d\n' % ctr)
 
-   '\hline\hline',\
+      if pos % 2 == 1:
+         lines.append('\hline')
 
-   '- & - & + & %2.3f & %2.3f \\%%\\\\' % ( results['--+'][0], results['--+'][1] ),\
-   '+ & - & + & %2.3f & %2.3f \\%%\\\\' % ( results['+-+'][0], results['+-+'][1] ),\
+   lines.append('\end{tabular}')
 
-   '\hline',\
+   lines = [l+'\n' for l in lines]
+   for l in lines:
+      fh.write(l)
+   fh.close()
 
-   '- & + & + & %2.3f & %2.3f \\%%\\\\' % ( results['-++'][0], results['-++'][1] ),\
-   '+ & + & + & %2.3f & %2.3f \\%%\\\\' % ( results['+++'][0], results['+++'][1] ),\
 
-   '\hline',\
-   '\end{tabular}']
+def createTable(results):
+   """
+   This function takes the results of the evaluation and creates a tex table.
+   """
 
+   fh = open('result_table.tex','w+')
+   lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
+   'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
+   'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']
+
+   for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+      res = [e*100 for e in results[key]]
+   
+      lines.append( '%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % ( key[0], key[1], key[2], res[0], res[1] ) )
+      if pos % 2 == 1:
+         lines.append('\hline')
+
+   for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+      res = [e*100 for e in results[key]]
+   
+      lines.append( '%s & %s & %s & %2.2f & x \\%%\\\\' % ( key[0], key[1], key[2], res[2] ) )
+      if pos % 2 == 1:
+         lines.append('\hline')
+
+   lines.append('\end{tabular}')
 
    lines = [l+'\n' for l in lines]
    for l in lines:
@@ -62,12 +102,34 @@ def compare_scores_and_labels(scores,labels):
             if otherPos == currentPos:
                continue
 
-            if labels[otherPos] == False and otherElem > currentElem:
+            if labels[otherPos] == False and otherElem >= currentElem:
                return False
 
    return True
 
 
+def compare_exons(predExons,trueExons):
+   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
+
+   if len(predExons) == 4:
+      e1_begin,e1_end = predExons[0],predExons[1]
+      e2_begin,e2_end = predExons[2],predExons[3]
+   else:
+      return False
+
+   e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
+   e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))
+
+   e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
+   e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))
+
+   if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0\
+   and e2_e_off == 0:
+      return True
+
+   return False
+
+
 def evaluate_unmapped_example(current_prediction):
    predExons = current_prediction['predExons']
    trueExons = current_prediction['trueExons']
@@ -93,112 +155,122 @@ def evaluate_example(current_prediction):
    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]
 
-   #pdb.set_trace()
-
    pos_comparison = (predPositions == truePositions)
 
-   #if label == True and pos_comparison == False:
-   #   pdb.set_trace()
-
    return label,pos_comparison,pred_score
 
 
-def compare_exons(predExons,trueExons):
-   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
-
-   if len(predExons) == 4:
-      e1_begin,e1_end = predExons[0],predExons[1]
-      e2_begin,e2_end = predExons[2],predExons[3]
-   else:
-      return False
-
-   e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
-   e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))
-
-   e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
-   e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))
-
-   if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0\
-   and e2_e_off == 0:
-      return True
-
-   return False
-
 def prediction_on(filename):
-   allPredictions = cPickle.load(open(filename))
 
-   exon1Begin     = []
-   exon1End       = []
-   exon2Begin     = []
-   exon2End       = []
-   allWrongExons  = []
-   allDoubleScores = []
+   allPredictions = cPickle.load(open(filename))
 
    gt_correct_ctr = 0
+   gt_incorrect_ctr = 0
+   incorrect_gt_cuts       = {}
+
    pos_correct_ctr   = 0
    pos_incorrect_ctr = 0
+   incorrect_vmatch_cuts   = {}
+
    score_correct_ctr = 0
    score_incorrect_ctr = 0
 
+   total_gt_examples = 0
    total_vmatch_instances_ctr = 0
+
    true_vmatch_instances_ctr = 0
 
-   for current_example_pred in allPredictions:
-      gt_example = current_example_pred[0]
-      gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
-      gt_correct = evaluate_unmapped_example(gt_example)
+   allUniquePredictions = [False]*len(allPredictions)
+
+   for pos,current_example_pred in enumerate(allPredictions):
+      for elem_nr,new_prediction in enumerate(current_example_pred[1:]):
+
+         if allUniquePredictions[pos] != False:
+            current_prediction = allUniquePredictions[pos]
 
-      current_scores = []
-      current_labels = []
-      current_scores.append(gt_score)
-      current_labels.append(gt_correct)
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
 
-      if gt_correct:
-         gt_correct_ctr += 1
+            if current_a_score < new_score :
+               allUniquePredictions[id] = new_prediction
 
-      for elem_nr,current_pred in enumerate(current_example_pred[1:]):
-         current_label,comparison_result,current_score = evaluate_example(current_pred)
+         else:
+            allUniquePredictions[pos] = new_prediction
 
-         # if vmatch found the right read pos we check for right exons
-         # boundaries
-         if current_label:
-            if comparison_result:
-               pos_correct_ctr += 1
-            else:
-               pos_incorrect_ctr += 1
+   for current_pred in allUniquePredictions:
+      if current_pred == False:
+         continue
 
-            true_vmatch_instances_ctr += 1
-            
+   #for current_example_pred in allPredictions:
+      #gt_example = current_example_pred[0]
+      #gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
+      #gt_correct = evaluate_unmapped_example(gt_example)
 
-         current_scores.append(current_score)
-         current_labels.append(current_label)
+      #exampleIdx = gt_example['exampleIdx']
 
-         total_vmatch_instances_ctr += 1
+      #cut_pos = gt_example['true_cut']
+
+      #if gt_correct:
+      #   gt_correct_ctr += 1
+      #else:
+      #   gt_incorrect_ctr += 1
+
+      #   try:
+      #      incorrect_gt_cuts[cut_pos] += 1
+      #   except:
+      #      incorrect_gt_cuts[cut_pos] = 1
+
+      #total_gt_examples += 1
+     
+      #current_scores = []
+      #current_labels = []
+      #for elem_nr,current_pred in enumerate(current_example_pred[1:]):
+
+      current_label,comparison_result,current_score = evaluate_example(current_pred)
+
+      # if vmatch found the right read pos we check for right exons
+      # boundaries
+      #if current_label:
+      if comparison_result:
+         pos_correct_ctr += 1
+      else:
+         pos_incorrect_ctr += 1
+
+            #try:
+            #   incorrect_vmatch_cuts[cut_pos] += 1
+            #except:
+            #   incorrect_vmatch_cuts[cut_pos] = 1
+
+         true_vmatch_instances_ctr += 1
+         
+      #current_scores.append(current_score)
+      #current_labels.append(current_label)
+
+      total_vmatch_instances_ctr += 1
 
       # check whether the correct predictions score higher than the incorrect
       # ones
-      cmp_res = compare_scores_and_labels(current_scores,current_labels)
-      if cmp_res:
-         score_correct_ctr += 1
-      else:
-         score_incorrect_ctr += 1
+      #cmp_res = compare_scores_and_labels(current_scores,current_labels)
+      #if cmp_res:
+      #   score_correct_ctr += 1
+      #else:
+      #   score_incorrect_ctr += 1
 
    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
-   print 'Correct pos: %d, incorrect pos: %d' %\
-   (pos_correct_ctr,pos_incorrect_ctr)
+   print 'Correct pos: %d, incorrect pos: %d' % (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
    (score_correct_ctr,score_incorrect_ctr)
 
-   pos_error   = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
+   pos_error   = 1.0 * pos_incorrect_ctr / total_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
+   gt_error    = 1.0 * gt_incorrect_ctr / total_gt_examples
 
-   print pos_error,score_error
+   return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
 
-   return (pos_error,score_error)
 
 def collect_prediction(current_dir,run_name):
    """
@@ -206,21 +278,24 @@ def collect_prediction(current_dir,run_name):
    experiment the training and test predictions are evaluated.
 
    """
-   train_suffix   = '_allPredictions_TRAIN'
-   test_suffix    = '_allPredictions_TEST'
+   idx = 5
+
+   train_suffix   = '_%d_allPredictions_TRAIN' % (idx)
+   test_suffix    = '_%d_allPredictions_TEST' % (idx)
 
    jp = os.path.join
    b2s = ['-','+']
 
-   currentRun = cPickle.load(open(jp(current_dir,'run_object.pickle')))
+   currentRun = cPickle.load(open(jp(current_dir,'run_object_%d.pickle'%(idx))))
    QFlag    = currentRun['enable_quality_scores']
    SSFlag   = currentRun['enable_splice_signals']
    ILFlag   = currentRun['enable_intron_length']
    currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])
    
-   filename =  jp(current_dir,run_name)+train_suffix
-   print 'Prediction on: %s' % filename
-   train_result = prediction_on(filename)
+   #filename =  jp(current_dir,run_name)+train_suffix
+   #print 'Prediction on: %s' % filename
+   #train_result = prediction_on(filename)
+   train_result = []
 
    filename =  jp(current_dir,run_name)+test_suffix
    print 'Prediction on: %s' % filename
@@ -234,10 +309,16 @@ def perform_prediction(current_dir,run_name):
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma
    """
-   cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s | qsub -l h_vmem=12.0G -cwd -j y -N \"%s.log\"'%(current_dir,run_name)
-   #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
-   #print cmd
-   os.system(cmd)
+
+   #for i in range(1,6):
+   for i in range(1,2):
+      cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
+      qsub -l h_vmem=12.0G -cwd -j y -N \"%s_%d.log\"'%(current_dir,i,run_name,i)
+
+      #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
+      #print cmd
+      os.system(cmd)
+
 
 
 def forall_experiments(current_func,tl_dir):
@@ -255,84 +336,594 @@ def forall_experiments(current_func,tl_dir):
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]
 
    all_results = {}
+   all_error_rates = {}
 
    for current_dir in run_dirs:
       run_name = current_dir.split('/')[-1]
-      #current_func(current_dir,run_name)
-      train_result,test_result,currentRunId = current_func(current_dir,run_name)
-      all_results[currentRunId] = test_result
 
-   createTable(all_results)
+      pdb.set_trace()
 
+      if current_func.__name__ == 'perform_prediction':
+         current_func(current_dir,run_name)
+
+      if current_func.__name__ == 'collect_prediction':
+         train_result,test_result,currentRunId = current_func(current_dir,run_name)
+         all_results[currentRunId] = test_result
+         pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
+         all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
+
+   if current_func.__name__ == 'collect_prediction':
+      #createErrorVSCutPlot(all_error_rates)
+      createTable(all_results)
+
+def _predict_on(filename,filtered_reads,with_coverage):
+   """
+   This function evaluates the predictions made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specify a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+   """
+
+   coverage_map = {}
+
+   if with_coverage:
+      for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   print 'parsing filtered reads..'
+   all_filtered_reads = parse_filtered_reads(filtered_reads)
+   print 'found %d filtered reads' % len(all_filtered_reads)
+
+   out_fh = open('predicted_positions.txt','w+')
+
+   allPredictions = cPickle.load(open(filename))
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for new_prediction in allPredictions:
+      id = new_prediction['id']
+      id = int(id)
+
+      if allUniqPredictions.has_key(id):
+         current_prediction = allUniqPredictions[id]
+
+         current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+         new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+         if current_a_score < new_score :
+            allUniqPredictions[id] = new_prediction
 
-if __name__ == '__main__':
-   dir = sys.argv[1]
-   assert os.path.exists(dir), 'Error directory does not exist!'
-
-   #forall_experiments(perform_prediction,dir)
-   forall_experiments(collect_prediction,dir)
-
-"""
-def evaluateExample(dna,est,exons,SpliceAlign,newEstAlign):
-   newExons = []
-   oldElem = -1
-   SpliceAlign = SpliceAlign.flatten().tolist()[0]
-   SpliceAlign.append(-1)
-   for pos,elem in enumerate(SpliceAlign):
-      if pos == 0:
-         oldElem = -1
       else:
-         oldElem = SpliceAlign[pos-1]
+         allUniqPredictions[id] = new_prediction
 
-      if oldElem != 0 and elem == 0: # start of exon
-         newExons.append(pos)
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
 
-      if oldElem == 0 and elem != 0: # end of exon
-         newExons.append(pos)
+   #for current_prediction in allPredictions:
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
 
-   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
 
-   if len(newExons) == 4:
-      e1_begin,e1_end = newExons[0],newExons[1]
-      e2_begin,e2_end = newExons[2],newExons[3]
-   else:
-      return None,None,None,None,newExons
-
-   e1_b_off = int(math.fabs(e1_begin - exons[0,0]))
-   e1_e_off = int(math.fabs(e1_end - exons[0,1]))
-
-   e2_b_off = int(math.fabs(e2_begin - exons[1,0]))
-   e2_e_off = int(math.fabs(e2_end - exons[1,1]))
-
-   pdb.set_trace()
-
-   return e1_b_off,e1_e_off,e2_b_off,e2_e_off,newExons
-"""
-
-"""
-def evaluatePositions(eBegin,eEnd):
-   eBegin_pos = [elem for elem in eBegin if elem == 0]
-   eBegin_neg = [elem for elem in eBegin if elem != 0]
-   eEnd_pos = [elem for elem in eEnd if elem == 0]
-   eEnd_neg = [elem for elem in eEnd if elem != 0]
-
-   mean_eBegin_neg = 0
-   for idx in range(len(eBegin_neg)):
-      mean_eBegin_neg += eBegin_neg[idx]
-      
-   try:
-      mean_eBegin_neg /= 1.0*len(eBegin_neg)
-   except:
-      mean_eBegin_neg = -1
-
-   mean_eEnd_neg = 0
-   for idx in range(len(eEnd_neg)):
-      mean_eEnd_neg += eEnd_neg[idx]
-
-   try:
-      mean_eEnd_neg /= 1.0*len(eEnd_neg)
-   except:
-      mean_eEnd_neg = -1
-
-   return eBegin_pos,eBegin_neg,eEnd_pos,eEnd_neg,mean_eBegin_neg,mean_eEnd_neg
-"""
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         current_ground_truth = all_filtered_reads[id]
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         if predExons[0] == 19504568:
+            pdb.set_trace()
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      else:
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage and spliced_flag:
+          if not is_covered:
+              current_coverage_nr=0 
+          if pos_correct:
+              print "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+          else:
+              print "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+      total_ctr += 1
+
+
+   numPredictions = len(allUniqPredictions)
+
+   # now that we have evaluated all instances put out all counters and sizes
+   print 'Total num. of examples: %d' % numPredictions
+
+   print "spliced/unspliced:  %d,%d " % (spliced_ctr, unspliced_ctr )
+   print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr )
+   print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr , incorrect_unspliced_ctr )
+   print "Correct/incorrect covered spliced read: %d,%d " %\
+   (correct_covered_splice_ctr,incorrect_covered_splice_ctr)
+
+   print "pos_correct: %d,%d" % (pos_correct_ctr ,  pos_incorrect_ctr )
+
+   print 'unspliced_spliced reads: %d' % unspliced_spliced_reads_ctr 
+   print 'spliced reads at wrong_place: %d' % wrong_spliced_reads_ctr
+
+   print 'spliced_unspliced reads:  %d' % wrong_unspliced_reads_ctr
+   print 'wrong aligned at wrong_pos: %d' % wrong_aligned_unspliced_reads_ctr
+
+   print 'total_ctr: %d' % total_ctr
+
+   print "skipped: %d "  % skipped_ctr
+   print 'min. coverage: %d' % min_coverage
+
+   result_dict = {}
+   result_dict['skipped_ctr'] = skipped_ctr
+   result_dict['min_coverage'] = min_coverage
+
+   return result_dict
+
+
+
+
+
+def predict_on(allPredictions,all_filtered_reads,all_labels_fn,with_coverage,coverage_fn,coverage_labels_fn):
+   """
+   This function evaluates the predictions made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specify a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+
+   """
+
+   coverage_labels_fh = open(coverage_labels_fn,'w+')
+
+   all_labels_fh = open(all_labels_fn,'w+')
+
+   import qparser
+   qparser.parse_reads(all_filtered_reads)
+
+   coverage_map = {}
+
+
+   if with_coverage:
+      for line in open(coverage_fn):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   #out_fh = open('predicted_positions.txt','w+')
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for k,predictions in allPredictions.items():
+      for new_prediction in predictions:
+         id = new_prediction['id']
+         id = int(id)
+
+         if allUniqPredictions.has_key(id):
+            current_prediction = allUniqPredictions[id]
+
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+            if current_a_score < new_score :
+               allUniqPredictions[id] = new_prediction
+
+         else:
+            allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         #current_ground_truth = all_filtered_reads[id]
+         current_ground_truth = qparser.fetch_read(id)
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      else:
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               all_labels_fh.write('%d correct\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               all_labels_fh.write('%d wrong\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage:
+         if not is_covered:
+            current_coverage_nr=0 
+
+         if pos_correct:
+            new_line = "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+         else:
+            new_line = "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+         coverage_labels_fh.write(new_line+'\n')
+
+      total_ctr += 1
+
+   coverage_labels_fh.close()
+
+   numPredictions = len(allUniqPredictions)
+
+   result = []
+
+   # now that we have evaluated all instances put out all counters and sizes
+   result.append(('numPredictions',numPredictions))
+   result.append(('spliced_ctr',spliced_ctr))
+   result.append(('unspliced_ctr',unspliced_ctr))
+
+   result.append(('correct_spliced_ctr',correct_spliced_ctr))
+   result.append(('incorrect_spliced_ctr',incorrect_spliced_ctr))
+
+   result.append(('correct_unspliced_ctr',correct_unspliced_ctr))
+   result.append(('incorrect_unspliced_ctr',incorrect_unspliced_ctr))
+
+   result.append(('correct_covered_splice_ctr',correct_covered_splice_ctr))
+   result.append(('incorrect_covered_splice_ctr',incorrect_covered_splice_ctr))
+
+   result.append(('pos_correct_ctr',pos_correct_ctr))
+   result.append(('pos_incorrect_ctr',pos_incorrect_ctr))
+
+   result.append(('unspliced_spliced_reads_ctr',unspliced_spliced_reads_ctr))
+   result.append(('wrong_spliced_reads_ctr',wrong_spliced_reads_ctr))
+
+   result.append(('wrong_unspliced_reads_ctr',wrong_unspliced_reads_ctr))
+   result.append(('wrong_aligned_unspliced_reads_ctr',wrong_aligned_unspliced_reads_ctr))
+
+   result.append(('total_ctr',total_ctr))
+
+   result.append(('skipped_ctr',skipped_ctr))
+   result.append(('min_coverage',min_coverage))
+
+   return result
+
+
+
+def print_result(result):
+   # now that we have evaluated all instances put out all counters and sizes
+   for name,ctr in result:
+      print name,ctr
+
+
+def load_chunks(current_dir):
+   chunks_fn = []
+   for fn in os.listdir(current_dir):
+      if fn.startswith('chunk'):
+         chunks_fn.append(fn)
+
+   allPredictions = []
+
+   for c_fn in chunks_fn:
+      full_fn = os.path.join(current_dir,c_fn)
+      print full_fn
+      current_chunk = cPickle.load(open(full_fn))
+      allPredictions.extend(current_chunk)
+
+   return allPredictions
+
+
+def predict_on_all_chunks(current_dir,training_keys_fn):
+   """
+   We load all chunks from the current_dir belonging to one run.
+   Then we load the saved keys of the training set to restore the training and
+   testing sets.
+   Once we have done that we separately evaluate both sets.
+
+   """
+   
+   allPredictions = load_chunks(current_dir)
+
+   allPredictionsDict = {}
+   for elem in allPredictions:
+      id = elem['id']
+
+      if allPredictionsDict.has_key(id):
+         old_entry = allPredictionsDict[id]
+         old_entry.append(elem)
+         allPredictionsDict[id] = old_entry
+      else:
+         allPredictionsDict[id] = [elem]
+
+   training_keys = cPickle.load(open(training_keys_fn))
+
+   training_set = {}
+   for key in training_keys:
+      # we have the try construct because some of the reads used for training
+      # may not be found using vmatch at all
+      try:
+         training_set[key] = allPredictionsDict[key]
+         del allPredictionsDict[key]
+      except:
+         pass
+
+   test_set = allPredictionsDict
+
+   #test_set = {}
+   #for k in allPredictionsDict.keys()[:100]:
+   #   test_set[k] = allPredictionsDict[k]
+   #result_train = predict_on(training_set,all_filtered_reads,False,coverage_fn)
+   #pdb.set_trace()
+
+   # this is the heuristic.parsed_spliced_reads.txt file from the vmatch remapping step
+   coverage_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/all_coverages' 
+
+   all_filtered_reads = '/fml/ag-raetsch/share/projects/qpalma/solexa/new_run/allReads.full'
+
+   coverage_labels_fn = 'COVERAGE_LABELS'
+
+   result_test = predict_on(test_set,all_filtered_reads,'all_prediction_labels.txt',True,coverage_fn,coverage_labels_fn)
+   #print_result(result_train)
+
+   return result_test
+   
+
+if __name__ == '__main__':
+   pass