author | blanchet |
Thu, 31 Jan 2013 11:20:12 +0100 | |
changeset 50997 | 31f9ba85dc2e |
parent 50951 | e1cbaa7d5536 |
child 53100 | 1133b9e83f09 |
permissions | -rw-r--r-- |
# Title:  HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
# Copyright 2012
#
# An updatable sparse naive Bayes classifier.
'''
Created on Jul 11, 2012

@author: Daniel Kuehlwein
'''
13 |
from cPickle import dump,load |
|
50840 | 14 |
from numpy import array |
50482 | 15 |
from math import log |
16 |
||
17 |
class sparseNBClassifier(object):
    '''
    An updateable sparse naive Bayes classifier.

    self.counts maps each known fact d to a pair [posD, featureCounts]:
    posD is the (weighted) number of proofs using d, and featureCounts maps a
    feature to the number of those proofs whose goal exhibits that feature.
    '''

    def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
        '''
        Constructor.

        defaultPriorWeight -- pseudo-count with which every fact is assumed
            to prove itself (the "p proves p" prior).
        posWeight -- multiplier applied to positive feature evidence in predict.
        defVal -- log-space penalty for a query feature that never co-occurred
            with the candidate fact.
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def initializeModel(self,trainData,dicts):
        """
        Build basic model from training data.

        trainData -- iterable of fact identifiers to seed the model with.
        dicts -- object exposing featureDict (fact -> [(feature, weight), ...])
            and dependenciesDict (fact -> [dependency facts]).
        """
        for d in trainData:
            dFeatureCounts = {}
            # Add p proves p with weight self.defaultPriorWeight.
            if not self.defaultPriorWeight == 0:
                for f,_w in dicts.featureDict[d]:
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[d] = [self.defaultPriorWeight,dFeatureCounts]

        # items() instead of the removed iteritems(), and `in` instead of the
        # removed has_key(): identical behavior on Python 2, portable to 3.
        for key,keyDeps in dicts.dependenciesDict.items():
            for dep in keyDeps:
                self.counts[dep][0] += 1
                depFeatures = dicts.featureDict[key]
                for f,_w in depFeatures:
                    if f in self.counts[dep][1]:
                        self.counts[dep][1][f] += 1
                    else:
                        self.counts[dep][1][f] = 1

    def update(self,dataPoint,features,dependencies):
        """
        Updates the model with one new proof: dataPoint was proved from
        dependencies, and its goal exhibits the given (feature, weight) pairs.
        """
        if not dataPoint in self.counts:
            dFeatureCounts = {}
            # Give p |- p a higher weight
            if not self.defaultPriorWeight == 0:
                for f,_w in features:
                    dFeatureCounts[f] = self.defaultPriorWeight
            self.counts[dataPoint] = [self.defaultPriorWeight,dFeatureCounts]
        for dep in dependencies:
            self.counts[dep][0] += 1
            for f,_w in features:
                if f in self.counts[dep][1]:
                    self.counts[dep][1][f] += 1
                else:
                    self.counts[dep][1][f] = 1

    def delete(self,dataPoint,features,dependencies):
        """
        Deletes a single datapoint from the model (exact inverse of update's
        dependency bookkeeping; assumes update was called with the same data).
        """
        for dep in dependencies:
            self.counts[dep][0] -= 1
            for f,_w in features:
                self.counts[dep][1][f] -= 1

    def overwrite(self,problemId,newDependencies,dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId,features,oldDeps)
        self.update(problemId,features,newDependencies)

    def predict(self,features,accessibles,dicts):
        """
        For each accessible, predicts the probability of it being useful given the features.
        Returns a ranking of the accessibles.

        features -- (feature, weight) pairs of the current goal.
        Returns (accessibles sorted by descending score, sorted scores),
        both as numpy arrays.
        """
        predictions = []
        for a in accessibles:
            posA = self.counts[a][0]
            fWeightsA = self.counts[a][1]
            resultA = log(posA)
            for f,w in features:
                # Test membership on the dict directly; materializing a set of
                # its keys per accessible was redundant work.
                if f in fWeightsA:
                    if fWeightsA[f] == 0:
                        # Count zeroed out (e.g. by delete): treat as unseen.
                        resultA += w*self.defVal
                    else:
                        assert fWeightsA[f] <= posA
                        resultA += w*log(float(self.posWeight*fWeightsA[f])/posA)
                else:
                    resultA += w*self.defVal
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm],predictions[perm]

    def save(self,fileName):
        """
        Pickles the full model state to fileName.
        """
        # with-statement closes the stream even if dump raises.
        with open(fileName, 'wb') as OStream:
            dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream)

    def load(self,fileName):
        """
        Restores the full model state pickled by save.
        """
        with open(fileName, 'rb') as IStream:
            self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(IStream)
129 |
||
130 |
||
131 |
if __name__ == '__main__':
    # Minimal smoke test.  Note the data formats the class actually expects:
    # featureDict values are (feature, weight) pairs (initializeModel unpacks
    # `for f,_w in ...`), and the dictionaries are accessed as attributes of a
    # dicts object, not tuple slots.  The old driver passed bare ints and a
    # tuple and called nonexistent storeModel/loadModel.
    class _Dicts(object):
        def __init__(self,featureDict,dependenciesDict):
            self.featureDict = featureDict
            self.dependenciesDict = dependenciesDict
    featureDict = {0:[(0,1.0),(1,1.0),(2,1.0)],1:[(3,1.0),(2,1.0),(1,1.0)]}
    dependenciesDict = {0:[0],1:[0,1]}
    libDicts = _Dicts(featureDict,dependenciesDict)
    c = sparseNBClassifier()
    c.initializeModel([0,1],libDicts)
    c.update(2,[(14,1.0),(1,1.0),(3,1.0)],[0,2])
    print(c.counts)
    # predict takes the goal's (feature, weight) pairs, the accessible facts
    # and the dicts object.
    print(c.predict([(0,1.0),(14,1.0)],[0,1,2],libDicts))
    # The persistence methods are named save/load.
    c.save('x')
    d = sparseNBClassifier()
    d.load('x')
    print(c.counts)
    print(d.counts)
    print('Done')