+ added framework code for training modus
[qpalma.git] / scripts / Evaluation.py
index bc30d59..952eb4c 100644 (file)
@@ -3,15 +3,27 @@
 
 import cPickle
 import sys
-import pydb
 import pdb
 import os
 import os.path
 import math
 
+from qpalma.parsers import *
+
+
 data = None
 
 
+def result_statistic():
+   """
+
+   """   
+
+   num_unaligned_reads = 0
+   num_incorrectly_aligned_reads = 0
+   pass
+
+   
 def createErrorVSCutPlot(results):
    """
    This function takes the results of the evaluation and creates a tex table.
@@ -90,7 +102,7 @@ def compare_scores_and_labels(scores,labels):
             if otherPos == currentPos:
                continue
 
-            if labels[otherPos] == False and otherElem > currentElem:
+            if labels[otherPos] == False and otherElem >= currentElem:
                return False
 
    return True
@@ -143,119 +155,120 @@ def evaluate_example(current_prediction):
    predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
    truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]
 
-   #pdb.set_trace()
-
    pos_comparison = (predPositions == truePositions)
 
-   #if label == True and pos_comparison == False:
-   # pdb.set_trace()
-
    return label,pos_comparison,pred_score
 
 
 def prediction_on(filename):
 
-   incorrect_gt_cuts       = {}
-   incorrect_vmatch_cuts   = {}
-
    allPredictions = cPickle.load(open(filename))
 
-   exon1Begin     = []
-   exon1End       = []
-   exon2Begin     = []
-   exon2End       = []
-   allWrongExons  = []
-   allDoubleScores = []
-
    gt_correct_ctr = 0
    gt_incorrect_ctr = 0
+   incorrect_gt_cuts       = {}
 
    pos_correct_ctr   = 0
    pos_incorrect_ctr = 0
+   incorrect_vmatch_cuts   = {}
 
    score_correct_ctr = 0
    score_incorrect_ctr = 0
 
    total_gt_examples = 0
-
    total_vmatch_instances_ctr = 0
+
    true_vmatch_instances_ctr = 0
 
-   for current_example_pred in allPredictions:
-      gt_example = current_example_pred[0]
-      gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
-      gt_correct = evaluate_unmapped_example(gt_example)
+   allUniquePredictions = [False]*len(allPredictions)
 
-      exampleIdx = gt_example['exampleIdx']
+   for pos,current_example_pred in enumerate(allPredictions):
+      for elem_nr,new_prediction in enumerate(current_example_pred[1:]):
 
-      cut_pos = gt_example['true_cut']
+         if allUniquePredictions[pos] != False:
+            current_prediction = allUniquePredictions[pos]
 
-      current_scores = []
-      current_labels = []
-      current_scores.append(gt_score)
-      current_labels.append(gt_correct)
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
 
-      if gt_correct:
-         gt_correct_ctr += 1
-      else:
-         gt_incorrect_ctr += 1
+            if current_a_score < new_score :
+               allUniquePredictions[id] = new_prediction
 
-         try:
-            incorrect_gt_cuts[cut_pos] += 1
-         except:
-            incorrect_gt_cuts[cut_pos] = 1
+         else:
+            allUniquePredictions[pos] = new_prediction
+
+   for current_pred in allUniquePredictions:
+      if current_pred == False:
+         continue
+
+   #for current_example_pred in allPredictions:
+      #gt_example = current_example_pred[0]
+      #gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
+      #gt_correct = evaluate_unmapped_example(gt_example)
 
-      total_gt_examples += 1
+      #exampleIdx = gt_example['exampleIdx']
+
+      #cut_pos = gt_example['true_cut']
+
+      #if gt_correct:
+      #   gt_correct_ctr += 1
+      #else:
+      #   gt_incorrect_ctr += 1
+
+      #   try:
+      #      incorrect_gt_cuts[cut_pos] += 1
+      #   except:
+      #      incorrect_gt_cuts[cut_pos] = 1
+
+      #total_gt_examples += 1
      
-      #pdb.set_trace()
+      #current_scores = []
+      #current_labels = []
+      #for elem_nr,current_pred in enumerate(current_example_pred[1:]):
 
-      for elem_nr,current_pred in enumerate(current_example_pred[1:]):
-         current_label,comparison_result,current_score = evaluate_example(current_pred)
+      current_label,comparison_result,current_score = evaluate_example(current_pred)
 
-         # if vmatch found the right read pos we check for right exons
-         # boundaries
-         if current_label:
-            if comparison_result:
-               pos_correct_ctr += 1
-            else:
-               pos_incorrect_ctr += 1
+      # if vmatch found the right read pos we check for right exons
+      # boundaries
+      #if current_label:
+      if comparison_result:
+         pos_correct_ctr += 1
+      else:
+         pos_incorrect_ctr += 1
 
-               try:
-                  incorrect_vmatch_cuts[cut_pos] += 1
-               except:
-                  incorrect_vmatch_cuts[cut_pos] = 1
+            #try:
+            #   incorrect_vmatch_cuts[cut_pos] += 1
+            #except:
+            #   incorrect_vmatch_cuts[cut_pos] = 1
 
-            true_vmatch_instances_ctr += 1
-            
-         current_scores.append(current_score)
-         current_labels.append(current_label)
+         true_vmatch_instances_ctr += 1
+         
+      #current_scores.append(current_score)
+      #current_labels.append(current_label)
 
-         total_vmatch_instances_ctr += 1
+      total_vmatch_instances_ctr += 1
 
       # check whether the correct predictions score higher than the incorrect
       # ones
-      cmp_res = compare_scores_and_labels(current_scores,current_labels)
-      if cmp_res:
-         score_correct_ctr += 1
-      else:
-         score_incorrect_ctr += 1
+      #cmp_res = compare_scores_and_labels(current_scores,current_labels)
+      #if cmp_res:
+      #   score_correct_ctr += 1
+      #else:
+      #   score_incorrect_ctr += 1
 
    # now that we have evaluated all instances put out all counters and sizes
    print 'Total num. of examples: %d' % len(allPredictions)
    print 'Number of correct ground truth examples: %d' % gt_correct_ctr
    print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
-   print 'Correct pos: %d, incorrect pos: %d' %\
-   (pos_correct_ctr,pos_incorrect_ctr)
+   print 'Correct pos: %d, incorrect pos: %d' % (pos_correct_ctr,pos_incorrect_ctr)
    print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
    print 'Correct scores: %d, incorrect scores: %d' %\
    (score_correct_ctr,score_incorrect_ctr)
 
-   pos_error   = 1.0 * pos_incorrect_ctr / true_vmatch_instances_ctr
+   pos_error   = 1.0 * pos_incorrect_ctr / total_vmatch_instances_ctr
    score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
    gt_error    = 1.0 * gt_incorrect_ctr / total_gt_examples
 
-   #print pos_error,score_error,gt_error
-
    return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
 
 
@@ -296,7 +309,9 @@ def perform_prediction(current_dir,run_name):
    This function takes care of starting the jobs needed for the prediction phase
    of qpalma
    """
