src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
author blanchet
Thu, 14 Nov 2013 15:57:48 +0100
changeset 54432 68f8bd1641da
parent 54149 70456a8f5e6e
child 54692 5ce1b9613705
permissions -rw-r--r--
have MaSh support nameless facts (i.e. proofs) and use that support
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     1
#     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     2
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     3
#     Copyright   2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     4
#
50619
b958a94cf811 new version of MaSh, with theory-level reasoning
blanchet
parents: 50482
diff changeset
     5
# An updatable sparse naive Bayes classifier.
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     6
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     7
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     8
Created on Jul 11, 2012
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
     9
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    10
@author: Daniel Kuehlwein
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    11
'''
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    12
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    13
from cPickle import dump,load
50840
a5cc092156da new version of MaSh Python component
blanchet
parents: 50827
diff changeset
    14
from numpy import array
50482
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    15
from math import log
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    16
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
    17
class sparseNBClassifier(object):
    '''
    An updateable sparse naive Bayes classifier.

    The whole model lives in self.counts, a dict mapping a fact id to a
    two-element list [total, featureCounts]:
      * total         - the prior weight plus the number of times the fact
                        was counted as a dependency;
      * featureCounts - dict mapping a feature to the prior weight (for the
                        fact's own features) plus the number of times the
                        feature was seen on something depending on the fact.
    Features are expected to be mappings from feature to weight; only
    predict() reads the weights, everything else only uses the keys.
    '''

    def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
        '''
        Constructor.

        defaultPriorWeight -- initial count given to a fact and to each of
                              its own features ("p proves p").
        posWeight          -- factor applied to a matching feature count
                              inside the logarithm in predict().
        defVal             -- score added (times the feature weight) for a
                              queried feature the fact has no count for.
        '''
        self.counts = {}
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def _priorEntry(self,features):
        """
        Returns a fresh [total, featureCounts] entry for a fact with the given
        features, seeded with self.defaultPriorWeight ("p proves p").
        """
        weight = self.defaultPriorWeight
        if weight == 0:
            # No prior: features are ignored and only appear once observed.
            return [weight,{}]
        return [weight,dict.fromkeys(features,weight)]

    def _countDependencies(self,features,dependencies):
        """
        Records one observation of features for every fact in dependencies.
        Raises KeyError if a dependency is not part of the model.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] += 1
            featureCounts = entry[1]
            for f in features:
                featureCounts[f] = featureCounts.get(f,0) + 1

    def initializeModel(self,trainData,dicts):
        """
        Build basic model from training data.

        trainData -- iterable of fact ids to register.
        dicts     -- object providing featureDict (fact id -> features) and
                     dependenciesDict (fact id -> iterable of fact ids).

        Every dependency listed in dicts.dependenciesDict must be in
        trainData (or already in the model), otherwise KeyError is raised.
        """
        # The features are only needed when there is a prior to distribute.
        usePrior = self.defaultPriorWeight != 0
        for d in trainData:
            self.counts[d] = self._priorEntry(dicts.featureDict[d] if usePrior else ())

        for key,keyDeps in dicts.dependenciesDict.items():
            self._countDependencies(dicts.featureDict[key],keyDeps)

    def update(self,dataPoint,features,dependencies):
        """
        Updates the Model.

        Registers dataPoint (unless it is already known or equal to 0) and
        counts features once for each fact in dependencies.
        Raises KeyError if a dependency is not part of the model.
        """
        # NOTE(review): id 0 is never registered here; presumably it stands
        # for nameless facts (proofs) -- confirm against the caller.
        if dataPoint not in self.counts and dataPoint != 0:
            self.counts[dataPoint] = self._priorEntry(features)
        self._countDependencies(features,dependencies)

    def delete(self,dataPoint,features,dependencies):
        """
        Deletes a single datapoint from the model.

        Reverts the counts that update() added for features on every fact in
        dependencies. The entry of dataPoint itself is left in place
        (dataPoint is not used). Raises KeyError if a dependency or one of
        its feature counts is missing.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] -= 1
            featureCounts = entry[1]
            for f in features:
                featureCounts[f] -= 1
                # Keep the dict sparse: drop features that are no longer counted.
                if featureCounts[f] == 0:
                    del featureCounts[f]

    def overwrite(self,problemId,newDependencies,dicts):
        """
        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.

        The old dependencies and the features are read from
        dicts.dependenciesDict and dicts.featureDict; dicts itself is not
        modified. Raises LookupError if problemId is not part of the model.
        """
        # Explicit check instead of an assert, which is stripped under -O.
        if problemId not in self.counts:
            # Fall back to the raw id so that a missing name cannot hide the real error.
            name = dicts.idNameDict.get(problemId,problemId)
            raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % name)
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId,features,oldDeps)
        self.update(problemId,features,newDependencies)

    def predict(self,features,accessibles,dicts,tau = 0.05):
        """
        For each accessible, predicts the probability of it being useful given the features.
        Returns a ranking of the accessibles.

        features    -- mapping from feature to weight.
        accessibles -- sequence of fact ids, all of which must be in the model.
        dicts       -- unused; kept for interface compatibility.
        tau         -- weight of the correction for features that the
                       accessible has counts for but that are absent from
                       features; 0.0 disables the correction.

        The score of an accessible a with total count posA is
            log(posA)
            + sum over queried f with a count c:  w * log(posWeight * c / posA)
            + sum over the other queried f:       w * defVal
            - tau * sum over non-queried f with a count c:  log(c / posA)

        Returns a pair of numpy arrays (accessibles, scores), both sorted by
        decreasing score. Raises KeyError for an unknown accessible and
        ValueError if a count passed to log is not positive.
        """
        predictions = []
        for a in accessibles:
            posA,fWeightsA = self.counts[a]
            resultA = log(posA)
            for f,w in features.items():
                count = fWeightsA.get(f,0)
                if count == 0:
                    resultA += w*self.defVal
                else:
                    # Model invariant: a feature is never counted more often than its fact.
                    assert count <= posA
                    resultA += w*log(float(self.posWeight*count)/posA)
            if tau != 0.0:
                missingFeatures = [f for f in fWeightsA if f not in features]
                # sum(log(c/posA)) with the constant log(posA) term factored out.
                sumOfWeights = sum([log(float(fWeightsA[x])) for x in missingFeatures]) - log(posA) * len(missingFeatures)
                resultA -= tau * sumOfWeights
            predictions.append(resultA)
        predictions = array(predictions)
        perm = (-predictions).argsort()
        return array(accessibles)[perm],predictions[perm]

    def save(self,fileName):
        """
        Pickles the counts and the three weight parameters to fileName.
        """
        with open(fileName, 'wb') as OStream:
            dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream)

    def load(self,fileName):
        """
        Replaces the counts and the three weight parameters with the ones
        stored in fileName by save().
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(fileName, 'rb') as IStream:
            # 'load' is the module-level pickle function, not this method.
            self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(IStream)
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   143
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   144
d7be7ccf428b updated version of MaSh learner engine
blanchet
parents:
diff changeset
   145
if __name__ == '__main__':
    # Minimal smoke test: build a tiny model, extend it, rank the known facts
    # and check that the model survives a save/load round trip.
    import os
    import tempfile
    from collections import namedtuple

    # The classifier reads featureDict/dependenciesDict/idNameDict as
    # attributes, and features as mappings from feature to weight.
    Dicts = namedtuple('Dicts',['featureDict','dependenciesDict','idNameDict'])
    featureDict = {0:{0:1.0,1:1.0,2:1.0},1:{3:1.0,2:1.0,1:1.0}}
    dependenciesDict = {0:[0],1:[0,1]}
    libDicts = Dicts(featureDict,dependenciesDict,{})
    c = sparseNBClassifier()
    c.initializeModel([0,1],libDicts)
    c.update(2,{14:1.0,1:1.0,3:1.0},[0,2])
    print(c.counts)
    print(c.predict({0:1.0,14:1.0},[0,1,2],libDicts))
    # Use a temporary file so that the test leaves nothing behind.
    handle,modelFile = tempfile.mkstemp()
    os.close(handle)
    try:
        c.save(modelFile)
        d = sparseNBClassifier()
        d.load(modelFile)
    finally:
        os.remove(modelFile)
    print(c.counts)
    print(d.counts)
    print('Done')