src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
changeset 50399 52d9720f7a48
parent 50388 a5b666e0c3c2
child 50434 960a3429615c
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Thu Dec 06 11:25:10 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Thu Dec 06 11:27:44 2012 +0100
@@ -20,6 +20,7 @@
 from time import time
 from stats import Statistics
 from dictionaries import Dictionaries
+#from fullNaiveBayes import NBClassifier
 from naiveBayes import NBClassifier
 from snow import SNoW
 from predefined import Predefined
@@ -48,7 +49,7 @@
 parser.add_argument('--nb',default=False,action='store_true',help="Use Naive Bayes for learning. This is the default learning method.")
 parser.add_argument('--snow',default=False,action='store_true',help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
 parser.add_argument('--predef',default=False,action='store_true',\
-                    help="Use predefined predictions. Used only for comparison with the actual learning. Expects mash_meng_paulson_suggestions in inputDir.")
+                    help="Use predefined predictions. Used only for comparison with the actual learning. Expects mash_mepo_suggestions in inputDir.")
 parser.add_argument('--statistics',default=False,action='store_true',help="Create and show statistics for the top CUTOFF predictions.\
                     WARNING: This will make the program a lot slower! Default=False.")
 parser.add_argument('--saveStats',default=None,help="If defined, stores the statistics in the filename provided.")
@@ -90,9 +91,10 @@
         modelFile = os.path.join(args.outputDir,'SNoW.pickle')
     elif args.predef:
         logger.info('Using predefined predictions.')
-        predictionFile = os.path.join(args.inputDir,'mash_meng_paulson_suggestions')
+        #predictionFile = os.path.join(args.inputDir,'mash_meng_paulson_suggestions') 
+        predictionFile = os.path.join(args.inputDir,'mash_mepo_suggestions')
         model = Predefined(predictionFile)
-        modelFile = os.path.join(args.outputDir,'isabelle.pickle')
+        modelFile = os.path.join(args.outputDir,'mepo.pickle')        
     else:
         logger.info('No algorithm specified. Using Naive Bayes.')
         model = NBClassifier()
@@ -122,7 +124,9 @@
         return 0
     # Create predictions and/or update model
     else:
-        lineCounter = 0
+        lineCounter = 1
+        statementCounter = 1
+        computeStats = False
         dicts = Dictionaries()
         # Load Files
         if os.path.isfile(dictsFile):
@@ -141,21 +145,23 @@
         predictions = None
         #Reading Input File
         for line in IS:
- #           try:
+#           try:
             if True:
                 if line.startswith('!'):
                     problemId = dicts.parse_fact(line)
                     # Statistics
-                    if args.statistics:
+                    if args.statistics and computeStats:
+                        computeStats = False
                         acc = dicts.accessibleDict[problemId]
                         if args.predef:
-                            predictions = model.predict[problemId]
+                            predictions = model.predict(problemId)
                         else:
                             predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))
-                        stats.update(predictions,dicts.dependenciesDict[problemId])
+                        stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter)
                         if not stats.badPreds == []:
                             bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',')
                             logger.debug('Bad predictions: %s',bp)
+                    statementCounter += 1
                     # Update Dependencies, p proves p
                     dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId]
                     model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId])
@@ -165,8 +171,9 @@
                     newDependencies = [problemId]+newDependencies
                     model.overwrite(problemId,newDependencies,dicts)
                     dicts.dependenciesDict[problemId] = newDependencies
-                elif line.startswith('?'):
+                elif line.startswith('?'):                    
                     startTime = time()
+                    computeStats = True
                     if args.predef:
                         continue
                     name,features,accessibles = dicts.parse_problem(line)
@@ -175,18 +182,17 @@
                     predictions,predictionValues = model.predict(features,accessibles)
                     assert len(predictions) == len(predictionValues)
                     logger.info('Done. %s seconds needed.',round(time()-startTime,2))
-
-                    # Output
+                    # Output        
                     predictionNames = [str(dicts.idNameDict[p]) for p in predictions[:args.numberOfPredictions]]
                     predictionValues = [str(x) for x in predictionValues[:args.numberOfPredictions]]
                     predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
                     predictionsString = string.join(predictionsStringList,' ')
                     outString = '%s: %s' % (name,predictionsString)
                     OS.write('%s\n' % outString)
-                    lineCounter += 1
                 else:
                     logger.warning('Unspecified input format: \n%s',line)
                     sys.exit(-1)
+                lineCounter += 1
             """
             except:
                 logger.warning('An error occurred on line %s .',line)
@@ -216,11 +222,26 @@
     #args = ['-i', '../data/Nat/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/natATPMP.stats']
     #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Nat/']
     #args = ['-i', '../data/Nat/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/natATPNB.stats','--cutOff','500']
-    # BUG
+    # List
     #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/List/','--isabelle']
     #args = ['-i', '../data/List/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--isabelle','-o','../tmp/','--statistics']
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../bug/init','--init']
-    #args = ['-i', '../bug/adds/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/']
+    # Huffmann
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies']
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/']
+    #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics']
+    # Arrow
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Arrow_Order/']    
+    #args = ['-i', '../data/Arrow_Order/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/arrowIsarNB.stats','--cutOff','500']
+    # Fundamental_Theorem_Algebra
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/']    
+    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraIsarNB.stats','--cutOff','500']
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/','--predef']
+    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/Fundamental_Theorem_AlgebraMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraMePo.stats']
+    # Jinja
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']    
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
+
+    
     #startTime = time()
     #sys.exit(main(args))
     #print 'New ' + str(round(time()-startTime,2))