# Title: HOL/Tools/Sledgehammer/MaSh/src/stats.py
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
# Copyright 2012
#
# Statistics collector.

'''
Created on Jul 9, 2012

@author: Daniel Kuehlwein
'''

import logging,string
|
|
14 |
from cPickle import load,dump
|
|
15 |
|
|
16 |
class Statistics(object):
    '''
    Collects evaluation statistics for ranked premise predictions.

    Every call to update() compares one ranked prediction list with the
    set of dependencies that were actually needed and accumulates AUC,
    recall and bookkeeping data.  printAvg() summarises the accumulated
    data, save()/load() persist it with pickle.
    '''

    def __init__(self,cutOff=500):
        '''
        Constructor

        cutOff -- only the first cutOff predictions of every ranking are
                  evaluated; it is also the length of the recall curves.
        '''
        self.logger = logging.getLogger('Statistics')
        # Running sums; printAvg() divides them by self.problems.
        self.avgAUC = 0.0
        self.avgRecall100 = 0.0
        self.avgAvailable = 0.0
        self.avgDepNr = 0.0
        # Number of evaluated problems (statements with >= 1 dependency).
        self.problems = 0.0
        self.cutOff = cutOff
        # recallData[i]: summed fraction of dependencies found in the top i+1.
        self.recallData = [0]*cutOff
        # One recall100 value per evaluated problem (used for the median).
        self.recall100Median = []
        # recall100Data[i]: number of problems fully covered by the top i+1.
        self.recall100Data = [0]*cutOff
        # One AUC value per evaluated problem.
        self.aucData = []
        # How often each premise occurred as a dependency.
        self.premiseOccurenceCounter = {}
        # statementCounter at which each dependency was first seen.
        self.firstDepAppearance = {}
        # Distances (in statementCounter units) to the first appearance.
        self.depAppearances = []
        # Badly ranked / missing dependencies of the most recent update();
        # initialised here so it can be read before the first update().
        self.badPreds = []

    def update(self,predictions,dependencies,statementCounter):
        """
        Evaluates AUC, dependencies, recall100 and number of available premises of a prediction.

        predictions      -- ranked sequence of premise ids, best first.
        dependencies     -- iterable of the premise ids that were needed.
        statementCounter -- number identifying the current statement; it is
                            logged and used to measure the distance to the
                            first appearance of each dependency.

        Statements without dependencies are skipped: nothing is accumulated
        and self.badPreds is reset to an empty list.
        """
        available = len(predictions)
        predictions = predictions[:self.cutOff]
        dependencies = set(dependencies)
        # No stats if there are no dependencies.
        if not dependencies:
            self.logger.debug('No Dependencies for statement %s', statementCounter)
            self.badPreds = []
            return
        # Ranks beyond the end of a short prediction list are counted as
        # full recall in both recall curves.
        for i in range(len(predictions),self.cutOff):
            self.recall100Data[i] += 1
            self.recallData[i] += 1
        for d in dependencies:
            self.premiseOccurenceCounter[d] = self.premiseOccurenceCounter.get(d, 0) + 1
            if d in self.firstDepAppearance:
                self.depAppearances.append(statementCounter-self.firstDepAppearance[d])
            else:
                self.firstDepAppearance[d] = statementCounter
        depNr = len(dependencies)
        aucSum = 0.
        posResults = 0.
        positives, negatives = 0, 0
        recall100 = 0.0
        badPreds = []
        depsFound = set()
        for index,pId in enumerate(predictions):
            if pId in dependencies:
                # Positive: a needed premise was predicted at this rank.
                posResults += 1
                positives += 1
                # 1-based rank of the last dependency found so far.
                recall100 = index+1
                depsFound.add(pId)
                if index > 200:
                    badPreds.append(pId)
            else:
                # Negative: every positive ranked before it is a correctly
                # ordered (positive, negative) pair.
                aucSum += posResults
                negatives += 1
            # Update Recall and Recall100 stats
            if depNr == positives:
                self.recall100Data[index] += 1
            # depNr > 0 is guaranteed by the early return above.
            self.recallData[index] += float(positives)/depNr

        if depNr != positives:
            # Some dependencies are not among the evaluated predictions.
            # They are counted as positives ranked behind everything else
            # (aucSum is not increased) and recall100 is estimated as one
            # past the evaluated predictions.
            missing = [dep for dep in dependencies if dep not in depsFound]
            badPreds.extend(missing)
            positives += len(missing)
            recall100 = len(predictions)+1
            self.logger.debug('Dependencies missing for %s in cutoff predictions! Estimating Statistics.',
                              ','.join([str(dep) for dep in missing]))

        if positives == 0 or negatives == 0:
            auc = 1.0
        else:
            auc = aucSum/(negatives*positives)

        self.aucData.append(auc)
        self.avgAUC += auc
        self.avgRecall100 += recall100
        self.recall100Median.append(recall100)
        self.problems += 1
        self.badPreds = badPreds
        self.avgAvailable += available
        self.avgDepNr += depNr
        self.logger.info('Statement: %s: AUC: %s \t Needed: %s \t Recall100: %s \t Available: %s \t cutOff: %s',
                         statementCounter,round(100*auc,2),depNr,recall100,available,self.cutOff)

    @staticmethod
    def _median(values):
        '''
        Returns the median of a non-empty sequence of numbers.

        For an odd number of values this is the middle element of the
        sorted sequence, otherwise the mean of the two middle elements.
        '''
        ordered = sorted(values)
        middle = len(ordered) // 2
        if len(ordered) % 2 == 1:
            return ordered[middle]
        return float(ordered[middle - 1] + ordered[middle]) / 2

    def printAvg(self):
        '''
        Logs and returns a one-line summary of the accumulated statistics.

        Returns the string "No data points" if no problem has been
        evaluated yet, otherwise the formatted summary line.
        '''
        self.logger.info('Average results:')
        # update() appends exactly one entry to each list per evaluated problem.
        assert len(self.aucData) == len(self.recall100Median)
        nrDataPoints = len(self.aucData)
        if nrDataPoints == 0:
            return "No data points"
        medianAUC = self._median(self.aucData)
        medianrecall100 = self._median(self.recall100Median)

        returnString = 'avgAUC: %s \t medianAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t medianRecall100: %s \t cutOff: %s' %\
                       (round(100*self.avgAUC/self.problems,2),round(100*medianAUC,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),round(medianrecall100,2),self.cutOff)
        self.logger.info(returnString)
        return returnString

    def save(self,fileName):
        '''
        Pickles the accumulated statistics into the file fileName.

        The first ten tuple entries keep their original order; the
        per-problem recall100 values are appended as an eleventh entry so
        that printAvg() works on loaded data.
        '''
        with open(fileName, 'wb') as oStream:
            dump((self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter,self.recall100Median),oStream)

    def load(self,fileName):
        '''
        Restores statistics that were written by save() from fileName.

        Files with only the ten original entries are still accepted; they
        carry no per-problem recall100 values, so recall100Median is reset
        to an empty list in that case.
        '''
        # NOTE(security): unpickling executes code embedded in the file;
        # only load statistics files from a trusted source.
        with open(fileName, 'rb') as iStream:
            data = load(iStream)
        self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter = data[:10]
        if len(data) > 10:
            self.recall100Median = list(data[10])
        else:
            self.recall100Median = []
|