-   for i in range(1,6):
+
+   #for i in range(1,6):
+   for i in range(1,2):
       cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
       qsub -l h_vmem=12.0G -cwd -j y -N \"%s_%d.log\"'%(current_dir,i,run_name,i)
 
@@ -305,6 +320,7 @@ def perform_prediction(current_dir,run_name):
       os.system(cmd)
 
 
+
 def forall_experiments(current_func,tl_dir):
    """
    Given the toplevel directoy this function calls for each subdir the
@@ -325,24 +341,589 @@ def forall_experiments(current_func,tl_dir):
    for current_dir in run_dirs:
       run_name = current_dir.split('/')[-1]
 
-      current_func(current_dir,run_name)
+      pdb.set_trace()
 
-      train_result,test_result,currentRunId = current_func(current_dir,run_name)
-      all_results[currentRunId] = test_result
-      pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
-      all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
+      if current_func.__name__ == 'perform_prediction':
+         current_func(current_dir,run_name)
 
-   createErrorVSCutPlot(all_error_rates)
-   #createTable(all_results)
+      if current_func.__name__ == 'collect_prediction':
+         train_result,test_result,currentRunId = current_func(current_dir,run_name)
+         all_results[currentRunId] = test_result
+         pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
+         all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
 
+   if current_func.__name__ == 'collect_prediction':
+      #createErrorVSCutPlot(all_error_rates)
+      createTable(all_results)
 
-if __name__ == '__main__':
-   dir = sys.argv[1]
-   assert os.path.exists(dir), 'Error: Directory does not exist!'
+def _predict_on(filename,filtered_reads,with_coverage):
+   """
+   This function evaluates the predictions made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specify a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+   """
+
+   coverage_map = {}
+
+   if with_coverage:
+      for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   print 'parsing filtered reads..'
+   all_filtered_reads = parse_filtered_reads(filtered_reads)
+   print 'found %d filtered reads' % len(all_filtered_reads)
+
+   out_fh = open('predicted_positions.txt','w+')
+
+   allPredictions = cPickle.load(open(filename))
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for new_prediction in allPredictions:
+      id = new_prediction['id']
+      id = int(id)
+
+      if allUniqPredictions.has_key(id):
+         current_prediction = allUniqPredictions[id]
+
+         current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+         new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+         if current_a_score < new_score :
+            allUniqPredictions[id] = new_prediction
+
+      else:
+         allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   #for current_prediction in allPredictions:
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         current_ground_truth = all_filtered_reads[id]
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         if predExons[0] == 19504568:
+            pdb.set_trace()
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      else:
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage and spliced_flag:
+          if not is_covered:
+              current_coverage_nr=0 
+          if pos_correct:
+              print "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+          else:
+              print "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+      total_ctr += 1
+
+
+   numPredictions = len(allUniqPredictions)
+
+   # now that we have evaluated all instances put out all counters and sizes
+   print 'Total num. of examples: %d' % numPredictions
+
+   print "spliced/unspliced:  %d,%d " % (spliced_ctr, unspliced_ctr )
+   print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr )
+   print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr , incorrect_unspliced_ctr )
+   print "Correct/incorrect covered spliced read: %d,%d " %\
+   (correct_covered_splice_ctr,incorrect_covered_splice_ctr)
+
+   print "pos_correct: %d,%d" % (pos_correct_ctr ,  pos_incorrect_ctr )
+
+   print 'unspliced_spliced reads: %d' % unspliced_spliced_reads_ctr 
+   print 'spliced reads at wrong_place: %d' % wrong_spliced_reads_ctr
+
+   print 'spliced_unspliced reads:  %d' % wrong_unspliced_reads_ctr
+   print 'wrong aligned at wrong_pos: %d' % wrong_aligned_unspliced_reads_ctr
+
+   print 'total_ctr: %d' % total_ctr
+
+   print "skipped: %d "  % skipped_ctr
+   print 'min. coverage: %d' % min_coverage
+
+   result_dict = {}
+   result_dict['skipped_ctr'] = skipped_ctr
+   result_dict['min_coverage'] = min_coverage
+
+   return result_dict
 
-   #global data
-   #data_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/dataset_remapped_test_new'
-   #data = cPickle.load(open(data_fn))
 
-   #forall_experiments(perform_prediction,dir)
-   forall_experiments(collect_prediction,dir)
+
+
+
+def predict_on(allPredictions,all_filtered_reads,all_labels_fn,with_coverage,coverage_fn,coverage_labels_fn):
+   """
+   This function evaluates the predictions made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specify a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+
+   """
+
+   coverage_labels_fh = open(coverage_labels_fn,'w+')
+
+   all_labels_fh = open(all_labels_fn,'w+')
+
+   import qparser
+   qparser.parse_reads(all_filtered_reads)
+
+   coverage_map = {}
+
+
+   if with_coverage:
+      for line in open(coverage_fn):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   #out_fh = open('predicted_positions.txt','w+')
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for k,predictions in allPredictions.items():
+      for new_prediction in predictions:
+         id = new_prediction['id']
+         id = int(id)
+
+         if allUniqPredictions.has_key(id):
+            current_prediction = allUniqPredictions[id]
+
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+            if current_a_score < new_score :
+               allUniqPredictions[id] = new_prediction
+
+         else:
+            allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         #current_ground_truth = all_filtered_reads[id]
+         current_ground_truth = qparser.fetch_read(id)
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      else:
+         pos_correct = False
+
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
+
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
+
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
+
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
+
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               all_labels_fh.write('%d correct\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               all_labels_fh.write('%d wrong\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage:
+         if not is_covered:
+            current_coverage_nr=0 
+
+         if pos_correct:
+            new_line = "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+         else:
+            new_line = "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+         coverage_labels_fh.write(new_line+'\n')
+
+      total_ctr += 1
+
+   coverage_labels_fh.close()
+
+   numPredictions = len(allUniqPredictions)
+
+   result = []
+
+   # now that we have evaluated all instances put out all counters and sizes
+   result.append(('numPredictions',numPredictions))
+   result.append(('spliced_ctr',spliced_ctr))
+   result.append(('unspliced_ctr',unspliced_ctr))
+
+   result.append(('correct_spliced_ctr',correct_spliced_ctr))
+   result.append(('incorrect_spliced_ctr',incorrect_spliced_ctr))
+
+   result.append(('correct_unspliced_ctr',correct_unspliced_ctr))
+   result.append(('incorrect_unspliced_ctr',incorrect_unspliced_ctr))
+
+   result.append(('correct_covered_splice_ctr',correct_covered_splice_ctr))
+   result.append(('incorrect_covered_splice_ctr',incorrect_covered_splice_ctr))
+
+   result.append(('pos_correct_ctr',pos_correct_ctr))
+   result.append(('pos_incorrect_ctr',pos_incorrect_ctr))
+
+   result.append(('unspliced_spliced_reads_ctr',unspliced_spliced_reads_ctr))
+   result.append(('wrong_spliced_reads_ctr',wrong_spliced_reads_ctr))
+
+   result.append(('wrong_unspliced_reads_ctr',wrong_unspliced_reads_ctr))
+   result.append(('wrong_aligned_unspliced_reads_ctr',wrong_aligned_unspliced_reads_ctr))
+
+   result.append(('total_ctr',total_ctr))
+
+   result.append(('skipped_ctr',skipped_ctr))
+   result.append(('min_coverage',min_coverage))
+
+   return result
+
+
+
+def print_result(result):
+   # now that we have evaluated all instances put out all counters and sizes
+   for name,ctr in result:
+      print name,ctr
+
+
+def load_chunks(current_dir):
+   chunks_fn = []
+   for fn in os.listdir(current_dir):
+      if fn.startswith('chunk'):
+         chunks_fn.append(fn)
+
+   allPredictions = []
+
+   for c_fn in chunks_fn:
+      full_fn = os.path.join(current_dir,c_fn)
+      print full_fn
+      current_chunk = cPickle.load(open(full_fn))
+      allPredictions.extend(current_chunk)
+
+   return allPredictions
+
+
+def predict_on_all_chunks(current_dir,training_keys_fn):
+   """
+   We load all chunks from the current_dir belonging to one run.
+   Then we load the saved keys of the training set to restore the training and
+   testing sets.
+   Once we have done that we separately evaluate both sets.
+
+   """
+   
+   allPredictions = load_chunks(current_dir)
+
+   allPredictionsDict = {}
+   for elem in allPredictions:
+      id = elem['id']
+
+      if allPredictionsDict.has_key(id):
+         old_entry = allPredictionsDict[id]
+         old_entry.append(elem)
+         allPredictionsDict[id] = old_entry
+      else:
+         allPredictionsDict[id] = [elem]
+
+   training_keys = cPickle.load(open(training_keys_fn))
+
+   training_set = {}
+   for key in training_keys:
+      # we have the try construct because some of the reads used for training
+      # may not be found using vmatch at all
+      try:
+         training_set[key] = allPredictionsDict[key]
+         del allPredictionsDict[key]
+      except:
+         pass
+
+   test_set = allPredictionsDict
+
+   #test_set = {}
+   #for k in allPredictionsDict.keys()[:100]:
+   #   test_set[k] = allPredictionsDict[k]
+   #result_train = predict_on(training_set,all_filtered_reads,False,coverage_fn)
+   #pdb.set_trace()
+
+   # this is the heuristic.parsed_spliced_reads.txt file from the vmatch remapping step
+   coverage_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/all_coverages' 
+
+   all_filtered_reads = '/fml/ag-raetsch/share/projects/qpalma/solexa/new_run/allReads.full'
+
+   coverage_labels_fn = 'COVERAGE_LABELS'
+
+   result_test = predict_on(test_set,all_filtered_reads,'all_prediction_labels.txt',True,coverage_fn,coverage_labels_fn)
+   #print_result(result_train)
+
+   return result_test
+   
+
+if __name__ == '__main__':
+   pass