'''
Created on Jul 11, 2012

@author: Daniel Kuehlwein
'''

|
import os
import tempfile
from cPickle import dump,load
from math import log

from numpy import array,exp

|
class NBClassifier(object):
    '''
    An updateable naive Bayes classifier.

    The model keeps two tables keyed by fact. Every entry has the form
    ``[observationCount, {feature: count}]``:

    * ``counts[a]``    -- observations in which ``a`` counted as useful:
      ``a`` was the datapoint itself ("d proves d") or one of its
      dependencies.
    * ``negCounts[a]`` -- observations in which ``a`` was accessible to a
      datapoint but was not one of its dependencies.

    Methods taking a ``dicts`` argument read these members from it:
    ``featureDict``, ``dependenciesDict``, ``accessibleDict``,
    ``expandedAccessibles`` (a dict used as a cache) and
    ``expand_accessibles(accessibles)``.
    '''

    # Log-space weight used in place of log(0) when a feature was never
    # observed together with a fact (on either the positive or negative side).
    missingFeatureWeight = 15

    # dicts.expandedAccessibles is flushed once it holds this many entries,
    # which bounds the size of that cache.
    expandedCacheLimit = 100

    def __init__(self):
        '''
        Constructor: start with an empty model.
        '''
        self.counts = {}
        self.negCounts = {}

    @staticmethod
    def _add_observation(table, key, features, delta):
        """Add delta to table[key]'s observation count and to each feature's count (entry created if missing)."""
        entry = table.setdefault(key, [0, {}])
        entry[0] += delta
        featureCounts = entry[1]
        for f in features:
            featureCounts[f] = featureCounts.get(f, 0) + delta

    def _ensure_entries(self, key):
        """Make sure key has an entry (possibly empty) in both count tables."""
        self.counts.setdefault(key, [0, {}])
        self.negCounts.setdefault(key, [0, {}])

    def _expanded_accessibles(self, dataPoint, dicts):
        """Return the expanded accessibles of dataPoint, memoised in dicts.expandedAccessibles."""
        cache = dicts.expandedAccessibles
        if dataPoint in cache:
            return cache[dataPoint]
        if len(cache) >= self.expandedCacheLimit:
            # Crude bound on the cache: drop everything and start over.
            dicts.expandedAccessibles = {}
        dAcc = dicts.expand_accessibles(dicts.accessibleDict[dataPoint])
        dicts.expandedAccessibles[dataPoint] = dAcc
        return dAcc

    def _credit_accessibles(self, accessibles, features, dependencies, delta):
        """Add delta to counts for accessibles in dependencies and to negCounts for the rest."""
        deps = set(dependencies)
        for acc in accessibles:
            self._ensure_entries(acc)
            table = self.counts if acc in deps else self.negCounts
            self._add_observation(table, acc, features, delta)

    def _add_datapoint(self, dataPoint, features, dependencies, dicts, delta):
        """Apply (delta=1) or retract (delta=-1) one datapoint's full contribution."""
        self._ensure_entries(dataPoint)
        dAcc = self._expanded_accessibles(dataPoint, dicts)
        # d proves d
        self._add_observation(self.counts, dataPoint, features, delta)
        self._credit_accessibles(dAcc, features, dependencies, delta)

    def initializeModel(self, trainData, dicts):
        """
        Build basic model from training data.

        Each d in trainData is added exactly as update() would add it. Its
        features and dependencies come from dicts.featureDict[d] and
        dicts.dependenciesDict[d]. Counts already accumulated for d (for
        example as an accessible of an earlier item) are kept, not reset.
        """
        for d in trainData:
            self._add_datapoint(d, dicts.featureDict[d],
                                dicts.dependenciesDict[d], dicts, 1)

    def update(self, dataPoint, features, dependencies, dicts):
        """
        Updates the Model with one new datapoint.

        dataPoint counts once as useful for itself. Every expanded
        accessible of dataPoint is then credited positively (counts) if it
        is in dependencies, and negatively (negCounts) otherwise.
        Dependencies that are not accessible are ignored. Missing table
        entries are created on demand.
        """
        self._add_datapoint(dataPoint, features, dependencies, dicts, 1)

    def delete(self, dataPoint, features, dependencies, dicts=None):
        """
        Deletes a single datapoint from the model.

        When dicts is given, the exact contribution update() made for these
        arguments is reversed: the self-count, the positive counts and the
        negative counts. Without dicts, only the positive counts of each
        dependency are decremented. This is the legacy behaviour; it
        raises KeyError for unknown dependencies or features.
        """
        if dicts is not None:
            self._add_datapoint(dataPoint, features, dependencies, dicts, -1)
            return
        for dep in dependencies:
            self.counts[dep][0] -= 1
            for f in features:
                self.counts[dep][1][f] -= 1

    def overwrite(self, problemId, newDependencies, dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.

        The old dependencies are read from dicts.dependenciesDict[problemId],
        so that entry must still hold them when this is called. This method
        does not update dicts.dependenciesDict itself. The "problemId
        proves problemId" observation is left untouched.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        dAcc = self._expanded_accessibles(problemId, dicts)
        self._credit_accessibles(dAcc, features, oldDeps, -1)
        self._credit_accessibles(dAcc, features, newDependencies, 1)

    def predict(self, features, accessibles):
        """
        For each accessible, predicts the probability of it being useful given the features.
        Returns a ranking of the accessibles.

        The score of a is a log-odds estimate:
        log(pos/neg), plus for each feature f the term
        log P(f | a) - log P(f | not a). Unseen feature counts contribute
        -/+ missingFeatureWeight instead of log(0). The prior term is 0
        when a has no negative observations.

        Returns (accessibles as a numpy array, scores as a numpy array),
        both ordered by descending score.
        Raises KeyError if an accessible is not in the model. Raises
        ValueError if an accessible has negative but no positive
        observations, since its prior would be log(0).
        """
        penalty = self.missingFeatureWeight
        predictions = []
        for a in accessibles:
            posA, fPosWeightsA = self.counts[a]
            negA, fNegWeightsA = self.negCounts[a]
            if negA == 0:
                resultA = 0
            elif posA == 0:
                raise ValueError('accessible %r has negative observations '
                                 'but no positive ones' % (a,))
            else:
                resultA = log(posA) - log(negA)
            for f in features:
                # P(f | a)
                posF = fPosWeightsA.get(f, 0)
                if posF == 0:
                    resultA -= penalty
                else:
                    assert posF <= posA
                    resultA += log(float(posF) / posA)
                # P(f | not a)
                negF = fNegWeightsA.get(f, 0)
                if negF == 0:
                    resultA += penalty
                else:
                    assert negF <= negA
                    resultA -= log(float(negF) / negA)
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm], predictions[perm]

    def save(self, fileName):
        """Pickle the model tables (counts, negCounts) to fileName."""
        with open(fileName, 'wb') as OStream:
            dump((self.counts, self.negCounts), OStream)

    def load(self, fileName):
        """Restore the model tables from a file written by save()."""
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(fileName, 'rb') as IStream:
            self.counts, self.negCounts = load(IStream)

