reactive "sledgehammer";
authorwenzelm
Sat, 15 May 2021 22:06:05 +0200
changeset 73697 0e7a5c7a14c8
parent 73696 03e134d5f867
child 73698 3d0952893db8
reactive "sledgehammer";
src/HOL/Mirabelle.thy
src/HOL/Tools/Mirabelle/mirabelle.ML
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
--- a/src/HOL/Mirabelle.thy	Sat May 15 17:40:36 2021 +0200
+++ b/src/HOL/Mirabelle.thy	Sat May 15 22:06:05 2021 +0200
@@ -11,9 +11,7 @@
 ML_file \<open>Tools/Mirabelle/mirabelle_arith.ML\<close>
 ML_file \<open>Tools/Mirabelle/mirabelle_metis.ML\<close>
 ML_file \<open>Tools/Mirabelle/mirabelle_quickcheck.ML\<close>
-(*
 ML_file \<open>Tools/Mirabelle/mirabelle_sledgehammer.ML\<close>
-*)
 ML_file \<open>Tools/Mirabelle/mirabelle_sledgehammer_filter.ML\<close>
 ML_file \<open>Tools/Mirabelle/mirabelle_try0.ML\<close>
 
--- a/src/HOL/Tools/Mirabelle/mirabelle.ML	Sat May 15 17:40:36 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle.ML	Sat May 15 22:06:05 2021 +0200
@@ -8,7 +8,8 @@
   (*core*)
   val print_name: string -> string
   val print_properties: Properties.T -> string
-  type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}
+  type context =
+    {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}
   type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state}
   val theory_action: binding -> (context -> command list -> XML.body) -> theory -> theory
   val log_command: command -> XML.body -> XML.tree
@@ -48,7 +49,8 @@
 (* actions *)
 
 type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state};
