src/HOL/Tools/Sledgehammer/sledgehammer_proof_methods.ML
changeset 81254 d3c0734059ee
parent 80910 406a85a25189
child 82202 a1f85f579a07
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_proof_methods.ML	Thu Oct 24 22:05:57 2024 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_proof_methods.ML	Fri Oct 25 15:31:58 2024 +0200
@@ -14,7 +14,7 @@
     SMT_Verit of string
 
   datatype proof_method =
-    Metis_Method of string option * string option |
+    Metis_Method of string option * string option * string list |
     Meson_Method |
     SMT_Method of SMT_backend |
     SATx_Method |
@@ -36,9 +36,10 @@
     Play_Failed
 
   type one_line_params =
-    ((string * stature) list * (proof_method * play_outcome)) * string * int * int
+    ((Pretty.T * stature) list * (proof_method * play_outcome)) * string * int * int
 
   val is_proof_method_direct : proof_method -> bool
+  val pretty_proof_method : string -> string -> Pretty.T list -> proof_method -> Pretty.T
   val string_of_proof_method : string list -> proof_method -> string
   val tac_of_proof_method : Proof.context -> thm list * thm list -> proof_method -> int -> tactic
   val string_of_play_outcome : play_outcome -> string
@@ -52,13 +53,14 @@
 open ATP_Util
 open ATP_Problem_Generate
 open ATP_Proof_Reconstruct
+open Sledgehammer_Util
 
 datatype SMT_backend =
   SMT_Z3 |
   SMT_Verit of string
 
 datatype proof_method =
-  Metis_Method of string option * string option |
+  Metis_Method of string option * string option * string list |
   Meson_Method |
   SMT_Method of SMT_backend |
   SATx_Method |
@@ -80,7 +82,7 @@
   Play_Failed
 
 type one_line_params =
-  ((string * stature) list * (proof_method * play_outcome)) * string * int * int
+  ((Pretty.T * stature) list * (proof_method * play_outcome)) * string * int * int
 
 fun is_proof_method_direct (Metis_Method _) = true
   | is_proof_method_direct Meson_Method = true
@@ -91,7 +93,13 @@
 fun is_proof_method_multi_goal Auto_Method = true
   | is_proof_method_multi_goal _ = false
 
-fun maybe_paren s = s |> not (Symbol_Pos.is_identifier s) ? enclose "(" ")"
+fun pretty_paren prefix suffix = Pretty.enclose (prefix ^ "(") (")" ^ suffix)
+fun pretty_maybe_paren prefix suffix [pretty] =
+    if Symbol_Pos.is_identifier (content_of_pretty pretty) then
+      Pretty.block [Pretty.str prefix, pretty, Pretty.str suffix]
+    else
+      pretty_paren prefix suffix [pretty]
+  | pretty_maybe_paren prefix suffix pretties = pretty_paren prefix suffix pretties
 
 (*
 Combine indexed fact names for pretty-printing.
@@ -99,35 +107,36 @@
 Combines only adjacent same names.
 Input should not have same name with and without index.
 *)
-fun merge_indexed_facts (ss: string list) :string list =
+fun merge_indexed_facts (facts: Pretty.T list) : Pretty.T list =
   let
 
-    fun split (s: string) : string * string =
-      if String.isPrefix "\<open>" s then (s,"")
-      else
-        case first_field "(" s of
-          NONE => (s,"")
-        | SOME (name,isp) => (name, String.substring (isp, 0, size isp - 1))
+    fun split (p: Pretty.T) : (string * string) option =
+      try (unsuffix ")" o content_of_pretty) p
+      |> Option.mapPartial (first_field "(")
+
+    fun add_pretty (name,is) = (SOME (name,is),Pretty.str (name ^ "(" ^ is ^ ")"))
 
-    fun merge ((name1,is1) :: (name2,is2) :: zs) =
+    fun merge ((SOME (name1,is1),p1) :: (y as (SOME (name2,is2),_)) :: zs) =
         if name1 = name2
-        then merge ((name1,is1 ^ "," ^ is2) :: zs)
-        else (name1,is1) :: merge ((name2,is2) :: zs)
-      | merge xs = xs;
-
-    fun parents is = if is = "" then "" else "(" ^ is ^ ")"
+        then merge (add_pretty (name1,is1 ^ "," ^ is2) :: zs)
+        else p1 :: merge (y :: zs)
+      | merge ((_,p) :: ys) = p :: merge ys
+      | merge [] = []
 
   in
