'''
Created on Jul 9, 2012

@author: Daniel Kuehlwein
'''

import logging
import string

try:
    from cPickle import load, dump
except ImportError:  # Python 3: cPickle was merged into pickle
    from pickle import load, dump


class Statistics(object):
    """
    Running statistics over a sequence of ranked predictions.

    Each call to update() scores one problem (a ranked list of predicted ids
    against the set of ids that were actually needed) and adds the result to
    the running sums kept on the instance. printAvg() reports the averages;
    save()/load() persist the accumulated state with pickle.
    """

    def __init__(self, cutOff=500):
        """
        Create an empty statistics accumulator.

        cutOff -- number of top-ranked predictions evaluated per problem;
                  also the length of the per-rank recall tables.
        """
        self.logger = logging.getLogger('Statistics')
        # Running sums: despite the "avg" prefix they are only divided by
        # self.problems when printAvg() reports them.
        self.avgAUC = 0.0
        self.avgRecall100 = 0.0
        self.avgAvailable = 0.0
        self.avgDepNr = 0.0
        self.problems = 0.0
        self.cutOff = cutOff
        # recallData[i]: summed recall reached within the first i+1 predictions.
        self.recallData = [0] * cutOff
        # recall100Data[i]: number of problems whose dependencies were all
        # found within the first i+1 predictions.
        self.recall100Data = [0] * cutOff
        self.aucData = []
        # Overwritten by every update(); initialised here so that it can be
        # read before the first update() call.
        self.badPreds = []

    def update(self, predictions, dependencies):
        """
        Evaluate AUC, dependencies, recall100 and number of available
        premises of a prediction and add them to the running statistics.

        predictions  -- ranked sequence of predicted ids, best first; only
                        the first self.cutOff entries are evaluated.
                        NOTE(review): assumed to contain no duplicates -
                        confirm against callers.
        dependencies -- iterable of the ids that were actually needed.

        Dependencies that do not occur among the evaluated predictions are
        counted as if they were ranked right after the last prediction.
        """
        available = len(predictions)
        predictions = predictions[:self.cutOff]
        dependencies = set(dependencies)
        depNr = len(dependencies)
        aucSum = 0.
        positives, negatives = 0, 0
        recall100 = 0.0
        badPreds = []
        depsFound = set()
        for index, pId in enumerate(predictions):
            if pId in dependencies:  # positive
                positives += 1
                recall100 = index + 1
                depsFound.add(pId)
                # Needed ids that were only found this far down the ranking
                # are recorded as bad predictions.
                if index > 200:
                    badPreds.append(pId)
            else:
                # Each positive ranked above this negative is one correctly
                # ordered (positive, negative) pair.
                aucSum += positives
                negatives += 1
            # Update Recall and Recall100 stats
            if depNr == positives:
                self.recall100Data[index] += 1
            if depNr == 0:
                self.recallData[index] += 1
            else:
                self.recallData[index] += float(positives) / depNr

        if depNr != positives:
            missing = [dep for dep in dependencies if dep not in depsFound]
            if missing:
                badPreds.extend(missing)
                recall100 = len(predictions) + 1
                positives += len(missing)
            # str.join instead of string.join: the latter no longer exists
            # in Python 3.
            self.logger.debug('Dependencies missing for %s in accessibles! Estimating Statistics.',
                              ','.join([str(dep) for dep in missing]))

        if positives == 0 or negatives == 0:
            auc = 1.0
        else:
            auc = aucSum / (negatives * positives)

        self.aucData.append(auc)
        self.avgAUC += auc
        self.avgRecall100 += recall100
        self.problems += 1
        self.badPreds = badPreds
        self.avgAvailable += available
        self.avgDepNr += depNr
        self.logger.info('AUC: %s \t Needed: %s \t Recall100: %s \t Available: %s \t cutOff:%s',
                         round(100 * auc, 2), depNr, recall100, available, self.cutOff)

    def printAvg(self):
        """
        Log the averaged statistics and, if matplotlib can be imported, show
        the recall, 100%-recall and AUC-histogram plots.

        Does nothing except log a warning when no problem has been evaluated.
        """
        if not self.problems:
            # Averaging over zero problems would divide by zero.
            self.logger.warning('No problems evaluated. Skipping averages.')
            return

        self.logger.info('Average results:')
        self.logger.info('avgAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t cutOff:%s',
                         round(100 * self.avgAUC / self.problems, 2),
                         round(self.avgDepNr / self.problems, 2),
                         self.avgRecall100 / self.problems,
                         self.cutOff)

        # matplotlib is optional: import it lazily and skip the graphs when
        # it is not installed.
        try:
            from matplotlib.pyplot import plot, figure, show, xlabel, ylabel, axis, hist
        except ImportError:
            self.logger.warning('Matplotlib module missing. Skipping graphs.')
            return

        # Plotting stays best-effort, but failures are reported as what they
        # are instead of as a missing module.
        try:
            avgRecall = [float(x) / self.problems for x in self.recallData]
            figure('Recall')
            plot(range(self.cutOff), avgRecall)
            ylabel('Average Recall')
            xlabel('Highest ranked premises')
            axis([0, self.cutOff, 0.0, 1.0])
            figure('100%Recall')
            plot(range(self.cutOff), self.recall100Data)
            ylabel('100%Recall')
            xlabel('Highest ranked premises')
            axis([0, self.cutOff, 0, self.problems])
            figure('AUC Histogram')
            hist(self.aucData, bins=100)
            ylabel('Problems')
            xlabel('AUC')
            show()
        except Exception:
            self.logger.warning('Plotting failed. Skipping graphs.', exc_info=True)

    def save(self, fileName):
        """
        Pickle the accumulated statistics to the file fileName.

        badPreds is not part of the stored state.
        """
        state = (self.avgAUC, self.avgRecall100, self.avgAvailable, self.avgDepNr,
                 self.problems, self.cutOff, self.recallData, self.recall100Data,
                 self.aucData)
        # 'with' closes the file even if pickling fails.
        with open(fileName, 'wb') as oStream:
            dump(state, oStream)

    def load(self, fileName):
        """
        Replace the accumulated statistics with the state that save() wrote
        to the file fileName.

        NOTE: the file is unpickled. Only load files from a trusted source,
        because unpickling can execute arbitrary code.
        """
        with open(fileName, 'rb') as iStream:
            state = load(iStream)
        (self.avgAUC, self.avgRecall100, self.avgAvailable, self.avgDepNr,
         self.problems, self.cutOff, self.recallData, self.recall100Data,
         self.aucData) = state