author | blanchet |
Thu, 14 Nov 2013 15:57:48 +0100 | |
changeset 54432 | 68f8bd1641da |
parent 54149 | 70456a8f5e6e |
child 54692 | 5ce1b9613705 |
permissions | -rw-r--r-- |
#     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
#     Copyright   2012
#
50619
b958a94cf811
new version of MaSh, with theory-level reasoning
blanchet
parents:
50482
diff
changeset
|
# An updatable sparse naive Bayes classifier.

'''
Created on Jul 11, 2012

@author: Daniel Kuehlwein
'''
from cPickle import dump,load
from math import log

from numpy import array
class sparseNBClassifier(object):
    '''
    An updateable sparse naive Bayes classifier.

    The model lives in self.counts, which maps each fact id to a pair
    [posCount, featureCounts]: posCount is the prior weight plus the number
    of proofs the fact occurs in as a dependency, and featureCounts maps
    feature ids to co-occurrence counts with that fact.
    '''

    def __init__(self, defaultPriorWeight=20.0, posWeight=20.0, defVal=-15.0):
        '''
        Constructor.

        defaultPriorWeight -- weight given to the trivial proof "p proves p"
                              when a fact is first seen (0 disables it).
        posWeight -- multiplier applied to a matched feature's count in predict().
        defVal -- log-score contribution of a feature the fact has never co-occurred with.
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def initializeModel(self, trainData, dicts):
        """
        Build the basic model from training data.

        trainData -- iterable of fact ids to initialize.
        dicts -- container exposing featureDict (fact id -> {feature id: weight})
                 and dependenciesDict (fact id -> list of dependency fact ids).
        """
        for d in trainData:
            dFeatureCounts = {}
            # Add "p proves p" with weight self.defaultPriorWeight.
            if self.defaultPriorWeight != 0:
                for f in dicts.featureDict[d]:
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[d] = [self.defaultPriorWeight, dFeatureCounts]

        # Credit every dependency with the features of the fact it helped prove.
        for key, keyDeps in dicts.dependenciesDict.items():
            keyFeatures = dicts.featureDict[key]
            for dep in keyDeps:
                self.counts[dep][0] += 1
                for f in keyFeatures:
                    if f in self.counts[dep][1]:
                        self.counts[dep][1][f] += 1
                    else:
                        self.counts[dep][1][f] = 1

    def update(self, dataPoint, features, dependencies):
        """
        Update the model with one new proof.

        dataPoint -- id of the newly proved fact (0 denotes a nameless fact
                     and gets no entry of its own — presumably a proof; see
                     the MaSh nameless-fact support).
        features -- {feature id: weight} of the new fact.
        dependencies -- fact ids used in the proof; their counts are credited.
        """
        if dataPoint not in self.counts and dataPoint != 0:
            dFeatureCounts = {}
            # Give p |- p a higher weight.
            if self.defaultPriorWeight != 0:
                for f in features:
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[dataPoint] = [self.defaultPriorWeight, dFeatureCounts]
        for dep in dependencies:
            self.counts[dep][0] += 1
            for f in features:
                if f in self.counts[dep][1]:
                    self.counts[dep][1][f] += 1
                else:
                    self.counts[dep][1][f] = 1

    def delete(self, dataPoint, features, dependencies):
        """
        Delete a single datapoint from the model (inverse of update's
        dependency crediting). Feature counts that reach 0 are removed to
        keep the representation sparse.
        """
        for dep in dependencies:
            self.counts[dep][0] -= 1
            for f, _w in features.items():
                self.counts[dep][1][f] -= 1
                if self.counts[dep][1][f] == 0:
                    del self.counts[dep][1][f]

    def overwrite(self, problemId, newDependencies, dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.

        Raises LookupError when problemId has never been introduced.
        """
        # Was "try: assert ... except: raise": a bare except around an assert
        # both swallows unrelated errors and vanishes under python -O.
        if problemId not in self.counts:
            raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % dicts.idNameDict[problemId])
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId, features, oldDeps)
        self.update(problemId, features, newDependencies)

    def predict(self, features, accessibles, dicts):
        """
        For each accessible, predicts the probability of it being useful given the features.
        Returns a ranking of the accessibles.

        features -- {feature id: weight} of the current goal.
        accessibles -- fact ids to rank (each must already be in self.counts).
        Returns (accessibles sorted by descending score, sorted scores),
        both as numpy arrays.
        """
        tau = 0.05  # Jasmin, change value here
        predictions = []
        observedFeatures = features.keys()
        for a in accessibles:
            posA = self.counts[a][0]
            fA = set(self.counts[a][1].keys())
            fWeightsA = self.counts[a][1]
            resultA = log(posA)
            for f, w in features.items():
                if f in fA:
                    if fWeightsA[f] == 0:
                        resultA += w * self.defVal
                    else:
                        assert fWeightsA[f] <= posA
                        resultA += w * log(float(self.posWeight * fWeightsA[f]) / posA)
                else:
                    resultA += w * self.defVal
            if tau != 0.0:
                # Penalize features the fact has but the goal lacks.
                missingFeatures = list(fA.difference(observedFeatures))
                sumOfWeights = sum([log(float(fWeightsA[x])) for x in missingFeatures]) - log(posA) * len(missingFeatures)
                resultA -= tau * sumOfWeights
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm], predictions[perm]

    def save(self, fileName):
        """Pickle the full model state to fileName."""
        # "with" closes the handle even if dump raises (the explicit
        # open/close pair leaked it on error).
        with open(fileName, 'wb') as OStream:
            dump((self.counts, self.defaultPriorWeight, self.posWeight, self.defVal), OStream)

    def load(self, fileName):
        """Restore model state previously written by save()."""
        with open(fileName, 'rb') as IStream:
            self.counts, self.defaultPriorWeight, self.posWeight, self.defVal = load(IStream)
143 |
||
144 |
||
145 |
if __name__ == '__main__':
    # Smoke test. The old demo was stale and crashed on every call:
    # it passed a tuple where the classifier does attribute access
    # (dicts.featureDict), passed feature *lists* where {id: weight} dicts
    # are iterated, called predict with two arguments instead of three, and
    # called storeModel/loadModel, which do not exist (methods are save/load).
    class _DemoDicts(object):
        # Minimal stand-in for the real dictionaries container.
        pass
    dicts = _DemoDicts()
    dicts.featureDict = {0: {0: 1.0, 1: 1.0, 2: 1.0}, 1: {3: 1.0, 2: 1.0, 1: 1.0}}
    dicts.dependenciesDict = {0: [0], 1: [0, 1]}
    dicts.idNameDict = {0: 'fact0', 1: 'fact1'}
    c = sparseNBClassifier()
    c.initializeModel([0, 1], dicts)
    c.update(2, {14: 1.0, 1: 1.0, 3: 1.0}, [0, 1])
    print(c.counts)
    print(c.predict({0: 1.0, 14: 1.0}, [0, 1, 2], dicts))
    c.save('x')
    d = sparseNBClassifier()
    d.load('x')
    print(c.counts)
    print(d.counts)
    print('Done')