src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
author blanchet
Thu, 31 Jan 2013 11:20:12 +0100
changeset 50997 31f9ba85dc2e
parent 50951 e1cbaa7d5536
child 53100 1133b9e83f09
permissions -rw-r--r--
compute proper weight for "p proves p" in MaSh
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     1
#     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     2
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     3
#     Copyright   2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     4
#
50619
b958a94cf811 new version of MaSh, with theory-level reasoning
blanchet
parents: 50482
diff changeset
     5
# An updatable sparse naive Bayes classifier.
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     6
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     7
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     8
Created on Jul 11, 2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     9
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    10
@author: Daniel Kuehlwein
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    11
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    12
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    13
from cPickle import dump,load
50840
a5cc092156da new version of MaSh Python component
blanchet
parents: 50827
diff changeset
    14
from numpy import array
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    15
from math import log
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    16
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    17
class sparseNBClassifier(object):
    '''
    An updateable sparse naive Bayes classifier.

    The model is stored in self.counts, which maps every data point (a fact
    that can serve as a dependency) to a two-element list
    [priorCount, featureCounts]:
      - priorCount is how often the data point was used as a dependency,
        seeded with defaultPriorWeight (the "p proves p" pseudo-observation);
      - featureCounts maps a feature to how often it occurred in data points
        that used this one as a dependency, seeded with defaultPriorWeight
        for the data point's own features.
    '''

    def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
        '''
        Constructor.

        defaultPriorWeight -- pseudo-count for "p proves p" given to a data
                              point when it is first added (0 disables it)
        posWeight          -- factor applied to a feature count before it is
                              normalised by the prior count in predict
        defVal             -- log-score added for a feature with no recorded
                              (or a zero) count for an accessible
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def _newEntry(self,features):
        """
        Return a fresh [priorCount, featureCounts] entry for a data point with
        the given (feature, weight) pairs, seeded with the "p proves p"
        pseudo-observation of weight self.defaultPriorWeight.
        """
        if self.defaultPriorWeight != 0:
            # Every own feature starts at the prior weight (weights are ignored).
            featureCounts = dict.fromkeys((f for f,_w in features),self.defaultPriorWeight)
        else:
            featureCounts = {}
        return [self.defaultPriorWeight,featureCounts]

    def _addDependencies(self,features,dependencies):
        """
        Record that a data point with the given (feature, weight) pairs used
        every data point in `dependencies`: bump each dependency's prior count
        and its count for each of the features.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] += 1
            featureCounts = entry[1]
            for f,_w in features:
                featureCounts[f] = featureCounts.get(f,0) + 1

    def initializeModel(self,trainData,dicts):
        """
        Build basic model from training data.

        trainData -- iterable of data points; each gets a fresh model entry
        dicts     -- object with attributes featureDict (data point -> list of
                     (feature, weight) pairs) and dependenciesDict (data
                     point -> list of data points it depends on)
        """
        for d in trainData:
            # Add p proves p with weight self.defaultPriorWeight
            self.counts[d] = self._newEntry(dicts.featureDict[d])
        for key,keyDeps in dicts.dependenciesDict.items():
            # Only look up key's features when there is something to count,
            # so a dependency-free key need not appear in featureDict.
            if keyDeps:
                self._addDependencies(dicts.featureDict[key],keyDeps)

    def update(self,dataPoint,features,dependencies):
        """
        Updates the Model.

        Adds dataPoint (with its "p proves p" prior) if it is new, then
        records that it used each of `dependencies`. Every dependency must
        already be in the model (or be dataPoint itself).
        """
        if dataPoint not in self.counts:
            self.counts[dataPoint] = self._newEntry(features)
        self._addDependencies(features,dependencies)

    def delete(self,dataPoint,features,dependencies):
        """
        Deletes a single datapoint from the model.

        Undoes the counts that update() added for dataPoint's dependencies;
        dataPoint's own entry is kept. Raises KeyError if a dependency or one
        of its feature counts is not present.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] -= 1
            featureCounts = entry[1]
            for f,_w in features:
                featureCounts[f] -= 1

    def overwrite(self,problemId,newDependencies,dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.

        The old dependencies and the features are read from
        dicts.dependenciesDict / dicts.featureDict; those dictionaries are
        not modified here.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId,features,oldDeps)
        self.update(problemId,features,newDependencies)

    def predict(self,features,accessibles,dicts):
        """
        For each accessible, predicts the (log-)score of it being useful given
        the (feature, weight) pairs in `features`.

        Returns a pair (ranked accessibles, scores), both numpy arrays sorted
        by descending score. `dicts` is unused and kept for interface
        compatibility.
        """
        predictions = []
        for a in accessibles:
            posA,fWeightsA = self.counts[a]
            if posA > 0:
                resultA = log(posA)
            else:
                # Never used as a dependency and no "p proves p" prior
                # (defaultPriorWeight == 0): log(0) is undefined, so score the
                # prior like an unseen feature instead of crashing.
                resultA = self.defVal
            for f,w in features:
                fCount = fWeightsA.get(f,0)
                if fCount == 0:
                    # Feature never (or no longer) co-occurred with a.
                    resultA += w*self.defVal
                else:
                    assert fCount <= posA
                    resultA += w*log(float(self.posWeight*fCount)/posA)
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm],predictions[perm]

    def save(self,fileName):
        """Pickle the model and its parameters to fileName."""
        with open(fileName, 'wb') as OStream:
            dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream)

    def load(self,fileName):
        """Restore the model and its parameters from a file written by save()."""
        # NOTE: unpickling executes code from the file; only load trusted models.
        with open(fileName, 'rb') as IStream:
            self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(IStream)
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   129
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   130
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   131
if __name__ == '__main__':
    # Small smoke test of the classifier on a toy library.
    import os
    import tempfile

    class _LibDicts(object):
        """Minimal stand-in for the dictionaries object the classifier reads:
        only the featureDict and dependenciesDict attributes are used."""
        def __init__(self, featureDict, dependenciesDict):
            self.featureDict = featureDict
            self.dependenciesDict = dependenciesDict

    # Features are (feature, weight) pairs, as the classifier unpacks them.
    featureDict = {0: [(0, 1.0), (1, 1.0), (2, 1.0)],
                   1: [(3, 1.0), (2, 1.0), (1, 1.0)]}
    dependenciesDict = {0: [0], 1: [0, 1]}
    libDicts = _LibDicts(featureDict, dependenciesDict)
    c = sparseNBClassifier()
    c.initializeModel([0, 1], libDicts)
    c.update(2, [(14, 1.0), (1, 1.0), (3, 1.0)], [0, 2])
    print(c.counts)
    print(c.predict([(0, 1.0), (14, 1.0)], [0, 1, 2], libDicts))
    # Round-trip the model through a temporary file instead of littering cwd.
    fd, modelFile = tempfile.mkstemp()
    os.close(fd)
    try:
        c.save(modelFile)
        d = sparseNBClassifier()
        d.load(modelFile)
    finally:
        os.remove(modelFile)
    print(c.counts)
    print(d.counts)
    print('Done')