#!/usr/bin/python
# Title: HOL/Tools/Sledgehammer/MaSh/src/mash.py
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
# Copyright 2012
#
# Entry point for MaSh (Machine Learning for Sledgehammer).

50220
|
8 |
'''
MaSh - Machine Learning for Sledgehammer

MaSh lets you use different machine learning algorithms to predict relevant facts for Sledgehammer.

Created on July 12, 2012

@author: Daniel Kuehlwein
'''
|
|
17 |
|
|
18 |
import logging,datetime,string,os,sys
|
|
19 |
from argparse import ArgumentParser,RawDescriptionHelpFormatter
|
|
20 |
from time import time
|
|
21 |
from stats import Statistics
|
|
22 |
from dictionaries import Dictionaries
|
|
23 |
from naiveBayes import NBClassifier
|
|
24 |
from snow import SNoW
|
|
25 |
from predefined import Predefined
|
|
26 |
|
|
27 |
# Set up command-line parser.
# The description is shown verbatim by RawDescriptionHelpFormatter, so its
# line breaks are spelled out explicitly.
parser = ArgumentParser(
    description=('MaSh - Machine Learning for Sledgehammer. \n\n'
                 'MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n'
                 '--------------- Example Usage ---------------\n'
                 'First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Nat/ \n'
                 'Then create predictions:\n./mash.py -i ../data/Nat/mash_commands -p ../tmp/test.pred -l test.log -o ../tmp/ --statistics\n'
                 '\n\n'
                 'Author: Daniel Kuehlwein, July 2012'),
    formatter_class=RawDescriptionHelpFormatter)

# Input / output locations.
parser.add_argument('-i', '--inputFile', help='File containing all problems to be solved.')
parser.add_argument('-o', '--outputDir', default='../tmp/',
                    help='Directory where all created files are stored. Default=../tmp/.')
# NOTE: timestamped defaults are computed once, when this module is imported.
parser.add_argument('-p', '--predictions', default='../tmp/%s.predictions' % datetime.datetime.now(),
                    help='File where the predictions stored. Default=../tmp/dateTime.predictions.')
parser.add_argument('--numberOfPredictions', default=200, type=int,
                    help="Number of premises to write in the output. Default=200.")

# Model initialization and persistence.
parser.add_argument('--init', default=False, action='store_true',
                    help="Initialize Mash. Requires --inputDir to be defined. Default=False.")
