53100
|
1 |
#!/usr/bin/python
|
|
2 |
# Title: HOL/Tools/Sledgehammer/MaSh/src/mash.py
|
|
3 |
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
|
|
4 |
# Copyright 2012
|
|
5 |
#
|
|
6 |
# Entry point for MaSh (Machine Learning for Sledgehammer).
|
|
7 |
|
|
8 |
'''
|
|
9 |
MaSh - Machine Learning for Sledgehammer
|
|
10 |
|
|
11 |
MaSh allows the use of different machine learning algorithms to predict relevant facts for Sledgehammer.
|
|
12 |
|
|
13 |
Created on July 12, 2012
|
|
14 |
|
|
15 |
@author: Daniel Kuehlwein
|
|
16 |
'''
|
|
17 |
|
|
18 |
import logging,datetime,string,os,sys
|
|
19 |
from argparse import ArgumentParser,RawDescriptionHelpFormatter
|
|
20 |
from time import time
|
|
21 |
from stats import Statistics
|
|
22 |
from theoryStats import TheoryStatistics
|
|
23 |
from theoryModels import TheoryModels
|
|
24 |
from dictionaries import Dictionaries
|
|
25 |
#from fullNaiveBayes import NBClassifier
|
|
26 |
from sparseNaiveBayes import sparseNBClassifier
|
|
27 |
from snow import SNoW
|
|
28 |
from predefined import Predefined
|
|
29 |
|
|
30 |
# Set up the command-line parser.
# The description is passed through RawDescriptionHelpFormatter, so its
# embedded newlines are printed verbatim by --help.
parser = ArgumentParser(
    description='MaSh - Machine Learning for Sledgehammer. \n\n'
                'MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n'
                '--------------- Example Usage ---------------\n'
                'First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n'
                'Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n'
                '\n\n'
                'Author: Daniel Kuehlwein, July 2012',
    formatter_class=RawDescriptionHelpFormatter)

# Input / output.
# NOTE: the timestamped defaults of --predictions and --log are computed once,
# when this module is imported.
parser.add_argument('-i', '--inputFile', help='File containing all problems to be solved.')
parser.add_argument('-o', '--outputDir', default='../tmp/',
                    help='Directory where all created files are stored. Default=../tmp/.')
parser.add_argument('-p', '--predictions', default='../tmp/%s.predictions' % datetime.datetime.now(),
                    help='File where the predictions are stored. Default=../tmp/dateTime.predictions.')
parser.add_argument('--numberOfPredictions', default=200, type=int,
                    help='Number of premises to write in the output. Default=200.')

# Initialization.
parser.add_argument('--init', default=False, action='store_true',
                    help='Initialize Mash. Requires --inputDir to be defined. Default=False.')
