--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 00:12:20 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Fri Aug 23 13:30:25 2013 +0200
@@ -103,30 +103,27 @@
received = communicate(data,args.host,args.port)
logger.info(received)
- if args.inputFile == None:
- return
- logger.debug('Using the following settings: %s',args)
- # IO Streams
- OS = open(args.predictions,'w')
- IS = open(args.inputFile,'r')
- lineCount = 0
- for line in IS:
- lineCount += 1
- if lineCount % 100 == 0:
- logger.info('On line %s', lineCount)
- #if lineCount == 50: ###
- # break
- received = communicate(line,args.host,args.port)
- if not received == '':
- OS.write('%s\n' % received)
- OS.close()
- IS.close()
+ if not args.inputFile == None:
+ logger.debug('Using the following settings: %s',args)
+ # IO Streams
+ OS = open(args.predictions,'w')
+ IS = open(args.inputFile,'r')
+ lineCount = 0
+ for line in IS:
+ lineCount += 1
+ if lineCount % 100 == 0:
+ logger.info('On line %s', lineCount)
+ received = communicate(line,args.host,args.port)
+ if not received == '':
+ OS.write('%s\n' % received)
+ OS.close()
+ IS.close()
# Statistics
if args.statistics:
received = communicate('avgStats',args.host,args.port)
logger.info(received)
- elif args.saveModels:
+ if args.saveModels:
communicate('save',args.host,args.port)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 00:12:20 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Aug 23 13:30:25 2013 +0200
@@ -151,7 +151,7 @@
xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
handle IO.Io _ => ()
-fun run_mash_tool ctxt overlord extra_args write_cmds read_suggs =
+fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
let
val (temp_dir, serial) =
if overlord then (getenv "ISABELLE_HOME_USER", "")
@@ -172,7 +172,8 @@
" --dictsFile=" ^ model_dir ^ "/dict.pickle" ^
" --log " ^ log_file ^ " " ^ core ^
(if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
- " >& " ^ err_file
+ " >& " ^ err_file ^
+ (if background then " &" else "")
fun run_on () =
(Isabelle_System.bash command
|> tap (fn _ => trace_msg ctxt (fn () =>
@@ -254,7 +255,10 @@
struct
fun shutdown ctxt overlord =
- run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ())
+ run_mash_tool ctxt overlord [shutdown_server_arg] true ([], K "") (K ())
+
+fun save ctxt overlord =
+ run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())
fun unlearn ctxt overlord =
let val path = mash_model_dir () in
@@ -270,19 +274,19 @@
| learn ctxt overlord learns =
(trace_msg ctxt (fn () => "MaSh learn " ^
elide_string 1000 (space_implode " " (map #1 learns)));
- run_mash_tool ctxt overlord [save_models_arg] (learns, str_of_learn)
+ run_mash_tool ctxt overlord [] false (learns, str_of_learn)
(K ()))
fun relearn _ _ [] = ()
| relearn ctxt overlord relearns =
(trace_msg ctxt (fn () => "MaSh relearn " ^
elide_string 1000 (space_implode " " (map #1 relearns)));
- run_mash_tool ctxt overlord [save_models_arg] (relearns, str_of_relearn)
+ run_mash_tool ctxt overlord [] false (relearns, str_of_relearn)
(K ()))
fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
(trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
- run_mash_tool ctxt overlord [] ([query], str_of_query max_suggs)
+ run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs)
(fn suggs =>
case suggs () of
[] => []
@@ -359,8 +363,8 @@
| _ => NONE)
| _ => NONE
-fun load _ _ (state as (true, _)) = state
- | load ctxt overlord _ =
+fun load_state _ _ (state as (true, _)) = state
+ | load_state ctxt overlord _ =
let val path = mash_state_file () in
(true,
case try File.read_lines path of
@@ -394,8 +398,8 @@
| _ => empty_state)
end
-fun save _ (state as {dirty = SOME [], ...}) = state
- | save ctxt {access_G, num_known_facts, dirty} =
+fun save_state _ (state as {dirty = SOME [], ...}) = state
+ | save_state ctxt {access_G, num_known_facts, dirty} =
let
fun str_of_entry (name, parents, kind) =
str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
@@ -424,12 +428,13 @@
in
fun map_state ctxt overlord f =
- Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt))
+ Synchronized.change global_state
+ (load_state ctxt overlord ##> (f #> save_state ctxt))
handle FILE_VERSION_TOO_NEW () => ()
fun peek_state ctxt overlord f =
Synchronized.change_result global_state
- (perhaps (try (load ctxt overlord)) #> `snd #>> f)
+ (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
fun clear_state ctxt overlord =
Synchronized.change global_state (fn _ =>
@@ -1044,7 +1049,8 @@
used_ths |> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
in
- MaSh.learn ctxt overlord [(name, parents, feats, deps)]
+ MaSh.learn ctxt overlord [(name, parents, feats, deps)];
+ MaSh.save ctxt overlord
end);
(true, "")
end)
@@ -1056,7 +1062,7 @@
(* The timeout is understood in a very relaxed fashion. *)
fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
- auto_level run_prover learn_timeout facts =
+ save auto_level run_prover learn_timeout facts =
let
val timer = Timer.startRealTimer ()
fun next_commit_time () =
@@ -1107,6 +1113,7 @@
in
MaSh.learn ctxt overlord (rev learns);
MaSh.relearn ctxt overlord relearns;
+ if save then MaSh.save ctxt overlord else ();
{access_G = access_G, num_known_facts = num_known_facts,
dirty = dirty}
end
@@ -1228,7 +1235,7 @@
val num_facts = length facts
val prover = hd provers
fun learn auto_level run_prover =
- mash_learn_facts ctxt params prover auto_level run_prover NONE facts
+ mash_learn_facts ctxt params prover true auto_level run_prover NONE facts
|> Output.urgent_message
in
if run_prover then
@@ -1261,6 +1268,8 @@
val mepo_weight = 0.5
val mash_weight = 0.5
+val max_facts_to_learn_before_query = 100
+
(* The threshold should be large enough so that MaSh doesn't kick in for Auto
Sledgehammer and Try. *)
val min_secs_for_learning = 15
@@ -1278,28 +1287,45 @@
[("", [])]
else
let
- fun maybe_learn () =
- if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
+ fun maybe_launch_thread () =
+ if not (Async_Manager.has_running_threads MaShN) andalso
(timeout = NONE orelse
Time.toSeconds (the timeout) >= min_secs_for_learning) then
let
val timeout = Option.map (time_mult learn_timeout_slack) timeout
in
launch_thread (timeout |> the_default one_day)
- (fn () => (true, mash_learn_facts ctxt params prover 2 false
- timeout facts))
+ (fn () => (true, mash_learn_facts ctxt params prover true 2
+ false timeout facts))
end
else
()
- val effective_fact_filter =
+ fun maybe_learn () =
+ if learn then
+ let
+ val {access_G, num_known_facts, ...} = peek_state ctxt overlord I
+ val is_in_access_G = is_fact_in_graph access_G o snd
+ in
+ if length facts - num_known_facts <= max_facts_to_learn_before_query
+ andalso length (filter_out is_in_access_G facts)
+ <= max_facts_to_learn_before_query then
+ (mash_learn_facts ctxt params prover false 2 false timeout facts
+ |> (fn "" => () | s => Output.urgent_message s);
+ true)
+ else
+ (maybe_launch_thread (); false)
+ end
+ else
+ false
+ val (save, effective_fact_filter) =
case fact_filter of
- SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
+ SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
| NONE =>
if is_mash_enabled () then
- (maybe_learn ();
+ (maybe_learn (),
if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
else
- mepoN
+ (false, mepoN)
val add_ths = Attrib.eval_thms ctxt add
fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
fun add_and_take accepts =
@@ -1330,6 +1356,7 @@
mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess
|> add_and_take
in
+ if save then MaSh.save ctxt overlord else ();
case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),