src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
changeset 50619 b958a94cf811
parent 50482 d7be7ccf428b
child 50827 aba769dc82e9
equal deleted inserted replaced
50617:9df2f825422b 50619:b958a94cf811
     1 #     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
     1 #     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
     2 #     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     2 #     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     3 #     Copyright   2012
     3 #     Copyright   2012
     4 #
     4 #
     5 # An updatable naive Bayes classifier.
     5 # An updatable sparse naive Bayes classifier.
     6 
     6 
     7 '''
     7 '''
     8 Created on Jul 11, 2012
     8 Created on Jul 11, 2012
     9 
     9 
    10 @author: Daniel Kuehlwein
    10 @author: Daniel Kuehlwein
    35             self.counts[d] = [dPosCount,dFeatureCounts]
    35             self.counts[d] = [dPosCount,dFeatureCounts]
    36 
    36 
    37         for key in dicts.dependenciesDict.keys():
    37         for key in dicts.dependenciesDict.keys():
    38             # Add p proves p
    38             # Add p proves p
    39             keyDeps = [key]+dicts.dependenciesDict[key]
    39             keyDeps = [key]+dicts.dependenciesDict[key]
    40 
       
    41             for dep in keyDeps:
    40             for dep in keyDeps:
    42                 self.counts[dep][0] += 1
    41                 self.counts[dep][0] += 1
    43                 depFeatures = dicts.featureDict[key]
    42                 depFeatures = dicts.featureDict[key]
    44                 for f,_w in depFeatures:
    43                 for f,_w in depFeatures:
    45                     if self.counts[dep][1].has_key(f):
    44                     if self.counts[dep][1].has_key(f):
    87     def predict(self,features,accessibles):
    86     def predict(self,features,accessibles):
    88         """
    87         """
    89         For each accessible, predicts the probability of it being useful given the features.
    88         For each accessible, predicts the probability of it being useful given the features.
    90         Returns a ranking of the accessibles.
    89         Returns a ranking of the accessibles.
    91         """
    90         """
       
    91         posWeight = 20.0
       
    92         defVal = 15
    92         predictions = []
    93         predictions = []
    93         for a in accessibles:
    94         for a in accessibles:
    94             posA = self.counts[a][0]
    95             posA = self.counts[a][0]
    95             fA = set(self.counts[a][1].keys())
    96             fA = set(self.counts[a][1].keys())
    96             fWeightsA = self.counts[a][1]
    97             fWeightsA = self.counts[a][1]
    97             resultA = log(posA)
    98             resultA = log(posA)
    98             for f,w in features:
    99             for f,w in features:
       
   100                 # DEBUG
       
   101                 #w = 1
    99                 if f in fA:
   102                 if f in fA:
   100                     if fWeightsA[f] == 0:
   103                     if fWeightsA[f] == 0:
   101                         resultA -= w*15
   104                         resultA -= w*defVal
   102                     else:
   105                     else:
   103                         assert fWeightsA[f] <= posA
   106                         assert fWeightsA[f] <= posA
   104                         resultA += w*log(float(fWeightsA[f])/posA)
   107                         resultA += w*log(float(posWeight*fWeightsA[f])/posA)
   105                 else:
   108                 else:
   106                     resultA -= w*15
   109                     resultA -= w*defVal
   107             predictions.append(resultA)
   110             predictions.append(resultA)
   108         #expPredictions = array([exp(x) for x in predictions])
   111         #expPredictions = array([exp(x) for x in predictions])
   109         predictions = array(predictions)
   112         predictions = array(predictions)
   110         perm = (-predictions).argsort()
   113         perm = (-predictions).argsort()
   111         #return array(accessibles)[perm],expPredictions[perm]
   114         #return array(accessibles)[perm],expPredictions[perm]