+ added framework code for training mode
[qpalma.git] / scripts / Evaluation.py
index b02ee9c..952eb4c 100644 (file)
 
 import cPickle
 import sys
-import pydb
 import pdb
 import os
 import os.path
 import math
 
+from qpalma.parsers import *
 
-def evaluatePositions(eBegin,eEnd):
-   eBegin_pos = [elem for elem in eBegin if elem == 0]
-   eBegin_neg = [elem for elem in eBegin if elem != 0]
-   eEnd_pos = [elem for elem in eEnd if elem == 0]
-   eEnd_neg = [elem for elem in eEnd if elem != 0]
 
-   mean_eBegin_neg = 0
-   for idx in range(len(eBegin_neg)):
-      mean_eBegin_neg += eBegin_neg[idx]
-      
-   try:
-      mean_eBegin_neg /= 1.0*len(eBegin_neg)
-   except:
-      mean_eBegin_neg = -1
+data = None
 
-   mean_eEnd_neg = 0
-   for idx in range(len(eEnd_neg)):
-      mean_eEnd_neg += eEnd_neg[idx]
 
-   try:
-      mean_eEnd_neg /= 1.0*len(eEnd_neg)
-   except:
-      mean_eEnd_neg = -1
+def result_statistic():
+   """
+
+   """   
+
+   num_unaligned_reads = 0
+   num_incorrectly_aligned_reads = 0
+   pass
+
+   
+def createErrorVSCutPlot(results):
+   """
+   This function takes the results of the evaluation and creates a tex table.
+   """
+
+   fh = open('error_rates_table.tex','w+')
+   lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
+   'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
+   'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']
+
+   #for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+   for pos,key in enumerate(['+++']):
+      res = results[key]
+      for i in range(37):
+         ctr = 0
+         try:
+            ctr = res[1][i]
+         except:
+            ctr = 0
+
+         lines.append( '%d\n' % ctr)
+
+      if pos % 2 == 1:
+         lines.append('\hline')
+
+   lines.append('\end{tabular}')
+
+   lines = [l+'\n' for l in lines]
+   for l in lines:
+      fh.write(l)
+   fh.close()
+
+
+def createTable(results):
+   """
+   This function takes the results of the evaluation and creates a tex table.
+   """
+
+   fh = open('result_table.tex','w+')
+   lines = ['\\begin{tabular}{|c|c|c|r|}', '\hline',\
+   'Quality & Splice & Intron & \multicolumn{1}{c|}{Error on Positions} & \multicolumn{1}{c|}{Error on Scores} & \\',\
+   'information & site pred. & length & \multicolumn{1}{c|}{rate}\\', '\hline']
+
+   for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+      res = [e*100 for e in results[key]]
+   
+      lines.append( '%s & %s & %s & %2.2f & %2.2f \\%%\\\\' % ( key[0], key[1], key[2], res[0], res[1] ) )
+      if pos % 2 == 1:
+         lines.append('\hline')
+
+   for pos,key in enumerate(['---','+--','-+-','++-','--+','+-+','-++','+++']):
+      res = [e*100 for e in results[key]]
+   
+      lines.append( '%s & %s & %s & %2.2f & x \\%%\\\\' % ( key[0], key[1], key[2], res[2] ) )
+      if pos % 2 == 1:
+         lines.append('\hline')
+
+   lines.append('\end{tabular}')
+
+   lines = [l+'\n' for l in lines]
+   for l in lines:
+      fh.write(l)
+   fh.close()
+
+
+def compare_scores_and_labels(scores,labels):
+   """
+   Iterate through all predictions. If we find a correct prediction check
+   whether this correct prediction scores higher than the incorrect
+   predictions for this example.
+   """
+
+   for currentPos,currentElem in enumerate(scores):
+      if labels[currentPos] == True:
+         for otherPos,otherElem in enumerate(scores):
+            if otherPos == currentPos:
+               continue
+
+            if labels[otherPos] == False and otherElem >= currentElem:
+               return False
+
+   return True
+
+
+def compare_exons(predExons,trueExons):
+   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
+
+   if len(predExons) == 4:
+      e1_begin,e1_end = predExons[0],predExons[1]
+      e2_begin,e2_end = predExons[2],predExons[3]
+   else:
+      return False
+
+   e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
+   e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))
+
+   e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
+   e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))
+
+   if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0\
+   and e2_e_off == 0:
+      return True
+
+   return False
+
+
+def evaluate_unmapped_example(current_prediction):
+   predExons = current_prediction['predExons']
+   trueExons = current_prediction['trueExons']
+
+   result = compare_exons(predExons,trueExons)
+   return result
+
+
+def evaluate_example(current_prediction):
+   label = False
+   label = current_prediction['label'] 
+
+   pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+
+   # if the read was mapped by vmatch at an incorrect position we only have to
+   # compare the score
+   if label == False:
+      return label,False,pred_score
+
+   predExons = current_prediction['predExons']
+   trueExons = current_prediction['trueExons']
+
+   predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
+   truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]
+
+   pos_comparison = (predPositions == truePositions)
+
+   return label,pos_comparison,pred_score
+
+
+def prediction_on(filename):
+
+   allPredictions = cPickle.load(open(filename))
+
+   gt_correct_ctr = 0
+   gt_incorrect_ctr = 0
+   incorrect_gt_cuts       = {}
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+   incorrect_vmatch_cuts   = {}
+
+   score_correct_ctr = 0
+   score_incorrect_ctr = 0
+
+   total_gt_examples = 0
+   total_vmatch_instances_ctr = 0
+
+   true_vmatch_instances_ctr = 0
+
+   allUniquePredictions = [False]*len(allPredictions)
+
+   for pos,current_example_pred in enumerate(allPredictions):
+      for elem_nr,new_prediction in enumerate(current_example_pred[1:]):
+
+         if allUniquePredictions[pos] != False:
+            current_prediction = allUniquePredictions[pos]
+
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+            if current_a_score < new_score :
+               allUniquePredictions[id] = new_prediction
+
+         else:
+            allUniquePredictions[pos] = new_prediction
+
+   for current_pred in allUniquePredictions:
+      if current_pred == False:
+         continue
+
+   #for current_example_pred in allPredictions:
+      #gt_example = current_example_pred[0]
+      #gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
+      #gt_correct = evaluate_unmapped_example(gt_example)
+
+      #exampleIdx = gt_example['exampleIdx']
+
+      #cut_pos = gt_example['true_cut']
+
+      #if gt_correct:
+      #   gt_correct_ctr += 1
+      #else:
+      #   gt_incorrect_ctr += 1
+
+      #   try:
+      #      incorrect_gt_cuts[cut_pos] += 1
+      #   except:
+      #      incorrect_gt_cuts[cut_pos] = 1
+
+      #total_gt_examples += 1
+     
+      #current_scores = []
+      #current_labels = []
+      #for elem_nr,current_pred in enumerate(current_example_pred[1:]):
+
+      current_label,comparison_result,current_score = evaluate_example(current_pred)
 
