author | blanchet |
Sat, 08 Dec 2012 00:48:50 +0100 | |
changeset 50434 | 960a3429615c |
parent 50399 | 52d9720f7a48 |
child 50441 | 1e71f9d3cd57 |
permissions | -rwxr-xr-x |
50220 | 1 |
#!/usr/bin/python |
50222 | 2 |
# Title: HOL/Tools/Sledgehammer/MaSh/src/mash.py |
3 |
# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
4 |
# Copyright 2012 |
|
5 |
# |
|
6 |
# Entry point for MaSh (Machine Learning for Sledgehammer). |
|
7 |
||
50220 | 8 |
''' |
9 |
MaSh - Machine Learning for Sledgehammer |
|
10 |
||
11 |
MaSh allows the use of different machine learning algorithms to predict relevant facts for Sledgehammer.
|
12 |
||
13 |
Created on July 12, 2012 |
|
14 |
||
15 |
@author: Daniel Kuehlwein |
|
16 |
''' |
|
17 |
||
18 |
import logging,datetime,string,os,sys |
|
19 |
from argparse import ArgumentParser,RawDescriptionHelpFormatter |
|
20 |
from time import time |
|
21 |
from stats import Statistics |
|
22 |
from dictionaries import Dictionaries |
|
50399 | 23 |
#from fullNaiveBayes import NBClassifier |
50220 | 24 |
from naiveBayes import NBClassifier |
25 |
from snow import SNoW |
|
26 |
from predefined import Predefined |
|
27 |
||
28 |
# Set up command-line parser.
# `parser` is module-level so that main() can parse its arguments (and report
# usage errors through it).
parser = ArgumentParser(
    description='MaSh - Machine Learning for Sledgehammer. \n\n'
                'MaSh lets you use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n'
                '--------------- Example Usage ---------------\n'
                'First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n'
                'Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n'
                '\n\n'
                'Author: Daniel Kuehlwein, July 2012',
    formatter_class=RawDescriptionHelpFormatter)

# Input / output locations.
parser.add_argument('-i', '--inputFile',
                    help='File containing all problems to be solved. Required unless --init is given.')
parser.add_argument('-o', '--outputDir', default='../tmp/',
                    help='Directory where all created files are stored. Default=../tmp/.')
# NOTE: the timestamped defaults below are computed once, at import time.
parser.add_argument('-p', '--predictions', default='../tmp/%s.predictions' % datetime.datetime.now(),
                    help='File where the predictions are stored. Default=../tmp/dateTime.predictions.')
parser.add_argument('--numberOfPredictions', default=200, type=int,
                    help="Number of premises to write in the output. Default=200.")

# Initialization.
parser.add_argument('--init', default=False, action='store_true',
                    help="Initialize MaSh. Requires --inputDir to be defined. Default=False.")
parser.add_argument('--inputDir', default='../data/Jinja/',
                    help='Directory containing all the input data. MaSh expects the following files: '
                         'mash_features, mash_dependencies, mash_accessibility')
parser.add_argument('--depFile', default='mash_dependencies',
                    help='Name of the file with the premise dependencies. The file must be in inputDir. '
                         'Default = mash_dependencies')
parser.add_argument('--saveModel', default=False, action='store_true',
                    help="Stores the learned Model at the end of a prediction run. Default=False.")

# Learning algorithm (Naive Bayes is used when none is selected).
parser.add_argument('--nb', default=False, action='store_true',
                    help="Use Naive Bayes for learning. This is the default learning method.")