-    map (fn (name,is) => name ^ parents is) (merge (map split ss))
+    merge (map (`split) facts)
   end
 
-fun string_of_proof_method ss meth =
+fun pretty_proof_method prefix suffix facts meth =
   let
     val meth_s =
       (case meth of
-        Metis_Method (NONE, NONE) => "metis"
-      | Metis_Method (type_enc_opt, lam_trans_opt) =>
-        "metis (" ^ commas (map_filter I [type_enc_opt, lam_trans_opt]) ^ ")"
+        Metis_Method (NONE, NONE, additional_fact_names) =>
+        implode_space ("metis" :: additional_fact_names)
+      | Metis_Method (type_enc_opt, lam_trans_opt, additional_fact_names) =>
+        implode_space ("metis" ::
+          "(" ^ commas (map_filter I [type_enc_opt, lam_trans_opt]) ^ ")" ::
+          additional_fact_names)
       | Meson_Method => "meson"
       | SMT_Method SMT_Z3 => "smt (z3)"
       | SMT_Method (SMT_Verit strategy) =>
@@ -135,7 +144,7 @@
       | SATx_Method => "satx"
       | Argo_Method => "argo"
       | Blast_Method => "blast"
-      | Simp_Method => if null ss then "simp" else "simp add:"
+      | Simp_Method => if null facts then "simp" else "simp add:"
       | Auto_Method => "auto"
       | Fastforce_Method => "fastforce"
       | Force_Method => "force"
@@ -145,19 +154,25 @@
       | Algebra_Method => "algebra"
       | Order_Method => "order")
   in
-    maybe_paren (implode_space (meth_s :: merge_indexed_facts ss))
+    pretty_maybe_paren prefix suffix
+      (Pretty.str meth_s :: merge_indexed_facts facts |> Pretty.breaks)
   end
 
+fun string_of_proof_method ss =
+  pretty_proof_method "" "" (map Pretty.str ss)
+  #> content_of_pretty
+
 fun tac_of_proof_method ctxt (local_facts, global_facts) meth =
   let
-    fun tac_of_metis (type_enc_opt, lam_trans_opt) =
+    fun tac_of_metis (type_enc_opt, lam_trans_opt, additional_fact_names) =
       let
+        val additional_facts = maps (thms_of_name ctxt) additional_fact_names
         val ctxt = ctxt
           |> Config.put Metis_Tactic.verbose false
           |> Config.put Metis_Tactic.trace false
       in
         SELECT_GOAL (Metis_Tactic.metis_method ((Option.map single type_enc_opt, lam_trans_opt),
-          global_facts) ctxt local_facts)
+          additional_facts @ global_facts) ctxt local_facts)
       end
 
     fun tac_of_smt SMT_Z3 = SMT_Solver.smt_tac
@@ -207,27 +222,36 @@
 (* FIXME *)
 fun proof_method_command meth i n used_chaineds _(*num_chained*) extras =
   let
-    val (indirect_ss, direct_ss) =
+    val (indirect_facts, direct_facts) =
       if is_proof_method_direct meth then ([], extras) else (extras, [])
+    val suffix =
+      if is_proof_method_multi_goal meth andalso n <> 1 then "[1]" else ""
   in
-    (if null indirect_ss then ""
-     else "using " ^ implode_space (merge_indexed_facts indirect_ss) ^ " ") ^
-    apply_on_subgoal i n ^ string_of_proof_method direct_ss meth ^
-    (if is_proof_method_multi_goal meth andalso n <> 1 then "[1]" else "")
+    (if null indirect_facts then []
+     else Pretty.str "using" :: merge_indexed_facts indirect_facts) @
+    [pretty_proof_method (apply_on_subgoal i n) suffix direct_facts meth]
+    |> Pretty.block o Pretty.breaks
+    |> Pretty.symbolic_string_of (* markup string *)
   end
 
 fun try_command_line banner play command =
   let val s = string_of_play_outcome play in
-    banner ^ Active.sendback_markup_command command ^ (s |> s <> "" ? enclose " (" ")")
+    (* Add optional markup break (command may need multiple lines) *)
+    banner ^ Markup.markup (Markup.break {width = 1, indent = 2}) " " ^
+    Active.sendback_markup_command command ^ (s |> s <> "" ? enclose " (" ")")
   end
 
+val failed_command_line =
+  prefix ("One-line proof reconstruction failed:" ^
+    (* Add optional markup break (command may need multiple lines) *)
+    Markup.markup (Markup.break {width = 1, indent = 2}) " ")
+
 fun one_line_proof_text _ num_chained
     ((used_facts, (meth, play)), banner, subgoal, subgoal_count) =
   let val (chained, extra) = List.partition (fn (_, (sc, _)) => sc = Chained) used_facts in
     map fst extra
     |> proof_method_command meth subgoal subgoal_count (map fst chained) num_chained
-    |> (if play = Play_Failed then prefix "One-line proof reconstruction failed: "
-        else try_command_line banner play)
+    |> (if play = Play_Failed then failed_command_line else try_command_line banner play)
   end
 
 end;