-   return eBegin_pos,eBegin_neg,eEnd_pos,eEnd_neg,mean_eBegin_neg,mean_eEnd_neg
+      # if vmatch found the right read pos we check for right exons
+      # boundaries
+      #if current_label:
+      if comparison_result:
+         pos_correct_ctr += 1
+      else:
+         pos_incorrect_ctr += 1
+
+            #try:
+            #   incorrect_vmatch_cuts[cut_pos] += 1
+            #except:
+            #   incorrect_vmatch_cuts[cut_pos] = 1
+
+         true_vmatch_instances_ctr += 1
+         
+      #current_scores.append(current_score)
+      #current_labels.append(current_label)
+
+      total_vmatch_instances_ctr += 1
+
+      # check whether the correct predictions score higher than the incorrect
+      # ones
+      #cmp_res = compare_scores_and_labels(current_scores,current_labels)
+      #if cmp_res:
+      #   score_correct_ctr += 1
+      #else:
+      #   score_incorrect_ctr += 1
+
+   # now that we have evaluated all instances put out all counters and sizes
+   print 'Total num. of examples: %d' % len(allPredictions)
+   print 'Number of correct ground truth examples: %d' % gt_correct_ctr
+   print 'Total num. of true vmatch instances %d' % true_vmatch_instances_ctr
+   print 'Correct pos: %d, incorrect pos: %d' % (pos_correct_ctr,pos_incorrect_ctr)
+   print 'Total num. of vmatch instances %d' % total_vmatch_instances_ctr
+   print 'Correct scores: %d, incorrect scores: %d' %\
+   (score_correct_ctr,score_incorrect_ctr)
+
+   pos_error   = 1.0 * pos_incorrect_ctr / total_vmatch_instances_ctr
+   score_error = 1.0 * score_incorrect_ctr / total_vmatch_instances_ctr
+   gt_error    = 1.0 * gt_incorrect_ctr / total_gt_examples
+
+   return (pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts)
+
+
+def collect_prediction(current_dir,run_name):
+   """
+   Given the toplevel directory this function takes care that for each distinct
+   experiment the training and test predictions are evaluated.
+
+   """
+   idx = 5
+
+   train_suffix   = '_%d_allPredictions_TRAIN' % (idx)
+   test_suffix    = '_%d_allPredictions_TEST' % (idx)
+
+   jp = os.path.join
+   b2s = ['-','+']
+
+   currentRun = cPickle.load(open(jp(current_dir,'run_object_%d.pickle'%(idx))))
+   QFlag    = currentRun['enable_quality_scores']
+   SSFlag   = currentRun['enable_splice_signals']
+   ILFlag   = currentRun['enable_intron_length']
+   currentRunId = '%s%s%s' % (b2s[QFlag],b2s[SSFlag],b2s[ILFlag])
+   
+   #filename =  jp(current_dir,run_name)+train_suffix
+   #print 'Prediction on: %s' % filename
+   #train_result = prediction_on(filename)
+   train_result = []
+
+   filename =  jp(current_dir,run_name)+test_suffix
+   print 'Prediction on: %s' % filename
+   test_result = prediction_on(filename)
+
+   return train_result,test_result,currentRunId
 
 
 def perform_prediction(current_dir,run_name):
