# HG changeset patch # User wenzelm # Date 1621109165 -7200 # Node ID 0e7a5c7a14c854ed7bf50b51821c731dd184de1f # Parent 03e134d5f867465fa521b802779a64b5fbbb8e6a reactive "sledgehammer"; diff -r 03e134d5f867 -r 0e7a5c7a14c8 src/HOL/Mirabelle.thy --- a/src/HOL/Mirabelle.thy Sat May 15 17:40:36 2021 +0200 +++ b/src/HOL/Mirabelle.thy Sat May 15 22:06:05 2021 +0200 @@ -11,9 +11,7 @@ ML_file \Tools/Mirabelle/mirabelle_arith.ML\ ML_file \Tools/Mirabelle/mirabelle_metis.ML\ ML_file \Tools/Mirabelle/mirabelle_quickcheck.ML\ -(* ML_file \Tools/Mirabelle/mirabelle_sledgehammer.ML\ -*) ML_file \Tools/Mirabelle/mirabelle_sledgehammer_filter.ML\ ML_file \Tools/Mirabelle/mirabelle_try0.ML\ diff -r 03e134d5f867 -r 0e7a5c7a14c8 src/HOL/Tools/Mirabelle/mirabelle.ML --- a/src/HOL/Tools/Mirabelle/mirabelle.ML Sat May 15 17:40:36 2021 +0200 +++ b/src/HOL/Tools/Mirabelle/mirabelle.ML Sat May 15 22:06:05 2021 +0200 @@ -8,7 +8,8 @@ (*core*) val print_name: string -> string val print_properties: Properties.T -> string - type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory} + type context = + {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory} type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state} val theory_action: binding -> (context -> command list -> XML.body) -> theory -> theory val log_command: command -> XML.body -> XML.tree @@ -48,7 +49,8 @@ (* actions *) type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state}; -type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}; +type context = + {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}; structure Data = Theory_Data ( @@ -82,7 +84,7 @@ let val action = #2 (Name_Space.check (Context.Theory thy) (Data.get thy) (name, Position.none)); val tag = "#" ^ Value.print_int index ^ " " ^ name ^ ": "; - val context = {tag = tag, arguments = arguments, timeout = timeout, theory = thy}; + val context = {index = index, tag = tag, arguments = arguments, timeout = timeout, theory = thy}; val export_body = action context commands; val export_head = log_action name arguments; val export_name = Path.binding0 (Path.basic "mirabelle" + Path.basic (Value.print_int index)); diff -r 03e134d5f867 -r 0e7a5c7a14c8 src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML --- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Sat May 15 17:40:36 2021 +0200 +++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Sat May 15 22:06:05 2021 +0200 @@ -1,10 +1,11 @@ (* Title: HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML Author: Jasmin Blanchette and Sascha Boehme and Tobias Nipkow, TU Munich + Author: Makarius Mirabelle action: "sledgehammer". *) -structure Mirabelle_Sledgehammer : MIRABELLE_ACTION = +structure Mirabelle_Sledgehammer: sig end = struct (*To facilitate synching the description of Mirabelle Sledgehammer parameters @@ -40,9 +41,6 @@ val type_encK = "type_enc" (*=STRING: type encoding scheme*) val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*) -fun sh_tag id = "#" ^ string_of_int id ^ " sledgehammer: " -fun proof_method_tag meth id = "#" ^ string_of_int id ^ " " ^ (!meth) ^ " (sledgehammer): " - val separator = "-----" (*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*) @@ -121,6 +119,8 @@ re_u: re_data (* proof method with unminimized set of lemmas *) } +type change_data = (data -> data) -> unit + fun make_data (sh, re_u) = Data {sh=sh, re_u=re_u} val empty_data = make_data (empty_sh_data, empty_re_data) @@ -216,89 +216,98 @@ val str = string_of_int val str3 = Real.fmt (StringCvt.FIX (SOME 3)) fun percentage a b = string_of_int (a * 100 div b) -fun time t = Real.fromInt t / 1000.0 +fun ms t = Real.fromInt t / 1000.0 fun avg_time t n = if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0 -fun log_sh_data log - (calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail) = - (log ("Total number of sledgehammer calls: " ^ str calls); - log ("Number of successful sledgehammer calls: " ^ str success); - log ("Number of sledgehammer lemmas: " ^ str lemmas); - log ("Max number of sledgehammer lemmas: " ^ str max_lems); - log ("Success rate: " ^ percentage success calls ^ "%"); - log ("Total number of nontrivial sledgehammer calls: " ^ str nontriv_calls); - log ("Number of successful nontrivial sledgehammer calls: " ^ str nontriv_success); - log ("Total time for sledgehammer calls (Isabelle): " ^ str3 (time time_isa)); - log ("Total time for successful sledgehammer calls (ATP): " ^ str3 (time time_prover)); - log ("Total time for failed sledgehammer calls (ATP): " ^ str3 (time time_prover_fail)); - log ("Average time for sledgehammer calls (Isabelle): " ^ - str3 (avg_time time_isa calls)); - log ("Average time for successful sledgehammer calls (ATP): " ^ - str3 (avg_time time_prover success)); - log ("Average time for failed sledgehammer calls (ATP): " ^ - str3 (avg_time time_prover_fail (calls - success))) - ) +fun log_sh_data (ShData + {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail}) = + let + val props = + [("sh_calls", str calls), + ("sh_success", str success), + ("sh_nontriv_calls", str nontriv_calls), + ("sh_nontriv_success", str nontriv_success), + ("sh_lemmas", str lemmas), + ("sh_max_lems", str max_lems), + ("sh_time_isa", str3 (ms time_isa)), + ("sh_time_prover", str3 (ms time_prover)), + ("sh_time_prover_fail", str3 (ms time_prover_fail))] + val text = + "\nTotal number of sledgehammer calls: " ^ str calls ^ + "\nNumber of successful sledgehammer calls: " ^ str success ^ + "\nNumber of sledgehammer lemmas: " ^ str lemmas ^ + "\nMax number of sledgehammer lemmas: " ^ str max_lems ^ + "\nSuccess rate: " ^ percentage success calls ^ "%" ^ + "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^ + "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^ + "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^ + "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^ + "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^ + "\nAverage time for sledgehammer calls (Isabelle): " ^ + str3 (avg_time time_isa calls) ^ + "\nAverage time for successful sledgehammer calls (ATP): " ^ + str3 (avg_time time_prover success) ^ + "\nAverage time for failed sledgehammer calls (ATP): " ^ + str3 (avg_time time_prover_fail (calls - success)) + in (props, text) end -fun str_of_pos (pos, triv) = - str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^ - (if triv then "[T]" else "") - -fun log_re_data log tag sh_calls (re_calls, re_success, re_nontriv_calls, - re_nontriv_success, re_proofs, re_time, re_timeout, - (lemmas, lems_sos, lems_max), re_posns) = - (log ("Total number of " ^ tag ^ "proof method calls: " ^ str re_calls); - log ("Number of successful " ^ tag ^ "proof method calls: " ^ str re_success ^ - " (proof: " ^ str re_proofs ^ ")"); - log ("Number of " ^ tag ^ "proof method timeouts: " ^ str re_timeout); - log ("Success rate: " ^ percentage re_success sh_calls ^ "%"); - log ("Total number of nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_calls); - log ("Number of successful nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_success ^ - " (proof: " ^ str re_proofs ^ ")"); - log ("Number of successful " ^ tag ^ "proof method lemmas: " ^ str lemmas); - log ("SOS of successful " ^ tag ^ "proof method lemmas: " ^ str lems_sos); - log ("Max number of successful " ^ tag ^ "proof method lemmas: " ^ str lems_max); - log ("Total time for successful " ^ tag ^ "proof method calls: " ^ str3 (time re_time)); - log ("Average time for successful " ^ tag ^ "proof method calls: " ^ - str3 (avg_time re_time re_success)); - if tag="" - then log ("Proved: " ^ space_implode " " (map str_of_pos re_posns)) - else () - ) +fun log_re_data sh_calls (ReData {calls, success, nontriv_calls, + nontriv_success, proofs, time, timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) = + let + val proved = + posns |> map (fn (pos, triv) => + str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^ + (if triv then "[T]" else "")) + val props = + [("re_calls", str calls), + ("re_success", str success), + ("re_nontriv_calls", str nontriv_calls), + ("re_nontriv_success", str nontriv_success), + ("re_proofs", str proofs), + ("re_time", str3 (ms time)), + ("re_timeout", str timeout), + ("re_lemmas", str lemmas), + ("re_lems_sos", str lems_sos), + ("re_lems_max", str lems_max), + ("re_proved", space_implode "," proved)] + val text = + "\nTotal number of proof method calls: " ^ str calls ^ + "\nNumber of successful proof method calls: " ^ str success ^ + " (proof: " ^ str proofs ^ ")" ^ + "\nNumber of proof method timeouts: " ^ str timeout ^ + "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^ + "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^ + "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^ + " (proof: " ^ str proofs ^ ")" ^ + "\nNumber of successful proof method lemmas: " ^ str lemmas ^ + "\nSOS of successful proof method lemmas: " ^ str lems_sos ^ + "\nMax number of successful proof method lemmas: " ^ str lems_max ^ + "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^ + "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^ + "\nProved: " ^ space_implode " " proved + in (props, text) end in -fun log_data id log (Data {sh, re_u}) = +fun log_data index (Data {sh, re_u}) = let val ShData {calls=sh_calls, ...} = sh - - fun app_if (ReData {calls, ...}) f = if calls > 0 then f () else () - fun log_re tag m = - log_re_data log tag sh_calls (tuple_of_re_data m) - fun log_proof_method (tag1, m1) = app_if m1 (fn () => (log_re tag1 m1; log "")) + val ReData {calls=re_calls, ...} = re_u in - if sh_calls > 0 - then - (log ("\n\n\nReport #" ^ string_of_int id ^ ":\n"); - log_sh_data log (tuple_of_sh_data sh); - log ""; - log_proof_method ("", re_u)) - else () + if sh_calls > 0 then + let + val (props1, text1) = log_sh_data sh + val (props2, text2) = log_re_data sh_calls re_u + val text = + "\n\nReport #" ^ string_of_int index ^ ":\n" ^ + (if re_calls > 0 then text1 ^ "\n" ^ text2 else text1) + in [Mirabelle.log_report (props1 @ props2) [XML.Text text]] end + else [] end end -(* Warning: we implicitly assume single-threaded execution here *) -val data = Unsynchronized.ref ([] : (int * data) list) - -fun init id thy = (Unsynchronized.change data (cons (id, empty_data)); thy) -fun done id ({log, ...}: Mirabelle.done_args) = - AList.lookup (op =) (!data) id - |> Option.map (log_data id log) - |> K () - -fun change_data id f = (Unsynchronized.change data (AList.map_entry (op =) id f); ()) - fun get_prover_name thy args = let fun default_prover_name () = @@ -450,13 +459,12 @@ in -fun run_sledgehammer trivial args proof_method named_thms id - ({pre=st, log, pos, ...}: Mirabelle.run_args) = +fun run_sledgehammer change_data trivial args proof_method named_thms pos st = let val thy = Proof.theory_of st val triv_str = if trivial then "[T] " else "" - val _ = change_data id inc_sh_calls - val _ = if trivial then () else change_data id inc_sh_nontriv_calls + val _ = change_data inc_sh_calls + val _ = if trivial then () else change_data inc_sh_nontriv_calls val prover_name = get_prover_name thy args val fact_filter = AList.lookup (op =) args fact_filterK |> the_default fact_filter_default val type_enc = AList.lookup (op =) args type_encK |> the_default type_enc_default @@ -476,7 +484,7 @@ val force_sos = AList.lookup (op =) args force_sosK |> Option.map (curry (op <>) "false") val dir = AList.lookup (op =) args keepK - val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30) + val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30) (* always use a hard timeout, but give some slack so that the automatic minimizer has a chance to do its magic *) val preplay_timeout = AList.lookup (op =) args preplay_timeoutK @@ -502,23 +510,23 @@ name |> Option.map (pair (name, stature)) in - change_data id inc_sh_success; - if trivial then () else change_data id inc_sh_nontriv_success; - change_data id (inc_sh_lemmas (length names)); - change_data id (inc_sh_max_lems (length names)); - change_data id (inc_sh_time_isa time_isa); - change_data id (inc_sh_time_prover time_prover); + change_data inc_sh_success; + if trivial then () else change_data inc_sh_nontriv_success; + change_data (inc_sh_lemmas (length names)); + change_data (inc_sh_max_lems (length names)); + change_data (inc_sh_time_isa time_isa); + change_data (inc_sh_time_prover time_prover); proof_method := proof_method_from_msg args msg; named_thms := SOME (map_filter get_thms names); - log (sh_tag id ^ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^ - string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg) + triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^ + string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg end | SH_FAIL (time_isa, time_prover) => let - val _ = change_data id (inc_sh_time_isa time_isa) - val _ = change_data id (inc_sh_time_prover_fail time_prover) - in log (sh_tag id ^ triv_str ^ "failed: " ^ msg) end - | SH_ERROR => log (sh_tag id ^ "failed: " ^ msg)) + val _ = change_data (inc_sh_time_isa time_isa) + val _ = change_data (inc_sh_time_prover_fail time_prover) + in triv_str ^ "failed: " ^ msg end + | SH_ERROR => "failed: " ^ msg) end end @@ -531,8 +539,7 @@ ("slice", "false"), ("timeout", timeout |> Time.toSeconds |> string_of_int)] -fun run_proof_method trivial full name meth named_thms id - ({pre=st, timeout, log, pos, ...}: Mirabelle.run_args) = +fun run_proof_method change_data trivial full name meth named_thms timeout pos st = let fun do_method named_thms ctxt = let @@ -582,79 +589,96 @@ Mirabelle.can_apply timeout (do_method named_thms) st fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")" - | with_time (true, t) = (change_data id inc_proof_method_success; + | with_time (true, t) = (change_data inc_proof_method_success; if trivial then () - else change_data id inc_proof_method_nontriv_success; - change_data id (inc_proof_method_lemmas (length named_thms)); - change_data id (inc_proof_method_time t); - change_data id (inc_proof_method_posns (pos, trivial)); - if name = "proof" then change_data id inc_proof_method_proofs else (); + else change_data inc_proof_method_nontriv_success; + change_data (inc_proof_method_lemmas (length named_thms)); + change_data (inc_proof_method_time t); + change_data (inc_proof_method_posns (pos, trivial)); + if name = "proof" then change_data inc_proof_method_proofs else (); "succeeded (" ^ string_of_int t ^ ")") fun timed_method named_thms = - (with_time (Mirabelle.cpu_time apply_method named_thms), true) - handle Timeout.TIMEOUT _ => (change_data id inc_proof_method_timeout; ("timeout", false)) - | ERROR msg => ("error: " ^ msg, false) + with_time (Mirabelle.cpu_time apply_method named_thms) + handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout") + | ERROR msg => ("error: " ^ msg) - val _ = log separator - val _ = change_data id inc_proof_method_calls - val _ = if trivial then () else change_data id inc_proof_method_nontriv_calls - in - named_thms - |> timed_method - |>> log o prefix (proof_method_tag meth id) - |> snd - end + val _ = change_data inc_proof_method_calls + val _ = if trivial then () else change_data inc_proof_method_nontriv_calls + in timed_method named_thms end val try_timeout = seconds 5.0 +fun catch e = + e () handle exn => + if Exn.is_interrupt exn then Exn.reraise exn + else Mirabelle.print_exn exn + (* crude hack *) val num_sledgehammer_calls = Unsynchronized.ref 0 val remaining_stride = Unsynchronized.ref stride_default -fun sledgehammer_action args = - let - val stride = Mirabelle.get_int_setting args (strideK, stride_default) - val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default) - val check_trivial = Mirabelle.get_bool_setting args (check_trivialK, check_trivial_default) - in - fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) => - let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in - if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then - () - else if !remaining_stride > 1 then - (* We still have some steps to do *) - (remaining_stride := !remaining_stride - 1; - log "Skipping because of stride") - else - (* This was the last step, now run the action *) - let - val _ = remaining_stride := stride - val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1 - in - if !num_sledgehammer_calls > max_calls then - log "Skipping because max number of calls reached" - else - let - val meth = Unsynchronized.ref "" - val named_thms = - Unsynchronized.ref (NONE : ((string * stature) * thm list) list option) - val trivial = - if check_trivial then - Try0.try0 (SOME try_timeout) ([], [], [], []) pre - handle Timeout.TIMEOUT _ => false - else false - fun apply_method () = - (Mirabelle.catch_result (proof_method_tag meth) false - (run_proof_method trivial false name meth (these (!named_thms))) id st; ()) - in - Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st; - if is_some (!named_thms) then apply_method () else () - end - end - end - end +val _ = + Theory.setup (Mirabelle.theory_action \<^binding>\sledgehammer\ + (fn context => fn commands => + let + val {index, tag = sh_tag, arguments = args, timeout, ...} = context + fun proof_method_tag meth = "#" ^ string_of_int index ^ " " ^ meth ^ " (sledgehammer): " + + val data = Unsynchronized.ref empty_data + val change_data = Unsynchronized.change data + + val stride = Mirabelle.get_int_argument args (strideK, stride_default) + val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default) + val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default) -fun invoke args = - Mirabelle.register (init, sledgehammer_action args, done) + val results = + commands |> maps (fn command => + let + val {name, pos, pre = st, ...} = command + val goal = Thm.major_prem_of (#goal (Proof.goal st)) + val log = + if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then [] + else if !remaining_stride > 1 then + (* We still have some steps to do *) + (Unsynchronized.dec remaining_stride; ["Skipping because of stride"]) + else + (* This was the last step, now run the action *) + let + val _ = remaining_stride := stride + val _ = Unsynchronized.inc num_sledgehammer_calls + in + if !num_sledgehammer_calls > max_calls then + ["Skipping because max number of calls reached"] + else + let + val meth = Unsynchronized.ref "" + val named_thms = + Unsynchronized.ref (NONE : ((string * stature) * thm list) list option) + val trivial = + if check_trivial then + Try0.try0 (SOME try_timeout) ([], [], [], []) st + handle Timeout.TIMEOUT _ => false + else false + val log1 = + sh_tag ^ catch (fn () => + run_sledgehammer change_data trivial args meth named_thms pos st) + val log2 = + (case ! named_thms of + SOME thms => + [separator, + proof_method_tag (!meth) ^ + catch (fn () => + run_proof_method change_data trivial false name meth thms + timeout pos st)] + | NONE => []) + in log1 :: log2 end + end + in + if null log then [] + else [Mirabelle.log_command command [XML.Text (cat_lines log)]] + end) + + val report = log_data index (! data) + in results @ report end)) end