1 #!/usr/bin/python |
|
2 # Title: HOL/Tools/Sledgehammer/MaSh/src/mash.py |
|
3 # Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
4 # Copyright 2012 |
|
5 # |
|
6 # Entry point for MaSh (Machine Learning for Sledgehammer). |
|
7 |
|
8 ''' |
|
9 MaSh - Machine Learning for Sledgehammer |
|
10 |
|
11 MaSh allows the use of different machine learning algorithms to predict relevant facts for Sledgehammer. |
|
12 |
|
13 Created on July 12, 2012 |
|
14 |
|
15 @author: Daniel Kuehlwein |
|
16 ''' |
|
17 |
|
18 import logging,datetime,string,os,sys |
|
19 from argparse import ArgumentParser,RawDescriptionHelpFormatter |
|
20 from time import time |
|
21 from stats import Statistics |
|
22 from theoryStats import TheoryStatistics |
|
23 from theoryModels import TheoryModels |
|
24 from dictionaries import Dictionaries |
|
25 #from fullNaiveBayes import NBClassifier |
|
26 from sparseNaiveBayes import sparseNBClassifier |
|
27 from snow import SNoW |
|
28 from predefined import Predefined |
|
29 |
|
# Set up command-line parser.
# All options are parsed by mash(); defaults are shown in each option's help text
# and must agree with the actual `default=` value.
parser = ArgumentParser(description='MaSh - Machine Learning for Sledgehammer. \n\n'
                        'MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n'
                        '--------------- Example Usage ---------------\n'
                        'First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n'
                        'Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n'
                        '\n\n'
                        'Author: Daniel Kuehlwein, July 2012',
                        formatter_class=RawDescriptionHelpFormatter)
parser.add_argument('-i', '--inputFile', help='File containing all problems to be solved.')
parser.add_argument('-o', '--outputDir', default='../tmp/', help='Directory where all created files are stored. Default=../tmp/.')
parser.add_argument('-p', '--predictions', default='../tmp/%s.predictions' % datetime.datetime.now(),
                    help='File where the predictions are stored. Default=../tmp/dateTime.predictions.')
parser.add_argument('--numberOfPredictions', default=200, help="Number of premises to write in the output. Default=200.", type=int)