-   cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s | qsub -l h_vmem=12.0G -cwd -j y -N \"%s.log\"'%(current_dir,run_name)
-   #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
-   #print cmd
-   os.system(cmd)
+   """
+   This function takes care of starting the jobs needed for the prediction phase
+   of qpalma
+   """
+
+   #for i in range(1,6):
+   for i in range(1,2):
+      cmd = 'echo /fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/doPrediction.sh %s %d |\
+      qsub -l h_vmem=12.0G -cwd -j y -N \"%s_%d.log\"'%(current_dir,i,run_name,i)
+
+      #cmd = './doPrediction.sh %s 1>%s.out 2>%s.err' %(current_dir,run_name,run_name)
+      #print cmd
+      os.system(cmd)
+
+
 
 def forall_experiments(current_func,tl_dir):
    """
    Given the toplevel directoy this function calls for each subdir the
-   function given as first argument
+   function given as first argument. Which are at the moment:
+
+   - perform_prediction, and
+   - collect_prediction.
+   
    """
 
    dir_entries = os.listdir(tl_dir)
    dir_entries = [os.path.join(tl_dir,de) for de in dir_entries]
    run_dirs = [de for de in dir_entries if os.path.isdir(de)]
 
+   all_results = {}
+   all_error_rates = {}
+
    for current_dir in run_dirs:
       run_name = current_dir.split('/')[-1]
-      current_func(current_dir,run_name)
 
+      pdb.set_trace()
 
-def collect_prediction(current_dir,run_name):
+      if current_func.__name__ == 'perform_prediction':
+         current_func(current_dir,run_name)
+
+      if current_func.__name__ == 'collect_prediction':
+         train_result,test_result,currentRunId = current_func(current_dir,run_name)
+         all_results[currentRunId] = test_result
+         pos_error,score_error,gt_error,incorrect_gt_cuts,incorrect_vmatch_cuts = test_result
+         all_error_rates[currentRunId] = (incorrect_gt_cuts,incorrect_vmatch_cuts)
+
+   if current_func.__name__ == 'collect_prediction':
+      #createErrorVSCutPlot(all_error_rates)
+      createTable(all_results)
+
+def _predict_on(filename,filtered_reads,with_coverage):
    """
-   Given the toplevel directoy this function calls for each subdir the
+   This function evaluates the predictions  made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specifiy a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
    """
 
-   train_suffix   = '_allPredictions_TRAIN'
-   test_suffix    = '_allPredictions_TEST'
-   jp = os.path.join
-   
-   filename =  jp(current_dir,run_name)+train_suffix
-   print 'Prediction on: %s' % filename
-   prediction_on(filename)
+   coverage_map = {}
 
-   filename =  jp(current_dir,run_name)+test_suffix
-   print 'Prediction on: %s' % filename
-   prediction_on(filename)
+   if with_coverage:
+      for line in open('/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/coverage_results/ALL_COVERAGES'):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
 
+   print 'parsing filtered reads..'
+   all_filtered_reads = parse_filtered_reads(filtered_reads)
+   print 'found %d filtered reads' % len(all_filtered_reads)
+
+   out_fh = open('predicted_positions.txt','w+')
 
-def prediction_on(filename):
    allPredictions = cPickle.load(open(filename))
 