parser.add_argument('--snow', default=False, action='store_true',
                    help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
parser.add_argument('--predef', default=False, action='store_true',
                    help="Use predefined predictions. Used only for comparison with the actual learning. "
                         "Expects mash_mepo_suggestions in inputDir.")

# Statistics.
parser.add_argument('--statistics', default=False, action='store_true',
                    help="Create and show statistics for the top CUTOFF predictions. "
                         "WARNING: This will make the program a lot slower! Default=False.")
parser.add_argument('--saveStats', default=None,
                    help="If defined, stores the statistics in the filename provided.")
parser.add_argument('--cutOff', default=500, type=int,
                    help="Option for statistics. Only consider the first cutOff predictions. Default=500.")

# Logging.
parser.add_argument('-l', '--log', default='../tmp/%s.log' % datetime.datetime.now(),
                    help='Log file name. Default=../tmp/dateTime.log')
parser.add_argument('-q', '--quiet', default=False, action='store_true',
                    help="If enabled, only print warnings. Default=False.")
|
59 |
||
50388 | 60 |
def main(argv=sys.argv[1:]):
    """Initialize a MaSh model, or run a command file through it.

    With --init, the dictionaries are built from the files in --inputDir, the
    selected model is initialized from them, and both are pickled into
    --outputDir.

    Otherwise the pickled dictionaries and model are loaded (when present) and
    --inputFile is processed line by line:

      '!' lines -- a new fact: its dependencies (plus the fact itself) are
                   recorded and the model is updated. With --statistics, the
                   first fact after a '?' query is also scored against its
                   actual dependencies before being learned.
      'p' lines -- overwrite the stored dependencies of a fact.
      '?' lines -- a query: the top --numberOfPredictions suggestions are
                   appended to the --predictions file.

    Any other line is logged and the process stops with sys.exit(-1).

    Args:
        argv: command-line arguments, without the program name.

    Returns:
        0 on success.
    """
    # Initializing command-line arguments
    args = parser.parse_args(argv)

    # Set up logging: DEBUG and above goes to the log file, INFO and above is
    # echoed to stdout (WARNING and above with --quiet).
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%d-%m %H:%M:%S',
                        filename=args.log,
                        filemode='w')
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('# %(message)s'))
    logging.getLogger('').addHandler(console)
    logger = logging.getLogger('main.py')
    if args.quiet:
        logger.setLevel(logging.WARNING)
        console.setLevel(logging.WARNING)
    if not os.path.exists(args.outputDir):
        os.makedirs(args.outputDir)

    logger.info('Using the following settings: %s', args)

    # Pick algorithm
    if args.nb:
        logger.info('Using Naive Bayes for learning.')
        model = NBClassifier()
        modelFile = os.path.join(args.outputDir, 'NB.pickle')
    elif args.snow:
        logger.info('Using naive bayes (SNoW) for learning.')
        model = SNoW()
        modelFile = os.path.join(args.outputDir, 'SNoW.pickle')
    elif args.predef:
        logger.info('Using predefined predictions.')
        predictionFile = os.path.join(args.inputDir, 'mash_mepo_suggestions')
        model = Predefined(predictionFile)
        modelFile = os.path.join(args.outputDir, 'mepo.pickle')
    else:
        logger.info('No algorithm specified. Using Naive Bayes.')
        model = NBClassifier()
        modelFile = os.path.join(args.outputDir, 'NB.pickle')
    dictsFile = os.path.join(args.outputDir, 'dicts.pickle')

    # Initializing model
    if args.init:
        logger.info('Initializing Model.')
        startTime = time()

        # Load all data
        dicts = Dictionaries()
        dicts.init_all(args.inputDir, depFileName=args.depFile)

        # Create Model
        trainData = dicts.featureDict.keys()
        if args.predef:
            # The predefined model hands back the dictionaries to keep using.
            dicts = model.initializeModel(trainData, dicts)
        else:
            model.initializeModel(trainData, dicts)

        model.save(modelFile)
        dicts.save(dictsFile)

        logger.info('All Done. %s seconds needed.', round(time() - startTime, 2))
        return 0

    # Create predictions and/or update model
    if args.inputFile is None:
        # Report a usage error instead of failing inside open(None).
        parser.error('--inputFile is required unless --init is given.')

    statementCounter = 1
    computeStats = False
    dicts = Dictionaries()
    # Load Files
    if os.path.isfile(dictsFile):
        dicts.load(dictsFile)
    if os.path.isfile(modelFile):
        model.load(modelFile)

    # Statistics
    if args.statistics:
        stats = Statistics(args.cutOff)

    # IO Streams: the with-statement closes both files on every exit path,
    # including exceptions and the sys.exit() below.
    with open(args.predictions, 'a') as OS, open(args.inputFile, 'r') as IS:
        # enumerate keeps the line number correct even for lines skipped
        # with `continue`.
        for lineCounter, line in enumerate(IS, 1):
            if line.startswith('!'):
                problemId = dicts.parse_fact(line)
                # Statistics: score the fact that follows a '?' query
                # against its real dependencies before learning it.
                if args.statistics and computeStats:
                    computeStats = False
                    acc = dicts.accessibleDict[problemId]
                    if args.predef:
                        predictions = model.predict(problemId)
                    else:
                        predictions, _predictionValues = model.predict(dicts.featureDict[problemId],
                                                                       dicts.expand_accessibles(acc))
                    stats.update(predictions, dicts.dependenciesDict[problemId], statementCounter)
                    if stats.badPreds:
                        bp = ','.join([str(dicts.idNameDict[x]) for x in stats.badPreds])
                        logger.debug('Bad predictions: %s', bp)
                statementCounter += 1
                # Update Dependencies, p proves p
                dicts.dependenciesDict[problemId] = [problemId] + dicts.dependenciesDict[problemId]
                model.update(problemId, dicts.featureDict[problemId], dicts.dependenciesDict[problemId])
            elif line.startswith('p'):
                # Overwrite old proof.
                problemId, newDependencies = dicts.parse_overwrite(line)
                newDependencies = [problemId] + newDependencies
                model.overwrite(problemId, newDependencies, dicts)
                dicts.dependenciesDict[problemId] = newDependencies
            elif line.startswith('?'):
                # Mark the next '!' fact for scoring (used with --statistics).
                computeStats = True
                if args.predef:
                    continue
                startTime = time()
                name, features, accessibles = dicts.parse_problem(line)
                # Create predictions
                logger.info('Starting computation for problem on line %s', lineCounter)
                predictions, predictionValues = model.predict(features, accessibles)
                assert len(predictions) == len(predictionValues)
                logger.info('Done. %s seconds needed.', round(time() - startTime, 2))
                # Output one line: "name: fact1=value1 fact2=value2 ..."
                top = args.numberOfPredictions
                predictionsString = ' '.join(['%s=%s' % (str(dicts.idNameDict[p]), str(v))
                                              for p, v in zip(predictions[:top], predictionValues[:top])])
                OS.write('%s: %s\n' % (name, predictionsString))
            else:
                logger.warning('Unspecified input format: \n%s', line)
                sys.exit(-1)

    # Statistics
    if args.statistics:
        stats.printAvg()

    # Save
    if args.saveModel:
        model.save(modelFile)
        dicts.save(dictsFile)
    if args.saveStats is not None:
        if args.statistics:
            stats.save(os.path.join(args.outputDir, args.saveStats))
        else:
            # Without --statistics nothing was collected, so there is
            # nothing to save.
            logger.warning('--saveStats requires --statistics; no statistics were saved.')
    return 0