-type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory};
+type context =
+  {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory};
 
 structure Data = Theory_Data
 (
@@ -82,7 +84,7 @@
   let
     val action = #2 (Name_Space.check (Context.Theory thy) (Data.get thy) (name, Position.none));
     val tag = "#" ^ Value.print_int index ^ " " ^ name ^ ": ";
-    val context = {tag = tag, arguments = arguments, timeout = timeout, theory = thy};
+    val context = {index = index, tag = tag, arguments = arguments, timeout = timeout, theory = thy};
     val export_body = action context commands;
     val export_head = log_action name arguments;
     val export_name = Path.binding0 (Path.basic "mirabelle" + Path.basic (Value.print_int index));
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Sat May 15 17:40:36 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Sat May 15 22:06:05 2021 +0200
@@ -1,10 +1,11 @@
 (*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
     Author:     Jasmin Blanchette and Sascha Boehme and Tobias Nipkow, TU Munich
+    Author:     Makarius
 
 Mirabelle action: "sledgehammer".
 *)
 
-structure Mirabelle_Sledgehammer : MIRABELLE_ACTION =
+structure Mirabelle_Sledgehammer: sig end =
 struct
 
 (*To facilitate synching the description of Mirabelle Sledgehammer parameters
@@ -40,9 +41,6 @@
 val type_encK = "type_enc" (*=STRING: type encoding scheme*)
 val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
 
-fun sh_tag id = "#" ^ string_of_int id ^ " sledgehammer: "
-fun proof_method_tag meth id = "#" ^ string_of_int id ^ " " ^ (!meth) ^ " (sledgehammer): "
-
 val separator = "-----"
 
 (*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*)
@@ -121,6 +119,8 @@
   re_u: re_data (* proof method with unminimized set of lemmas *)
   }
 
+type change_data = (data -> data) -> unit
+
 fun make_data (sh, re_u) = Data {sh=sh, re_u=re_u}
 
 val empty_data = make_data (empty_sh_data, empty_re_data)
@@ -216,89 +216,98 @@
 val str = string_of_int
 val str3 = Real.fmt (StringCvt.FIX (SOME 3))
 fun percentage a b = string_of_int (a * 100 div b)
-fun time t = Real.fromInt t / 1000.0
+fun ms t = Real.fromInt t / 1000.0
 fun avg_time t n =
   if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
 
-fun log_sh_data log
-    (calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail) =
- (log ("Total number of sledgehammer calls: " ^ str calls);
-  log ("Number of successful sledgehammer calls: " ^ str success);
-  log ("Number of sledgehammer lemmas: " ^ str lemmas);
-  log ("Max number of sledgehammer lemmas: " ^ str max_lems);
-  log ("Success rate: " ^ percentage success calls ^ "%");
-  log ("Total number of nontrivial sledgehammer calls: " ^ str nontriv_calls);
-  log ("Number of successful nontrivial sledgehammer calls: " ^ str nontriv_success);
-  log ("Total time for sledgehammer calls (Isabelle): " ^ str3 (time time_isa));
-  log ("Total time for successful sledgehammer calls (ATP): " ^ str3 (time time_prover));
-  log ("Total time for failed sledgehammer calls (ATP): " ^ str3 (time time_prover_fail));
-  log ("Average time for sledgehammer calls (Isabelle): " ^
-    str3 (avg_time time_isa calls));
-  log ("Average time for successful sledgehammer calls (ATP): " ^
-    str3 (avg_time time_prover success));
-  log ("Average time for failed sledgehammer calls (ATP): " ^
-    str3 (avg_time time_prover_fail (calls - success)))
-  )
+fun log_sh_data (ShData
+    {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail}) =
+  let
+    val props =
+     [("sh_calls", str calls),
+      ("sh_success", str success),
+      ("sh_nontriv_calls", str nontriv_calls),
+      ("sh_nontriv_success", str nontriv_success),
+      ("sh_lemmas", str lemmas),
+      ("sh_max_lems", str max_lems),
+      ("sh_time_isa", str3 (ms time_isa)),
+      ("sh_time_prover", str3 (ms time_prover)),
+      ("sh_time_prover_fail", str3 (ms time_prover_fail))]
+    val text =
+      "\nTotal number of sledgehammer calls: " ^ str calls ^
+      "\nNumber of successful sledgehammer calls: " ^ str success ^
+      "\nNumber of sledgehammer lemmas: " ^ str lemmas ^
+      "\nMax number of sledgehammer lemmas: " ^ str max_lems ^
+      "\nSuccess rate: " ^ percentage success calls ^ "%" ^
+      "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^
+      "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^
+      "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^
+      "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^
+      "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^
+      "\nAverage time for sledgehammer calls (Isabelle): " ^
+        str3 (avg_time time_isa calls) ^
+      "\nAverage time for successful sledgehammer calls (ATP): " ^
+        str3 (avg_time time_prover success) ^
+      "\nAverage time for failed sledgehammer calls (ATP): " ^
+        str3 (avg_time time_prover_fail (calls - success))
+  in (props, text) end
 
-fun str_of_pos (pos, triv) =
-  str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
-  (if triv then "[T]" else "")
-
-fun log_re_data log tag sh_calls (re_calls, re_success, re_nontriv_calls,
-     re_nontriv_success, re_proofs, re_time, re_timeout,
-    (lemmas, lems_sos, lems_max), re_posns) =
- (log ("Total number of " ^ tag ^ "proof method calls: " ^ str re_calls);
-  log ("Number of successful " ^ tag ^ "proof method calls: " ^ str re_success ^
-    " (proof: " ^ str re_proofs ^ ")");
-  log ("Number of " ^ tag ^ "proof method timeouts: " ^ str re_timeout);
-  log ("Success rate: " ^ percentage re_success sh_calls ^ "%");
-  log ("Total number of nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_calls);
-  log ("Number of successful nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_success ^
-    " (proof: " ^ str re_proofs ^ ")");
-  log ("Number of successful " ^ tag ^ "proof method lemmas: " ^ str lemmas);
-  log ("SOS of successful " ^ tag ^ "proof method lemmas: " ^ str lems_sos);
-  log ("Max number of successful " ^ tag ^ "proof method lemmas: " ^ str lems_max);
-  log ("Total time for successful " ^ tag ^ "proof method calls: " ^ str3 (time re_time));
-  log ("Average time for successful " ^ tag ^ "proof method calls: " ^
-    str3 (avg_time re_time re_success));
-  if tag=""
-  then log ("Proved: " ^ space_implode " " (map str_of_pos re_posns))
-  else ()
- )
+fun log_re_data sh_calls (ReData {calls, success, nontriv_calls,
+     nontriv_success, proofs, time, timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =
+  let
+    val proved =
+      posns |> map (fn (pos, triv) =>
+        str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
+        (if triv then "[T]" else ""))
+    val props =
+     [("re_calls", str calls),
+      ("re_success", str success),
+      ("re_nontriv_calls", str nontriv_calls),
+      ("re_nontriv_success", str nontriv_success),
+      ("re_proofs", str proofs),
+      ("re_time", str3 (ms time)),
+      ("re_timeout", str timeout),
+      ("re_lemmas", str lemmas),
+      ("re_lems_sos", str lems_sos),
+      ("re_lems_max", str lems_max),
+      ("re_proved", space_implode "," proved)]
+    val text =
+      "\nTotal number of proof method calls: " ^ str calls ^
+      "\nNumber of successful proof method calls: " ^ str success ^
+        " (proof: " ^ str proofs ^ ")" ^
+      "\nNumber of proof method timeouts: " ^ str timeout ^
+      "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^
+      "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
+      "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
+        " (proof: " ^ str proofs ^ ")" ^
+      "\nNumber of successful proof method lemmas: " ^ str lemmas ^
+      "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
+      "\nMax number of successful proof method lemmas: " ^ str lems_max ^
+      "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
+      "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
+      "\nProved: " ^ space_implode " " proved
+  in (props, text) end
 
 in
 
-fun log_data id log (Data {sh, re_u}) =
+fun log_data index (Data {sh, re_u}) =
   let
     val ShData {calls=sh_calls, ...} = sh
-
-    fun app_if (ReData {calls, ...}) f = if calls > 0 then f () else ()
-    fun log_re tag m =
-      log_re_data log tag sh_calls (tuple_of_re_data m)
-    fun log_proof_method (tag1, m1) = app_if m1 (fn () => (log_re tag1 m1; log ""))
+    val ReData {calls=re_calls, ...} = re_u
   in
-    if sh_calls > 0
-    then
-     (log ("\n\n\nReport #" ^ string_of_int id ^ ":\n");
-      log_sh_data log (tuple_of_sh_data sh);
-      log "";
-      log_proof_method ("", re_u))
-    else ()
+    if sh_calls > 0 then
+      let
+        val (props1, text1) = log_sh_data sh
+        val (props2, text2) = log_re_data sh_calls re_u
+        val text =
+          "\n\nReport #" ^ string_of_int index ^ ":\n" ^
+          (if re_calls > 0 then text1 ^ "\n" ^ text2 else text1)
+      in [Mirabelle.log_report (props1 @ props2) [XML.Text text]] end
+    else []
   end
 
 end
 
-(* Warning: we implicitly assume single-threaded execution here *)
-val data = Unsynchronized.ref ([] : (int * data) list)
-
-fun init id thy = (Unsynchronized.change data (cons (id, empty_data)); thy)
-fun done id ({log, ...}: Mirabelle.done_args) =
-  AList.lookup (op =) (!data) id
-  |> Option.map (log_data id log)
-  |> K ()
-
-fun change_data id f = (Unsynchronized.change data (AList.map_entry (op =) id f); ())
-
 fun get_prover_name thy args =
   let
     fun default_prover_name () =
@@ -450,13 +459,12 @@
 
 in
 
-fun run_sledgehammer trivial args proof_method named_thms id
-      ({pre=st, log, pos, ...}: Mirabelle.run_args) =
+fun run_sledgehammer change_data trivial args proof_method named_thms pos st =
   let
     val thy = Proof.theory_of st
     val triv_str = if trivial then "[T] " else ""
-    val _ = change_data id inc_sh_calls
-    val _ = if trivial then () else change_data id inc_sh_nontriv_calls
+    val _ = change_data inc_sh_calls
+    val _ = if trivial then () else change_data inc_sh_nontriv_calls
     val prover_name = get_prover_name thy args
     val fact_filter = AList.lookup (op =) args fact_filterK |> the_default fact_filter_default
     val type_enc = AList.lookup (op =) args type_encK |> the_default type_enc_default
@@ -476,7 +484,7 @@
     val force_sos = AList.lookup (op =) args force_sosK
       |> Option.map (curry (op <>) "false")
     val dir = AList.lookup (op =) args keepK
-    val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
+    val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30)
     (* always use a hard timeout, but give some slack so that the automatic
        minimizer has a chance to do its magic *)
     val preplay_timeout = AList.lookup (op =) args preplay_timeoutK
@@ -502,23 +510,23 @@
               name
             |> Option.map (pair (name, stature))
         in
-          change_data id inc_sh_success;
-          if trivial then () else change_data id inc_sh_nontriv_success;
-          change_data id (inc_sh_lemmas (length names));
-          change_data id (inc_sh_max_lems (length names));
-          change_data id (inc_sh_time_isa time_isa);
-          change_data id (inc_sh_time_prover time_prover);
+          change_data inc_sh_success;
+          if trivial then () else change_data inc_sh_nontriv_success;
+          change_data (inc_sh_lemmas (length names));
+          change_data (inc_sh_max_lems (length names));
+          change_data (inc_sh_time_isa time_isa);
+          change_data (inc_sh_time_prover time_prover);
           proof_method := proof_method_from_msg args msg;
           named_thms := SOME (map_filter get_thms names);
-          log (sh_tag id ^ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
-            string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg)
+          triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
+            string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg
         end
     | SH_FAIL (time_isa, time_prover) =>
         let
-          val _ = change_data id (inc_sh_time_isa time_isa)
-          val _ = change_data id (inc_sh_time_prover_fail time_prover)
-        in log (sh_tag id ^ triv_str ^ "failed: " ^ msg) end
-    | SH_ERROR => log (sh_tag id ^ "failed: " ^ msg))
+          val _ = change_data (inc_sh_time_isa time_isa)
+          val _ = change_data (inc_sh_time_prover_fail time_prover)
+        in triv_str ^ "failed: " ^ msg end
+    | SH_ERROR => "failed: " ^ msg)
   end
 
 end
@@ -531,8 +539,7 @@
    ("slice", "false"),
    ("timeout", timeout |> Time.toSeconds |> string_of_int)]
 