-   exon1Begin     = []
-   exon1End       = []
-   exon2Begin     = []
-   exon2End       = []
-   allWrongExons  = []
-   allDoubleScores = []
+   spliced_ctr = 0
+   unspliced_ctr = 0
 
    pos_correct_ctr   = 0
    pos_incorrect_ctr = 0
-   score_correct_ctr = 0
-   score_incorrect_ctr = 0
 
-   for current_example_pred in allPredictions:
-      gt_example = current_example_pred[0]
-      gt_score = gt_example['DPScores'].flatten().tolist()[0][0]
-      gt_correct = evaluate_unmapped_example(gt_example)
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for new_prediction in allPredictions:
+      id = new_prediction['id']
+      id = int(id)
+
+      if allUniqPredictions.has_key(id):
+         current_prediction = allUniqPredictions[id]
 
-      for elem_nr,current_pred in enumerate(current_example_pred[1:]):
+         current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+         new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
 
-         current_label,comparison_result,current_score = evaluate_example(current_pred)
+         if current_a_score < new_score :
+            allUniqPredictions[id] = new_prediction
 
-         if current_label:
-            if comparison_result:
-               pos_correct_ctr += 1
-            else:
-               pos_incorrect_ctr += 1
+      else:
+         allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   #for current_prediction in allPredictions:
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
+   
+      try:
+         current_ground_truth = all_filtered_reads[id]
+      except:
+         skipped_ctr += 1
+         continue
+
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
+
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
+
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
+
+      spliced_flag = False
+
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
+
+         if predExons[0] == 19504568:
+            pdb.set_trace()
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
+
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
          else:
-            if gt_correct and current_label == False and gt_score >= current_score:
-               score_correct_ctr += 1
-            elif gt_correct and current_label == False and not (gt_score >= current_score):
-               score_incorrect_ctr += 1
+            pos_correct = False
 
-   print 'Total num. of examples: %d' % len(allPredictions)
-   print 'Number of correct examples: %d' % pos_correct_ctr
-   print 'Number of incorrect examples: %d' % pos_incorrect_ctr
-   print 'Number of correct scores : %d' % score_correct_ctr
-   print 'Number of incorrect scores: %d' % score_incorrect_ctr
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
 
-   #print 'Correct positions:\t\t%d\t%d\t%d\t%d' % (len(e1Begin_pos),len(e1End_pos),len(e2Begin_pos),len(e2End_pos))
-   #print 'Incorrect positions:\t\t%d\t%d\t%d\t%d' % (len(e1Begin_neg),len(e1End_neg),len(e2Begin_neg),len(e2End_neg))
-   #print 'Mean of pos. offset:\t\t%.2f\t%.2f\t%.2f\t%.2f' % (mean_e1Begin_neg,mean_e1End_neg,mean_e2Begin_neg,mean_e2End_neg)
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
 
-def evaluate_unmapped_example(current_prediction):
+         true_cut = current_ground_truth['true_cut']
 
-   predExons = current_prediction['predExons']
-   trueExons = current_prediction['trueExons']
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
 
-   return compare_exons(predExons,trueExons)
+      else:
+         pos_correct = False
 
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
 
-def evaluate_example(current_prediction):
-   label = False
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
 
-   try:
-      label = current_prediction['label'] 
-   except:
-      pass
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
 
-   predExons = current_prediction['predExons']
-   trueExons = current_prediction['trueExons']
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
 
