author | blanchet |
Fri, 18 Oct 2013 13:30:09 +0200 | |
changeset 54149 | 70456a8f5e6e |
parent 53957 | ce12e547e6bb |
child 54432 | 68f8bd1641da |
permissions | -rw-r--r-- |
50482 | 1 |
# Title: HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py |
2 |
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
3 |
# Copyright 2012 |
|
4 |
# |
|
50619
b958a94cf811
new version of MaSh, with theory-level reasoning
blanchet
parents:
50482
diff
changeset
|
5 |
# An updatable sparse naive Bayes classifier. |
50482 | 6 |
|
7 |
''' |
|
8 |
Created on Jul 11, 2012 |
|
9 |
||
10 |
@author: Daniel Kuehlwein |
|
11 |
''' |
|
12 |
||
13 |
from cPickle import dump,load |
|
50840 | 14 |
from numpy import array |
50482 | 15 |
from math import log |
16 |
||
17 |
class sparseNBClassifier(object):
    '''
    An updateable sparse naive Bayes classifier.

    For each known fact d, self.counts[d] is a pair
    [p, featureCounts] where p is the number of proofs that used d as a
    dependency (plus a prior pseudo-count) and featureCounts maps each
    feature to the number of such proofs whose goal exhibited that feature
    (plus the same prior for d's own features).
    '''

    def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
        '''
        Constructor.

        defaultPriorWeight -- pseudo-count encoding that every fact proves
                              itself (the "p proves p" prior).
        posWeight -- multiplier applied to positive feature evidence in predict.
        defVal -- log-score penalty for a feature with no positive evidence.
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def initializeModel(self,trainData,dicts):
        """
        Build the basic model from training data.

        trainData -- iterable of fact ids to train on.
        dicts -- object exposing featureDict (fact -> {feature: weight}) and
                 dependenciesDict (fact -> list of dependency fact ids).
        """
        for d in trainData:
            dFeatureCounts = {}
            # Add "d proves d" with weight self.defaultPriorWeight.
            if self.defaultPriorWeight != 0:
                for f in dicts.featureDict[d].keys():
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[d] = [self.defaultPriorWeight,dFeatureCounts]

        for key,keyDeps in dicts.dependenciesDict.items():
            for dep in keyDeps:
                self.counts[dep][0] += 1
                depFeatures = dicts.featureDict[key]
                for f in depFeatures.keys():
                    # 'f in dict' replaces the deprecated dict.has_key(f).
                    if f in self.counts[dep][1]:
                        self.counts[dep][1][f] += 1
                    else:
                        self.counts[dep][1][f] = 1

    def update(self,dataPoint,features,dependencies):
        """
        Updates the model with a single new proof.

        dataPoint -- id of the newly proved fact.
        features -- {feature: weight} dict of the goal's features.
        dependencies -- fact ids used in the proof.
        """
        if dataPoint not in self.counts:
            dFeatureCounts = {}
            # Give "p |- p" a higher weight, as in initializeModel.
            if self.defaultPriorWeight != 0:
                for f in features.keys():
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[dataPoint] = [self.defaultPriorWeight,dFeatureCounts]
        for dep in dependencies:
            self.counts[dep][0] += 1
            for f in features.keys():
                if f in self.counts[dep][1]:
                    self.counts[dep][1][f] += 1
                else:
                    self.counts[dep][1][f] = 1

    def delete(self,dataPoint,features,dependencies):
        """
        Deletes a single datapoint from the model (inverse of update).
        """
        for dep in dependencies:
            self.counts[dep][0] -= 1
            for f,_w in features.items():
                self.counts[dep][1][f] -= 1
                # Drop zero entries so stale features never contribute to
                # the "missing feature" sum in predict.
                if self.counts[dep][1][f] == 0:
                    del self.counts[dep][1][f]

    def overwrite(self,problemId,newDependencies,dicts):
        """
        Deletes the old dependencies of problemId and replaces them with
        the new ones. Updates the model accordingly.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId,features,oldDeps)
        self.update(problemId,features,newDependencies)

    def predict(self,features,accessibles,dicts):
        """
        For each accessible, predicts the probability of it being useful
        given the features. Returns a ranking of the accessibles as a pair
        (accessibles sorted by descending score, scores in that order).
        """
        tau = 0.05 # Jasmin, change value here
        predictions = []
        observedFeatures = features.keys()
        for a in accessibles:
            posA = self.counts[a][0]
            fA = set(self.counts[a][1].keys())
            fWeightsA = self.counts[a][1]
            resultA = log(posA)
            for f,w in features.items():
                if f in fA:
                    if fWeightsA[f] == 0:
                        resultA += w*self.defVal
                    else:
                        # Co-occurrence count can never exceed the positive count.
                        assert fWeightsA[f] <= posA
                        resultA += w*log(float(self.posWeight*fWeightsA[f])/posA)
                else:
                    resultA += w*self.defVal
            if tau != 0.0:
                # Penalize features the fact was seen with but the goal lacks.
                missingFeatures = list(fA.difference(observedFeatures))
                sumOfWeights = sum([log(float(fWeightsA[x])) for x in missingFeatures]) - log(posA) * len(missingFeatures) #DEFAULT
                resultA -= tau * sumOfWeights
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm],predictions[perm]

    def save(self,fileName):
        """
        Pickles the model state to fileName.
        """
        # 'with' closes the handle even if dump raises (the original leaked it).
        with open(fileName, 'wb') as OStream:
            dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream)

    def load(self,fileName):
        """
        Restores the model state pickled by save.
        """
        with open(fileName, 'rb') as OStream:
            # 'load' here is the module-level pickle load, not this method.
            self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(OStream)
if __name__ == '__main__':
    # Smoke test. The original version of this block could never run:
    # it passed a tuple where an object with featureDict/dependenciesDict
    # attributes is required, passed lists where {feature: weight} dicts
    # are expected, called predict without its dicts argument, and called
    # the nonexistent methods storeModel/loadModel (they are save/load).
    class _Dicts(object):
        pass
    libDicts = _Dicts()
    libDicts.featureDict = {0: {0: 1.0, 1: 1.0, 2: 1.0}, 1: {3: 1.0, 2: 1.0, 1: 1.0}}
    libDicts.dependenciesDict = {0: [0], 1: [0, 1]}
    c = sparseNBClassifier()
    c.initializeModel([0, 1], libDicts)
    c.update(2, {14: 1.0, 1: 1.0, 3: 1.0}, [0, 2])
    print(c.counts)
    print(c.predict({0: 1.0, 14: 1.0}, [0, 1, 2], libDicts))
    c.save('x')
    d = sparseNBClassifier()
    d.load('x')
    print(c.counts)
    print(d.counts)
    print('Done')