parser.add_argument('--init', default=False, action='store_true', help="Initialize Mash. Requires --inputDir to be defined. Default=False.")
parser.add_argument('--inputDir', default='../data/20121212/Jinja/',
                    help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
parser.add_argument('--depFile', default='mash_dependencies',
                    help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
parser.add_argument('--saveModel', default=False, action='store_true', help="Stores the learned Model at the end of a prediction run. Default=False.")

parser.add_argument('--learnTheories', default=False, action='store_true', help="Uses a two-lvl prediction mode. First the theories, then the premises. Default=False.")
# Theory Parameters
parser.add_argument('--theoryDefValPos', default=-7.5, help="Default value for positive unknown features. Default=-7.5.", type=float)
parser.add_argument('--theoryDefValNeg', default=-10.0, help="Default value for negative unknown features. Default=-10.0.", type=float)
parser.add_argument('--theoryPosWeight', default=2.0, help="Weight value for positive features. Default=2.0.", type=float)

parser.add_argument('--nb', default=False, action='store_true', help="Use Naive Bayes for learning. This is the default learning method.")
# NB Parameters
parser.add_argument('--NBDefaultPriorWeight', default=20.0, help="Initializes classifiers with value * p |- p. Default=20.0.", type=float)
parser.add_argument('--NBDefVal', default=-15.0, help="Default value for unknown features. Default=-15.0.", type=float)
parser.add_argument('--NBPosWeight', default=10.0, help="Weight value for positive features. Default=10.0.", type=float)
# SInE-like secondary features for premise-level predictions
parser.add_argument('--sineFeatures', default=False, action='store_true', help="Uses a SInE like prior for premise lvl predictions. Default=False.")
parser.add_argument('--sineWeight', default=0.5, help="How much the SInE prior is weighted. Default=0.5.", type=float)

parser.add_argument('--snow', default=False, action='store_true', help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
parser.add_argument('--predef', help="Use predefined predictions. Used only for comparison with the actual learning. Argument is the filename of the predictions.")
parser.add_argument('--statistics', default=False, action='store_true',
                    help="Create and show statistics for the top CUTOFF predictions. WARNING: This will make the program a lot slower! Default=False.")
parser.add_argument('--saveStats', default=None, help="If defined, stores the statistics in the filename provided.")
parser.add_argument('--cutOff', default=500, help="Option for statistics. Only consider the first cutOff predictions. Default=500.", type=int)
parser.add_argument('-l', '--log', default='../tmp/%s.log' % datetime.datetime.now(), help='Log file name. Default=../tmp/dateTime.log')
parser.add_argument('-q', '--quiet', default=False, action='store_true', help="If enabled, only print warnings. Default=False.")
parser.add_argument('--modelFile', default='../tmp/model.pickle', help='Model file name. Default=../tmp/model.pickle')
parser.add_argument('--dictsFile', default='../tmp/dict.pickle', help='Dict file name. Default=../tmp/dict.pickle')
parser.add_argument('--theoryFile', default='../tmp/theory.pickle', help='Theory model file name. Default=../tmp/theory.pickle')
|
77 |
|
def _setup_logging(args):
    """Configure file logging (args.log) and, unless args.quiet, console logging.

    Returns the 'main.py' logger used by the rest of this module.
    """
    # basicConfig only takes effect if the root logger has no handlers yet.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%d-%m %H:%M:%S',
                        filename=args.log,
                        filemode='w')
    logger = logging.getLogger('main.py')
    if args.quiet:
        logger.setLevel(logging.WARNING)
    else:
        root = logging.getLogger('')
        # mash() can be called several times in one process; attach our console
        # handler only once so messages are not printed repeatedly.
        if not any(getattr(h, '_mashConsole', False) for h in root.handlers):
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            console.setFormatter(logging.Formatter('# %(message)s'))
            console._mashConsole = True
            root.addHandler(console)
    return logger


def _select_model(args, logger):
    """Instantiate the premise-selection model chosen on the command line."""
    if args.nb:
        logger.info('Using sparse Naive Bayes for learning.')
        return sparseNBClassifier(args.NBDefaultPriorWeight, args.NBPosWeight, args.NBDefVal)
    if args.snow:
        logger.info('Using naive bayes (SNoW) for learning.')
        return SNoW()
    if args.predef:
        logger.info('Using predefined predictions.')
        return Predefined(args.predef)
    logger.info('No algorithm specified. Using sparse Naive Bayes.')
    return sparseNBClassifier(args.NBDefaultPriorWeight, args.NBPosWeight, args.NBDefVal)


def _initialize(args, model, logger):
    """Build dictionaries and models from args.inputDir and save them to disk.

    Writes args.modelFile, args.dictsFile and (with --learnTheories)
    args.theoryFile. Returns 0.
    """
    logger.info('Initializing Model.')
    startTime = time()

    # Load all data
    dicts = Dictionaries()
    dicts.init_all(args)

    # Create Model
    trainData = dicts.featureDict.keys()
    model.initializeModel(trainData, dicts)

    if args.learnTheories:
        depFile = os.path.join(args.inputDir, args.depFile)
        theoryModels = TheoryModels(args.theoryDefValPos, args.theoryDefValNeg, args.theoryPosWeight)
        theoryModels.init(depFile, dicts)
        theoryModels.save(args.theoryFile)

    model.save(args.modelFile)
    dicts.save(args.dictsFile)

    logger.info('All Done. %s seconds needed.', round(time() - startTime, 2))
    return 0


def _sine_features(features, dicts, sineWeight):
    """Return `features` extended with SInE-like secondary features.

    For every original feature that occurs more than once (per
    dicts.featureCountDict), the features of the formulas it triggers are
    added with weight `sineWeight`. Each secondary feature is added at most
    once and never duplicates an original feature.
    """
    origFeatures = [f for f, _w in features]
    # Track feature ids (not (id, weight) tuples) so that duplicates are
    # actually detected across triggered formulas.
    seen = set(origFeatures)
    secondaryFeatures = []
    for f in origFeatures:
        if dicts.featureCountDict[f] == 1:
            continue
        for formula in dicts.featureTriggeredFormulasDict[f]:
            newFeatures = set(dicts.triggerFeaturesDict[formula]).difference(seen)
            for fNew in newFeatures:
                secondaryFeatures.append((fNew, sineWeight))
            seen.update(newFeatures)
    return features + secondaryFeatures


def _write_predictions(OS, name, predictions, predictionValues, dicts, limit):
    """Write one output line 'name: p1=v1 p2=v2 ...' with at most `limit` entries."""
    predictionNames = [str(dicts.idNameDict[p]) for p in predictions[:limit]]
    valueStrings = [str(v) for v in predictionValues[:limit]]
    predictionsString = ' '.join('%s=%s' % pair for pair in zip(predictionNames, valueStrings))
    OS.write('%s\n' % ('%s: %s' % (name, predictionsString)))


def _run(args, model, logger):
    """Replay the command file args.inputFile and write predictions.

    Line kinds: '!' adds a fact (and its dependencies) to the models, 'p'
    overwrites a fact's proof, '?' asks for premise predictions that are
    written to args.predictions. Any other line aborts with sys.exit(-1).
    Returns 0.
    """
    if args.inputFile is None:
        parser.error('--inputFile is required unless --init is given.')

    statementCounter = 1
    computeStats = False
    dicts = Dictionaries()
    theoryModels = TheoryModels(args.theoryDefValPos, args.theoryDefValNeg, args.theoryPosWeight)
    # Load Files
    if os.path.isfile(args.dictsFile):
        dicts.load(args.dictsFile)
    if os.path.isfile(args.modelFile):
        model.load(args.modelFile)
    if args.learnTheories and os.path.isfile(args.theoryFile):
        theoryModels.load(args.theoryFile)
    logger.info('All loading completed')

    # Statistics (only created with --statistics)
    stats = None
    theoryStats = None
    if args.statistics:
        stats = Statistics(args.cutOff)
        if args.learnTheories:
            theoryStats = TheoryStatistics()

    predictions = None
    predictedTheories = None
    # Open the input first so a missing input file does not truncate the
    # predictions file; `with` guarantees both are closed on any exit path.
    with open(args.inputFile, 'r') as IS:
        with open(args.predictions, 'w') as OS:
            # enumerate keeps the line number correct even when a branch
            # `continue`s (the predef '?' case).
            for lineCounter, line in enumerate(IS, 1):
                if line.startswith('!'):
                    problemId = dicts.parse_fact(line)
                    # Statistics for the preceding '?' query.
                    if args.statistics and computeStats:
                        computeStats = False
                        # Assume '!' comes after '?'
                        if args.predef:
                            predictions = model.predict(problemId)
                        if args.learnTheories:
                            usedTheories = set(dicts.idNameDict[x].split('.')[0]
                                               for x in dicts.dependenciesDict[problemId])
                            theoryStats.update(dicts.idNameDict[problemId].split('.')[0], predictedTheories,
                                               usedTheories, len(theoryModels.accessibleTheories))
                        stats.update(predictions, dicts.dependenciesDict[problemId], statementCounter)
                        if stats.badPreds:
                            bp = ','.join([str(dicts.idNameDict[x]) for x in stats.badPreds])
                            logger.debug('Bad predictions: %s', bp)

                    statementCounter += 1
                    # Update Dependencies, p proves p
                    dicts.dependenciesDict[problemId] = [problemId] + dicts.dependenciesDict[problemId]
                    if args.learnTheories:
                        theoryModels.update(problemId, dicts.featureDict[problemId], dicts.dependenciesDict[problemId], dicts)
                    if args.snow:
                        model.update(problemId, dicts.featureDict[problemId], dicts.dependenciesDict[problemId], dicts)
                    else:
                        model.update(problemId, dicts.featureDict[problemId], dicts.dependenciesDict[problemId])
                elif line.startswith('p'):
                    # Overwrite old proof.
                    problemId, newDependencies = dicts.parse_overwrite(line)
                    newDependencies = [problemId] + newDependencies
                    model.overwrite(problemId, newDependencies, dicts)
                    if args.learnTheories:
                        theoryModels.overwrite(problemId, newDependencies, dicts)
                    dicts.dependenciesDict[problemId] = newDependencies
                elif line.startswith('?'):
                    startTime = time()
                    computeStats = True
                    if args.predef:
                        continue
                    name, features, accessibles, hints = dicts.parse_problem(line)

                    # Create predictions
                    logger.info('Starting computation for problem on line %s', lineCounter)
                    # Temporarily add the hints as a pseudo-fact named 'hints'.
                    if hints:
                        if args.learnTheories:
                            accessibleTheories = set([dicts.idNameDict[x].split('.')[0] for x in accessibles])
                            theoryModels.update_with_acc('hints', features, hints, dicts, accessibleTheories)
                        if not args.snow:
                            model.update('hints', features, hints)

                    # Predict premises (restricted to predicted theories if enabled)
                    if args.learnTheories:
                        predictedTheories, accessibles = theoryModels.predict(features, accessibles, dicts)
                    if args.sineFeatures:
                        predictionsFeatures = _sine_features(features, dicts, args.sineWeight)
                    else:
                        predictionsFeatures = features
                    predictions, predictionValues = model.predict(predictionsFeatures, accessibles, dicts)
                    assert len(predictions) == len(predictionValues)

                    # Delete hints
                    if hints:
                        if args.learnTheories:
                            theoryModels.delete('hints', features, hints, dicts)
                        if not args.snow:
                            model.delete('hints', features, hints)

                    logger.info('Done. %s seconds needed.', round(time() - startTime, 2))
                    _write_predictions(OS, name, predictions, predictionValues, dicts, args.numberOfPredictions)
                else:
                    logger.warning('Unspecified input format: \n%s', line)
                    sys.exit(-1)

    # Statistics
    if args.statistics:
        if args.learnTheories:
            theoryStats.printAvg()
        stats.printAvg()

    # Save
    if args.saveModel:
        model.save(args.modelFile)
        if args.learnTheories:
            theoryModels.save(args.theoryFile)
        dicts.save(args.dictsFile)
    if args.saveStats is not None:
        if not args.statistics:
            # Previously this crashed with a NameError on `stats`.
            logger.warning('--saveStats given without --statistics; no statistics to save.')
        else:
            if args.learnTheories:
                theoryStatsFile = os.path.join(args.outputDir, 'theoryStats')
                theoryStats.save(theoryStatsFile)
            statsFile = os.path.join(args.outputDir, args.saveStats)
            stats.save(statsFile)
    return 0


def mash(argv=None):
    """Run MaSh with the given command-line arguments.

    argv: list of argument strings; None (the default) means sys.argv[1:],
          read at call time rather than at import time.
    With --init the dictionaries and models are built and saved; otherwise
    the command file --inputFile is processed and predictions are written.
    Returns 0 on success.
    """
    # Initializing command-line arguments
    args = parser.parse_args(argv)
    logger = _setup_logging(args)

    if not os.path.exists(args.outputDir):
        os.makedirs(args.outputDir)

    logger.info('Using the following settings: %s', args)
    model = _select_model(args, logger)

    if args.init:
        return _initialize(args, model, logger)
    # Create predictions and/or update model
    return _run(args, model, logger)
|
320 |
|
if __name__ == '__main__':
    # Run MaSh on the real command-line arguments only. Example invocations
    # (initialisation, then prediction) are shown in the parser description.
    # The hard-coded "Cezary Auth" init + prediction runs on
    # ../data/20130118/Jinja previously executed unconditionally before the
    # user's own command and have been removed.
    sys.exit(mash())
|