|
217 |
||
218 |
if __name__ == '__main__':
    # Sample argument lists for running MaSh by hand on the bundled data sets.
    # To try one, uncomment it and replace the last line with
    #     sys.exit(main(args))
    # optionally timing the run, e.g.
    #     startTime = time()
    #     ...
    #     print 'New ' + str(round(time()-startTime,2))
    #
    # Jinja
    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/natATPMP.stats']
    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']
    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/natATPNB.stats','--cutOff','500']
    # List
    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/List/','--predef']
    #args = ['-i', '../data/List/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics']
    # Huffman
    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies']
    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/']
    #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics']
    # Arrow
    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Arrow_Order/']
    #args = ['-i', '../data/Arrow_Order/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/arrowIsarNB.stats','--cutOff','500']
    # Fundamental_Theorem_Algebra
    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/']
    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraIsarNB.stats','--cutOff','500']
    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Fundamental_Theorem_Algebra/','--predef']
    #args = ['-i', '../data/Fundamental_Theorem_Algebra/mash_commands','-p','../tmp/Fundamental_Theorem_AlgebraMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/Fundamental_Theorem_AlgebraMePo.stats']

    # Run with the real command line; main()'s return value becomes the exit
    # status (raising SystemExit is exactly what sys.exit does).
    raise SystemExit(main())