parser.add_argument('--inputDir', default='../data/20121212/Jinja/',
                    help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
parser.add_argument('--depFile', default='mash_dependencies',
                    help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
parser.add_argument('--saveModel', default=False, action='store_true',
                    help='Stores the learned Model at the end of a prediction run. Default=False.')

# Two-level (theory, then premise) prediction and its parameters.
parser.add_argument('--learnTheories', default=False, action='store_true',
                    help='Uses a two-lvl prediction mode. First the theories, then the premises. Default=False.')
parser.add_argument('--theoryDefValPos', default=-7.5, type=float,
                    help='Default value for positive unknown features. Default=-7.5.')
parser.add_argument('--theoryDefValNeg', default=-10.0, type=float,
                    help='Default value for negative unknown features. Default=-10.0.')
parser.add_argument('--theoryPosWeight', default=2.0, type=float,
                    help='Weight value for positive features. Default=2.0.')

# Sparse Naive Bayes (the default learner) and its parameters.
parser.add_argument('--nb', default=False, action='store_true',
                    help='Use Naive Bayes for learning. This is the default learning method.')
parser.add_argument('--NBDefaultPriorWeight', default=20.0, type=float,
                    help='Initializes classifiers with value * p |- p. Default=20.0.')
parser.add_argument('--NBDefVal', default=-15.0, type=float,
                    help='Default value for unknown features. Default=-15.0.')
parser.add_argument('--NBPosWeight', default=10.0, type=float,
                    help='Weight value for positive features. Default=10.0.')

# SInE-like secondary features on the premise level.
parser.add_argument('--sineFeatures', default=False, action='store_true',
                    help='Uses a SInE like prior for premise lvl predictions. Default=False.')
parser.add_argument('--sineWeight', default=0.5, type=float,
                    help='How much the SInE prior is weighted. Default=0.5.')

# Alternative learners.
parser.add_argument('--snow', default=False, action='store_true',
                    help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
parser.add_argument('--predef',
                    help='Use predefined predictions. Used only for comparison with the actual learning. Argument is the filename of the predictions.')

# Statistics.
parser.add_argument('--statistics', default=False, action='store_true',
                    help='Create and show statistics for the top CUTOFF predictions. WARNING: This will make the program a lot slower! Default=False.')
parser.add_argument('--saveStats', default=None,
                    help='If defined, stores the statistics in the filename provided (relative to --outputDir). Requires --statistics.')
parser.add_argument('--cutOff', default=500, type=int,
                    help='Option for statistics. Only consider the first cutOff predictions. Default=500.')

# Logging and persisted state.
parser.add_argument('-l', '--log', default='../tmp/%s.log' % datetime.datetime.now(),
                    help='Log file name. Default=../tmp/dateTime.log')
parser.add_argument('-q', '--quiet', default=False, action='store_true',
                    help='If enabled, only print warnings. Default=False.')
parser.add_argument('--modelFile', default='../tmp/model.pickle',
                    help='Model file name. Default=../tmp/model.pickle')
parser.add_argument('--dictsFile', default='../tmp/dict.pickle',
                    help='Dict file name. Default=../tmp/dict.pickle')
parser.add_argument('--theoryFile', default='../tmp/theory.pickle',
                    help='Theory file name. Default=../tmp/theory.pickle')
|
|
77 |
|
|
78 |
def _ensure_dir(path):
    """Create directory ``path`` (with parents) unless it is empty or already exists."""
    if path and not os.path.exists(path):
        os.makedirs(path)


def _setup_logging(args):
    """Log DEBUG+ to ``args.log``; unless ``--quiet``, also echo INFO+ to stdout.

    NOTE: ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so calling ``mash`` repeatedly in one process keeps writing to the
    first log file and adds one more console handler per call.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%d-%m %H:%M:%S',
                        filename=args.log,
                        filemode='w')
    logger = logging.getLogger('main.py')
    if args.quiet:
        logger.setLevel(logging.WARNING)
    else:
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('# %(message)s'))
        logging.getLogger('').addHandler(console)
    return logger


def _select_model(args, logger):
    """Instantiate the learner chosen on the command line.

    Precedence is ``--nb`` > ``--snow`` > ``--predef`` > sparse Naive Bayes.

    :returns: ``(model, kind)`` where ``kind`` is ``'nb'``, ``'snow'`` or
        ``'predef'``. The command loop branches on ``kind`` rather than on the
        raw flags, so conflicting flags (e.g. ``--nb --snow``) cannot call a
        model with another learner's API.
    """
    if args.nb:
        logger.info('Using sparse Naive Bayes for learning.')
    elif args.snow:
        logger.info('Using naive bayes (SNoW) for learning.')
        return SNoW(), 'snow'
    elif args.predef:
        logger.info('Using predefined predictions.')
        return Predefined(args.predef), 'predef'
    else:
        logger.info('No algorithm specified. Using sparse Naive Bayes.')
    return sparseNBClassifier(args.NBDefaultPriorWeight, args.NBPosWeight, args.NBDefVal), 'nb'


def _initialize(args, model, logger):
    """Build dictionaries, model and (optionally) theory models from ``--inputDir`` and save them."""
    logger.info('Initializing Model.')
    startTime = time()

    # Load all data.
    dicts = Dictionaries()
    dicts.init_all(args)

    # Create the model from every known fact.
    trainData = dicts.featureDict.keys()
    model.initializeModel(trainData, dicts)

    if args.learnTheories:
        depFile = os.path.join(args.inputDir, args.depFile)
        theoryModels = TheoryModels(args.theoryDefValPos, args.theoryDefValNeg, args.theoryPosWeight)
        theoryModels.init(depFile, dicts)
        theoryModels.save(args.theoryFile)

    model.save(args.modelFile)
    dicts.save(args.dictsFile)

    logger.info('All Done. %s seconds needed.', round(time() - startTime, 2))
    return 0


def _sine_features(features, dicts, sineWeight):
    """Return SInE-style secondary features for ``features`` (a list of ``(feature, weight)`` pairs).

    For every original feature whose ``featureCountDict`` entry is not 1, every
    feature of every formula listed in ``featureTriggeredFormulasDict`` for it
    is added once with weight ``sineWeight``. Features already present (original
    or secondary) are never added again.
    """
    origFeatures = [f for f, _w in features]
    # Track bare feature ids; the secondary list itself holds (id, weight) pairs.
    seen = set(origFeatures)
    secondaryFeatures = []
    for f in origFeatures:
        if dicts.featureCountDict[f] == 1:
            continue
        for formula in dicts.featureTriggeredFormulasDict[f]:
            for fNew in dicts.triggerFeaturesDict[formula]:
                if fNew not in seen:
                    seen.add(fNew)
                    secondaryFeatures.append((fNew, sineWeight))
    return secondaryFeatures


def _format_prediction_line(name, predictions, predictionValues, dicts, limit):
    """Render ``name: fact1=score1 fact2=score2 ...`` for the first ``limit`` predictions."""
    entries = ['%s=%s' % (str(dicts.idNameDict[p]), str(v))
               for p, v in zip(predictions[:limit], predictionValues[:limit])]
    return '%s: %s' % (name, ' '.join(entries))


def _run_commands(args, model, modelKind, logger):
    """Process ``--inputFile`` line by line, writing predictions to ``--predictions``.

    Line types (by first character):
      ``!`` a new fact: scored against the preceding ``?`` when statistics are
            on, then added to the model(s) with itself prepended to its
            dependencies;
      ``p`` overwrite the recorded dependencies of a fact;
      ``?`` a problem: predict premises and write one output line.
    Any other line logs a warning and calls ``sys.exit(-1)``.
    """
    dicts = Dictionaries()
    theoryModels = TheoryModels(args.theoryDefValPos, args.theoryDefValNeg, args.theoryPosWeight)
    # Load previously saved state, if any.
    if os.path.isfile(args.dictsFile):
        dicts.load(args.dictsFile)
    if os.path.isfile(args.modelFile):
        model.load(args.modelFile)
    if args.learnTheories and os.path.isfile(args.theoryFile):
        theoryModels.load(args.theoryFile)
    logger.info('All loading completed')

    stats = None
    theoryStats = None
    if args.statistics:
        stats = Statistics(args.cutOff)
        if args.learnTheories:
            theoryStats = TheoryStatistics()

    usePredef = modelKind == 'predef'
    useSnow = modelKind == 'snow'
    statementCounter = 1
    computeStats = False
    predictions = None
    predictedTheories = None

    # Nested `with` (rather than one multi-item `with`) keeps Python 2.6 support.
    with open(args.predictions, 'w') as OS:
        with open(args.inputFile, 'r') as IS:
            for lineCounter, line in enumerate(IS, 1):
                if line.startswith('!'):
                    problemId = dicts.parse_fact(line)
                    if args.statistics and computeStats:
                        computeStats = False
                        # Predefined mode skips '?' lines, so it predicts here;
                        # this assumes each '!' follows its '?'.
                        if usePredef:
                            predictions = model.predict(problemId)
                        if args.learnTheories:
                            usedTheories = set(dicts.idNameDict[x].split('.')[0]
                                               for x in dicts.dependenciesDict[problemId])
                            theoryStats.update(dicts.idNameDict[problemId].split('.')[0], predictedTheories,
                                               usedTheories, len(theoryModels.accessibleTheories))
                        stats.update(predictions, dicts.dependenciesDict[problemId], statementCounter)
                        if stats.badPreds != []:
                            bp = ','.join([str(dicts.idNameDict[x]) for x in stats.badPreds])
                            logger.debug('Bad predictions: %s', bp)

                    statementCounter += 1
                    # Update dependencies: p proves p.
                    dicts.dependenciesDict[problemId] = [problemId] + dicts.dependenciesDict[problemId]
                    if args.learnTheories:
                        theoryModels.update(problemId, dicts.featureDict[problemId],
                                            dicts.dependenciesDict[problemId], dicts)
                    if useSnow:
                        model.update(problemId, dicts.featureDict[problemId],
                                     dicts.dependenciesDict[problemId], dicts)
                    else:
                        model.update(problemId, dicts.featureDict[problemId],
                                     dicts.dependenciesDict[problemId])

                elif line.startswith('p'):
                    # Overwrite the old proof.
                    problemId, newDependencies = dicts.parse_overwrite(line)
                    newDependencies = [problemId] + newDependencies
                    model.overwrite(problemId, newDependencies, dicts)
                    if args.learnTheories:
                        theoryModels.overwrite(problemId, newDependencies, dicts)
                    dicts.dependenciesDict[problemId] = newDependencies

                elif line.startswith('?'):
                    startTime = time()
                    computeStats = True
                    if usePredef:
                        continue
                    name, features, accessibles, hints = dicts.parse_problem(line)
                    logger.info('Starting computation for problem on line %s', lineCounter)

                    # Temporarily teach the models the hints (the SNoW model skips them).
                    if hints:
                        if args.learnTheories:
                            accessibleTheories = set(dicts.idNameDict[x].split('.')[0] for x in accessibles)
                            theoryModels.update_with_acc('hints', features, hints, dicts, accessibleTheories)
                        if not useSnow:
                            model.update('hints', features, hints)

                    # Predict theories first (restricting accessibles), then premises.
                    if args.learnTheories:
                        predictedTheories, accessibles = theoryModels.predict(features, accessibles, dicts)
                    predictionsFeatures = features
                    if args.sineFeatures:
                        predictionsFeatures = features + _sine_features(features, dicts, args.sineWeight)
                    predictions, predictionValues = model.predict(predictionsFeatures, accessibles, dicts)
                    assert len(predictions) == len(predictionValues)

                    # Remove the hints again.
                    if hints:
                        if args.learnTheories:
                            theoryModels.delete('hints', features, hints, dicts)
                        if not useSnow:
                            model.delete('hints', features, hints)

                    logger.info('Done. %s seconds needed.', round(time() - startTime, 2))
                    OS.write('%s\n' % _format_prediction_line(name, predictions, predictionValues,
                                                              dicts, args.numberOfPredictions))
                else:
                    logger.warning('Unspecified input format: \n%s', line)
                    sys.exit(-1)

    # Statistics
    if args.statistics:
        if args.learnTheories:
            theoryStats.printAvg()
        stats.printAvg()

    # Save
    if args.saveModel:
        model.save(args.modelFile)
        if args.learnTheories:
            theoryModels.save(args.theoryFile)
        dicts.save(args.dictsFile)
    if args.saveStats is not None:
        if not args.statistics:
            # Without --statistics nothing was collected (previously a NameError).
            logger.warning('--saveStats requires --statistics; no statistics were saved.')
        else:
            if args.learnTheories:
                theoryStats.save(os.path.join(args.outputDir, 'theoryStats'))
            stats.save(os.path.join(args.outputDir, args.saveStats))
    return 0


def mash(argv=None):
    """Entry point for MaSh.

    With ``--init``, the dictionaries, the model and (with ``--learnTheories``)
    the theory models are built from ``--inputDir`` and saved. Otherwise the
    commands in ``--inputFile`` are processed and one prediction line is written
    per ``?`` command.

    :param argv: argument list to parse. ``None`` (the default) lets argparse
        use ``sys.argv[1:]`` at call time rather than at import time.
    :returns: 0 on success. An unrecognised input line calls ``sys.exit(-1)``.
    """
    args = parser.parse_args(argv)

    # Create the output directory and the log's directory before opening the
    # log file; with the defaults both are ../tmp/.
    _ensure_dir(args.outputDir)
    _ensure_dir(os.path.dirname(args.log))
    logger = _setup_logging(args)

    logger.info('Using the following settings: %s', args)
    model, modelKind = _select_model(args, logger)

    if args.init:
        return _initialize(args, model, logger)
    return _run_commands(args, model, modelKind, logger)
|
|
320 |
|
|
321 |
if __name__ == '__main__':
    # Run on the real command line only. Before this fix, two hard-coded
    # debugging runs executed first: an --init on ../data/20130118/Jinja and a
    # prediction run. See the parser description for example invocations.
    sys.exit(mash())
|