+ added framework code for training mode
[qpalma.git] / scripts / Evaluation.py
index af9bccb..952eb4c 100644 (file)
@@ -14,6 +14,16 @@ from qpalma.parsers import *
 data = None
 
 
def result_statistic():
   """
   Placeholder for computing summary statistics of an alignment run.

   TODO(review): still a stub -- the counters below are initialised but
   never updated, and nothing is returned yet.
   """

   # reads that could not be aligned at all
   num_unaligned_reads = 0
   # reads aligned to a wrong position / with wrong boundaries
   num_incorrectly_aligned_reads = 0
   pass


 def createErrorVSCutPlot(results):
    """
    This function takes the results of the evaluation and creates a tex table.
@@ -92,7 +102,7 @@ def compare_scores_and_labels(scores,labels):
             if otherPos == currentPos:
                continue
 
-            if labels[otherPos] == False and otherElem > currentElem:
+            if labels[otherPos] == False and otherElem >= currentElem:
                return False
 
    return True
@@ -145,119 +155,120 @@ def evaluate_example(current_prediction):
    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]
 
-   #pdb.set_trace()
-
    pos_comparison = (predPositions == truePositions)
 
-   #if label == True and pos_comparison == False:
-   # pdb.set_trace()
-
    return label,pos_comparison,pred_score
 
 
 def prediction_on(filename):
 
-   incorrect_gt_cuts       = {}
-   incorrect_vmatch_cuts   = {}
-
    allPredictions = cPickle.load(open(filename))
 
-   exon1Begin     = []
-   exon1End       = []
-   exon2Begin     = []
-   exon2End       = []
-   allWrongExons  = []
-   allDoubleScores = []
-
    gt_correct_ctr = 0
    gt_incorrect_ctr = 0
+   incorrect_gt_cuts       = {}
 
    pos_correct_ctr   = 0
    pos_incorrect_ctr = 0
+   incorrect_vmatch_cuts   = {}
 
    score_correct_ctr = 0
    score_incorrect_ctr = 0
 
    total_gt_examples = 0
-
    total_vmatch_instances_ctr = 0
+
    true_vmatch_instances_ctr = 0
 
-   for current_example_pred in allPredictions:
-      gt_example = current_example_pred[0]
-      gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
-      gt_correct = evaluate_unmapped_example(gt_example)
+   allUniquePredictions = [False]*len(allPredictions)
 
-      exampleIdx = gt_example['exampleIdx']
+   for pos,current_example_pred in enumerate(allPredictions):
+      for elem_nr,new_prediction in enumerate(current_example_pred[1:]):
 
-      cut_pos = gt_example['true_cut']
+         if allUniquePredictions[pos] != False:
+            current_prediction = allUniquePredictions[pos]
 
-      current_scores = []
-      current_labels = []
-      current_scores.append(gt_score)
-      current_labels.append(gt_correct)
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
 
-      if gt_correct:
-         gt_correct_ctr += 1
-      else:
-         gt_incorrect_ctr += 1
+            if current_a_score < new_score :
+               allUniquePredictions[id] = new_prediction
 
-         try:
-            incorrect_gt_cuts[cut_pos] += 1
-         except:
-            incorrect_gt_cuts[cut_pos] = 1
+         else:
+            allUniquePredictions[pos] = new_prediction
+
+   for current_pred in allUniquePredictions:
+      if current_pred == False:
+         continue
+
+   #for current_example_pred in allPredictions:
+      #gt_example = current_example_pred[0]
+      #gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
+      #gt_correct = evaluate_unmapped_example(gt_example)
+
+      #exampleIdx = gt_example['exampleIdx']
+
+      #cut_pos = gt_example['true_cut']
 
-      total_gt_examples += 1
+      #if gt_correct:
+      #   gt_correct_ctr += 1
+      #else:
+      #   gt_incorrect_ctr += 1
+
+      #   try:
+      #      incorrect_gt_cuts[cut_pos] += 1
+      #   except:
+      #      incorrect_gt_cuts[cut_pos] = 1
+
+      #total_gt_examples += 1
      
-      #pdb.set_trace()
+      #current_scores = []
+      #current_labels = []
+      #for elem_nr,current_pred in enumerate(current_example_pred[1:]):
 
-      for elem_nr,current_pred in enumerate(current_example_pred[1:]):
-         current_label,comparison_result,current_score = evaluate_example(current_pred)
+      current_label,comparison_result,current_score = evaluate_example(current_pred)
 
-         # if vmatch found the right read pos we check for right exons
-         # boundaries
-         if current_label:
-            if comparison_result:
-               pos_correct_ctr += 1
-            else:
-               pos_incorrect_ctr += 1
+      # if vmatch found the right read pos we check for right exons
+      # boundaries
+      #if current_label:
+      if comparison_result:
+         pos_correct_ctr += 1
+      else:
+         pos_incorrect_ctr += 1
 
-               try:
-                  incorrect_vmatch_cuts[cut_pos] += 1
-               except:
-                  incorrect_vmatch_cuts[cut_pos] = 1
+            #try:
+            #   incorrect_vmatch_cuts[cut_pos] += 1
+            #except:
+            #   incorrect_vmatch_cuts[cut_pos] = 1
 
-            true_vmatch_instances_ctr += 1
-            
-         current_scores.append(current_score)
-         current_labels.append(current_label)
+         true_vmatch_instances_ctr += 1
+         
+      #current_scores.append(current_score)
+      #current_labels.append(current_label)
 
-         total_vmatch_instances_ctr += 1
+      total_vmatch_instances_ctr += 1
 
       # check whether the correct predictions score higher than the incorrect
       # ones
-      cmp_res = compare_scores_and_labels(current_scores,current_labels)
-      if cmp_res:
-         score_correct_ctr += 1
-      else:
-         score_incorrect_ctr += 1
+      #cmp_res = compare_scores_and_labels(current_scores,current_labels)
+      #if cmp_res:
+      #   score_correct_ctr += 1
+      #else:
+      #   score_incorrect_ctr += 1
 
    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
-   print 'Correct pos: %d, incorrect pos: %d' %\
-   (pos_correct_ctr,pos_incorrect_ctr)
+   print 'Correct pos: %d, incorrect pos: %d' % (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
    (score_correct_ctr,score_incorrect_ctr)
 
-   pos_error   = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
+   pos_error   = 1.0 * pos_incorrect_ctr / total_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error    = 1.0 * gt_incorrect_ctr / total_gt_examples
 
-   #print pos_error,score_error,gt_error
-
    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
 
 
@@ -298,6 +309,7 @@ def perform_prediction(current_dir,run_name):
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma
    """
