# HG changeset patch # User blanchet # Date 1354971326 -3600 # Node ID 1e71f9d3cd57d2dbebb75e8ea54f8e147e7f5c24 # Parent ca99c269ca3aadc71090449f49266e3eb8d220bf more changes to MaSh Python program (by Daniel K.) diff -r ca99c269ca3a -r 1e71f9d3cd57 src/HOL/Tools/Sledgehammer/MaSh/src/mash.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Sat Dec 08 00:48:51 2012 +0100 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Sat Dec 08 13:55:26 2012 +0100 @@ -82,7 +82,7 @@ logger.info('Using the following settings: %s',args) # Pick algorithm if args.nb: - logger.info('Using Naive Bayes for learning.') + logger.info('Using sparse Naive Bayes for learning.') model = NBClassifier() modelFile = os.path.join(args.outputDir,'NB.pickle') elif args.snow: @@ -96,7 +96,7 @@ model = Predefined(predictionFile) modelFile = os.path.join(args.outputDir,'mepo.pickle') else: - logger.info('No algorithm specified. Using Naive Bayes.') + logger.info('No algorithm specified. Using sparse Naive Bayes.') model = NBClassifier() modelFile = os.path.join(args.outputDir,'NB.pickle') dictsFile = os.path.join(args.outputDir,'dicts.pickle') @@ -113,7 +113,7 @@ # Create Model trainData = dicts.featureDict.keys() if args.predef: - dicts = model.initializeModel(trainData,dicts) + model.initializeModel(trainData,dicts) else: model.initializeModel(trainData,dicts) @@ -135,7 +135,7 @@ model.load(modelFile) # IO Streams - OS = open(args.predictions,'a') + OS = open(args.predictions,'w') IS = open(args.inputFile,'r') # Statistics @@ -148,7 +148,7 @@ # try: if True: if line.startswith('!'): - problemId = dicts.parse_fact(line) + problemId = dicts.parse_fact(line) # Statistics if args.statistics and computeStats: computeStats = False @@ -156,7 +156,10 @@ if args.predef: predictions = model.predict(problemId) else: - predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc)) + if args.snow: + predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc),dicts) + else: + predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc)) stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter) if not stats.badPreds == []: bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',') @@ -164,7 +167,10 @@ statementCounter += 1 # Update Dependencies, p proves p dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId] - model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId]) + if args.snow: + model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts) + else: + model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId]) elif line.startswith('p'): # Overwrite old proof. problemId,newDependencies = dicts.parse_overwrite(line) @@ -179,7 +185,10 @@ name,features,accessibles = dicts.parse_problem(line) # Create predictions logger.info('Starting computation for problem on line %s',lineCounter) - predictions,predictionValues = model.predict(features,accessibles) + if args.snow: + predictions,predictionValues = model.predict(features,accessibles,dicts) + else: + predictions,predictionValues = model.predict(features,accessibles) assert len(predictions) == len(predictionValues) logger.info('Done. %s seconds needed.',round(time()-startTime,2)) # Output @@ -229,17 +238,23 @@ #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies'] #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/'] #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics'] - # Arrow - #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Arrow_Order/'] - #args = ['-i', '../data/Arrow_Order/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/arrowIsarNB.stats','--cutOff','500'] - # Fundamental_Theorem_Algebra - #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/'] - #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraIsarNB.stats','--cutOff','500'] - #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/','--predef'] - #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/Fundamental_Theorem_AlgebraMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraMePo.stats'] # Jinja + # ISAR #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/'] #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500'] + #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef'] + #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats'] + #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow'] + #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies'] + + # ATP + #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies'] + #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies'] + #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies'] + #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies'] + #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow'] + #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies'] + #startTime = time()