parser.add_argument('--inputDir', default='../data/Nat/',
                    help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
parser.add_argument('--depFile', default='mash_dependencies',
                    help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
parser.add_argument('--saveModel', default=False, action='store_true',
                    help="Stores the learned Model at the end of a prediction run. Default=False.")

# Learning algorithm selection (Naive Bayes is used when none is given).
parser.add_argument('--nb', default=False, action='store_true',
                    help="Use Naive Bayes for learning. This is the default learning method.")
parser.add_argument('--snow', default=False, action='store_true',
                    help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
parser.add_argument('--predef', default=False, action='store_true',
                    help="Use predefined predictions. Used only for comparison with the actual learning. Expects mash_meng_paulson_suggestions in inputDir.")

# Statistics.
parser.add_argument('--statistics', default=False, action='store_true',
                    help="Create and show statistics for the top CUTOFF predictions. WARNING: This will make the program a lot slower! Default=False.")
parser.add_argument('--saveStats', default=None,
                    help="If defined, stores the statistics in the filename provided.")
parser.add_argument('--cutOff', default=500, type=int,
                    help="Option for statistics. Only consider the first cutOff predictions. Default=500.")

# Logging.
parser.add_argument('-l', '--log', default='../tmp/%s.log' % datetime.datetime.now(),
                    help='Log file name. Default=../tmp/dateTime.log')
parser.add_argument('-q', '--quiet', default=False, action='store_true',
                    help="If enabled, only print warnings. Default=False.")
|
|
58 |
|
|
59 |
def main(argv=None):
    """Initialize a MaSh model, or run a command file against a saved one.

    With --init, the dictionaries are built from --inputDir, the selected
    model is initialized with the keys of the feature dictionary, and both
    are saved into --outputDir.  Otherwise the previously saved dictionaries
    and model are loaded (when those files exist) and every line of
    --inputFile is processed in order (see _process_commands).

    Args:
        argv: Command-line arguments without the program name.  None (the
            default) lets argparse read sys.argv[1:] at call time instead of
            freezing sys.argv when the module is imported.

    Returns:
        0 on success.  A missing --inputFile outside --init mode exits via
        parser.error; an unrecognized input line exits via sys.exit(-1).
    """
    args = parser.parse_args(argv)
    if not args.init and args.inputFile is None:
        # Without this check open(None) fails later with an opaque TypeError.
        parser.error('--inputFile is required unless --init is given.')

    # Create the output directory before logging starts: the default log
    # file lives in the default output directory, and logging.basicConfig
    # does not create missing directories.
    if not os.path.exists(args.outputDir):
        os.makedirs(args.outputDir)
    logger = _setup_logging(args)
    logger.info('Using the following settings: %s', args)

    model, modelFile = _select_model(args, logger)
    dictsFile = os.path.join(args.outputDir, 'dicts.pickle')

    if args.init:
        _initialize_model(args, model, modelFile, dictsFile, logger)
    else:
        _process_commands(args, model, modelFile, dictsFile, logger)
    return 0


def _setup_logging(args):
    """Log DEBUG+ to args.log (overwritten) and INFO+ to stdout; WARNING+ only with --quiet."""
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%d-%m %H:%M:%S',
                        filename=args.log,
                        filemode='w')
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('# %(message)s'))
    logging.getLogger('').addHandler(console)
    logger = logging.getLogger('main.py')
    if args.quiet:
        logger.setLevel(logging.WARNING)
        console.setLevel(logging.WARNING)
    return logger


def _select_model(args, logger):
    """Return (model, modelFile) for the learner chosen on the command line; Naive Bayes by default."""
    if args.nb:
        logger.info('Using Naive Bayes for learning.')
        return NBClassifier(), os.path.join(args.outputDir, 'NB.pickle')
    if args.snow:
        logger.info('Using naive bayes (SNoW) for learning.')
        return SNoW(), os.path.join(args.outputDir, 'SNoW.pickle')
    if args.predef:
        logger.info('Using predefined predictions.')
        predictionFile = os.path.join(args.inputDir, 'mash_meng_paulson_suggestions')
        return Predefined(predictionFile), os.path.join(args.outputDir, 'isabelle.pickle')
    logger.info('No algorithm specified. Using Naive Bayes.')
    return NBClassifier(), os.path.join(args.outputDir, 'NB.pickle')


def _initialize_model(args, model, modelFile, dictsFile, logger):
    """Build the dictionaries from args.inputDir, initialize the model, and save both."""
    logger.info('Initializing Model.')
    startTime = time()

    # Load all data.
    dicts = Dictionaries()
    dicts.init_all(args.inputDir, depFileName=args.depFile)

    # Create the model from every fact that has features.
    trainData = dicts.featureDict.keys()
    if args.predef:
        # The predefined model's initializeModel returns the dictionaries to keep.
        dicts = model.initializeModel(trainData, dicts)
    else:
        model.initializeModel(trainData, dicts)

    model.save(modelFile)
    dicts.save(dictsFile)

    logger.info('All Done. %s seconds needed.', round(time() - startTime, 2))


def _process_commands(args, model, modelFile, dictsFile, logger):
    """Replay args.inputFile against the model, appending predictions to args.predictions.

    Each input line is dispatched on its first character:
      '!'  a new fact, parsed by dicts.parse_fact.  The fact is prepended to
           its own dependencies and passed to model.update.  With
           --statistics, the model's prediction for the fact is scored
           against its dependencies first.
      'p'  replace the dependencies of an existing fact (the fact itself is
           prepended) in both the model and the dictionaries.
      '?'  a prediction request; the top --numberOfPredictions results are
           written as 'name: fact1=value1 fact2=value2 ...'.  Skipped with
           --predef.
    Any other line is logged and ends the process with sys.exit(-1).
    """
    dicts = Dictionaries()
    if os.path.isfile(dictsFile):
        dicts.load(dictsFile)
    if os.path.isfile(modelFile):
        model.load(modelFile)

    stats = Statistics(args.cutOff) if args.statistics else None
    topN = args.numberOfPredictions

    # 'with' closes both files even when an exception or sys.exit leaves the loop.
    with open(args.predictions, 'a') as OS, open(args.inputFile, 'r') as IS:
        for lineNumber, line in enumerate(IS, 1):
            if line.startswith('!'):
                problemId = dicts.parse_fact(line)
                if stats is not None:
                    if args.predef:
                        # NOTE(review): every other model call uses
                        # model.predict(...) as a method, but here it is
                        # subscripted; confirm Predefined supports this.
                        predictions = model.predict[problemId]
                    else:
                        acc = dicts.accessibleDict[problemId]
                        predictions, _predictionValues = model.predict(
                            dicts.featureDict[problemId], dicts.expand_accessibles(acc))
                    stats.update(predictions, dicts.dependenciesDict[problemId])
                    if stats.badPreds:
                        bp = ','.join([str(dicts.idNameDict[x]) for x in stats.badPreds])
                        logger.debug('Bad predictions: %s', bp)
                # Update dependencies: p proves p.
                dicts.dependenciesDict[problemId] = [problemId] + dicts.dependenciesDict[problemId]
                model.update(problemId, dicts.featureDict[problemId], dicts.dependenciesDict[problemId])
            elif line.startswith('p'):
                # Overwrite old proof.
                problemId, newDependencies = dicts.parse_overwrite(line)
                newDependencies = [problemId] + newDependencies
                model.overwrite(problemId, newDependencies, dicts)
                dicts.dependenciesDict[problemId] = newDependencies
            elif line.startswith('?'):
                if args.predef:
                    continue
                startTime = time()
                name, features, accessibles = dicts.parse_problem(line)
                logger.info('Starting computation for problem on line %s', lineNumber)
                predictions, predictionValues = model.predict(features, accessibles)
                assert len(predictions) == len(predictionValues)
                logger.info('Done. %s seconds needed.', round(time() - startTime, 2))

                pairs = ['%s=%s' % (str(dicts.idNameDict[p]), str(v))
                         for p, v in zip(predictions[:topN], predictionValues[:topN])]
                OS.write('%s: %s\n' % (name, ' '.join(pairs)))
            else:
                logger.warning('Unspecified input format: \n%s', line)
                sys.exit(-1)

    if stats is not None:
        stats.printAvg()

    if args.saveModel:
        model.save(modelFile)
        dicts.save(dictsFile)
    if args.saveStats is not None:
        # Statistics only exist when --statistics was given; the original
        # code raised NameError here otherwise.
        if stats is None:
            logger.warning('--saveStats requires --statistics; no statistics were saved.')
        else:
            stats.save(os.path.join(args.outputDir, args.saveStats))
|
|
211 |
|
|
212 |
if __name__ == '__main__':
    # Example invocations are part of the parser description (./mash.py --help):
    # initialize once with --init, then run prediction passes with -i/-p.
    sys.exit(main())
|
|
228 |
|