-   predPositions = [elem + current_prediction['alternative_start_pos'] for elem in predExons]
-   truePositions = [elem + current_prediction['start_pos'] for elem in trueExons.flatten().tolist()[0]]
-   #pdb.set_trace()
+      if pos_correct:
+         pos_correct_ctr += 1
+
+         if is_spliced:
+               correct_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               correct_unspliced_ctr += 1
+
+      else:
+         pos_incorrect_ctr += 1
+
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage and spliced_flag:
+          if not is_covered:
+              current_coverage_nr=0 
+          if pos_correct:
+              print "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
+          else:
+              print "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+      total_ctr += 1
+
+
+   numPredictions = len(allUniqPredictions)
+
+   # now that we have evaluated all instances put out all counters and sizes
+   print 'Total num. of examples: %d' % numPredictions
+
+   print "spliced/unspliced:  %d,%d " % (spliced_ctr, unspliced_ctr )
+   print "Correct/incorrect spliced: %d,%d " % (correct_spliced_ctr, incorrect_spliced_ctr )
+   print "Correct/incorrect unspliced: %d,%d " % (correct_unspliced_ctr , incorrect_unspliced_ctr )
+   print "Correct/incorrect covered spliced read: %d,%d " %\
+   (correct_covered_splice_ctr,incorrect_covered_splice_ctr)
+
+   print "pos_correct: %d,%d" % (pos_correct_ctr ,  pos_incorrect_ctr )
+
+   print 'unspliced_spliced reads: %d' % unspliced_spliced_reads_ctr 
+   print 'spliced reads at wrong_place: %d' % wrong_spliced_reads_ctr
+
+   print 'spliced_unspliced reads:  %d' % wrong_unspliced_reads_ctr
+   print 'wrong aligned at wrong_pos: %d' % wrong_aligned_unspliced_reads_ctr
+
+   print 'total_ctr: %d' % total_ctr
+
+   print "skipped: %d "  % skipped_ctr
+   print 'min. coverage: %d' % min_coverage
+
+   result_dict = {}
+   result_dict['skipped_ctr'] = skipped_ctr
+   result_dict['min_coverage'] = min_coverage
+
+   return result_dict
+
+
+
+
+
+def predict_on(allPredictions,all_filtered_reads,all_labels_fn,with_coverage,coverage_fn,coverage_labels_fn):
+   """
+   This function evaluates the predictions  made by QPalma.
+   It needs a pickled file containing the predictions themselves and the
+   ascii file with original reads.
+
+   Optionally one can specifiy a coverage file containing for each read the
+   coverage number estimated by a remapping step.
+
+
+   """
+
+   coverage_labels_fh = open(coverage_labels_fn,'w+')
+
+   all_labels_fh = open(all_labels_fn,'w+')
+
+   import qparser
+   qparser.parse_reads(all_filtered_reads)
+
+   coverage_map = {}
+
+
+   if with_coverage:
+      for line in open(coverage_fn):
+         id,coverage_nr = line.strip().split()
+         coverage_map[int(id)] = int(coverage_nr)
+
+   #out_fh = open('predicted_positions.txt','w+')
+
+   spliced_ctr = 0
+   unspliced_ctr = 0
+
+   pos_correct_ctr   = 0
+   pos_incorrect_ctr = 0
+
+   correct_spliced_ctr     = 0
+   correct_unspliced_ctr   = 0
+
+   incorrect_spliced_ctr     = 0
+   incorrect_unspliced_ctr   = 0
+
+   correct_covered_splice_ctr = 0
+   incorrect_covered_splice_ctr = 0
+
+   total_vmatch_instances_ctr = 0
+
+   unspliced_spliced_reads_ctr = 0
+   wrong_spliced_reads_ctr = 0
+
+   wrong_aligned_unspliced_reads_ctr = 0
+   wrong_unspliced_reads_ctr = 0
+
+   cut_pos_ctr = {}
+
+   total_ctr = 0
+   skipped_ctr = 0
+
+   is_spliced = False
+   min_coverage = 3
+
+   allUniqPredictions = {}
+
+   print 'Got %d predictions' % len(allPredictions)
+
+   for k,predictions in allPredictions.items():
+      for new_prediction in predictions:
+         id = new_prediction['id']
+         id = int(id)
+
+         if allUniqPredictions.has_key(id):
+            current_prediction = allUniqPredictions[id]
+
+            current_a_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+            new_score =  new_prediction['DPScores'].flatten().tolist()[0][0]
+
+            if current_a_score < new_score :
+               allUniqPredictions[id] = new_prediction
+
+         else:
+            allUniqPredictions[id] = new_prediction
+
+   print 'Got %d uniq predictions' % len(allUniqPredictions)
+
+   for _id,current_prediction in allUniqPredictions.items():
+      id = current_prediction['id']
+      id = int(id)
+
+      if not id >= 1000000300000:
+         is_spliced = True
+      else:
+         is_spliced = False
+
+      is_covered = False
+
+      if is_spliced and with_coverage:
+         try:
+            current_coverage_nr = coverage_map[id]
+            is_covered = True
+         except:
+            is_covered = False
+
+
+      if is_spliced:
+         spliced_ctr += 1
+      else: 
+         unspliced_ctr += 1
    
-   pred_score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      try:
+         #current_ground_truth = all_filtered_reads[id]
+         current_ground_truth = qparser.fetch_read(id)
+      except:
+         skipped_ctr += 1
+         continue
 
