src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
author blanchet
Fri, 18 Oct 2013 13:30:09 +0200
changeset 54149 70456a8f5e6e
parent 53957 ce12e547e6bb
child 54432 68f8bd1641da
permissions -rw-r--r--
repair invariant in MaSh when learning new proofs
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     1
#     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     2
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     3
#     Copyright   2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     4
#
50619
b958a94cf811 new version of MaSh, with theory-level reasoning
blanchet
parents: 50482
diff changeset
     5
# An updatable sparse naive Bayes classifier.
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     6
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     7
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     8
Created on Jul 11, 2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     9
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    10
@author: Daniel Kuehlwein
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    11
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    12
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    13
from cPickle import dump,load
50840
a5cc092156da new version of MaSh Python component
blanchet
parents: 50827
diff changeset
    14
from numpy import array
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    15
from math import log
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    16
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    17
class sparseNBClassifier(object):
    '''
    An updatable sparse naive Bayes classifier.

    The model is stored in ``self.counts``, which maps each fact to a
    two-element list ``[posCount, featureCounts]``:

    * ``posCount`` starts at ``defaultPriorWeight`` and is incremented once
      for every recorded proof that uses the fact as a dependency.
    * ``featureCounts`` maps a feature to a count.  It is seeded with every
      feature of the fact itself at ``defaultPriorWeight`` (the "p proves p"
      prior, only when that weight is non-zero) and is incremented by one for
      each feature of every recorded proof that used the fact.  ``delete``
      removes entries whose count drops to zero, so stored counts stay
      positive.
    '''

    def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
        '''
        Constructor.

        defaultPriorWeight -- initial positive count of a fact and initial count
                              of each of its own features ("p proves p").
        posWeight          -- factor applied to a known feature's count in predict.
        defVal             -- score added (times the feature weight) in predict for
                              a feature with no stored count for the fact.
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def _priorEntry(self,features):
        '''
        Return a fresh ``[posCount, featureCounts]`` entry for a fact whose own
        features are the keys of ``features``, seeded with the "p proves p" prior.
        '''
        if self.defaultPriorWeight == 0:
            # Do not store zero counts: predict takes logs of stored counts.
            return [self.defaultPriorWeight,{}]
        return [self.defaultPriorWeight,dict.fromkeys(features,self.defaultPriorWeight)]

    def _addObservation(self,dep,features):
        '''
        Record that ``dep`` was used in a proof of a fact with ``features``:
        bump dep's positive count and each of those feature counts by one.
        '''
        entry = self.counts[dep]
        entry[0] += 1
        featureCounts = entry[1]
        for f in features:
            featureCounts[f] = featureCounts.get(f,0) + 1

    def initializeModel(self,trainData,dicts):
        """
        Build basic model from training data.

        trainData -- iterable of facts; each gets a fresh prior entry.
        dicts     -- object exposing ``featureDict`` (fact -> dict keyed by
                     feature) and ``dependenciesDict`` (fact -> iterable of
                     facts used in its proof).  Every dependency must be in
                     trainData (its entry is looked up, not created).
        """
        for d in trainData:
            self.counts[d] = self._priorEntry(dicts.featureDict[d])
        for key,keyDeps in dicts.dependenciesDict.items():
            for dep in keyDeps:
                self._addObservation(dep,dicts.featureDict[key])

    def update(self,dataPoint,features,dependencies):
        """
        Updates the model with a proof of dataPoint.

        Creates a prior entry for dataPoint if it is not yet known, then records
        that each fact in ``dependencies`` was used to prove a fact with the
        given ``features`` (a dict keyed by feature).  Every dependency must
        already have an entry.
        """
        if dataPoint not in self.counts:
            self.counts[dataPoint] = self._priorEntry(features)
        for dep in dependencies:
            self._addObservation(dep,features)

    def delete(self,dataPoint,features,dependencies):
        """
        Deletes a single datapoint from the model.

        Undoes the counts that ``update`` added for ``dependencies`` and
        ``features``.  Feature counts that reach zero are removed so that every
        stored count stays positive.  ``dataPoint`` itself is not used; its own
        prior entry is kept.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] -= 1
            featureCounts = entry[1]
            for f in features:
                featureCounts[f] -= 1
                if featureCounts[f] == 0:
                    del featureCounts[f]

    def overwrite(self,problemId,newDependencies,dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.

        The old dependencies and the features are read from
        ``dicts.dependenciesDict`` and ``dicts.featureDict``; ``dicts`` itself
        is not modified.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId,features,oldDeps)
        self.update(problemId,features,newDependencies)

    def predict(self,features,accessibles,dicts,tau=0.05):
        """
        For each accessible, predicts the probability of it being useful given the features.
        Returns a ranking of the accessibles.

        features    -- dict mapping each observed feature to its weight.
        accessibles -- facts to rank; each must have an entry in self.counts.
        dicts       -- not used here; kept for interface compatibility.
        tau         -- factor for the correction over features stored for a
                       fact but absent from ``features``; 0.0 disables it.

        Returns ``(ranked accessibles, scores)`` as numpy arrays sorted by
        decreasing score.
        """
        predictions = []
        for a in accessibles:
            posA = self.counts[a][0]
            fWeightsA = self.counts[a][1]
            # posA >= defaultPriorWeight, so it is positive when that weight is.
            resultA = log(posA)
            for f,w in features.items():
                countF = fWeightsA.get(f,0)
                if countF == 0:
                    # No stored count for this feature: fixed default score.
                    resultA += w*self.defVal
                else:
                    assert countF <= posA
                    resultA += w*log(float(self.posWeight*countF)/posA)
            if not tau == 0.0:
                # Subtract tau * sum(log(count/posA)) over the stored features of
                # a that were not observed; computed as a difference of logs.
                missingFeatures = [x for x in fWeightsA if x not in features]
                sumOfWeights = sum(log(float(fWeightsA[x])) for x in missingFeatures) - log(posA) * len(missingFeatures)
                resultA -= tau * sumOfWeights
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm],predictions[perm]

    def save(self,fileName):
        """Pickle the counts and hyperparameters to fileName (file is closed even on error)."""
        with open(fileName,'wb') as outStream:
            dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),outStream)

    def load(self,fileName):
        """
        Restore counts and hyperparameters written by save.

        NOTE: this unpickles the file, so it must come from a trusted source.
        """
        with open(fileName,'rb') as inStream:
            self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(inStream)
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   139
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   140
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   141
if __name__ == '__main__':
    # Small smoke test.  initializeModel/overwrite read dicts.featureDict and
    # dicts.dependenciesDict as attributes, and update/predict iterate feature
    # collections as dicts (feature -> weight), so build the data in that shape.
    import os
    import tempfile

    class _Dicts(object):
        pass

    libDicts = _Dicts()
    libDicts.featureDict = {0:{0:1.0,1:1.0,2:1.0},1:{3:1.0,2:1.0,1:1.0}}
    libDicts.dependenciesDict = {0:[0],1:[0,1]}
    c = sparseNBClassifier()
    c.initializeModel([0,1],libDicts)
    c.update(2,{14:1.0,1:1.0,3:1.0},[0,2])
    print(c.counts)
    print(c.predict({0:1.0,14:1.0},[0,1,2],libDicts))
    # Round-trip through save/load using a temporary file rather than ./x.
    fd,modelFile = tempfile.mkstemp()
    os.close(fd)
    try:
        c.save(modelFile)
        d = sparseNBClassifier()
        d.load(modelFile)
    finally:
        os.remove(modelFile)
    print(c.counts)
    print(d.counts)
    print('Done')