src/HOL/Tools/Sledgehammer/MaSh/src/fullNaiveBayes.py
changeset 53100 1133b9e83f09
equal deleted inserted replaced
53099:5c7780d21d24 53100:1133b9e83f09
       
     1 '''
       
     2 Created on Jul 11, 2012
       
     3 
       
     4 @author: Daniel Kuehlwein
       
     5 '''
       
     6 
       
     7 from cPickle import dump,load
       
     8 from numpy import array,exp
       
     9 from math import log
       
    10 
       
    11 class NBClassifier(object):
       
    12     '''
       
    13     An updateable naive Bayes classifier.
       
    14     '''
       
    15 
       
    16     def __init__(self):
       
    17         '''
       
    18         Constructor
       
    19         '''
       
    20         self.counts = {}
       
    21         self.negCounts = {}
       
    22     
       
    23     def initializeModel(self,trainData,dicts):
       
    24         """
       
    25         Build basic model from training data.
       
    26         """        
       
    27         for d in trainData:
       
    28             self.counts[d] = [0,{}]
       
    29             self.negCounts[d] = [0,{}]
       
    30             dAccUnExp = dicts.accessibleDict[d]
       
    31             if dicts.expandedAccessibles.has_key(d):
       
    32                 dAcc = dicts.expandedAccessibles(d)
       
    33             else:
       
    34                 if len(dicts.expandedAccessibles.keys()) >= 100:
       
    35                     dicts.expandedAccessibles = {}
       
    36                 dAcc = dicts.expand_accessibles(dAccUnExp)
       
    37                 dicts.expandedAccessibles[d] = dAcc 
       
    38             dDeps = set(dicts.dependenciesDict[d])
       
    39             dFeatures = dicts.featureDict[d]
       
    40             # d proves d
       
    41             self.counts[d][0] += 1
       
    42             for f in dFeatures:
       
    43                 if self.counts[d][1].has_key(f):
       
    44                     self.counts[d][1][f] += 1
       
    45                 else:
       
    46                     self.counts[d][1][f] = 1
       
    47             for acc in dAcc:
       
    48                 if not self.counts.has_key(acc):
       
    49                     self.counts[acc] = [0,{}]
       
    50                 if not self.negCounts.has_key(acc):
       
    51                     self.negCounts[acc] = [0,{}]        
       
    52                 if acc in dDeps:
       
    53                     self.counts[acc][0] += 1
       
    54                     for f in dFeatures:
       
    55                         if self.counts[acc][1].has_key(f):
       
    56                             self.counts[acc][1][f] += 1
       
    57                         else:
       
    58                             self.counts[acc][1][f] = 1
       
    59                 else:
       
    60                     self.negCounts[acc][0] += 1
       
    61                     for f in dFeatures:
       
    62                         if self.negCounts[acc][1].has_key(f):
       
    63                             self.negCounts[acc][1][f] += 1
       
    64                         else:
       
    65                             self.negCounts[acc][1][f] = 1
       
    66     
       
    67     def update(self,dataPoint,features,dependencies,dicts):
       
    68         """
       
    69         Updates the Model.
       
    70         """
       
    71         if not self.counts.has_key(dataPoint):
       
    72             self.counts[dataPoint] = [0,{}]
       
    73         if not self.negCounts.has_key(dataPoint):
       
    74             self.negCounts[dataPoint] = [0,{}]
       
    75         if dicts.expandedAccessibles.has_key(dataPoint):
       
    76             dAcc = dicts.expandedAccessibles(dataPoint)
       
    77         else:
       
    78             if len(dicts.expandedAccessibles.keys()) >= 100:
       
    79                 dicts.expandedAccessibles = {}
       
    80             dAccUnExp = dicts.accessibleDict[dataPoint]
       
    81             dAcc = dicts.expand_accessibles(dAccUnExp)
       
    82             dicts.expandedAccessibles[dataPoint] = dAcc 
       
    83         dDeps = set(dicts.dependenciesDict[dataPoint])
       
    84         dFeatures = dicts.featureDict[dataPoint]
       
    85         # d proves d
       
    86         self.counts[dataPoint][0] += 1
       
    87         for f in dFeatures:
       
    88             if self.counts[dataPoint][1].has_key(f):
       
    89                 self.counts[dataPoint][1][f] += 1
       
    90             else:
       
    91                 self.counts[dataPoint][1][f] = 1
       
    92 
       
    93         for acc in dAcc:
       
    94             if acc in dDeps:
       
    95                 self.counts[acc][0] += 1
       
    96                 for f in dFeatures:
       
    97                     if self.counts[acc][1].has_key(f):
       
    98                         self.counts[acc][1][f] += 1
       
    99                     else:
       
   100                         self.counts[acc][1][f] = 1
       
   101             else:
       
   102                 self.negCounts[acc][0] += 1
       
   103                 for f in dFeatures:
       
   104                     if self.negCounts[acc][1].has_key(f):
       
   105                         self.negCounts[acc][1][f] += 1
       
   106                     else:
       
   107                         self.negCounts[acc][1][f] = 1
       
   108 
       
   109     def delete(self,dataPoint,features,dependencies):
       
   110         """
       
   111         Deletes a single datapoint from the model.
       
   112         """
       
   113         for dep in dependencies:
       
   114             self.counts[dep][0] -= 1
       
   115             for f in features:
       
   116                 self.counts[dep][1][f] -= 1
       
   117 
       
   118             
       
   119     def overwrite(self,problemId,newDependencies,dicts):
       
   120         """
       
   121         Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
       
   122         """
       
   123         assert self.counts.has_key(problemId)
       
   124         oldDeps = dicts.dependenciesDict[problemId]
       
   125         features = dicts.featureDict[problemId]
       
   126         self.delete(problemId,features,oldDeps)
       
   127         self.update(problemId,features,newDependencies)
       
   128     
       
   129     def predict(self,features,accessibles):
       
   130         """
       
   131         For each accessible, predicts the probability of it being useful given the features.
       
   132         Returns a ranking of the accessibles.
       
   133         """
       
   134         predictions = []
       
   135         for a in accessibles:            
       
   136             posA = self.counts[a][0]
       
   137             negA = self.negCounts[a][0]
       
   138             fPosA = set(self.counts[a][1].keys())
       
   139             fNegA = set(self.negCounts[a][1].keys())
       
   140             fPosWeightsA = self.counts[a][1]
       
   141             fNegWeightsA = self.negCounts[a][1]
       
   142             if negA == 0:
       
   143                 resultA = 0 
       
   144             elif posA == 0:
       
   145                 print a
       
   146                 print 'xx'
       
   147                 import sys
       
   148                 sys.exit(-1)
       
   149             else:
       
   150                 resultA = log(posA) - log(negA) 
       
   151                 for f in features:
       
   152                     if f in fPosA:
       
   153                         # P(f | a)
       
   154                         if fPosWeightsA[f] == 0:
       
   155                             resultA -= 15
       
   156                         else:
       
   157                             assert fPosWeightsA[f] <= posA
       
   158                             resultA += log(float(fPosWeightsA[f])/posA)
       
   159                     else:
       
   160                         resultA -= 15
       
   161                     # P(f | not a)
       
   162                     if f in fNegA:
       
   163                         if fNegWeightsA[f] == 0:
       
   164                             resultA += 15
       
   165                         else:
       
   166                             assert fNegWeightsA[f] <= negA
       
   167                             resultA -= log(float(fNegWeightsA[f])/negA)
       
   168                     else: 
       
   169                         resultA += 15
       
   170 
       
   171             predictions.append(resultA)
       
   172         #expPredictions = array([exp(x) for x in predictions])
       
   173         predictions = array(predictions)
       
   174         perm = (-predictions).argsort()        
       
   175         #return array(accessibles)[perm],expPredictions[perm] 
       
   176         return array(accessibles)[perm],predictions[perm]
       
   177     
       
   178     def save(self,fileName):
       
   179         OStream = open(fileName, 'wb')
       
   180         dump((self.counts,self.negCounts),OStream)        
       
   181         OStream.close()
       
   182         
       
   183     def load(self,fileName):
       
   184         OStream = open(fileName, 'rb')
       
   185         self.counts,self.negCounts = load(OStream)      
       
   186         OStream.close()
       
   187 
       
   188     
       
   189 if __name__ == '__main__':
       
   190     featureDict = {0:[0,1,2],1:[3,2,1]}
       
   191     dependenciesDict = {0:[0],1:[0,1]}
       
   192     libDicts = (featureDict,dependenciesDict,{})
       
   193     c = NBClassifier()
       
   194     c.initializeModel([0,1],libDicts)
       
   195     c.update(2,[14,1,3],[0,2])
       
   196     print c.counts
       
   197     print c.predict([0,14],[0,1,2])
       
   198     c.storeModel('x')
       
   199     d = NBClassifier()
       
   200     d.loadModel('x')
       
   201     print c.counts
       
   202     print d.counts
       
   203     print 'Done'