-   altPredPositions = [0]*4
-   if len(predPositions) == 4:
-      altPredPositions[0] = predPositions[0]
-      altPredPositions[1] = predPositions[1]+1
-      altPredPositions[2] = predPositions[2]+1
-      altPredPositions[3] = predPositions[3]
-
-   altPredPositions2 = [0]*4
-   if len(predPositions) == 4:
-      altPredPositions2[0] = predPositions[0]
-      altPredPositions2[1] = predPositions[1]
-      altPredPositions2[2] = predPositions[2]+1
-      altPredPositions2[3] = predPositions[3]
-
-   altPredPositions3 = [0]*4
-   if len(predPositions) == 4:
-      altPredPositions3[0] = predPositions[0]
-      altPredPositions3[1] = predPositions[1]
-      altPredPositions3[2] = predPositions[2]-1
-      altPredPositions3[3] = predPositions[3]
-
-   pos_comparison = (predPositions == truePositions or altPredPositions == truePositions or altPredPositions2 == truePositions or altPredPositions3 == truePositions)
-
-   if label == True and pos_comparison == False:
-      pdb.set_trace()
+      start_pos = current_prediction['start_pos']
+      chr = current_prediction['chr']
+      strand = current_prediction['strand']
 
-   return label,pos_comparison,pred_score
+      #score = current_prediction['DPScores'].flatten().tolist()[0][0]
+      #pdb.set_trace()
 
-def compare_exons(predExons,trueExons):
-   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
+      predExons = current_prediction['predExons'] #:newExons, 'dna':dna, 'est':est
+      predExons = [e+start_pos for e in predExons]
 
-   if len(predExons) == 4:
-      e1_begin,e1_end = predExons[0],predExons[1]
-      e2_begin,e2_end = predExons[2],predExons[3]
-   else:
-      return False
+      spliced_flag = False
 
-   e1_b_off = int(math.fabs(e1_begin - trueExons[0,0]))
-   e1_e_off = int(math.fabs(e1_end - trueExons[0,1]))
+      if len(predExons) == 4:
+         spliced_flag = True
+         predExons[1] -= 1
+         predExons[3] -= 1
 
-   e2_b_off = int(math.fabs(e2_begin - trueExons[1,0]))
-   e2_e_off = int(math.fabs(e2_end - trueExons[1,1]))
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         e_stop = current_ground_truth['exon_stop']
+         e_start = current_ground_truth['exon_start']
+         p_stop = current_ground_truth['p_stop']
 
-   if e1_b_off == 0 and e1_e_off == 0 and e2_b_off == 0\
-   and e2_e_off == 0:
-      return True
+         true_cut = current_ground_truth['true_cut']
 
-   return False
+         if p_start == predExons[0] and e_stop == predExons[1] and\
+         e_start == predExons[2] and p_stop == predExons[3]:
+            pos_correct = True
+         else:
+            pos_correct = False
+
+      elif len(predExons) == 2:
+         spliced_flag = False
+         predExons[1] -= 1
+
+         cut_pos = current_ground_truth['true_cut']
+         p_start = current_ground_truth['p_start']
+         p_stop = current_ground_truth['p_stop']
+
+         true_cut = current_ground_truth['true_cut']
 
+         if math.fabs(p_start - predExons[0]) <= 0:# and math.fabs(p_stop - predExons[1]) <= 2:
+            pos_correct = True
+         else:
+            pos_correct = False
 
-def evaluateExample(dna,est,exons,SpliceAlign,newEstAlign):
-   newExons = []
-   oldElem = -1
-   SpliceAlign = SpliceAlign.flatten().tolist()[0]
-   SpliceAlign.append(-1)
-   for pos,elem in enumerate(SpliceAlign):
-      if pos == 0:
-         oldElem = -1
       else:
-         oldElem = SpliceAlign[pos-1]
+         pos_correct = False
 
-      if oldElem != 0 and elem == 0: # start of exon
-         newExons.append(pos)
+      if is_spliced and not spliced_flag:
+         unspliced_spliced_reads_ctr += 1
 
-      if oldElem == 0 and elem != 0: # end of exon
-         newExons.append(pos)
+      if is_spliced and not pos_correct and len(predExons) == 4 and predExons[1]!=-1:
+          wrong_spliced_reads_ctr += 1
 
-   e1_b_off,e1_e_off,e2_b_off,e2_e_off = 0,0,0,0
+      if not is_spliced and spliced_flag:
+         wrong_unspliced_reads_ctr += 1
 
-   if len(newExons) == 4:
-      e1_begin,e1_end = newExons[0],newExons[1]
-      e2_begin,e2_end = newExons[2],newExons[3]
-   else:
-      return None,None,None,None,newExons
+      if not is_spliced and not pos_correct:
+         wrong_aligned_unspliced_reads_ctr += 1
 
-   e1_b_off = int(math.fabs(e1_begin - exons[0,0]))
-   e1_e_off = int(math.fabs(e1_end - exons[0,1]))
+      if pos_correct:
+         pos_correct_ctr += 1
 
