new version of MaSh, with theory-level reasoning
author blanchet
Thu Dec 27 10:01:40 2012 +0100 (2012-12-27)
changeset 50619:b958a94cf811
parent 50617 9df2f825422b
child 50620 07e08250a880
new version of MaSh, with theory-level reasoning
src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py
src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py
src/HOL/Tools/Sledgehammer/MaSh/src/snow.py
src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py
     1.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Wed Dec 26 11:06:21 2012 +0100
     1.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Dec 27 10:01:40 2012 +0100
     1.3 @@ -35,7 +35,7 @@
     1.4          self.changed = True
     1.5  
     1.6      """
     1.7 -    Init functions. Side Effect: nameIdDict, idNameDict, featureIdDict get filled!
     1.8 +    Init functions. Side Effect: nameIdDict, idNameDict, featureIdDict, articleDict get filled!
     1.9      """
    1.10      def init_featureDict(self,featureFile):
    1.11          self.featureDict,self.maxNameId,self.maxFeatureId = create_feature_dict(self.nameIdDict,self.idNameDict,self.maxNameId,self.featureIdDict,\
    1.12 @@ -175,12 +175,8 @@
    1.13                  self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc)
    1.14                  self.changed = True
    1.15          accessibles = self.expand_accessibles(unExpAcc)
    1.16 -#        # Feature Ids
    1.17 -#        featureNames = [f.strip() for f in line[1].split()]
    1.18 -#        for fn in featureNames:
    1.19 -#            self.add_feature(fn)
    1.20 -#        features = [self.featureIdDict[fn] for fn in featureNames]
    1.21          features = self.get_features(line)
    1.22 +
    1.23          return name,features,accessibles
    1.24  
    1.25      def save(self,fileName):
     2.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Wed Dec 26 11:06:21 2012 +0100
     2.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Thu Dec 27 10:01:40 2012 +0100
     2.3 @@ -19,10 +19,11 @@
     2.4  from argparse import ArgumentParser,RawDescriptionHelpFormatter
     2.5  from time import time
     2.6  from stats import Statistics
     2.7 +from theoryStats import TheoryStatistics
     2.8 +from theoryModels import TheoryModels
     2.9  from dictionaries import Dictionaries
    2.10  #from fullNaiveBayes import NBClassifier
    2.11  from sparseNaiveBayes import sparseNBClassifier
    2.12 -#from naiveBayes import sparseNBClassifier
    2.13  from snow import SNoW
    2.14  from predefined import Predefined
    2.15  
    2.16 @@ -41,11 +42,13 @@
    2.17  parser.add_argument('--numberOfPredictions',default=200,help="Number of premises to write in the output. Default=200.",type=int)
    2.18  
    2.19  parser.add_argument('--init',default=False,action='store_true',help="Initialize Mash. Requires --inputDir to be defined. Default=False.")
    2.20 -parser.add_argument('--inputDir',default='../data/Jinja/',\
    2.21 +parser.add_argument('--inputDir',default='../data/20121212/Jinja/',\
    2.22                      help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
    2.23  parser.add_argument('--depFile', default='mash_dependencies',
    2.24                      help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
    2.25  parser.add_argument('--saveModel',default=False,action='store_true',help="Stores the learned Model at the end of a prediction run. Default=False.")
    2.26 +parser.add_argument('--learnTheories',default=False,action='store_true',help="Uses a two-lvl prediction mode. First the theories, then the premises. Default=False.")
    2.27 +
    2.28  
    2.29  parser.add_argument('--nb',default=False,action='store_true',help="Use Naive Bayes for learning. This is the default learning method.")
    2.30  parser.add_argument('--snow',default=False,action='store_true',help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
    2.31 @@ -101,6 +104,7 @@
    2.32          model = sparseNBClassifier()
    2.33          modelFile = os.path.join(args.outputDir,'NB.pickle')
    2.34      dictsFile = os.path.join(args.outputDir,'dicts.pickle')
    2.35 +    theoryFile = os.path.join(args.outputDir,'theory.pickle')
    2.36  
    2.37      # Initializing model
    2.38      if args.init:
    2.39 @@ -110,14 +114,17 @@
    2.40          # Load all data
    2.41          dicts = Dictionaries()
    2.42          dicts.init_all(args.inputDir,depFileName=args.depFile)
    2.43 -
    2.44 +        
    2.45          # Create Model
    2.46          trainData = dicts.featureDict.keys()
    2.47 -        if args.predef:
    2.48 -            model.initializeModel(trainData,dicts)
    2.49 -        else:
    2.50 -            model.initializeModel(trainData,dicts)
    2.51 +        model.initializeModel(trainData,dicts)
    2.52  
    2.53 +        if args.learnTheories:
    2.54 +            depFile = os.path.join(args.inputDir,args.depFile)
    2.55 +            theoryModels = TheoryModels()
    2.56 +            theoryModels.init(depFile,dicts)
    2.57 +            theoryModels.save(theoryFile)
    2.58 +            
    2.59          model.save(modelFile)
    2.60          dicts.save(dictsFile)
    2.61  
    2.62 @@ -129,11 +136,14 @@
    2.63          statementCounter = 1
    2.64          computeStats = False
    2.65          dicts = Dictionaries()
    2.66 +        theoryModels = TheoryModels()
    2.67          # Load Files
    2.68          if os.path.isfile(dictsFile):
    2.69              dicts.load(dictsFile)
    2.70          if os.path.isfile(modelFile):
    2.71              model.load(modelFile)
    2.72 +        if os.path.isfile(theoryFile) and args.learnTheories:
    2.73 +            theoryModels.load(theoryFile)
    2.74  
    2.75          # IO Streams
    2.76          OS = open(args.predictions,'w')
    2.77 @@ -142,32 +152,37 @@
    2.78          # Statistics
    2.79          if args.statistics:
    2.80              stats = Statistics(args.cutOff)
    2.81 +            if args.learnTheories:
    2.82 +                theoryStats = TheoryStatistics()
    2.83  
    2.84          predictions = None
    2.85 +        predictedTheories = None
    2.86          #Reading Input File
    2.87          for line in IS:
    2.88  #           try:
    2.89              if True:
    2.90                  if line.startswith('!'):
    2.91 -                    problemId = dicts.parse_fact(line)                    
    2.92 +                    problemId = dicts.parse_fact(line)                        
    2.93                      # Statistics
    2.94                      if args.statistics and computeStats:
    2.95                          computeStats = False
    2.96 -                        acc = dicts.accessibleDict[problemId]
    2.97 +                        # Assume '!' comes after '?'
    2.98                          if args.predef:
    2.99                              predictions = model.predict(problemId)
   2.100 -                        else:
   2.101 -                            if args.snow:
   2.102 -                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc),dicts)
   2.103 -                            else:
   2.104 -                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))                        
   2.105 +                        if args.learnTheories:
   2.106 +                            tmp = [dicts.idNameDict[x] for x in dicts.dependenciesDict[problemId]]
   2.107 +                            usedTheories = set([x.split('.')[0] for x in tmp]) 
   2.108 +                            theoryStats.update((dicts.idNameDict[problemId]).split('.')[0],predictedTheories,usedTheories)                        
   2.109                          stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter)
   2.110                          if not stats.badPreds == []:
   2.111                              bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',')
   2.112                              logger.debug('Bad predictions: %s',bp)
   2.113 +
   2.114                      statementCounter += 1
   2.115                      # Update Dependencies, p proves p
   2.116                      dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId]
   2.117 +                    if args.learnTheories:
   2.118 +                        theoryModels.update(problemId,dicts)
   2.119                      if args.snow:
   2.120                          model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts)
   2.121                      else:
   2.122 @@ -177,15 +192,19 @@
   2.123                      problemId,newDependencies = dicts.parse_overwrite(line)
   2.124                      newDependencies = [problemId]+newDependencies
   2.125                      model.overwrite(problemId,newDependencies,dicts)
   2.126 +                    if args.learnTheories:
   2.127 +                        theoryModels.overwrite(problemId,newDependencies,dicts)
   2.128                      dicts.dependenciesDict[problemId] = newDependencies
   2.129 -                elif line.startswith('?'):                    
   2.130 +                elif line.startswith('?'):               
   2.131                      startTime = time()
   2.132                      computeStats = True
   2.133                      if args.predef:
   2.134                          continue
   2.135 -                    name,features,accessibles = dicts.parse_problem(line)
   2.136 +                    name,features,accessibles = dicts.parse_problem(line)    
   2.137                      # Create predictions
   2.138                      logger.info('Starting computation for problem on line %s',lineCounter)
   2.139 +                    if args.learnTheories:
   2.140 +                        predictedTheories,accessibles = theoryModels.predict(features,accessibles,dicts)
   2.141                      if args.snow:
   2.142                          predictions,predictionValues = model.predict(features,accessibles,dicts)
   2.143                      else:
   2.144 @@ -214,13 +233,20 @@
   2.145  
   2.146          # Statistics
   2.147          if args.statistics:
   2.148 +            if args.learnTheories:
   2.149 +                theoryStats.printAvg()
   2.150              stats.printAvg()
   2.151  
   2.152          # Save
   2.153          if args.saveModel:
   2.154              model.save(modelFile)
   2.155 +            if args.learnTheories:
   2.156 +                theoryModels.save(theoryFile)
   2.157          dicts.save(dictsFile)
   2.158          if not args.saveStats == None:
   2.159 +            if args.learnTheories:
   2.160 +                theoryStatsFile = os.path.join(args.outputDir,'theoryStats')
   2.161 +                theoryStats.save(theoryStatsFile)
   2.162              statsFile = os.path.join(args.outputDir,args.saveStats)
   2.163              stats.save(statsFile)
   2.164      return 0
   2.165 @@ -228,28 +254,37 @@
   2.166  if __name__ == '__main__':
   2.167      # Example:
   2.168      # Jinja
   2.169 -    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
   2.170 -    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/natATPMP.stats']
   2.171 -    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']
   2.172 -    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/natATPNB.stats','--cutOff','500']
   2.173 -    # List
   2.174 -    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/List/','--isabelle']
   2.175 -    #args = ['-i', '../data/List/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--isabelle','-o','../tmp/','--statistics']
   2.176 -    # Huffmann
   2.177 -    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies']
   2.178 -    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/']
   2.179 -    #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics']
   2.180 -    # Jinja
   2.181 -    # ISAR
   2.182 -    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']    
   2.183 -    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
   2.184 -    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
   2.185 -    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
   2.186 +    # ISAR Theories
   2.187 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121221/Jinja/','--learnTheories']    
   2.188 +    #args = ['-i', '../data/20121221/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--learnTheories']
   2.189 +    # ISAR NB
   2.190 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121221/Jinja/']    
   2.191 +    #args = ['-i', '../data/20121221/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
   2.192 +    # ISAR MePo
   2.193 +    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--predef']
   2.194 +    #args = ['-i', '../data/20121212/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
   2.195 +    # ISAR NB ATP
   2.196 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--depFile','mash_atp_dependencies']    
   2.197 +    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
   2.198 +    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies']
   2.199 +    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies']
   2.200      #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow']    
   2.201 -    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
   2.202 +    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
   2.203 +    # ISAR Snow
   2.204 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--snow']    
   2.205 +    #args = ['-i', '../data/20121212/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
   2.206 + 
   2.207 +
   2.208  
   2.209 -    # ATP
   2.210 -    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies']    
   2.211 +    # Probability
   2.212 +    # ISAR NB
   2.213 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121213/Probability/']    
   2.214 +    #args = ['-i', '../data/20121213/Probability/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/ProbIsarNB.stats','--cutOff','500']
   2.215 +    # ISAR MePo
   2.216 +    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121213/Probability/','--predef']
   2.217 +    #args = ['-i', '../data/20121213/Probability/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
   2.218 +    # ISAR NB ATP
   2.219 +    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--depFile','mash_atp_dependencies']    
   2.220      #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
   2.221      #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies']
   2.222      #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies']
     3.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py	Wed Dec 26 11:06:21 2012 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py	Thu Dec 27 10:01:40 2012 +0100
     3.3 @@ -37,8 +37,7 @@
     3.4              line = line[1].split()
     3.5              preds = [dicts.get_name_id(x.strip())for x in line]
     3.6              self.predictions[predId] = preds
     3.7 -        IS.close()
     3.8 -        return dicts
     3.9 +        IS.close()        
    3.10  
    3.11      def update(self,dataPoint,features,dependencies):
    3.12          """
     4.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     4.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py	Thu Dec 27 10:01:40 2012 +0100
     4.3 @@ -0,0 +1,173 @@
     4.4 +#     Title:      HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py
     4.5 +#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     4.6 +#     Copyright   2012
     4.7 +#
     4.8 +# An updatable sparse naive Bayes classifier.
     4.9 +
    4.10 +'''
    4.11 +Created on Jul 11, 2012
    4.12 +
    4.13 +@author: Daniel Kuehlwein
    4.14 +'''
    4.15 +
    4.16 +from cPickle import dump,load
    4.17 +from math import log,exp
    4.18 +
    4.19 +
    4.20 +class singleNBClassifier(object):
    4.21 +    '''
    4.22 +    An updateable naive Bayes classifier.
    4.23 +    '''
    4.24 +
    4.25 +    def __init__(self):
    4.26 +        '''
    4.27 +        Constructor
    4.28 +        '''
    4.29 +        self.neg = 0.0
    4.30 +        self.pos = 0.0
    4.31 +        self.counts = {} # Counts is the tuple poscounts,negcounts
    4.32 +    
    4.33 +    def update(self,features,label):
    4.34 +        """
    4.35 +        Updates the Model.
    4.36 +        
    4.37 +        @param label: True or False, True if the features belong to a positive label, false else.
    4.38 +        """
    4.39 +        #print label,self.pos,self.neg,self.counts
    4.40 +        if label:
    4.41 +            self.pos += 1
    4.42 +        else:
    4.43 +            self.neg += 1
    4.44 +        
    4.45 +        for f,_w in features:
    4.46 +            if not self.counts.has_key(f):
    4.47 +                fPosCount = 0.0
    4.48 +                fNegCount = 0.0
    4.49 +                self.counts[f] = [fPosCount,fNegCount]
    4.50 +            posCount,negCount = self.counts[f]
    4.51 +            if label:
    4.52 +                posCount += 1
    4.53 +            else:
    4.54 +                negCount += 1
    4.55 +            self.counts[f] = [posCount,negCount]
    4.56 +        #print label,self.pos,self.neg,self.counts
    4.57 +                
    4.58 + 
    4.59 +    def delete(self,features,label):
    4.60 +        """
    4.61 +        Deletes a single datapoint from the model.
    4.62 +        """
    4.63 +        if label:
    4.64 +            self.pos -= 1
    4.65 +        else:
    4.66 +            self.neg -= 1
    4.67 +        for f in features:
    4.68 +            posCount,negCount = self.counts[f]
    4.69 +            if label:
    4.70 +                posCount -= 1
    4.71 +            else:
    4.72 +                negCount -= 1
    4.73 +            self.counts[f] = [posCount,negCount]
    4.74 +
    4.75 +            
    4.76 +    def overwrite(self,features,label):
    4.77 +        """
    4.78 +        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
    4.79 +        """
    4.80 +        self.delete(features,label)
    4.81 +        self.update(features,label)
    4.82 +    
    4.83 +    def predict_sparse(self,features):
    4.84 +        """
    4.85 +        Returns 1 if the probability is greater than 50%.
    4.86 +        """
    4.87 +        if self.neg == 0:
    4.88 +            return 1
    4.89 +        elif self.pos ==0:
    4.90 +            return 0
    4.91 +        defValPos = -7.5       
    4.92 +        defValNeg = -15.0
    4.93 +        posWeight = 10.0
    4.94 +        
    4.95 +        logneg = log(self.neg)
    4.96 +        logpos = log(self.pos)
    4.97 +        prob = logpos - logneg
    4.98 +        
    4.99 +        for f,_w in features:
   4.100 +            if self.counts.has_key(f):
   4.101 +                posCount,negCount = self.counts[f]
   4.102 +                if posCount > 0:
   4.103 +                    prob += (log(posWeight * posCount) - logpos)
   4.104 +                else:
   4.105 +                    prob += defValPos
   4.106 +                if negCount > 0:
   4.107 +                    prob -= (log(negCount) - logneg)
   4.108 +                else:
   4.109 +                    prob -= defValNeg 
   4.110 +        if prob >= 0 : 
   4.111 +            return 1
   4.112 +        else:
   4.113 +            return 0
   4.114 +    
   4.115 +    def predict(self,features):    
   4.116 +        """
   4.117 +        Returns 1 if the probability is greater than 50%.
   4.118 +        """
   4.119 +        if self.neg == 0:
   4.120 +            return 1
   4.121 +        elif self.pos ==0:
   4.122 +            return 0
   4.123 +        defVal = -15.0       
   4.124 +        expDefVal = exp(defVal)
   4.125 +        
   4.126 +        logneg = log(self.neg)
   4.127 +        logpos = log(self.pos)
   4.128 +        prob = logpos - logneg
   4.129 +        
   4.130 +        for f in self.counts.keys():
   4.131 +            posCount,negCount = self.counts[f]
   4.132 +            if f in features:
   4.133 +                if posCount == 0:
   4.134 +                    prob += defVal
   4.135 +                else:
   4.136 +                    prob += log(float(posCount)/self.pos)
   4.137 +                if negCount == 0:
   4.138 +                    prob -= defVal
   4.139 +                else:
   4.140 +                    prob -= log(float(negCount)/self.neg)
   4.141 +            else:
   4.142 +                if posCount == self.pos:
   4.143 +                    prob += log(1-expDefVal)
   4.144 +                else:
   4.145 +                    prob += log(1-float(posCount)/self.pos)
   4.146 +                if negCount == self.neg:
   4.147 +                    prob -= log(1-expDefVal)
   4.148 +                else:
   4.149 +                    prob -= log(1-float(negCount)/self.neg)
   4.150 +
   4.151 +        if prob >= 0 : 
   4.152 +            return 1
   4.153 +        else:
   4.154 +            return 0        
   4.155 +        
   4.156 +    def save(self,fileName):
   4.157 +        OStream = open(fileName, 'wb')
   4.158 +        dump(self.counts,OStream)        
   4.159 +        OStream.close()
   4.160 +        
   4.161 +    def load(self,fileName):
   4.162 +        OStream = open(fileName, 'rb')
   4.163 +        self.counts = load(OStream)      
   4.164 +        OStream.close()
   4.165 +
   4.166 +if __name__ == '__main__':
   4.167 +    x = singleNBClassifier()
   4.168 +    x.update([0], True)
   4.169 +    assert x.predict([0]) == 1
   4.170 +    x = singleNBClassifier()
   4.171 +    x.update([0], False)
   4.172 +    assert x.predict([0]) == 0    
   4.173 +    
   4.174 +    x.update([0], True)
   4.175 +    x.update([1], True)
   4.176 +    print x.pos,x.neg,x.predict([0,1])
   4.177 \ No newline at end of file
     5.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py	Wed Dec 26 11:06:21 2012 +0100
     5.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py	Thu Dec 27 10:01:40 2012 +0100
     5.3 @@ -5,16 +5,14 @@
     5.4  # Wrapper for SNoW.
     5.5  
     5.6  '''
     5.7 -THIS FILE IS NOT UP TO DATE!
     5.8 -NEEDS SOME FIXING BEFORE IT WILL WORK WITH THE MAIN ALGORITHM
     5.9  
    5.10  Created on Jul 12, 2012
    5.11  
    5.12  @author: daniel
    5.13  '''
    5.14  
    5.15 -import logging,shlex,subprocess,string
    5.16 -from cPickle import load,dump
    5.17 +import logging,shlex,subprocess,string,shutil
    5.18 +#from cPickle import load,dump
    5.19  
    5.20  class SNoW(object):
    5.21      '''
    5.22 @@ -29,6 +27,7 @@
    5.23          self.SNoWTrainFile = '../tmp/snow.train'
    5.24          self.SNoWTestFile = '../snow.test'
    5.25          self.SNoWNetFile = '../tmp/snow.net'
    5.26 +        self.defMaxNameId = 20000
    5.27  
    5.28      def initializeModel(self,trainData,dicts):
    5.29          """
    5.30 @@ -38,7 +37,8 @@
    5.31          self.logger.debug('Creating IO Files')
    5.32          OS = open(self.SNoWTrainFile,'w')
    5.33          for nameId in trainData:
    5.34 -            features = [f+dicts.maxNameId for f in dicts.featureDict[nameId]]
    5.35 +            features = [f+dicts.maxNameId for f,_w in dicts.featureDict[nameId]]
    5.36 +            #features = [f+self.defMaxNameId for f,_w in dicts.featureDict[nameId]]
    5.37              features = map(str,features)
    5.38              featureString = string.join(features,',')
    5.39              dependencies = dicts.dependenciesDict[nameId]
    5.40 @@ -51,25 +51,51 @@
    5.41          # Build Model
    5.42          self.logger.debug('Building Model START.')
    5.43          snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,dicts.maxNameId-1)
    5.44 +        #snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,self.defMaxNameId-1)
    5.45          args = shlex.split(snowTrainCommand)
    5.46          p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
    5.47          p.wait()
    5.48          self.logger.debug('Building Model END.')
    5.49  
    5.50 -
    5.51      def update(self,dataPoint,features,dependencies,dicts):
    5.52          """
    5.53          Updates the Model.
    5.54 -        THIS IS NOT WORKING ATM< BUT WE DONT CARE
    5.55 +        """
    5.56          """
    5.57          self.logger.debug('Updating Model START')
    5.58 -        trainData = dicts.featureDict.keys()
    5.59 -        self.initializeModel(trainData,dicts)
    5.60 +        # Ignore Feature weights        
    5.61 +        features = [f+self.defMaxNameId for f,_w in features]
    5.62 +        
    5.63 +        OS = open(self.SNoWTestFile,'w')
    5.64 +        features = map(str,features)
    5.65 +        featureString = string.join(features, ',')
    5.66 +        dependencies = map(str,dependencies)
    5.67 +        dependenciesString = string.join(dependencies,',')
    5.68 +        snowString = string.join([featureString,dependenciesString],',')+':\n'
    5.69 +        OS.write(snowString)
    5.70 +        OS.close()
    5.71 +        snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth -i+' % (self.SNoWTestFile,self.SNoWNetFile) 
    5.72 +        args = shlex.split(snowTestCommand)
    5.73 +        p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
    5.74 +        (_lines, _stderrdata) = p.communicate()
    5.75 +        # Move new net file        
    5.76 +        src = self.SNoWNetFile+'.new'
    5.77 +        dst = self.SNoWNetFile
    5.78 +        shutil.move(src, dst)        
    5.79          self.logger.debug('Updating Model END')
    5.80 +        """
    5.81 +        # Do nothing, only update at evaluation. Is a lot faster.
    5.82 +        pass
    5.83  
    5.84  
    5.85      def predict(self,features,accessibles,dicts):
    5.86 -        logger = logging.getLogger('predict_SNoW')
    5.87 +        trainData = dicts.featureDict.keys()
    5.88 +        self.initializeModel(trainData, dicts)        
    5.89 +        
    5.90 +        logger = logging.getLogger('predict_SNoW')        
    5.91 +        # Ignore Feature weights
    5.92 +        #features = [f+self.defMaxNameId for f,_w in features]
    5.93 +        features = [f+dicts.maxNameId for f,_w in features]
    5.94  
    5.95          OS = open(self.SNoWTestFile,'w')
    5.96          features = map(str,features)
    5.97 @@ -87,17 +113,22 @@
    5.98          assert lines[9].startswith('Example ')
    5.99          assert lines[-4] == ''
   5.100          predictionsCon = []
   5.101 +        predictionsValues = []
   5.102          for line in lines[10:-4]:
   5.103              premiseId = int(line.split()[0][:-1])
   5.104              predictionsCon.append(premiseId)
   5.105 -        return predictionsCon
   5.106 +            val = line.split()[4]
   5.107 +            if val.endswith('*'):
   5.108 +                val = float(val[:-1])
   5.109 +            else:
   5.110 +                val = float(val)
   5.111 +            predictionsValues.append(val)
   5.112 +        return predictionsCon,predictionsValues
   5.113  
   5.114      def save(self,fileName):
   5.115 -        OStream = open(fileName, 'wb')
   5.116 -        dump(self.counts,OStream)
   5.117 -        OStream.close()
   5.118 -
   5.119 +        # Nothing to do since we don't update
   5.120 +        pass
   5.121 +    
   5.122      def load(self,fileName):
   5.123 -        OStream = open(fileName, 'rb')
   5.124 -        self.counts = load(OStream)
   5.125 -        OStream.close()
   5.126 +        # Nothing to do since we don't update
   5.127 +        pass
     6.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Wed Dec 26 11:06:21 2012 +0100
     6.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Thu Dec 27 10:01:40 2012 +0100
     6.3 @@ -2,7 +2,7 @@
     6.4  #     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     6.5  #     Copyright   2012
     6.6  #
     6.7 -# An updatable naive Bayes classifier.
     6.8 +# An updatable sparse naive Bayes classifier.
     6.9  
    6.10  '''
    6.11  Created on Jul 11, 2012
    6.12 @@ -37,7 +37,6 @@
    6.13          for key in dicts.dependenciesDict.keys():
    6.14              # Add p proves p
    6.15              keyDeps = [key]+dicts.dependenciesDict[key]
    6.16 -
    6.17              for dep in keyDeps:
    6.18                  self.counts[dep][0] += 1
    6.19                  depFeatures = dicts.featureDict[key]
    6.20 @@ -89,6 +88,8 @@
    6.21          For each accessible, predicts the probability of it being useful given the features.
    6.22          Returns a ranking of the accessibles.
    6.23          """
    6.24 +        posWeight = 20.0
    6.25 +        defVal = 15
    6.26          predictions = []
    6.27          for a in accessibles:
    6.28              posA = self.counts[a][0]
    6.29 @@ -96,14 +97,16 @@
    6.30              fWeightsA = self.counts[a][1]
    6.31              resultA = log(posA)
    6.32              for f,w in features:
    6.33 +                # DEBUG
    6.34 +                #w = 1
    6.35                  if f in fA:
    6.36                      if fWeightsA[f] == 0:
    6.37 -                        resultA -= w*15
    6.38 +                        resultA -= w*defVal
    6.39                      else:
    6.40                          assert fWeightsA[f] <= posA
    6.41 -                        resultA += w*log(float(fWeightsA[f])/posA)
    6.42 +                        resultA += w*log(float(posWeight*fWeightsA[f])/posA)
    6.43                  else:
    6.44 -                    resultA -= w*15
    6.45 +                    resultA -= w*defVal
    6.46              predictions.append(resultA)
    6.47          #expPredictions = array([exp(x) for x in predictions])
    6.48          predictions = array(predictions)
     7.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     7.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py	Thu Dec 27 10:01:40 2012 +0100
     7.3 @@ -0,0 +1,136 @@
     7.4 +#     Title:      HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
     7.5 +#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     7.6 +#     Copyright   2012
     7.7 +#
     7.8 +# Theory-level models: one naive Bayes classifier per theory.
     7.9 +
    7.10 +'''
    7.11 +Created on Dec 26, 2012
    7.12 +
    7.13 +@author: Daniel Kuehlwein
    7.14 +'''
    7.15 +
    7.16 +from singleNaiveBayes import singleNBClassifier
    7.17 +from cPickle import load,dump
    7.18 +import sys,logging
    7.19 +
    7.20 +class TheoryModels(object):
    7.21 +    '''
    7.22 +    MetaClass for all the theory models.
    7.23 +    '''
    7.24 +
    7.25 +
    7.26 +    def __init__(self):
    7.27 +        '''
    7.28 +        Constructor
    7.29 +        '''
    7.30 +        self.theoryModels = {}
    7.31 +        self.theoryDict = {}
    7.32 +        self.accessibleTheories = []
    7.33 +        self.currentTheory = None
    7.34 +  
    7.35 +    def init(self,depFile,dicts):      
    7.36 +        logger = logging.getLogger('TheoryModels')
    7.37 +        IS = open(depFile,'r')
    7.38 +        for line in IS:
    7.39 +            line = line.split(':')
    7.40 +            name = line[0]
    7.41 +            theory = name.split('.')[0]
    7.42 +            # Name Id
    7.43 +            if not dicts.nameIdDict.has_key(name):
    7.44 +                logger.warning('%s is missing in nameIdDict. Aborting.',name)
    7.45 +                sys.exit(-1)
    7.46 +    
    7.47 +            nameId = dicts.nameIdDict[name]
    7.48 +            features = dicts.featureDict[nameId]
    7.49 +            if not self.theoryDict.has_key(theory):
    7.50 +                assert not theory == self.currentTheory
    7.51 +                if not self.currentTheory == None:
    7.52 +                    self.accessibleTheories.append(self.currentTheory)
    7.53 +                self.currentTheory = theory
    7.54 +                self.theoryDict[theory] = set([nameId])
    7.55 +                theoryModel = singleNBClassifier()
    7.56 +                self.theoryModels[theory] = theoryModel 
    7.57 +            else:
    7.58 +                self.theoryDict[theory] = self.theoryDict[theory].union([nameId])               
    7.59 +            
    7.60 +            # Find the actually used theories
    7.61 +            usedtheories = []    
    7.62 +            dependencies = line[1].split()
    7.63 +            if len(dependencies) == 0:
    7.64 +                continue
    7.65 +            for dep in dependencies:
    7.66 +                depId = dicts.nameIdDict[dep.strip()]
    7.67 +                deptheory = dep.split('.')[0]
    7.68 +                usedtheories.append(deptheory)
    7.69 +                if not self.theoryDict.has_key(deptheory):
    7.70 +                    self.theoryDict[deptheory] = set([depId])
    7.71 +                else:
    7.72 +                    self.theoryDict[deptheory] = self.theoryDict[deptheory].union([depId])                   
    7.73 +                        
    7.74 +            # Update theoryModels
    7.75 +            self.theoryModels[self.currentTheory].update(features,self.currentTheory in usedtheories)
    7.76 +            for a in self.accessibleTheories:                
    7.77 +                self.theoryModels[a].update(dicts.featureDict[nameId],a in usedtheories)
    7.78 +        IS.close()
    7.79 +    
    7.80 +    def overwrite(self,problemId,newDependencies,dicts):
    7.81 +        pass
    7.82 +    
    7.83 +    def delete(self):
    7.84 +        pass
    7.85 +    
    7.86 +    def update(self,problemId,dicts):        
    7.87 +        features = dicts.featureDict[problemId]
    7.88 +        
    7.89 +        # Find the actually used theories
    7.90 +        tmp = [dicts.idNameDict[x] for x in dicts.dependenciesDict[problemId]]
    7.91 +        usedTheories = set([x.split('.')[0] for x in tmp]) 
    7.92 +        currentTheory = (dicts.idNameDict[problemId]).split('.')[0]       
    7.93 +        # Create new theory model, if there is a new theory 
    7.94 +        if not self.theoryDict.has_key(currentTheory):
    7.95 +            assert not currentTheory == self.currentTheory
    7.96 +            if not currentTheory == None:
    7.97 +                self.theoryDict[currentTheory] = []
    7.98 +                self.currentTheory = currentTheory
    7.99 +                theoryModel = singleNBClassifier()
   7.100 +                self.theoryModels[currentTheory] = theoryModel          
   7.101 +        if not len(usedTheories) == 0:
   7.102 +            for a in self.accessibleTheories:                
   7.103 +                self.theoryModels[a].update(features,a in usedTheories)   
   7.104 +    
   7.105 +    def predict(self,features,accessibles,dicts):
   7.106 +        """
   7.107 +        Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories.
   7.108 +        """         
   7.109 +        # TODO: This can be made a lot faster!    
   7.110 +        self.accessibleTheories = []
   7.111 +        for x in accessibles:
   7.112 +            xArt = (dicts.idNameDict[x]).split('.')[0]
   7.113 +            self.accessibleTheories.append(xArt)
   7.114 +        self.accessibleTheories = set(self.accessibleTheories)
   7.115 +        
   7.116 +        # Predict Theories
   7.117 +        predictedTheories = [self.currentTheory]
   7.118 +        for a in self.accessibleTheories:
   7.119 +            if self.theoryModels[a].predict_sparse(features):
   7.120 +            #if theoryModels[a].predict(dicts.featureDict[nameId]):
   7.121 +                predictedTheories.append(a)
   7.122 +        predictedTheories = set(predictedTheories)
   7.123 +
   7.124 +        # Delete accessibles in unpredicted theories
   7.125 +        newAcc = []
   7.126 +        for x in accessibles:
   7.127 +            xArt = (dicts.idNameDict[x]).split('.')[0]
   7.128 +            if xArt in predictedTheories:
   7.129 +                newAcc.append(x)
   7.130 +        return predictedTheories,newAcc
   7.131 +        
   7.132 +    def save(self,fileName):
   7.133 +        outStream = open(fileName, 'wb')
   7.134 +        dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict),outStream)
   7.135 +        outStream.close()
   7.136 +    def load(self,fileName):
   7.137 +        inStream = open(fileName, 'rb')
   7.138 +        self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict = load(inStream)
   7.139 +        inStream.close()
   7.140 \ No newline at end of file
     8.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     8.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py	Thu Dec 27 10:01:40 2012 +0100
     8.3 @@ -0,0 +1,63 @@
     8.4 +#     Title:      HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py
     8.5 +#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
     8.6 +#     Copyright   2012
     8.7 +#
     8.8 +# Statistics for theory-level predictions.
     8.9 +
    8.10 +'''
    8.11 +Created on Dec 26, 2012
    8.12 +
    8.13 +@author: Daniel Kuehlwein
    8.14 +'''
    8.15 +
    8.16 +from cPickle import load,dump
    8.17 +import logging,string
    8.18 +
    8.19 +class TheoryStatistics(object):
    8.20 +    '''
    8.21 +    Stores statistics for theory lvl predictions
    8.22 +    '''
    8.23 +
    8.24 +
    8.25 +    def __init__(self):
    8.26 +        '''
    8.27 +        Constructor
    8.28 +        '''
    8.29 +        self.logger = logging.getLogger('TheoryStatistics')
    8.30 +        self.count = 0
    8.31 +        self.precision = 0.0
    8.32 +        self.recall100 = 0
    8.33 +        self.recall = 0.0
    8.34 +        self.predicted = 0.0
    8.35 +    
    8.36 +    def update(self,currentTheory,predictedTheories,usedTheories):
    8.37 +        self.count += 1
    8.38 +        allPredTheories = predictedTheories.union([currentTheory])
    8.39 +        if set(usedTheories).issubset(allPredTheories):
    8.40 +            self.recall100 += 1
    8.41 +        localPredicted = len(allPredTheories)
    8.42 +        self.predicted += localPredicted 
    8.43 +        localPrec = float(len(set(usedTheories).intersection(allPredTheories))) / localPredicted
    8.44 +        self.precision += localPrec
    8.45 +        localRecall = float(len(set(usedTheories).intersection(allPredTheories))) / len(set(usedTheories))
    8.46 +        self.recall += localRecall
    8.47 +        self.logger.info('Theory prediction results:')
    8.48 +        self.logger.info('Problem: %s \t Recall100: %s \t Precision: %s \t Recall: %s \t PredictedTheories: %s',\
    8.49 +                         self.count,self.recall100,round(localPrec,2),round(localRecall,2),localPredicted)
    8.50 +        
    8.51 +    def printAvg(self):
    8.52 +        self.logger.info('Average theory results:')
    8.53 +        self.logger.info('avgPrecision: %s \t avgRecall100: %s \t avgRecall: %s \t avgPredicted:%s', \
    8.54 +                         round(self.precision/self.count,2),\
    8.55 +                         round(float(self.recall100)/self.count,2),\
    8.56 +                         round(self.recall/self.count,2),\
    8.57 +                         round(self.predicted /self.count,2))
    8.58 +        
    8.59 +    def save(self,fileName):
    8.60 +        oStream = open(fileName, 'wb')
    8.61 +        dump((self.count,self.precision,self.recall100,self.recall,self.predicted),oStream)
    8.62 +        oStream.close()
    8.63 +    def load(self,fileName):
    8.64 +        iStream = open(fileName, 'rb')
    8.65 +        self.count,self.precision,self.recall100,self.recall,self.predicted = load(iStream)
    8.66 +        iStream.close()
    8.67 \ No newline at end of file