|
1 # Title: HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py |
|
2 # Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
3 # Copyright 2012 |
|
4 # |
|
# Accumulated precision/recall statistics for theory-level predictions.
|
6 |
|
7 ''' |
|
8 Created on Dec 26, 2012 |
|
9 |
|
10 @author: Daniel Kuehlwein |
|
11 ''' |
|
12 |
|
13 from cPickle import load,dump |
|
14 import logging,string |
|
15 |
|
class TheoryStatistics(object):
    '''
    Stores statistics for theory lvl predictions.

    Each call to update() scores one problem (precision, recall, full
    recall, number of predicted theories) and adds the result to running
    totals.  printAvg() logs the per-problem averages, and save()/load()
    persist the totals with pickle.
    '''

    def __init__(self):
        '''
        Constructor: all running totals start at zero.
        '''
        self.logger = logging.getLogger('TheoryStatistics')
        self.count = 0        # number of problems passed to update()
        self.precision = 0.0  # sum of per-problem precision
        self.recall100 = 0    # problems whose used theories were all predicted
        self.recall = 0.0     # sum of per-problem recall
        self.predicted = 0.0  # sum of per-problem predicted-theory counts

    def update(self,currentTheory,predictedTheories,usedTheories):
        '''
        Score one problem and add the result to the running totals.

        currentTheory     -- always counted as predicted, in addition to
                             predictedTheories.
        predictedTheories -- iterable of predicted theories.
        usedTheories      -- iterable of the theories that were actually used.

        If usedTheories is empty the recall for this problem is taken to be
        1.0, in line with the recall100 check, which also treats an empty
        used set as fully covered.
        '''
        self.count += 1
        # set() makes a private copy and lets callers pass any iterable.
        allPredTheories = set(predictedTheories).union([currentTheory])
        usedSet = set(usedTheories)
        if usedSet.issubset(allPredTheories):
            self.recall100 += 1
        # Never zero: allPredTheories contains at least currentTheory.
        localPredicted = len(allPredTheories)
        self.predicted += localPredicted
        nrHits = float(len(usedSet.intersection(allPredTheories)))
        localPrec = nrHits / localPredicted
        self.precision += localPrec
        if usedSet:
            localRecall = nrHits / len(usedSet)
        else:
            # Nothing had to be found; avoid dividing by zero.
            localRecall = 1.0
        self.recall += localRecall
        self.logger.info('Theory prediction results:')
        self.logger.info('Problem: %s \t Recall100: %s \t Precision: %s \t Recall: %s \t PredictedTeories: %s',
                         self.count,self.recall100,round(localPrec,2),round(localRecall,2),localPredicted)

    def printAvg(self):
        '''
        Log the per-problem averages of all running totals.

        If no problem has been recorded yet, a message saying so is logged
        instead of the averages, since they would require dividing by zero.
        '''
        self.logger.info('Average theory results:')
        if self.count == 0:
            self.logger.info('No theory predictions recorded, averages are undefined.')
            return
        self.logger.info('avgPrecision: %s \t avgRecall100: %s \t avgRecall: %s \t avgPredicted:%s',
                         round(self.precision/self.count,2),
                         round(float(self.recall100)/self.count,2),
                         round(self.recall/self.count,2),
                         round(self.predicted /self.count,2))

    def save(self,fileName):
        '''
        Pickle the running totals to fileName (the logger is not stored).
        '''
        # 'with' closes the file even if dump() raises.
        with open(fileName, 'wb') as oStream:
            dump((self.count,self.precision,self.recall100,self.recall,self.predicted),oStream)

    def load(self,fileName):
        '''
        Replace the running totals with those pickled in fileName by save().
        '''
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(fileName, 'rb') as iStream:
            self.count,self.precision,self.recall100,self.recall,self.predicted = load(iStream)