refactored try0's internals
authordesharna
Thu, 27 Mar 2025 10:06:43 +0100
changeset 82356 79a86a1ecb3d
parent 82355 4ace4f6f7101
child 82357 a3f30dc05920
refactored try0's internals
src/HOL/Tools/try0.ML
--- a/src/HOL/Tools/try0.ML	Wed Mar 26 09:51:26 2025 +0100
+++ b/src/HOL/Tools/try0.ML	Thu Mar 27 10:06:43 2025 +0100
@@ -67,6 +67,149 @@
   "" |> fold (add_attr_text tagged) tags
 
 type result = {name: string, command: string, time: Time.time, state: Proof.state}
+type proof_method = Time.time option -> tagged_xref list -> Proof.state -> result option
+type proof_method_options = {run_if_auto_try: bool}
+
+val noop_proof_method : proof_method = fn _ => fn _ => fn _ => NONE
+
+local
+  val proof_methods =
+    Synchronized.var "Try0.proof_methods" (Symtab.empty : proof_method Symtab.table);
+  val auto_try_proof_methods_names =
+    Synchronized.var "Try0.auto_try_proof_methods" (Symset.empty : Symset.T);
+in
+
+fun register_proof_method name ({run_if_auto_try} : proof_method_options) proof_method =
+  let
+    val () = if name = "" then error "Registering unnamed proof method" else ()
+    val () = Synchronized.change proof_methods (Symtab.update_new (name, proof_method))
+    val () =
+      if run_if_auto_try then
+        Synchronized.change auto_try_proof_methods_names (Symset.insert name)
+      else
+        ()
+  in () end
+val _ = Symset.dest
+
+fun get_proof_method name = Symtab.lookup (Synchronized.value proof_methods) name;
+
+fun get_all_proof_methods () =
+  Symtab.fold (fn x => fn xs => x :: xs) (Synchronized.value proof_methods) []
+
+fun get_all_proof_method_names () =
+  Symtab.fold (fn (name, _) => fn names => name :: names) (Synchronized.value proof_methods) []
+
+fun get_all_auto_try_proof_method_names () : string list =
+  Symset.dest (Synchronized.value auto_try_proof_methods_names)
+
+fun should_auto_try_proof_method (name : string) : bool =
+  Symset.member (Synchronized.value auto_try_proof_methods_names) name
+
+end
+
+fun get_proof_method_or_noop name =
+  (case get_proof_method name of
+    NONE => noop_proof_method
+  | SOME proof_method => proof_method)
+
+val apply_proof_method = get_proof_method_or_noop
+
+fun maybe_apply_proof_method name mode : proof_method =
+  if mode <> Auto_Try orelse should_auto_try_proof_method name then
+    get_proof_method_or_noop name
+  else
+    noop_proof_method
+
+fun time_string time = string_of_int (Time.toMilliseconds time) ^ " ms"
+fun tool_time_string (s, time) = s ^ ": " ^ time_string time
+
+(* Makes reconstructor tools as silent as possible. The "set_visible" calls suppresses "Unification
+   bound exceeded" warnings and the like. *)
+fun silence_methods debug =
+  Config.put Metis_Tactic.verbose debug
+  #> not debug ? (fn ctxt =>
+      ctxt
+      |> Simplifier_Trace.disable
+      |> Context_Position.set_visible false
+      |> Config.put Unify.unify_trace false
+      |> Config.put Argo_Tactic.trace "none")
+
+fun generic_try0 mode timeout_opt (tagged : tagged_xref list) st =
+  let
+    val st = Proof.map_contexts (silence_methods false) st
+    fun try_method method = method mode timeout_opt tagged st
+    fun get_message {command, time, ...} = "Found proof: " ^ Active.sendback_markup_command
+      command ^ " (" ^ time_string time ^ ")"
+    val print_step = Option.map (tap (writeln o get_message))
+    fun get_results methods : result list =
+      if mode = Normal then
+        methods
+        |> Par_List.map (try_method #> print_step)
+        |> map_filter I
+        |> sort (Time.compare o apply2 #time)
+      else
+        methods
+        |> Par_List.get_some try_method
+        |> the_list
+    val proof_method_names = get_all_proof_method_names ()
+    val maybe_apply_methods = map maybe_apply_proof_method proof_method_names
+  in
+    if mode = Normal then
+      let val names = map quote proof_method_names in
+        writeln ("Trying " ^ implode_space (Try.serial_commas "and" names) ^ "...")
+      end
+    else
+      ();
+    (case get_results maybe_apply_methods of
+      [] => (if mode = Normal then writeln "No proof found" else (); ((false, (noneN, [])), []))
+    | results as {name, command, ...} :: _ =>
+      let
+        val method_times =
+          results
+          |> map (fn {name, time, ...} => (time, name))
+          |> AList.coalesce (op =)
+          |> map (swap o apsnd commas)
+        val message =
+          (case mode of
+             Auto_Try => "Auto Try0 found a proof"
+           | Try => "Try0 found a proof"
+           | Normal => "Try this") ^ ": " ^
+          Active.sendback_markup_command command ^
+          (case method_times of
+            [(_, ms)] => " (" ^ time_string ms ^ ")"
+          | method_times => "\n(" ^ space_implode "; " (map tool_time_string method_times) ^ ")")
+      in
+        ((true, (name, if mode = Auto_Try then [message] else (writeln message; []))), results)
+      end)
+  end
+
+fun try0 timeout_opt = snd oo generic_try0 Normal timeout_opt
+
+fun try0_trans tagged =
+  Toplevel.keep_proof
+    (ignore o generic_try0 Normal (SOME default_timeout) tagged o Toplevel.proof_of)
+
+val parse_fact_refs = Scan.repeat1 (Scan.unless (Parse.name -- Args.colon) Parse.thm)
+
+val parse_attr =
+  Args.$$$ "simp" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Simp]))
+  || Args.$$$ "intro" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Intro]))
+  || Args.$$$ "elim" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Elim]))
+  || Args.$$$ "dest" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Dest]))
+
+fun parse_attrs x =
+  (Args.parens parse_attrs
+   || Scan.repeat parse_attr >> (fn tagged => fold (curry (op @)) tagged [])) x
+
+val _ =
+  Outer_Syntax.command \<^command_keyword>\<open>try0\<close> "try a combination of proof methods"
+    (Scan.optional parse_attrs [] #>> try0_trans)
+
+val _ =
+  Try.tool_setup
+   {name = "try0", weight = 30, auto_option = \<^system_option>\<open>auto_methods\<close>,
+    body = fn auto => fst o generic_try0 (if auto then Auto_Try else Try) NONE []}
+
 
 local
 
@@ -133,116 +276,17 @@
     else NONE
   end
 
-in
-
-val named_methods = map fst raw_named_methods
-
-fun apply_proof_method name timeout_opt tagged st :
-  result option =
+fun apply_proof_method name timeout_opt (tagged : tagged_xref list) st : result option =
   (case AList.lookup (op =) raw_named_methods name of
     NONE => NONE
   | SOME raw_method => apply_raw_named_method (name, raw_method) timeout_opt tagged st)
 
-fun maybe_apply_proof_method name mode timeout_opt tagged st :
-  result option =
-  (case AList.lookup (op =) raw_named_methods name of
-    NONE => NONE
-  | SOME (raw_method as ((_, run_if_auto_try), _)) =>
-    if mode <> Auto_Try orelse run_if_auto_try then
-      apply_raw_named_method (name, raw_method) timeout_opt tagged st
-    else
-      NONE)
+in
+
+val () = List.app (fn (name, ((_, run_if_auto_try), _)) =>
+  register_proof_method name {run_if_auto_try = run_if_auto_try} (apply_proof_method name)
+  handle Symtab.DUP _ => ()) raw_named_methods
 
 end
 
-val maybe_apply_methods = map maybe_apply_proof_method named_methods
-
-fun time_string time = string_of_int (Time.toMilliseconds time) ^ " ms"
-fun tool_time_string (s, time) = s ^ ": " ^ time_string time
-
-(* Makes reconstructor tools as silent as possible. The "set_visible" calls suppresses "Unification
-   bound exceeded" warnings and the like. *)
-fun silence_methods debug =
-  Config.put Metis_Tactic.verbose debug
-  #> not debug ? (fn ctxt =>
-      ctxt
-      |> Simplifier_Trace.disable
-      |> Context_Position.set_visible false
-      |> Config.put Unify.unify_trace false
-      |> Config.put Argo_Tactic.trace "none")
-
-fun generic_try0 mode timeout_opt (tagged : tagged_xref list) st =
-  let
-    val st = Proof.map_contexts (silence_methods false) st
-    fun try_method method = method mode timeout_opt tagged st
-    fun get_message {command, time, ...} = "Found proof: " ^ Active.sendback_markup_command
-      command ^ " (" ^ time_string time ^ ")"
-    val print_step = Option.map (tap (writeln o get_message))
-    fun get_results methods : result list =
-      if mode = Normal then
-        methods
-        |> Par_List.map (try_method #> print_step)
-        |> map_filter I
-        |> sort (Time.compare o apply2 #time)
-      else
-        methods
-        |> Par_List.get_some try_method
-        |> the_list
-  in
-    if mode = Normal then
-      "Trying " ^ implode_space (Try.serial_commas "and" (map quote named_methods)) ^
-      "..."
-      |> writeln
-    else
-      ();
-    (case get_results maybe_apply_methods of
-      [] => (if mode = Normal then writeln "No proof found" else (); ((false, (noneN, [])), []))
-    | results as {name, command, ...} :: _ =>
-      let
-        val method_times =
-          results
-          |> map (fn {name, time, ...} => (time, name))
-          |> AList.coalesce (op =)
-          |> map (swap o apsnd commas)
-        val message =
-          (case mode of
-             Auto_Try => "Auto Try0 found a proof"
-           | Try => "Try0 found a proof"
-           | Normal => "Try this") ^ ": " ^
-          Active.sendback_markup_command command ^
-          (case method_times of
-            [(_, ms)] => " (" ^ time_string ms ^ ")"
-          | method_times => "\n(" ^ space_implode "; " (map tool_time_string method_times) ^ ")")
-      in
-        ((true, (name, if mode = Auto_Try then [message] else (writeln message; []))), results)
-      end)
-  end
-
-fun try0 timeout_opt = snd oo generic_try0 Normal timeout_opt
-
-fun try0_trans tagged =
-  Toplevel.keep_proof
-    (ignore o generic_try0 Normal (SOME default_timeout) tagged o Toplevel.proof_of)
-
-val parse_fact_refs = Scan.repeat1 (Scan.unless (Parse.name -- Args.colon) Parse.thm)
-
-val parse_attr =
-  Args.$$$ "simp" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Simp]))
-  || Args.$$$ "intro" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Intro]))
-  || Args.$$$ "elim" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Elim]))
-  || Args.$$$ "dest" |-- Args.colon |-- parse_fact_refs >> (map (rpair [Dest]))
-
-fun parse_attrs x =
-  (Args.parens parse_attrs
-   || Scan.repeat parse_attr >> (fn tagged => fold (curry (op @)) tagged [])) x
-
-val _ =
-  Outer_Syntax.command \<^command_keyword>\<open>try0\<close> "try a combination of proof methods"
-    (Scan.optional parse_attrs [] #>> try0_trans)
-
-val _ =
-  Try.tool_setup
-   {name = "try0", weight = 30, auto_option = \<^system_option>\<open>auto_methods\<close>,
-    body = fn auto => fst o generic_try0 (if auto then Auto_Try else Try) NONE []}
-
 end