# HG changeset patch # User wenzelm # Date 1377287969 -7200 # Node ID b881bee69d3acd6047b55d17f3c23c3a21b37002 # Parent a5805fe4e91c32e0377548e72299afe76e7f5469# Parent 31e24d6ff1eaf8930eb55c5e9c236d699eb0be62 merged diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/TPTP/mash_export.ML --- a/src/HOL/TPTP/mash_export.ML Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/TPTP/mash_export.ML Fri Aug 23 21:59:29 2013 +0200 @@ -189,7 +189,7 @@ |> rev |> weight_facts_steeply |> map extra_features_of - |> rpair goal_feats |-> fold (union (op = o pairself fst)) + |> rpair goal_feats |-> fold (union (eq_fst (op =))) in "? " ^ string_of_int max_suggs ^ " # " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/MaSh/src/mash.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 21:59:29 2013 +0200 @@ -103,30 +103,27 @@ received = communicate(data,args.host,args.port) logger.info(received) - if args.inputFile == None: - return - logger.debug('Using the following settings: %s',args) - # IO Streams - OS = open(args.predictions,'w') - IS = open(args.inputFile,'r') - lineCount = 0 - for line in IS: - lineCount += 1 - if lineCount % 100 == 0: - logger.info('On line %s', lineCount) - #if lineCount == 50: ### - # break - received = communicate(line,args.host,args.port) - if not received == '': - OS.write('%s\n' % received) - OS.close() - IS.close() + if not args.inputFile == None: + logger.debug('Using the following settings: %s',args) + # IO Streams + OS = open(args.predictions,'w') + IS = open(args.inputFile,'r') + lineCount = 0 + for line in IS: + lineCount += 1 + if lineCount % 100 == 0: + logger.info('On line %s', lineCount) + received = communicate(line,args.host,args.port) + if not received == '': + OS.write('%s\n' % received) + OS.close() + IS.close() # Statistics if args.statistics: received = communicate('avgStats',args.host,args.port) logger.info(received) - elif args.saveModels: + if args.saveModels: communicate('save',args.host,args.port) diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/MaSh/src/mashOld.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mashOld.py Fri Aug 23 20:53:00 2013 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,329 +0,0 @@ -#!/usr/bin/python -# Title: HOL/Tools/Sledgehammer/MaSh/src/mash.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# Entry point for MaSh (Machine Learning for Sledgehammer). - -''' -MaSh - Machine Learning for Sledgehammer - -MaSh allows to use different machine learning algorithms to predict relevant fact for Sledgehammer. - -Created on July 12, 2012 - -@author: Daniel Kuehlwein -''' - -import logging,datetime,string,os,sys -from argparse import ArgumentParser,RawDescriptionHelpFormatter -from time import time -from stats import Statistics -from theoryStats import TheoryStatistics -from theoryModels import TheoryModels -from dictionaries import Dictionaries -#from fullNaiveBayes import NBClassifier -from sparseNaiveBayes import sparseNBClassifier -from snow import SNoW -from predefined import Predefined - -# Set up command-line parser -parser = ArgumentParser(description='MaSh - Machine Learning for Sledgehammer. \n\n\ -MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n\ ---------------- Example Usage ---------------\n\ -First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n\ -Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n\ -\n\n\ -Author: Daniel Kuehlwein, July 2012',formatter_class=RawDescriptionHelpFormatter) -parser.add_argument('-i','--inputFile',help='File containing all problems to be solved.') -parser.add_argument('-o','--outputDir', default='../tmp/',help='Directory where all created files are stored. Default=../tmp/.') -parser.add_argument('-p','--predictions',default='../tmp/%s.predictions' % datetime.datetime.now(), - help='File where the predictions stored. Default=../tmp/dateTime.predictions.') -parser.add_argument('--numberOfPredictions',default=200,help="Number of premises to write in the output. Default=200.",type=int) - -parser.add_argument('--init',default=False,action='store_true',help="Initialize Mash. Requires --inputDir to be defined. Default=False.") -parser.add_argument('--inputDir',default='../data/20121212/Jinja/',\ - help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility') -parser.add_argument('--depFile', default='mash_dependencies', - help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies') -parser.add_argument('--saveModel',default=False,action='store_true',help="Stores the learned Model at the end of a prediction run. Default=False.") - -parser.add_argument('--learnTheories',default=False,action='store_true',help="Uses a two-lvl prediction mode. First the theories, then the premises. Default=False.") -# Theory Parameters -parser.add_argument('--theoryDefValPos',default=-7.5,help="Default value for positive unknown features. Default=-7.5.",type=float) -parser.add_argument('--theoryDefValNeg',default=-10.0,help="Default value for negative unknown features. Default=-15.0.",type=float) -parser.add_argument('--theoryPosWeight',default=2.0,help="Weight value for positive features. Default=2.0.",type=float) - -parser.add_argument('--nb',default=False,action='store_true',help="Use Naive Bayes for learning. This is the default learning method.") -# NB Parameters -parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float) -parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float) -parser.add_argument('--NBPosWeight',default=10.0,help="Weight value for positive features. Default=10.0.",type=float) -# TODO: Rename to sineFeatures -parser.add_argument('--sineFeatures',default=False,action='store_true',help="Uses a SInE like prior for premise lvl predictions. Default=False.") -parser.add_argument('--sineWeight',default=0.5,help="How much the SInE prior is weighted. Default=0.5.",type=float) - -parser.add_argument('--snow',default=False,action='store_true',help="Use SNoW's naive bayes instead of Naive Bayes for learning.") -parser.add_argument('--predef',help="Use predefined predictions. Used only for comparison with the actual learning. Argument is the filename of the predictions.") -parser.add_argument('--statistics',default=False,action='store_true',help="Create and show statistics for the top CUTOFF predictions.\ - WARNING: This will make the program a lot slower! Default=False.") -parser.add_argument('--saveStats',default=None,help="If defined, stores the statistics in the filename provided.") -parser.add_argument('--cutOff',default=500,help="Option for statistics. Only consider the first cutOff predictions. Default=500.",type=int) -parser.add_argument('-l','--log', default='../tmp/%s.log' % datetime.datetime.now(), help='Log file name. Default=../tmp/dateTime.log') -parser.add_argument('-q','--quiet',default=False,action='store_true',help="If enabled, only print warnings. Default=False.") -parser.add_argument('--modelFile', default='../tmp/model.pickle', help='Model file name. Default=../tmp/model.pickle') -parser.add_argument('--dictsFile', default='../tmp/dict.pickle', help='Dict file name. Default=../tmp/dict.pickle') -parser.add_argument('--theoryFile', default='../tmp/theory.pickle', help='Model file name. Default=../tmp/theory.pickle') - -def mash(argv = sys.argv[1:]): - # Initializing command-line arguments - args = parser.parse_args(argv) - - # Set up logging - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', - datefmt='%d-%m %H:%M:%S', - filename=args.log, - filemode='w') - logger = logging.getLogger('main.py') - - """ - # remove old handler for tester - # TODO: Comment out for Jasmins version. This crashes python 2.6.1 - logger.root.handlers[0].stream.close() - logger.root.removeHandler(logger.root.handlers[0]) - file_handler = logging.FileHandler(args.log) - file_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - file_handler.setFormatter(formatter) - logger.root.addHandler(file_handler) - #""" - if args.quiet: - logger.setLevel(logging.WARNING) - #console.setLevel(logging.WARNING) - else: - console = logging.StreamHandler(sys.stdout) - console.setLevel(logging.INFO) - formatter = logging.Formatter('# %(message)s') - console.setFormatter(formatter) - logging.getLogger('').addHandler(console) - - if not os.path.exists(args.outputDir): - os.makedirs(args.outputDir) - - logger.info('Using the following settings: %s',args) - # Pick algorithm - if args.nb: - logger.info('Using sparse Naive Bayes for learning.') - model = sparseNBClassifier(args.NBDefaultPriorWeight,args.NBPosWeight,args.NBDefVal) - elif args.snow: - logger.info('Using naive bayes (SNoW) for learning.') - model = SNoW() - elif args.predef: - logger.info('Using predefined predictions.') - model = Predefined(args.predef) - else: - logger.info('No algorithm specified. Using sparse Naive Bayes.') - model = sparseNBClassifier(args.NBDefaultPriorWeight,args.NBPosWeight,args.NBDefVal) - - # Initializing model - if args.init: - logger.info('Initializing Model.') - startTime = time() - - # Load all data - dicts = Dictionaries() - dicts.init_all(args) - - # Create Model - trainData = dicts.featureDict.keys() - model.initializeModel(trainData,dicts) - - if args.learnTheories: - depFile = os.path.join(args.inputDir,args.depFile) - theoryModels = TheoryModels(args.theoryDefValPos,args.theoryDefValNeg,args.theoryPosWeight) - theoryModels.init(depFile,dicts) - theoryModels.save(args.theoryFile) - - model.save(args.modelFile) - dicts.save(args.dictsFile) - - logger.info('All Done. %s seconds needed.',round(time()-startTime,2)) - return 0 - # Create predictions and/or update model - else: - lineCounter = 1 - statementCounter = 1 - computeStats = False - dicts = Dictionaries() - theoryModels = TheoryModels(args.theoryDefValPos,args.theoryDefValNeg,args.theoryPosWeight) - # Load Files - if os.path.isfile(args.dictsFile): - #logger.info('Loading Dictionaries') - #startTime = time() - dicts.load(args.dictsFile) - #logger.info('Done %s',time()-startTime) - if os.path.isfile(args.modelFile): - #logger.info('Loading Model') - #startTime = time() - model.load(args.modelFile) - #logger.info('Done %s',time()-startTime) - if os.path.isfile(args.theoryFile) and args.learnTheories: - #logger.info('Loading Theory Models') - #startTime = time() - theoryModels.load(args.theoryFile) - #logger.info('Done %s',time()-startTime) - logger.info('All loading completed') - - # IO Streams - OS = open(args.predictions,'w') - IS = open(args.inputFile,'r') - - # Statistics - if args.statistics: - stats = Statistics(args.cutOff) - if args.learnTheories: - theoryStats = TheoryStatistics() - - predictions = None - predictedTheories = None - #Reading Input File - for line in IS: -# try: - if True: - if line.startswith('!'): - problemId = dicts.parse_fact(line) - # Statistics - if args.statistics and computeStats: - computeStats = False - # Assume '!' comes after '?' - if args.predef: - predictions = model.predict(problemId) - if args.learnTheories: - tmp = [dicts.idNameDict[x] for x in dicts.dependenciesDict[problemId]] - usedTheories = set([x.split('.')[0] for x in tmp]) - theoryStats.update((dicts.idNameDict[problemId]).split('.')[0],predictedTheories,usedTheories,len(theoryModels.accessibleTheories)) - stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter) - if not stats.badPreds == []: - bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',') - logger.debug('Bad predictions: %s',bp) - - statementCounter += 1 - # Update Dependencies, p proves p - dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId] - if args.learnTheories: - theoryModels.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts) - if args.snow: - model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts) - else: - model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId]) - elif line.startswith('p'): - # Overwrite old proof. - problemId,newDependencies = dicts.parse_overwrite(line) - newDependencies = [problemId]+newDependencies - model.overwrite(problemId,newDependencies,dicts) - if args.learnTheories: - theoryModels.overwrite(problemId,newDependencies,dicts) - dicts.dependenciesDict[problemId] = newDependencies - elif line.startswith('?'): - startTime = time() - computeStats = True - if args.predef: - continue - name,features,accessibles,hints = dicts.parse_problem(line) - - # Create predictions - logger.info('Starting computation for problem on line %s',lineCounter) - # Update Models with hints - if not hints == []: - if args.learnTheories: - accessibleTheories = set([(dicts.idNameDict[x]).split('.')[0] for x in accessibles]) - theoryModels.update_with_acc('hints',features,hints,dicts,accessibleTheories) - if args.snow: - pass - else: - model.update('hints',features,hints) - - # Predict premises - if args.learnTheories: - predictedTheories,accessibles = theoryModels.predict(features,accessibles,dicts) - - # Add additional features on premise lvl if sine is enabled - if args.sineFeatures: - origFeatures = [f for f,_w in features] - secondaryFeatures = [] - for f in origFeatures: - if dicts.featureCountDict[f] == 1: - continue - triggeredFormulas = dicts.featureTriggeredFormulasDict[f] - for formula in triggeredFormulas: - tFeatures = dicts.triggerFeaturesDict[formula] - #tFeatures = [ff for ff,_fw in dicts.featureDict[formula]] - newFeatures = set(tFeatures).difference(secondaryFeatures+origFeatures) - for fNew in newFeatures: - secondaryFeatures.append((fNew,args.sineWeight)) - predictionsFeatures = features+secondaryFeatures - else: - predictionsFeatures = features - predictions,predictionValues = model.predict(predictionsFeatures,accessibles,dicts) - assert len(predictions) == len(predictionValues) - - # Delete hints - if not hints == []: - if args.learnTheories: - theoryModels.delete('hints',features,hints,dicts) - if args.snow: - pass - else: - model.delete('hints',features,hints) - - logger.info('Done. %s seconds needed.',round(time()-startTime,2)) - # Output - predictionNames = [str(dicts.idNameDict[p]) for p in predictions[:args.numberOfPredictions]] - predictionValues = [str(x) for x in predictionValues[:args.numberOfPredictions]] - predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))] - predictionsString = string.join(predictionsStringList,' ') - outString = '%s: %s' % (name,predictionsString) - OS.write('%s\n' % outString) - else: - logger.warning('Unspecified input format: \n%s',line) - sys.exit(-1) - lineCounter += 1 - """ - except: - logger.warning('An error occurred on line %s .',line) - lineCounter += 1 - continue - """ - OS.close() - IS.close() - - # Statistics - if args.statistics: - if args.learnTheories: - theoryStats.printAvg() - stats.printAvg() - - # Save - if args.saveModel: - model.save(args.modelFile) - if args.learnTheories: - theoryModels.save(args.theoryFile) - dicts.save(args.dictsFile) - if not args.saveStats == None: - if args.learnTheories: - theoryStatsFile = os.path.join(args.outputDir,'theoryStats') - theoryStats.save(theoryStatsFile) - statsFile = os.path.join(args.outputDir,args.saveStats) - stats.save(statsFile) - return 0 - -if __name__ == '__main__': - # Cezary Auth - args = ['--statistics', '--init', '--inputDir', '../data/20130118/Jinja', '--log', '../tmp/auth.log', '--theoryFile', '../tmp/t0', '--modelFile', '../tmp/m0', '--dictsFile', '../tmp/d0','--NBDefaultPriorWeight', '20.0', '--NBDefVal', '-15.0', '--NBPosWeight', '10.0'] - mash(args) - args = ['-i', '../data/20130118/Jinja/mash_commands', '-p', '../tmp/auth.pred0', '--statistics', '--cutOff', '500', '--log', '../tmp/auth.log','--modelFile', '../tmp/m0', '--dictsFile', '../tmp/d0'] - mash(args) - - #sys.exit(mash(args)) - sys.exit(mash()) diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/MaSh/src/stats.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/stats.py Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/stats.py Fri Aug 23 21:59:29 2013 +0200 @@ -123,6 +123,8 @@ # HACK FOR PAPER assert len(self.aucData) == len(self.recall100Median) nrDataPoints = len(self.aucData) + if nrDataPoints == 0: + return "No data points" if nrDataPoints % 2 == 1: medianAUC = sorted(self.aucData)[nrDataPoints/2 + 1] else: diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 21:59:29 2013 +0200 @@ -151,7 +151,7 @@ xs |> chunk_list 500 |> List.app (File.append path o implode o map f)) handle IO.Io _ => () -fun run_mash_tool ctxt overlord extra_args write_cmds read_suggs = +fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs = let val (temp_dir, serial) = if overlord then (getenv "ISABELLE_HOME_USER", "") @@ -172,7 +172,8 @@ " --dictsFile=" ^ model_dir ^ "/dict.pickle" ^ " --log " ^ log_file ^ " " ^ core ^ (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^ - " >& " ^ err_file + " >& " ^ err_file ^ + (if background then " &" else "") fun run_on () = (Isabelle_System.bash command |> tap (fn _ => trace_msg ctxt (fn () => @@ -254,7 +255,12 @@ struct fun shutdown ctxt overlord = - run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ()) + (trace_msg ctxt (K "MaSh shutdown"); + run_mash_tool ctxt overlord [shutdown_server_arg] true ([], K "") (K ())) + +fun save ctxt overlord = + (trace_msg ctxt (K "MaSh save"); + run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())) fun unlearn ctxt overlord = let val path = mash_model_dir () in @@ -270,19 +276,19 @@ | learn ctxt overlord learns = (trace_msg ctxt (fn () => "MaSh learn " ^ elide_string 1000 (space_implode " " (map #1 learns))); - run_mash_tool ctxt overlord [save_models_arg] (learns, str_of_learn) + run_mash_tool ctxt overlord [] false (learns, str_of_learn) (K ())) fun relearn _ _ [] = () | relearn ctxt overlord relearns = (trace_msg ctxt (fn () => "MaSh relearn " ^ elide_string 1000 (space_implode " " (map #1 relearns))); - run_mash_tool ctxt overlord [save_models_arg] (relearns, str_of_relearn) + run_mash_tool ctxt overlord [] false (relearns, str_of_relearn) (K ())) fun query ctxt overlord max_suggs (query as (_, _, _, feats)) = (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats); - run_mash_tool ctxt overlord [] ([query], str_of_query max_suggs) + run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs => case suggs () of [] => [] @@ -359,8 +365,8 @@ | _ => NONE) | _ => NONE -fun load _ _ (state as (true, _)) = state - | load ctxt overlord _ = +fun load_state _ _ (state as (true, _)) = state + | load_state ctxt overlord _ = let val path = mash_state_file () in (true, case try File.read_lines path of @@ -394,8 +400,8 @@ | _ => empty_state) end -fun save _ (state as {dirty = SOME [], ...}) = state - | save ctxt {access_G, num_known_facts, dirty} = +fun save_state _ (state as {dirty = SOME [], ...}) = state + | save_state ctxt {access_G, num_known_facts, dirty} = let fun str_of_entry (name, parents, kind) = str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ @@ -424,12 +430,13 @@ in fun map_state ctxt overlord f = - Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt)) + Synchronized.change global_state + (load_state ctxt overlord ##> (f #> save_state ctxt)) handle FILE_VERSION_TOO_NEW () => () fun peek_state ctxt overlord f = Synchronized.change_result global_state - (perhaps (try (load ctxt overlord)) #> `snd #>> f) + (perhaps (try (load_state ctxt overlord)) #> `snd #>> f) fun clear_state ctxt overlord = Synchronized.change global_state (fn _ => @@ -513,14 +520,11 @@ |> map snd |> take max_facts end +fun free_feature_of s = ("f" ^ s, 40.0 (* FUDGE *)) fun thy_feature_of s = ("y" ^ s, 8.0 (* FUDGE *)) -fun free_feature_of s = ("f" ^ s, 40.0 (* FUDGE *)) fun type_feature_of s = ("t" ^ s, 4.0 (* FUDGE *)) fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *)) -fun status_feature_of status = (string_of_status status, 2.0 (* FUDGE *)) -val local_feature = ("local", 8.0 (* FUDGE *)) -val lams_feature = ("lams", 2.0 (* FUDGE *)) -val skos_feature = ("skos", 2.0 (* FUDGE *)) +val local_feature = ("local", 16.0 (* FUDGE *)) fun crude_theory_ord p = if Theory.subthy p then @@ -609,7 +613,7 @@ #> swap #> op :: #> subtract (op =) @{sort type} #> map massage_long_name #> map class_feature_of - #> union (op = o pairself fst)) S + #> union (eq_fst (op =))) S fun pattify_type 0 _ = [] | pattify_type _ (Type (s, [])) = @@ -625,13 +629,13 @@ | pattify_type _ (TVar (_, S)) = maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S) fun add_type_pat depth T = - union (op = o pairself fst) (map type_feature_of (pattify_type depth T)) + union (eq_fst (op =)) (map type_feature_of (pattify_type depth T)) fun add_type_pats 0 _ = I | add_type_pats depth t = add_type_pat depth t #> add_type_pats (depth - 1) t fun add_type T = add_type_pats type_max_depth T - #> fold_atyps_sorts (fn (_, S) => add_classes S) T + #> fold_atyps_sorts (add_classes o snd) T fun add_subtypes (T as Type (_, Ts)) = add_type T #> fold add_subtypes Ts | add_subtypes T = add_type T @@ -672,8 +676,7 @@ | (q, qw) => (p ^ "(" ^ q ^ ")", pw + qw)) ps qs end | pattify_term _ _ _ _ = [] - fun add_term_pat Ts depth = - union (op = o pairself fst) o pattify_term Ts [] depth + fun add_term_pat Ts = union (eq_fst (op =)) oo pattify_term Ts [] fun add_term_pats _ 0 _ = I | add_term_pats Ts depth t = add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t @@ -695,21 +698,16 @@ #> fold (add_subterms Ts) args in [] |> fold (add_subterms []) ts end -fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1}) - val term_max_depth = 2 -val type_max_depth = 2 +val type_max_depth = 1 (* TODO: Generate type classes for types? *) -fun features_of ctxt prover thy num_facts const_tab (scope, status) ts = +fun features_of ctxt prover thy num_facts const_tab (scope, _) ts = let val thy_name = Context.theory_name thy in thy_feature_of thy_name :: term_features_of ctxt prover thy_name num_facts const_tab term_max_depth type_max_depth ts - |> status <> General ? cons (status_feature_of status) |> scope <> Global ? cons local_feature - |> exists (not o is_lambda_free) ts ? cons lams_feature - |> exists (exists_Const is_exists) ts ? cons skos_feature end (* Too many dependencies is a sign that a decision procedure is at work. There @@ -906,7 +904,7 @@ val chained_feature_factor = 0.5 val extra_feature_factor = 0.1 -val num_extra_feature_facts = 0 (* FUDGE *) +val num_extra_feature_facts = 10 (* FUDGE *) (* FUDGE *) fun weight_of_proximity_fact rank = @@ -927,7 +925,7 @@ fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown) | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown = let - val inter_fact = inter (Thm.eq_thm_prop o pairself snd) + val inter_fact = inter (eq_snd Thm.eq_thm_prop) val raw_mash = find_suggested_facts ctxt facts suggs val proximate = take max_proximity_facts facts val unknown_chained = inter_fact raw_unknown chained @@ -938,9 +936,9 @@ (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))] val unknown = raw_unknown - |> fold (subtract (Thm.eq_thm_prop o pairself snd)) + |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate] - in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end + in (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown) end fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) @@ -972,15 +970,15 @@ chained |> map (rpair 1.0) |> map (chained_or_extra_features_of chained_feature_factor) - |> rpair [] |-> fold (union (op = o pairself fst)) + |> rpair [] |-> fold (union (eq_fst (op =))) val extra_feats = facts |> take (Int.max (0, num_extra_feature_facts - length chained)) |> weight_facts_steeply |> map (chained_or_extra_features_of extra_feature_factor) - |> rpair [] |-> fold (union (op = o pairself fst)) + |> rpair [] |-> fold (union (eq_fst (op =))) val feats = - fold (union (op = o pairself fst)) [chained_feats, extra_feats] + fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats val hints = chained |> filter (is_fact_in_graph access_G o snd) @@ -1044,7 +1042,8 @@ used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in - MaSh.learn ctxt overlord [(name, parents, feats, deps)] + MaSh.learn ctxt overlord [(name, parents, feats, deps)]; + MaSh.save ctxt overlord end); (true, "") end) @@ -1056,7 +1055,7 @@ (* The timeout is understood in a very relaxed fashion. *) fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover - auto_level run_prover learn_timeout facts = + save auto_level run_prover learn_timeout facts = let val timer = Timer.startRealTimer () fun next_commit_time () = @@ -1107,6 +1106,7 @@ in MaSh.learn ctxt overlord (rev learns); MaSh.relearn ctxt overlord relearns; + if save then MaSh.save ctxt overlord else (); {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} end @@ -1228,7 +1228,7 @@ val num_facts = length facts val prover = hd provers fun learn auto_level run_prover = - mash_learn_facts ctxt params prover auto_level run_prover NONE facts + mash_learn_facts ctxt params prover true auto_level run_prover NONE facts |> Output.urgent_message in if run_prover then @@ -1261,11 +1261,14 @@ val mepo_weight = 0.5 val mash_weight = 0.5 +val max_facts_to_learn_before_query = 100 + (* The threshold should be large enough so that MaSh doesn't kick in for Auto Sledgehammer and Try. *) val min_secs_for_learning = 15 -fun relevant_facts ctxt (params as {overlord, learn, fact_filter, timeout, ...}) +fun relevant_facts ctxt + (params as {overlord, blocking, learn, fact_filter, timeout, ...}) prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts = if not (subset (op =) (the_list fact_filter, fact_filters)) then @@ -1278,28 +1281,46 @@ [("", [])] else let - fun maybe_learn () = - if learn andalso not (Async_Manager.has_running_threads MaShN) andalso + fun maybe_launch_thread () = + if not blocking andalso + not (Async_Manager.has_running_threads MaShN) andalso (timeout = NONE orelse Time.toSeconds (the timeout) >= min_secs_for_learning) then let val timeout = Option.map (time_mult learn_timeout_slack) timeout in launch_thread (timeout |> the_default one_day) - (fn () => (true, mash_learn_facts ctxt params prover 2 false - timeout facts)) + (fn () => (true, mash_learn_facts ctxt params prover true 2 + false timeout facts)) end else () - val effective_fact_filter = + fun maybe_learn () = + if learn then + let + val {access_G, num_known_facts, ...} = peek_state ctxt overlord I + val is_in_access_G = is_fact_in_graph access_G o snd + in + if length facts - num_known_facts <= max_facts_to_learn_before_query + andalso length (filter_out is_in_access_G facts) + <= max_facts_to_learn_before_query then + (mash_learn_facts ctxt params prover false 2 false timeout facts + |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s)); + true) + else + (maybe_launch_thread (); false) + end + else + false + val (save, effective_fact_filter) = case fact_filter of - SOME ff => (() |> ff <> mepoN ? maybe_learn; ff) + SOME ff => (ff <> mepoN andalso maybe_learn (), ff) | NONE => if is_mash_enabled () then - (maybe_learn (); + (maybe_learn (), if mash_can_suggest_facts ctxt overlord then meshN else mepoN) else - mepoN + (false, mepoN) val add_ths = Attrib.eval_thms ctxt add fun in_add (_, th) = member Thm.eq_thm_prop add_ths th fun add_and_take accepts = @@ -1327,9 +1348,10 @@ else I) val mesh = - mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess + mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take in + if save then MaSh.save ctxt overlord else (); case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML Fri Aug 23 21:59:29 2013 +0200 @@ -311,10 +311,10 @@ s = pseudo_abs_name orelse String.isPrefix pseudo_skolem_prefix s orelse String.isSuffix theory_const_suffix s -fun fact_weight fudge stature const_tab relevant_consts chained_consts - fact_consts = - case fact_consts |> List.partition (pconst_hyper_mem I relevant_consts) - ||> filter_out (pconst_hyper_mem swap relevant_consts) of +fun fact_weight fudge stature const_tab rel_const_tab rel_const_iter_tab + chained_const_tab fact_consts = + case fact_consts |> List.partition (pconst_hyper_mem I rel_const_tab) + ||> filter_out (pconst_hyper_mem swap rel_const_tab) of ([], _) => 0.0 | (rel, irrel) => if forall (forall (is_odd_const_name o fst)) [rel, irrel] then @@ -327,7 +327,8 @@ val irrel_weight = ~ (stature_bonus fudge stature) |> fold (curry (op +) - o irrel_pconst_weight fudge const_tab chained_consts) irrel + o irrel_pconst_weight fudge const_tab chained_const_tab) + irrel val res = rel_weight / (rel_weight + irrel_weight) in if Real.isFinite res then res else 0.0 end @@ -400,30 +401,36 @@ | _ => NONE) val chained_const_tab = Symtab.empty |> fold (add_pconsts true) chained_ts val goal_const_tab = - Symtab.empty |> fold (add_pconsts true) hyp_ts - |> add_pconsts false concl_t + Symtab.empty + |> fold (add_pconsts true) hyp_ts + |> add_pconsts false concl_t |> (fn tab => if Symtab.is_empty tab then chained_const_tab else tab) |> fold (if_empty_replace_with_scope thy is_built_in_const facts) [Chained, Assum, Local] - fun iter j remaining_max thres rel_const_tab hopeless hopeful = + val goal_const_iter_tab = goal_const_tab |> Symtab.map (K (K ~1)) + fun iter j remaining_max thres rel_const_tab rel_const_iter_tab hopeless + hopeful = let fun relevant [] _ [] = (* Nothing has been added this iteration. *) if j = 0 andalso thres >= ridiculous_threshold then (* First iteration? Try again. *) iter 0 max_facts (thres / threshold_divisor) rel_const_tab - hopeless hopeful + rel_const_iter_tab hopeless hopeful else [] | relevant candidates rejects [] = let val (accepts, more_rejects) = take_most_relevant ctxt max_facts remaining_max fudge candidates + val sps = maps (snd o fst) accepts; val rel_const_tab' = - rel_const_tab - |> fold (add_pconst_to_table false) (maps (snd o fst) accepts) - fun is_dirty (c, _) = - Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c + rel_const_tab |> fold (add_pconst_to_table false) sps + val rel_const_iter_tab' = + rel_const_iter_tab + |> fold (fn (s, _) => Symtab.default (s, j)) sps + fun is_dirty (s, _) = + Symtab.lookup rel_const_tab' s <> Symtab.lookup rel_const_tab s val (hopeful_rejects, hopeless_rejects) = (rejects @ hopeless, ([], [])) |-> fold (fn (ax as (_, consts), old_weight) => @@ -441,7 +448,8 @@ val remaining_max = remaining_max - length accepts in trace_msg ctxt (fn () => "New or updated constants: " ^ - commas (rel_const_tab' |> Symtab.dest + commas (rel_const_tab' + |> Symtab.dest |> subtract (op =) (rel_const_tab |> Symtab.dest) |> map string_of_hyper_pconst)); map (fst o fst) accepts @ @@ -449,7 +457,7 @@ [] else iter (j + 1) remaining_max thres rel_const_tab' - hopeless_rejects hopeful_rejects) + rel_const_iter_tab' hopeless_rejects hopeful_rejects) end | relevant candidates rejects (((ax as (((_, stature), _), fact_consts)), cached_weight) @@ -458,8 +466,9 @@ val weight = case cached_weight of SOME w => w - | NONE => fact_weight fudge stature const_tab rel_const_tab - chained_const_tab fact_consts + | NONE => + fact_weight fudge stature const_tab rel_const_tab + rel_const_iter_tab chained_const_tab fact_consts in if weight >= thres then relevant ((ax, weight) :: candidates) rejects hopeful @@ -470,7 +479,8 @@ trace_msg ctxt (fn () => "ITERATION " ^ string_of_int j ^ ": current threshold: " ^ Real.toString thres ^ ", constants: " ^ - commas (rel_const_tab |> Symtab.dest + commas (rel_const_tab + |> Symtab.dest |> filter (curry (op <>) [] o snd) |> map string_of_hyper_pconst)); relevant [] [] hopeful @@ -499,7 +509,7 @@ |> insert_into_facts accepts in facts |> map_filter (pair_consts_fact thy is_built_in_const fudge) - |> iter 0 max_facts thres0 goal_const_tab [] + |> iter 0 max_facts thres0 goal_const_tab goal_const_iter_tab [] |> insert_special_facts |> tap (fn accepts => trace_msg ctxt (fn () => "Total relevant: " ^ string_of_int (length accepts))) diff -r 31e24d6ff1ea -r b881bee69d3a src/HOL/Tools/Sledgehammer/sledgehammer_run.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Aug 23 20:53:00 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML Fri Aug 23 21:59:29 2013 +0200 @@ -167,8 +167,6 @@ val auto_try_max_facts_divisor = 2 (* FUDGE *) -fun eq_facts p = eq_list (op = o pairself fst) p - fun string_of_facts facts = "Including " ^ string_of_int (length facts) ^ " relevant fact" ^ plural_s (length facts) ^ ":\n" ^