|
if __name__ == '__main__':
    # Smoke test: train on two facts, add a third incrementally, rank the
    # accessibles, and round-trip the model through save()/load().

    class _DemoDicts(object):
        """
        Minimal stand-in for the dictionary object NBClassifier expects. It
        provides only the members the classifier reads.
        """

        def __init__(self, featureDict, dependenciesDict, accessibleDict):
            self.featureDict = featureDict
            self.dependenciesDict = dependenciesDict
            self.accessibleDict = accessibleDict
            self.expandedAccessibles = {}

        def expand_accessibles(self, accessibles):
            # The demo data already lists every accessible fact explicitly,
            # so no expansion is performed here.
            return list(accessibles)

    featureDict = {0: [0, 1, 2], 1: [3, 2, 1], 2: [14, 1, 3]}
    dependenciesDict = {0: [], 1: [0], 2: [0, 1]}
    accessibleDict = {0: [], 1: [0], 2: [0, 1]}
    libDicts = _DemoDicts(featureDict, dependenciesDict, accessibleDict)

    c = NBClassifier()
    c.initializeModel([0, 1], libDicts)
    c.update(2, featureDict[2], dependenciesDict[2], libDicts)
    print(c.counts)
    print(c.predict([0, 14], [0, 1, 2]))

    # Use a temporary file instead of littering the working directory.
    fd, modelFile = tempfile.mkstemp()
    os.close(fd)
    try:
        c.save(modelFile)
        d = NBClassifier()
        d.load(modelFile)
    finally:
        os.remove(modelFile)
    print(c.counts)
    print(d.counts)
    print('Done')