author | blanchet |
Fri, 11 Jan 2013 16:30:56 +0100 | |
changeset 50827 | aba769dc82e9 |
parent 50619 | b958a94cf811 |
child 50840 | a5cc092156da |
permissions | -rw-r--r-- |
50482 | 1 |
# Title: HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py |
2 |
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
3 |
# Copyright 2012 |
|
4 |
# |
|
50619
b958a94cf811
new version of MaSh, with theory-level reasoning
blanchet
parents:
50482
diff
changeset
|
5 |
# An updatable sparse naive Bayes classifier. |
50482 | 6 |
|
7 |
''' |
|
8 |
Created on Jul 11, 2012 |
|
9 |
||
10 |
@author: Daniel Kuehlwein |
|
11 |
''' |
|
12 |
||
13 |
from cPickle import dump,load |
|
14 |
from numpy import array,exp |
|
15 |
from math import log |
|
16 |
||
17 |
class sparseNBClassifier(object):
    '''
    An updatable sparse naive Bayes classifier.

    self.counts maps each fact id to a pair [posCount, featureCounts]:
    posCount is how often the fact occurred as a dependency (seeded with a
    prior weight), and featureCounts maps each feature of a using problem
    to its co-occurrence count.
    '''

    def __init__(self, useSinePrior=False):
        '''
        Constructor.

        useSinePrior: stored flag for a SInE-style prior; it is not
        consulted anywhere in this class (presumably used by callers —
        TODO confirm).
        '''
        self.counts = {}
        self.sinePrior = useSinePrior
        # Weight of the artificial "p proves p" observation added for
        # every fact, so each fact predicts itself.
        self.defaultPriorWeight = 20
        # Multiplier applied to positive feature evidence in predict().
        self.posWeight = 20
        # Penalty subtracted per feature with no (or zero) positive evidence.
        self.defVal = 15

    def _freshCounts(self, features):
        '''
        Return the initial [posCount, featureCounts] entry for a new fact,
        with every feature seeded at the default prior weight.
        '''
        featureCounts = {}
        for f, _w in features:
            featureCounts[f] = self.defaultPriorWeight
        return [self.defaultPriorWeight, featureCounts]

    def initializeModel(self, trainData, dicts):
        """
        Build the basic model from training data.

        trainData: iterable of fact ids.
        dicts: object exposing featureDict (id -> [(feature, weight)]) and
               dependenciesDict (id -> [dependency ids]) as attributes.
        """
        for d in trainData:
            self.counts[d] = self._freshCounts(dicts.featureDict[d])

        for key in dicts.dependenciesDict.keys():
            # Add "p proves p": each fact counts as its own dependency.
            keyDeps = [key] + dicts.dependenciesDict[key]
            for dep in keyDeps:
                self.counts[dep][0] += 1
                depFeatures = dicts.featureDict[key]
                for f, _w in depFeatures:
                    if f in self.counts[dep][1]:
                        self.counts[dep][1][f] += 1
                    else:
                        self.counts[dep][1][f] = 1

    def update(self, dataPoint, features, dependencies):
        """
        Update the model with one new problem.

        dataPoint: id of the new problem (entry is created if unseen).
        features: [(feature, weight)] of the problem.
        dependencies: ids of the facts it depends on.
        """
        if dataPoint not in self.counts:
            self.counts[dataPoint] = self._freshCounts(features)
        for dep in dependencies:
            self.counts[dep][0] += 1
            for f, _w in features:
                if f in self.counts[dep][1]:
                    self.counts[dep][1][f] += 1
                else:
                    self.counts[dep][1][f] = 1

    def delete(self, dataPoint, features, dependencies):
        """
        Delete a single datapoint from the model: the inverse of update()'s
        dependency loop. The dataPoint's own counts entry is kept.
        """
        for dep in dependencies:
            self.counts[dep][0] -= 1
            for f, _w in features:
                self.counts[dep][1][f] -= 1

    def overwrite(self, problemId, newDependencies, dicts):
        """
        Delete the old dependencies of problemId, replace them with
        newDependencies, and update the model accordingly.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId, features, oldDeps)
        self.update(problemId, features, newDependencies)

    def predict(self, features, accessibles, dicts):
        """
        For each accessible fact, compute a log-space naive Bayes score of
        it being useful given the features.

        features: [(feature, weight)] of the current problem.
        accessibles: fact ids to rank (each must be in self.counts).
        dicts: unused here; kept for interface compatibility.

        Returns (rankedAccessibles, scores): both numpy arrays, sorted by
        descending score.
        """
        predictions = []
        for a in accessibles:
            posA = self.counts[a][0]
            fWeightsA = self.counts[a][1]
            resultA = log(posA)
            for f, w in features:
                if f in fWeightsA:
                    if fWeightsA[f] == 0:
                        # Feature known but never co-occurred: penalize.
                        resultA -= w * self.defVal
                    else:
                        assert fWeightsA[f] <= posA
                        resultA += w * log(float(self.posWeight * fWeightsA[f]) / posA)
                else:
                    resultA -= w * self.defVal
            predictions.append(resultA)
        predictions = array(predictions)
        # argsort of the negated scores gives a descending ranking.
        perm = (-predictions).argsort()
        return array(accessibles)[perm], predictions[perm]

    def save(self, fileName):
        """Pickle self.counts to fileName."""
        with open(fileName, 'wb') as OStream:
            dump(self.counts, OStream)

    def load(self, fileName):
        """Restore self.counts from a file written by save()."""
        with open(fileName, 'rb') as IStream:
            self.counts = load(IStream)
|
136 |
||
137 |
||
138 |
if __name__ == '__main__':
    # Smoke test. Note the expected shapes: features are (feature, weight)
    # pairs, and the dicts argument must expose featureDict and
    # dependenciesDict as attributes (the previous demo passed a bare tuple
    # and plain ints, and called the nonexistent storeModel/loadModel
    # instead of save/load).
    class _Dicts(object):
        # Minimal attribute holder matching what the classifier dereferences.
        def __init__(self, featureDict, dependenciesDict):
            self.featureDict = featureDict
            self.dependenciesDict = dependenciesDict

    featureDict = {0: [(0, 1), (1, 1), (2, 1)], 1: [(3, 1), (2, 1), (1, 1)]}
    dependenciesDict = {0: [0], 1: [0, 1]}
    libDicts = _Dicts(featureDict, dependenciesDict)
    c = sparseNBClassifier()
    c.initializeModel([0, 1], libDicts)
    c.update(2, [(14, 1), (1, 1), (3, 1)], [0, 2])
    print(c.counts)
    print(c.predict([(0, 1), (14, 1)], [0, 1, 2], libDicts))
    c.save('x')
    d = sparseNBClassifier()
    d.load('x')
    print(c.counts)
    print(d.counts)
    print('Done')