+
    #for i in range(1,6):
    for i in range(1,2):
       cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
@@ -329,53 +341,151 @@ def forall_experiments(current_func,tl_dir):
    for current_dir in run_dirs:
       run_name = current_dir.split('/')[-1]
 
-      current_func(current_dir,run_name)
+      pdb.set_trace()
+
+      if current_func.__name__ == 'perform_prediction':
+         current_func(current_dir,run_name)
+
+      if current_func.__name__ == 'collect_prediction':
+         train_result,test_result,currentRunId = current_func(current_dir,run_name)
+         all_results[currentRunId] = test_result
+         pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
+         all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
+
+   if current_func.__name__ == 'collect_prediction':
+      #createErrorVSCutPlot(all_error_rates)
+      createTable(all_results)
 
-      #train_result,test_result,currentRunId = current_func(current_dir,run_name)
-      #all_results[currentRunId] = test_result
-      #pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
-      #all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
+def _predict_on(filename,filtered_reads,with_coverage):
+   """
+   This function evaluates the predictions  made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
 
-   #createErrorVSCutPlot(all_error_rates)
-   #createTable(all_results)
+   Optionally one can specifiy a coverage file containing for each read the
+   coverage number estimated by a remapping step.
 
+   """
 
+   coverage_map = {}
 
-def predict_on(filename,filtered_reads):
+   if with_coverage:
+      for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
 
    print 'parsing filtered reads..'
    all_filtered_reads = parse_filtered_reads(filtered_reads)
    print 'found %d filtered reads' % len(all_filtered_reads)
 
+   out_fh = open('predicted_positions.txt','w+')
+
    allPredictions = cPickle.load(open(filename))
 
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
    pos_correct_ctr   = 0
    pos_incorrect_ctr = 0
 
-   score_correct_ctr = 0
-   score_incorrect_ctr = 0
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
 
    total_vmatch_instances_ctr = 0
 
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
 
-   for current_prediction in allPredictions:
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for new_prediction in allPredictions:
+      id = new_prediction['id']
+      id = int(id)
+
+      if allUniqPredictions.has_key(id):
+         current_prediction = allUniqPredictions[id]
+
+         current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+         new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+         if current_a_score < new_score :
+            allUniqPredictions[id] = new_prediction
+
+      else:
+         allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   #for current_prediction in allPredictions:
+   for _id,current_prediction in allUniqPredictions.items():
       id = current_prediction['id']
-      current_ground_truth = all_filtered_reads[id]
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         current_ground_truth = all_filtered_reads[id]
+      except:
+         skipped_ctr += 1
+         continue
 
       start_pos = current_prediction['start_pos']
       chr = current_prediction['chr']
       strand = current_prediction['strand']
 
       #score = current_prediction['DPScores'].flatten().tolist()[0][0]