-   e2_b_off = int(math.fabs(e2_begin - exons[1,0]))
-   e2_e_off = int(math.fabs(e2_end - exons[1,1]))
+         if is_spliced:
+               correct_spliced_ctr += 1
+               all_labels_fh.write('%d correct\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  correct_covered_splice_ctr += 1 
 
-   return e1_b_off,e1_e_off,e2_b_off,e2_e_off,newExons
+         if not is_spliced:
+               correct_unspliced_ctr += 1
 
+      else:
+         pos_incorrect_ctr += 1
 
-if __name__ == '__main__':
-   dir = sys.argv[1]
-   assert os.path.exists(dir), 'Error directory does not exist!'
-
-   forall_experiments(perform_prediction,dir)
-   #forall_experiments(collect_prediction,dir)
-
-"""
-         if elem_nr > 0:
-            #print 'start positions'
-            #print current_pred['start_pos'], current_pred['alternative_start_pos']
-
-            if current_pred['label'] == False or (current_pred['label'] == True
-            and len(current_pred['predExons']) != 4):
-               if 
-               current_example_pred[0]['DPScores'].flatten().tolist()[0][0]:
-                  print current_pred['trueExons'][0,1]-current_pred['trueExons'][0,0],\
-                  current_pred['trueExons'][1,1] - current_pred['trueExons'][1,0],\
-                  current_pred['predExons']
-                  print current_pred['DPScores'].flatten().tolist()[0][0],\
-                  current_example_pred[0]['DPScores'].flatten().tolist()[0][0]
-                  ctr += 1
-                  print ctr
-
-         if e1_b_off != None:
-            exon1Begin.append(e1_b_off)
-            exon1End.append(e1_e_off)
-            exon2Begin.append(e2_b_off)
-            exon2End.append(e2_e_off)
+         if is_spliced:
+               incorrect_spliced_ctr += 1
+               all_labels_fh.write('%d wrong\n'%id)
+               if with_coverage and is_covered and current_coverage_nr >= min_coverage:
+                  incorrect_covered_splice_ctr += 1 
+
+         if not is_spliced:
+               incorrect_unspliced_ctr += 1
+
+      if with_coverage:
+         if not is_covered:
+            current_coverage_nr=0 
+
+         if pos_correct:
+            new_line = "%s\tcorrect\t%i" %( current_prediction['id'], current_coverage_nr)
          else:
-            pass
-            #allWrongExons.append((newExons,exons))
+            new_line = "%s\twrong\t%i" %( current_prediction['id'], current_coverage_nr)
+
+         coverage_labels_fh.write(new_line+'\n')
+
+      total_ctr += 1
+
+   coverage_labels_fh.close()
+
+   numPredictions = len(allUniqPredictions)
+
+   result = []
+
+   # now that we have evaluated all instances put out all counters and sizes
+   result.append(('numPredictions',numPredictions))
+   result.append(('spliced_ctr',spliced_ctr))
+   result.append(('unspliced_ctr',unspliced_ctr))
+
+   result.append(('correct_spliced_ctr',correct_spliced_ctr))
+   result.append(('incorrect_spliced_ctr',incorrect_spliced_ctr))
+
+   result.append(('correct_unspliced_ctr',correct_unspliced_ctr))
+   result.append(('incorrect_unspliced_ctr',incorrect_unspliced_ctr))
+
+   result.append(('correct_covered_splice_ctr',correct_covered_splice_ctr))
+   result.append(('incorrect_covered_splice_ctr',incorrect_covered_splice_ctr))
+
+   result.append(('pos_correct_ctr',pos_correct_ctr))
+   result.append(('pos_incorrect_ctr',pos_incorrect_ctr))
+
+   result.append(('unspliced_spliced_reads_ctr',unspliced_spliced_reads_ctr))
+   result.append(('wrong_spliced_reads_ctr',wrong_spliced_reads_ctr))
+
+   result.append(('wrong_unspliced_reads_ctr',wrong_unspliced_reads_ctr))
+   result.append(('wrong_aligned_unspliced_reads_ctr',wrong_aligned_unspliced_reads_ctr))
+
+   result.append(('total_ctr',total_ctr))
+
+   result.append(('skipped_ctr',skipped_ctr))
+   result.append(('min_coverage',min_coverage))
+
+   return result
+
+
+
+def print_result(result):
+   """Print every (name, counter) pair collected by the evaluation run."""
+   # now that we have evaluated all instances put out all counters and sizes
+   for name,ctr in result:
+      print name,ctr
 
-         if ambigous_match == True:
-            current_score = current_pred['DPScores'][0]
-            example_scores.append(current_score)
 
-         e1Begin_pos,e1Begin_neg,e1End_pos,e1End_neg,mean_e1Begin_neg,mean_e1End_neg = evaluatePositions(exon1Begin,exon1End)
-         e2Begin_pos,e2Begin_neg,e2End_pos,e2End_neg,mean_e2Begin_neg,mean_e2End_neg = evaluatePositions(exon2Begin,exon2End)
+def load_chunks(current_dir):
+   """
+   Load every pickled prediction chunk from current_dir.
+
+   All files whose name starts with 'chunk' are unpickled and their
+   contents concatenated into one flat list of prediction dicts.
+   """
+   # collect the filenames of all chunk files belonging to this run
+   chunks_fn = []
+   for fn in os.listdir(current_dir):
+      if fn.startswith('chunk'):
+         chunks_fn.append(fn)

-   allDoubleScores.append(example_scores)
-"""

+   allPredictions = []
+
+   for c_fn in chunks_fn:
+      full_fn = os.path.join(current_dir,c_fn)
+      # progress output; each chunk is a pickled list of prediction dicts
+      print full_fn
+      # NOTE(review): file handle from open() is never closed explicitly --
+      # relies on CPython refcounting; also cPickle.load on untrusted files
+      # is unsafe. Confirm chunks are always produced by this pipeline.
+      current_chunk = cPickle.load(open(full_fn))
+      allPredictions.extend(current_chunk)
+
+   return allPredictions
+
+
+def predict_on_all_chunks(current_dir,training_keys_fn):
+   """
+   We load all chunks from the current_dir belonging to one run.
+   Then we load the saved keys of the training set to restore the training and
+   testing sets.
+   Once we have done that we separately evaluate both sets.
+
+   Parameters:
+      current_dir      -- directory holding the pickled 'chunk*' files of one run
+      training_keys_fn -- pickle file with the read ids used for training
+
+   Returns the result list produced by predict_on() for the test set
+   (list of (counter_name, value) tuples).
+   """
+   
+   allPredictions = load_chunks(current_dir)
+
+   # group predictions by read id: a read can have several alignments,
+   # so each dict value is a list of prediction entries
+   allPredictionsDict = {}
+   for elem in allPredictions:
+      # NOTE(review): 'id' shadows the builtin of the same name
+      id = elem['id']
+
+      if allPredictionsDict.has_key(id):
+         old_entry = allPredictionsDict[id]
+         old_entry.append(elem)
+         allPredictionsDict[id] = old_entry
+      else:
+         allPredictionsDict[id] = [elem]
+
+   # restore the training-set read ids saved during the training phase
+   training_keys = cPickle.load(open(training_keys_fn))
+
+   # split the predictions: entries moved into training_set, the remainder
+   # (allPredictionsDict) becomes the test set
+   training_set = {}
+   for key in training_keys:
+      # we have the try construct because some of the reads used for training
+      # may not be found using vmatch at all
+      # NOTE(review): bare except also hides genuine errors (e.g. typos);
+      # catching KeyError would be sufficient here
+      try:
+         training_set[key] = allPredictionsDict[key]
+         del allPredictionsDict[key]
+      except:
+         pass
+
+   test_set = allPredictionsDict
+
+   #test_set = {}
+   #for k in allPredictionsDict.keys()[:100]:
+   #   test_set[k] = allPredictionsDict[k]
+   #result_train = predict_on(training_set,all_filtered_reads,False,coverage_fn)
+   #pdb.set_trace()
+
+   # this is the heuristic.parsed_spliced_reads.txt file from the vmatch remapping step
+   # NOTE(review): hard-coded absolute cluster paths below -- consider making
+   # these parameters or config entries; the function is unusable elsewhere as is
+   coverage_fn = '/fml/ag-raetsch/home/fabio/svn/projects/QPalma/scripts/all_coverages' 
+
+   all_filtered_reads = '/fml/ag-raetsch/share/projects/qpalma/solexa/new_run/allReads.full'
+
+   coverage_labels_fn = 'COVERAGE_LABELS'
+
+   # evaluate only the test set; training-set evaluation is commented out above
+   result_test = predict_on(test_set,all_filtered_reads,'all_prediction_labels.txt',True,coverage_fn,coverage_labels_fn)
+   #print_result(result_train)
+
+   return result_test
+   
+
+if __name__ == '__main__':
+   # NOTE(review): entry point is a stub -- predict_on_all_chunks() is not
+   # wired up here; presumably this module is driven from another script.
+   pass