-fun run_proof_method trivial full name meth named_thms id
-    ({pre=st, timeout, log, pos, ...}: Mirabelle.run_args) =
+fun run_proof_method change_data trivial full name meth named_thms timeout pos st =
   let
     fun do_method named_thms ctxt =
       let
@@ -582,79 +589,96 @@
       Mirabelle.can_apply timeout (do_method named_thms) st
 
     fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")"
-      | with_time (true, t) = (change_data id inc_proof_method_success;
+      | with_time (true, t) = (change_data inc_proof_method_success;
           if trivial then ()
-          else change_data id inc_proof_method_nontriv_success;
-          change_data id (inc_proof_method_lemmas (length named_thms));
-          change_data id (inc_proof_method_time t);
-          change_data id (inc_proof_method_posns (pos, trivial));
-          if name = "proof" then change_data id inc_proof_method_proofs else ();
+          else change_data inc_proof_method_nontriv_success;
+          change_data (inc_proof_method_lemmas (length named_thms));
+          change_data (inc_proof_method_time t);
+          change_data (inc_proof_method_posns (pos, trivial));
+          if name = "proof" then change_data inc_proof_method_proofs else ();
           "succeeded (" ^ string_of_int t ^ ")")
     fun timed_method named_thms =
-      (with_time (Mirabelle.cpu_time apply_method named_thms), true)
-      handle Timeout.TIMEOUT _ => (change_data id inc_proof_method_timeout; ("timeout", false))
-           | ERROR msg => ("error: " ^ msg, false)
+      with_time (Mirabelle.cpu_time apply_method named_thms)
+        handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout")
+          | ERROR msg => ("error: " ^ msg)
 
-    val _ = log separator
-    val _ = change_data id inc_proof_method_calls
-    val _ = if trivial then () else change_data id inc_proof_method_nontriv_calls
-  in
-    named_thms
-    |> timed_method
-    |>> log o prefix (proof_method_tag meth id)
-    |> snd
-  end
+    val _ = change_data inc_proof_method_calls
+    val _ = if trivial then () else change_data inc_proof_method_nontriv_calls
+  in timed_method named_thms end
 
 val try_timeout = seconds 5.0
 
+fun catch e =
+  e () handle exn =>
+    if Exn.is_interrupt exn then Exn.reraise exn
+    else Mirabelle.print_exn exn
+
 (* crude hack *)
 val num_sledgehammer_calls = Unsynchronized.ref 0
 val remaining_stride = Unsynchronized.ref stride_default
 
-fun sledgehammer_action args =
-  let
-    val stride = Mirabelle.get_int_setting args (strideK, stride_default)
-    val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default)
-    val check_trivial = Mirabelle.get_bool_setting args (check_trivialK, check_trivial_default)
-  in
-    fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) =>
-      let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
-        if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
-          ()
-        else if !remaining_stride > 1 then
-          (* We still have some steps to do *)
-          (remaining_stride := !remaining_stride - 1;
-          log "Skipping because of stride")
-        else
-          (* This was the last step, now run the action *)
-          let
-            val _ = remaining_stride := stride
-            val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1
-          in
-            if !num_sledgehammer_calls > max_calls then
-              log "Skipping because max number of calls reached"
-            else
-              let
-                val meth = Unsynchronized.ref ""
-                val named_thms =
-                  Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
-                val trivial =
-                  if check_trivial then
-                    Try0.try0 (SOME try_timeout) ([], [], [], []) pre
-                    handle Timeout.TIMEOUT _ => false
-                  else false
-                fun apply_method () =
-                  (Mirabelle.catch_result (proof_method_tag meth) false
-                    (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
-              in
-                Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
-                if is_some (!named_thms) then apply_method () else ()
-              end
-          end
-      end
-  end
+val _ =
+  Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer\<close>
+    (fn context => fn commands =>
+      let
+        val {index, tag = sh_tag, arguments = args, timeout, ...} = context
+        fun proof_method_tag meth = "#" ^ string_of_int index ^ " " ^ meth ^ " (sledgehammer): "
+
+        val data = Unsynchronized.ref empty_data
+        val change_data = Unsynchronized.change data
+
+        val stride = Mirabelle.get_int_argument args (strideK, stride_default)
+        val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default)
+        val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default)
 
-fun invoke args =
-  Mirabelle.register (init, sledgehammer_action args, done)
+        val results =
+          commands |> maps (fn command =>
+            let
+              val {name, pos, pre = st, ...} = command
+              val goal = Thm.major_prem_of (#goal (Proof.goal st))
+              val log =
+                if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then []
+                else if !remaining_stride > 1 then
+                  (* We still have some steps to do *)
+                  (Unsynchronized.dec remaining_stride; ["Skipping because of stride"])
+                else
+                  (* This was the last step, now run the action *)
+                  let
+                    val _ = remaining_stride := stride
+                    val _ = Unsynchronized.inc num_sledgehammer_calls
+                  in
+                    if !num_sledgehammer_calls > max_calls then
+                      ["Skipping because max number of calls reached"]
+                    else
+                      let
+                        val meth = Unsynchronized.ref ""
+                        val named_thms =
+                          Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+                        val trivial =
+                          if check_trivial then
+                            Try0.try0 (SOME try_timeout) ([], [], [], []) st
+                              handle Timeout.TIMEOUT _ => false
+                          else false
+                        val log1 =
+                          sh_tag ^ catch (fn () =>
+                            run_sledgehammer change_data trivial args meth named_thms pos st)
+                        val log2 =
+                          (case ! named_thms of
+                            SOME thms =>
+                              [separator,
+                               proof_method_tag (!meth) ^
+                               catch (fn () =>
+                                  run_proof_method change_data trivial false name meth thms
+                                    timeout pos st)]
+                          | NONE => [])
+                      in log1 :: log2 end
+                  end
+            in
+              if null log then []
+              else [Mirabelle.log_command command [XML.Text (cat_lines log)]]
+            end)
+
+        val report = log_data index (! data)
+      in results @ report end))
 
 end