-
       #pdb.set_trace()
 
       predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
       predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
       if len(predExons) == 4:
+         spliced_flag = True
          predExons[1] -= 1
          predExons[3] -= 1
 
+         if predExons[0] == 19504568:
+            pdb.set_trace()
+
          cut_pos = current_ground_truth['true_cut']
          p_start = current_ground_truth['p_start']
          e_stop = current_ground_truth['exon_stop']
@@ -386,12 +496,12 @@ def predict_on(filename,filtered_reads):
 
          if p_start == predExons[0] and e_stop == predExons[1] and\
          e_start == predExons[2] and p_stop == predExons[3]:
-            pos_correct_ctr += 1
+            pos_correct = True
          else:
-            pos_incorrect_ctr += 1
-            #pdb.set_trace()
+            pos_correct = False
 
       elif len(predExons) == 2:
+         spliced_flag = False
          predExons[1] -= 1
 
          cut_pos = current_ground_truth['true_cut']
@@ -400,44 +510,420 @@ def predict_on(filename,filtered_reads):
 
          true_cut = current_ground_truth['true_cut']
 
-         if p_start == predExons[0] and p_stop == predExons[1]:
-            pos_correct_ctr += 1
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
          else:
-            pos_incorrect_ctr += 1
-            #pdb.set_trace()
+            pos_correct = False
 
       else:
-         pass
-      ## check whether the correct predictions score higher than the incorrect
-      ## ones
-      #cmp_res = compare_scores_and_labels(current_scores,current_labels)
-      #if cmp_res:
-      #   score_correct_ctr += 1
-      #else:
-      #   score_incorrect_ctr += 1
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage and spliced_flag:
+          if not is_covered:
+              current_coverage_nr=0 
+          if pos_correct:
+              print "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+          else:
+              print "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+      total_ctr += 1
+
 
-   numPredictions = len(allPredictions)
+   numPredictions = len(allUniqPredictions)
 
    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % numPredictions
-   print 'Correct pos: %2.3f, incorrect pos: %2.3f' %\
-   (pos_correct_ctr/(1.0*numPredictions),pos_incorrect_ctr/(1.0*numPredictions))
 
-   #print 'Correct scores: %d, incorrect scores: %d' %\
-   #(score_correct_ctr,score_incorrect_ctr)
+   print "spliced/unspliced:  %d,%d " % (spliced_ctr, unspliced_ctr )
+   print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr )
+   print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr , incorrect_unspliced_ctr )
+   print "Correct/incorrect covered spliced read: %d,%d " %\
+   (correct_covered_splice_ctr,incorrect_covered_splice_ctr)
 
-   #pos_error   = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
-   #score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
+   print "pos_correct: %d,%d" % (pos_correct_ctr ,  pos_incorrect_ctr )
 
+   print 'unspliced_spliced reads: %d' % unspliced_spliced_reads_ctr 
+   print 'spliced reads at wrong_place: %d' % wrong_spliced_reads_ctr
 
-if __name__ == '__main__':
-   #dir = sys.argv[1]
-   #assert os.path.exists(dir), 'Error: Directory does not exist!'
+   print 'spliced_unspliced reads:  %d' % wrong_unspliced_reads_ctr
+   print 'wrong aligned at wrong_pos: %d' % wrong_aligned_unspliced_reads_ctr
+
+   print 'total_ctr: %d' % total_ctr
+
+   print "skipped: %d "  % skipped_ctr
+   print 'min. coverage: %d' % min_coverage
+
+   result_dict = {}
+   result_dict['skipped_ctr'] = skipped_ctr
+   result_dict['min_coverage'] = min_coverage
+
+   return result_dict
+
+
+
+
+
+def predict_on(allPredictions,all_filtered_reads,all_labels_fn,with_coverage,coverage_fn,coverage_labels_fn):
+   """
+   This function evaluates the predictions  made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specifiy a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+
+   """
+
+   coverage_labels_fh = open(coverage_labels_fn,'w+')
+
+   all_labels_fh = open(all_labels_fn,'w+')
+
+   import qparser
+   qparser.parse_reads(all_filtered_reads)
+
+   coverage_map = {}
+
+
+   if with_coverage:
+      for line in open(coverage_fn):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   #out_fh = open('predicted_positions.txt','w+')
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
 
-   #global data
-   #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
-   #data = cPickle.load(open(data_fn))
-   #forall_experiments(perform_prediction,dir)
-   #forall_experiments(collect_prediction,dir)
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
 
