src/HOL/Tools/Sledgehammer/MaSh/src/stats.py
changeset 50220 90280d85cd03
child 50222 40e3c3be6bca
equal deleted inserted replaced
50219:f6b95f0bba78 50220:90280d85cd03
       
     1 '''
       
     2 Created on Jul 9, 2012
       
     3 
       
     4 @author: Daniel Kuehlwein
       
     5 '''
       
     6 
       
     7 import logging,string
       
     8 from cPickle import load,dump
       
     9 
       
    10 class Statistics(object):
       
    11     '''
       
    12     Class for all the statistics
       
    13     '''
       
    14 
       
    15     def __init__(self,cutOff=500):
       
    16         '''
       
    17         Constructor
       
    18         '''
       
    19         self.logger = logging.getLogger('Statistics')
       
    20         self.avgAUC = 0.0
       
    21         self.avgRecall100 = 0.0
       
    22         self.avgAvailable = 0.0
       
    23         self.avgDepNr = 0.0
       
    24         self.problems = 0.0
       
    25         self.cutOff = cutOff
       
    26         self.recallData = [0]*cutOff
       
    27         self.recall100Data = [0]*cutOff
       
    28         self.aucData = []
       
    29         
       
    30     def update(self,predictions,dependencies):
       
    31         """
       
    32         Evaluates AUC, dependencies, recall100 and number of available premises of a prediction.
       
    33         """
       
    34 
       
    35         available = len(predictions)
       
    36         predictions = predictions[:self.cutOff]
       
    37         dependencies = set(dependencies)
       
    38         depNr = len(dependencies)
       
    39         aucSum = 0.    
       
    40         posResults = 0.        
       
    41         positives, negatives = 0, 0
       
    42         recall100 = 0.0
       
    43         badPreds = []
       
    44         depsFound = []
       
    45         for index,pId in enumerate(predictions):
       
    46             if pId in dependencies:        #positive
       
    47                 posResults+=1
       
    48                 positives+=1
       
    49                 recall100 = index+1
       
    50                 depsFound.append(pId)
       
    51                 if index > 200:
       
    52                     badPreds.append(pId)
       
    53             else:            
       
    54                 aucSum += posResults
       
    55                 negatives+=1
       
    56             # Update Recall and Recall100 stats
       
    57             if depNr == positives:
       
    58                 self.recall100Data[index] += 1
       
    59             if depNr == 0:
       
    60                 self.recallData[index] += 1
       
    61             else:
       
    62                 self.recallData[index] += float(positives)/depNr
       
    63     
       
    64         if not depNr == positives:
       
    65             depsFound = set(depsFound)
       
    66             missing = []
       
    67             for dep in dependencies:
       
    68                 if not dep in depsFound:
       
    69                     missing.append(dep)
       
    70                     badPreds.append(dep)
       
    71                     recall100 = len(predictions)+1
       
    72                     positives+=1
       
    73             self.logger.debug('Dependencies missing for %s in accessibles! Estimating Statistics.',\
       
    74                               string.join([str(dep) for dep in missing],','))
       
    75     
       
    76         if positives == 0 or negatives == 0:
       
    77             auc = 1.0
       
    78         else:            
       
    79             auc = aucSum/(negatives*positives)
       
    80         
       
    81         self.aucData.append(auc)
       
    82         self.avgAUC += auc
       
    83         self.avgRecall100 += recall100
       
    84         self.problems += 1
       
    85         self.badPreds = badPreds
       
    86         self.avgAvailable += available 
       
    87         self.avgDepNr += depNr
       
    88         self.logger.info('AUC: %s \t Needed: %s \t Recall100: %s \t Available: %s \t cutOff:%s',\
       
    89                           round(100*auc,2),depNr,recall100,available,self.cutOff)        
       
    90         
       
    91     def printAvg(self):
       
    92         self.logger.info('Average results:')
       
    93         self.logger.info('avgAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t cutOff:%s', \
       
    94                          round(100*self.avgAUC/self.problems,2),round(self.avgDepNr/self.problems,2),self.avgRecall100/self.problems,self.cutOff)
       
    95 
       
    96         try:
       
    97             from matplotlib.pyplot import plot,figure,show,xlabel,ylabel,axis,hist
       
    98             avgRecall = [float(x)/self.problems for x in self.recallData]
       
    99             figure('Recall')
       
   100             plot(range(self.cutOff),avgRecall)
       
   101             ylabel('Average Recall')
       
   102             xlabel('Highest ranked premises')
       
   103             axis([0,self.cutOff,0.0,1.0])
       
   104             figure('100%Recall')
       
   105             plot(range(self.cutOff),self.recall100Data)
       
   106             ylabel('100%Recall')
       
   107             xlabel('Highest ranked premises')
       
   108             axis([0,self.cutOff,0,self.problems])
       
   109             figure('AUC Histogram')
       
   110             hist(self.aucData,bins=100)
       
   111             ylabel('Problems')
       
   112             xlabel('AUC')
       
   113             show()
       
   114         except:
       
   115             self.logger.warning('Matplotlib module missing. Skipping graphs.')
       
   116     
       
   117     def save(self,fileName):       
       
   118         oStream = open(fileName, 'wb')
       
   119         dump((self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData),oStream)
       
   120         oStream.close()
       
   121     def load(self,fileName):
       
   122         iStream = open(fileName, 'rb')        
       
   123         self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData = load(iStream)
       
   124         iStream.close()