# HG changeset patch # User blanchet # Date 1377257425 -7200 # Node ID cbd3c7c48d2c58b78227391be0e07b1428bbffd2 # Parent fbf4d50dec91a48dea97a7d6dc710186e347eea5 learn new facts on query if there aren't too many of them in MaSh diff -r fbf4d50dec91 -r cbd3c7c48d2c src/HOL/Tools/Sledgehammer/MaSh/src/mash.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 00:12:20 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 13:30:25 2013 +0200 @@ -103,30 +103,27 @@ received = communicate(data,args.host,args.port) logger.info(received) - if args.inputFile == None: - return - logger.debug('Using the following settings: %s',args) - # IO Streams - OS = open(args.predictions,'w') - IS = open(args.inputFile,'r') - lineCount = 0 - for line in IS: - lineCount += 1 - if lineCount % 100 == 0: - logger.info('On line %s', lineCount) - #if lineCount == 50: ### - # break - received = communicate(line,args.host,args.port) - if not received == '': - OS.write('%s\n' % received) - OS.close() - IS.close() + if not args.inputFile == None: + logger.debug('Using the following settings: %s',args) + # IO Streams + OS = open(args.predictions,'w') + IS = open(args.inputFile,'r') + lineCount = 0 + for line in IS: + lineCount += 1 + if lineCount % 100 == 0: + logger.info('On line %s', lineCount) + received = communicate(line,args.host,args.port) + if not received == '': + OS.write('%s\n' % received) + OS.close() + IS.close() # Statistics if args.statistics: received = communicate('avgStats',args.host,args.port) logger.info(received) - elif args.saveModels: + if args.saveModels: communicate('save',args.host,args.port) diff -r fbf4d50dec91 -r cbd3c7c48d2c src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 00:12:20 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 13:30:25 2013 +0200 @@ -151,7 +151,7 @@ xs |> chunk_list 500 |> List.app (File.append path o implode o map f)) handle IO.Io _ => () -fun run_mash_tool ctxt overlord extra_args write_cmds read_suggs = +fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs = let val (temp_dir, serial) = if overlord then (getenv "ISABELLE_HOME_USER", "") @@ -172,7 +172,8 @@ " --dictsFile=" ^ model_dir ^ "/dict.pickle" ^ " --log " ^ log_file ^ " " ^ core ^ (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^ - " >& " ^ err_file + " >& " ^ err_file ^ + (if background then " &" else "") fun run_on () = (Isabelle_System.bash command |> tap (fn _ => trace_msg ctxt (fn () => @@ -254,7 +255,10 @@ struct fun shutdown ctxt overlord = - run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ()) + run_mash_tool ctxt overlord [shutdown_server_arg] true ([], K "") (K ()) + +fun save ctxt overlord = + run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()) fun unlearn ctxt overlord = let val path = mash_model_dir () in @@ -270,19 +274,19 @@ | learn ctxt overlord learns = (trace_msg ctxt (fn () => "MaSh learn " ^ elide_string 1000 (space_implode " " (map #1 learns))); - run_mash_tool ctxt overlord [save_models_arg] (learns, str_of_learn) + run_mash_tool ctxt overlord [] false (learns, str_of_learn) (K ())) fun relearn _ _ [] = () | relearn ctxt overlord relearns = (trace_msg ctxt (fn () => "MaSh relearn " ^ elide_string 1000 (space_implode " " (map #1 relearns))); - run_mash_tool ctxt overlord [save_models_arg] (relearns, str_of_relearn) + run_mash_tool ctxt overlord [] false (relearns, str_of_relearn) (K ())) fun query ctxt overlord max_suggs (query as (_, _, _, feats)) = (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats); - run_mash_tool ctxt overlord [] ([query], str_of_query max_suggs) + run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs => case suggs () of [] => [] @@ -359,8 +363,8 @@ | _ => NONE) | _ => NONE -fun load _ _ (state as (true, _)) = state - | load ctxt overlord _ = +fun load_state _ _ (state as (true, _)) = state + | load_state ctxt overlord _ = let val path = mash_state_file () in (true, case try File.read_lines path of @@ -394,8 +398,8 @@ | _ => empty_state) end -fun save _ (state as {dirty = SOME [], ...}) = state - | save ctxt {access_G, num_known_facts, dirty} = +fun save_state _ (state as {dirty = SOME [], ...}) = state + | save_state ctxt {access_G, num_known_facts, dirty} = let fun str_of_entry (name, parents, kind) = str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ @@ -424,12 +428,13 @@ in fun map_state ctxt overlord f = - Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt)) + Synchronized.change global_state + (load_state ctxt overlord ##> (f #> save_state ctxt)) handle FILE_VERSION_TOO_NEW () => () fun peek_state ctxt overlord f = Synchronized.change_result global_state - (perhaps (try (load ctxt overlord)) #> `snd #>> f) + (perhaps (try (load_state ctxt overlord)) #> `snd #>> f) fun clear_state ctxt overlord = Synchronized.change global_state (fn _ => @@ -1044,7 +1049,8 @@ used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in - MaSh.learn ctxt overlord [(name, parents, feats, deps)] + MaSh.learn ctxt overlord [(name, parents, feats, deps)]; + MaSh.save ctxt overlord end); (true, "") end) @@ -1056,7 +1062,7 @@ (* The timeout is understood in a very relaxed fashion. *) fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover - auto_level run_prover learn_timeout facts = + save auto_level run_prover learn_timeout facts = let val timer = Timer.startRealTimer () fun next_commit_time () = @@ -1107,6 +1113,7 @@ in MaSh.learn ctxt overlord (rev learns); MaSh.relearn ctxt overlord relearns; + if save then MaSh.save ctxt overlord else (); {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} end @@ -1228,7 +1235,7 @@ val num_facts = length facts val prover = hd provers fun learn auto_level run_prover = - mash_learn_facts ctxt params prover auto_level run_prover NONE facts + mash_learn_facts ctxt params prover true auto_level run_prover NONE facts |> Output.urgent_message in if run_prover then @@ -1261,6 +1268,8 @@ val mepo_weight = 0.5 val mash_weight = 0.5 +val max_facts_to_learn_before_query = 100 + (* The threshold should be large enough so that MaSh doesn't kick in for Auto Sledgehammer and Try. *) val min_secs_for_learning = 15 @@ -1278,28 +1287,45 @@ [("", [])] else let - fun maybe_learn () = - if learn andalso not (Async_Manager.has_running_threads MaShN) andalso + fun maybe_launch_thread () = + if not (Async_Manager.has_running_threads MaShN) andalso (timeout = NONE orelse Time.toSeconds (the timeout) >= min_secs_for_learning) then let val timeout = Option.map (time_mult learn_timeout_slack) timeout in launch_thread (timeout |> the_default one_day) - (fn () => (true, mash_learn_facts ctxt params prover 2 false - timeout facts)) + (fn () => (true, mash_learn_facts ctxt params prover true 2 + false timeout facts)) end else () - val effective_fact_filter = + fun maybe_learn () = + if learn then + let + val {access_G, num_known_facts, ...} = peek_state ctxt overlord I + val is_in_access_G = is_fact_in_graph access_G o snd + in + if length facts - num_known_facts <= max_facts_to_learn_before_query + andalso length (filter_out is_in_access_G facts) + <= max_facts_to_learn_before_query then + (mash_learn_facts ctxt params prover false 2 false timeout facts + |> (fn "" => () | s => Output.urgent_message s); + true) + else + (maybe_launch_thread (); false) + end + else + false + val (save, effective_fact_filter) = case fact_filter of - SOME ff => (() |> ff <> mepoN ? maybe_learn; ff) + SOME ff => (ff <> mepoN andalso maybe_learn (), ff) | NONE => if is_mash_enabled () then - (maybe_learn (); + (maybe_learn (), if mash_can_suggest_facts ctxt overlord then meshN else mepoN) else - mepoN + (false, mepoN) val add_ths = Attrib.eval_thms ctxt add fun in_add (_, th) = member Thm.eq_thm_prop add_ths th fun add_and_take accepts = @@ -1330,6 +1356,7 @@ mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess |> add_and_take in + if save then MaSh.save ctxt overlord else (); case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),