-   predict_on(sys.argv[1],sys.argv[2])
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for k,predictions in allPredictions.items():
+      for new_prediction in predictions:
+         id = new_prediction['id']
+         id = int(id)
+
+         if allUniqPredictions.has_key(id):
+            current_prediction = allUniqPredictions[id]
+
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+            if current_a_score < new_score :
+               allUniqPredictions[id] = new_prediction
+
+         else:
+            allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         #current_ground_truth = all_filtered_reads[id]
+         current_ground_truth = qparser.fetch_read(id)
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      else:
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               all_labels_fh.write('%d correct\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               all_labels_fh.write('%d wrong\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage:
+         if not is_covered:
+            current_coverage_nr=0 
+
+         if pos_correct:
+            new_line = "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+         else:
+            new_line = "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+         coverage_labels_fh.write(new_line+'\n')
+
+      total_ctr += 1
+
+   coverage_labels_fh.close()
+
+   numPredictions = len(allUniqPredictions)
+
+   result = []
+
+   # now that we have evaluated all instances put out all counters and sizes
+   result.append(('numPredictions',numPredictions))
+   result.append(('spliced_ctr',spliced_ctr))
+   result.append(('unspliced_ctr',unspliced_ctr))
+
+   result.append(('correct_spliced_ctr',correct_spliced_ctr))
+   result.append(('incorrect_spliced_ctr',incorrect_spliced_ctr))
+
+   result.append(('correct_unspliced_ctr',correct_unspliced_ctr))
+   result.append(('incorrect_unspliced_ctr',incorrect_unspliced_ctr))
+
+   result.append(('correct_covered_splice_ctr',correct_covered_splice_ctr))
+   result.append(('incorrect_covered_splice_ctr',incorrect_covered_splice_ctr))
+
+   result.append(('pos_correct_ctr',pos_correct_ctr))
+   result.append(('pos_incorrect_ctr',pos_incorrect_ctr))
+
+   result.append(('unspliced_spliced_reads_ctr',unspliced_spliced_reads_ctr))
+   result.append(('wrong_spliced_reads_ctr',wrong_spliced_reads_ctr))
+
+   result.append(('wrong_unspliced_reads_ctr',wrong_unspliced_reads_ctr))
+   result.append(('wrong_aligned_unspliced_reads_ctr',wrong_aligned_unspliced_reads_ctr))
+
+   result.append(('total_ctr',total_ctr))
+
+   result.append(('skipped_ctr',skipped_ctr))
+   result.append(('min_coverage',min_coverage))
+
+   return result
+
+
+
+def print_result(result):
+   # now that we have evaluated all instances put out all counters and sizes
+   for name,ctr in result:
+      print name,ctr
+
+
+def load_chunks(current_dir):
+   chunks_fn = []
+   for fn in os.listdir(current_dir):
+      if fn.startswith('chunk'):
+         chunks_fn.append(fn)
+
+   allPredictions = []
+
+   for c_fn in chunks_fn:
+      full_fn = os.path.join(current_dir,c_fn)
+      print full_fn
+      current_chunk = cPickle.load(open(full_fn))
+      allPredictions.extend(current_chunk)
+
+   return allPredictions
+
+
def predict_on_all_chunks(current_dir,training_keys_fn):
   """
   We load all chunks from the current_dir belonging to one run.
   Then we load the saved keys of the training set to restore the training and
   testing sets.
   Once we have done that we separately evaluate both sets.

   Returns the (counter name, value) list produced by predict_on() for
   the test set.

   NOTE(review): evaluation of the training set is currently disabled;
   only the test set is evaluated.
   """

   allPredictions = load_chunks(current_dir)

   # group all predictions by read id
   allPredictionsDict = {}
   for elem in allPredictions:
      id = elem['id']
      allPredictionsDict.setdefault(id,[]).append(elem)

   training_keys = cPickle.load(open(training_keys_fn))

   # move the training reads out of the dict; what remains is the test set
   training_set = {}
   for key in training_keys:
      # we have the try construct because some of the reads used for training
      # may not be found using vmatch at all
      try:
         training_set[key] = allPredictionsDict[key]
         del allPredictionsDict[key]
      except KeyError:
         pass

   test_set = allPredictionsDict

   # this is the heuristic.parsed_spliced_reads.txt file from the vmatch remapping step
   coverage_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/all_coverages'

   all_filtered_reads = '/fml/ag-raetsch/share/projects/qpalma/solexa/new_run/allReads.full'

   coverage_labels_fn = 'COVERAGE_LABELS'

   result_test = predict_on(test_set,all_filtered_reads,'all_prediction_labels.txt',True,coverage_fn,coverage_labels_fn)

   return result_test
if __name__ == '__main__':
   # no command-line interface at the moment -- TODO(review): wire up
   # predict_on_all_chunks() or another driver here
   pass