more changes to MaSh Python program (by Daniel K.)
authorblanchet
Sat, 08 Dec 2012 13:55:26 +0100
changeset 50441 1e71f9d3cd57
parent 50440 ca99c269ca3a
child 50442 4f6a4d32522c
more changes to MaSh Python program (by Daniel K.)
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Sat Dec 08 00:48:51 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Sat Dec 08 13:55:26 2012 +0100
@@ -82,7 +82,7 @@
     logger.info('Using the following settings: %s',args)
     # Pick algorithm
     if args.nb:
-        logger.info('Using Naive Bayes for learning.')
+        logger.info('Using sparse Naive Bayes for learning.')
         model = NBClassifier()
         modelFile = os.path.join(args.outputDir,'NB.pickle')
     elif args.snow:
@@ -96,7 +96,7 @@
         model = Predefined(predictionFile)
         modelFile = os.path.join(args.outputDir,'mepo.pickle')        
     else:
-        logger.info('No algorithm specified. Using Naive Bayes.')
+        logger.info('No algorithm specified. Using sparse Naive Bayes.')
         model = NBClassifier()
         modelFile = os.path.join(args.outputDir,'NB.pickle')
     dictsFile = os.path.join(args.outputDir,'dicts.pickle')
@@ -113,7 +113,7 @@
         # Create Model
         trainData = dicts.featureDict.keys()
         if args.predef:
-            dicts = model.initializeModel(trainData,dicts)
+            model.initializeModel(trainData,dicts)
         else:
             model.initializeModel(trainData,dicts)
 
@@ -135,7 +135,7 @@
             model.load(modelFile)
 
         # IO Streams
-        OS = open(args.predictions,'a')
+        OS = open(args.predictions,'w')
         IS = open(args.inputFile,'r')
 
         # Statistics
@@ -148,7 +148,7 @@
 #           try:
             if True:
                 if line.startswith('!'):
-                    problemId = dicts.parse_fact(line)
+                    problemId = dicts.parse_fact(line)                    
                     # Statistics
                     if args.statistics and computeStats:
                         computeStats = False
@@ -156,7 +156,10 @@
                         if args.predef:
                             predictions = model.predict(problemId)
                         else:
-                            predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))
+                            if args.snow:
+                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc),dicts)
+                            else:
+                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))                        
                         stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter)
                         if not stats.badPreds == []:
                             bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',')
@@ -164,7 +167,10 @@
                     statementCounter += 1
                     # Update Dependencies, p proves p
                     dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId]
-                    model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId])
+                    if args.snow:
+                        model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts)
+                    else:
+                        model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId])
                 elif line.startswith('p'):
                     # Overwrite old proof.
                     problemId,newDependencies = dicts.parse_overwrite(line)
@@ -179,7 +185,10 @@
                     name,features,accessibles = dicts.parse_problem(line)
                     # Create predictions
                     logger.info('Starting computation for problem on line %s',lineCounter)
-                    predictions,predictionValues = model.predict(features,accessibles)
+                    if args.snow:
+                        predictions,predictionValues = model.predict(features,accessibles,dicts)
+                    else:
+                        predictions,predictionValues = model.predict(features,accessibles)
                     assert len(predictions) == len(predictionValues)
                     logger.info('Done. %s seconds needed.',round(time()-startTime,2))
                     # Output        
@@ -229,17 +238,23 @@
     #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies']
     #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/']
     #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics']
-    # Arrow
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Arrow_Order/']    
-    #args = ['-i', '../data/Arrow_Order/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/arrowIsarNB.stats','--cutOff','500']
-    # Fundamental_Theorem_Algebra
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/']    
-    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraIsarNB.stats','--cutOff','500']
-    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/','--predef']
-    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/Fundamental_Theorem_AlgebraMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraMePo.stats']
     # Jinja
+    # ISAR
     #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']    
     #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow']    
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+
+    # ATP
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies']    
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies']
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies']
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow']    
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+
 
     
     #startTime = time()