src/HOL/Tools/Mirabelle/mirabelle.ML
changeset 73847 58f6b41efe88
parent 73822 1192c68ebe1c
child 73848 77306bf4e1ee
--- a/src/HOL/Tools/Mirabelle/mirabelle.ML	Sun Jun 06 21:39:26 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle.ML	Thu Jun 10 11:21:57 2021 +0200
@@ -1,23 +1,21 @@
 (*  Title:      HOL/Mirabelle/Tools/mirabelle.ML
-    Author:     Jasmin Blanchette and Sascha Boehme, TU Munich
+    Author:     Jasmin Blanchette, TU Munich
+    Author:     Sascha Boehme, TU Munich
     Author:     Makarius
+    Author:     Martin Desharnais, UniBw Munich
 *)
 
 signature MIRABELLE =
 sig
   (*core*)
-  val print_name: string -> string
-  val print_properties: Properties.T -> string
-  type context =
-    {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}
-  type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state}
-  val theory_action: binding -> (context -> command list -> XML.body) -> theory -> theory
-  val log_command: command -> XML.body -> XML.tree
-  val log_report: Properties.T -> XML.body -> XML.tree
-  val print_exn: exn -> string
-  val command_action: binding -> (context -> command -> string) -> theory -> theory
+  type action_context = {index: int, name: string, arguments: Properties.T, timeout: Time.time}
+  type command =
+    {theory_index: int, name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state}
+  type action = {run_action: command -> string, finalize: unit -> string}
+  val register_action: string -> (action_context -> action) -> unit
 
   (*utility functions*)
+  val print_exn: exn -> string
   val can_apply : Time.time -> (Proof.context -> int -> tactic) ->
     Proof.state -> bool
   val theorems_in_proof_term : theory -> thm -> thm list
@@ -37,9 +35,6 @@
 
 val keywords = Keyword.no_command_keywords (Thy_Header.get_keywords \<^theory>);
 
-val print_name = Token.print_name keywords;
-val print_properties = Token.print_properties keywords;
-
 fun read_actions str =
   Token.read_body keywords
     (Parse.enum ";" (Parse.name -- Sledgehammer_Commands.parse_params))
@@ -48,68 +43,69 @@
 
 (* actions *)
 
-type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state};
-type context =
-  {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory};
-
-structure Data = Theory_Data
-(
-  type T = (context -> command list -> XML.body) Name_Space.table;
-  val empty = Name_Space.empty_table "mirabelle_action";
-  val extend = I;
-  val merge = Name_Space.merge_tables;
-);
+type command =
+  {theory_index: int, name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state};
+type action_context = {index: int, name: string, arguments: Properties.T, timeout: Time.time};
+type action = {run_action: command -> string, finalize: unit -> string};
 
-fun theory_action binding action thy =
-  let val context = Context.Theory thy |> Name_Space.map_naming (K Name_Space.global_naming);
-  in thy |> Data.map (#2 o Name_Space.define context true (binding, action)) end;
-
-
-(* log content *)
+local
+  val actions = Synchronized.var "Mirabelle.actions"
+    (Symtab.empty : (action_context -> action) Symtab.table);
+in
 
-fun log_action name arguments =
-  XML.Elem (("action", (Markup.nameN, name) :: arguments),
-    [XML.Text (print_name name ^ (if null arguments then "" else " " ^ print_properties arguments))]);
+val register_action = Synchronized.change actions oo curry Symtab.update;
 
-fun log_command ({name, pos, ...}: command) body =
-  XML.Elem (("command", (Markup.nameN, name) :: Position.properties_of pos), body);
+fun get_action name = Symtab.lookup (Synchronized.value actions) name;
 
-fun log_report props body =
-  XML.Elem (("report", props), body);
+end
 
 
 (* apply actions *)
 
-fun apply_action index name arguments timeout commands thy =
-  let
-    val action = #2 (Name_Space.check (Context.Theory thy) (Data.get thy) (name, Position.none));
-    val tag = "#" ^ Value.print_int index ^ " " ^ name ^ ": ";
-    val context = {index = index, tag = tag, arguments = arguments, timeout = timeout, theory = thy};
-    val export_body = action context commands;
-    val export_head = log_action name arguments;
-    val export_name = Path.binding0 (Path.basic "mirabelle" + Path.basic (Value.print_int index));
-  in
-    if null export_body then ()
-    else Export.export thy export_name (export_head :: export_body)
-  end;
-
 fun print_exn exn =
   (case exn of
     Timeout.TIMEOUT _ => "timeout"
   | ERROR msg => "error: " ^ msg
-  | exn => "exception:\n" ^ General.exnMessage exn);
+  | exn => "exception: " ^ General.exnMessage exn);
 
-fun command_action binding action =
+fun run_action_function f =
+  f () handle exn =>
+    if Exn.is_interrupt exn then Exn.reraise exn
+    else print_exn exn;
+
+fun make_action_path (context as {index, name, ...} : action_context) =
+  Path.basic (string_of_int index ^ "." ^ name);
+
+fun finalize_action ({finalize, ...} : action) context =
   let
-    fun apply context command =
-      let val s =
-        action context command handle exn =>
-          if Exn.is_interrupt exn then Exn.reraise exn
-          else #tag context ^ print_exn exn;
-      in
-        if s = "" then NONE
-        else SOME (log_command command [XML.Text s]) end;
-  in theory_action binding (map_filter o apply) end;
+    val s = run_action_function finalize;
+    val action_path = make_action_path context;
+    val export_name =
+      Path.binding0 (Path.basic "mirabelle" + action_path + Path.basic "finalize");
+  in
+    if s <> "" then
+      Export.export \<^theory> export_name [XML.Text s]
+    else
+      ()
+  end
+
+fun apply_action ({run_action, ...} : action) context (command as {pos, pre, ...} : command) =
+  let
+    val thy = Proof.theory_of pre;
+    val action_path = make_action_path context;
+    val goal_name_path = Path.basic (#name command)
+    val line_path = Path.basic (string_of_int (the (Position.line_of pos)));
+    val offset_path = Path.basic (string_of_int (the (Position.offset_of pos)));
+    val export_name =
+      Path.binding0 (Path.basic "mirabelle" + action_path + Path.basic "goal" + goal_name_path +
+        line_path + offset_path);
+    val s = run_action_function (fn () => run_action command);
+  in
+    if s <> "" then
+      Export.export thy export_name [XML.Text s]
+    else
+      ()
+  end;
 
 
 (* theory line range *)
@@ -147,9 +143,6 @@
     fun check_pos range = check_line range o Position.line_of;
   in check_pos o get_theory end;
 
-fun check_session qualifier thy_name (_: Position.T) =
-  Resources.theory_qualifier thy_name = qualifier;
-
 
 (* presentation hook *)
 
@@ -160,6 +153,7 @@
     let
       val mirabelle_timeout = Options.default_seconds \<^system_option>\<open>mirabelle_timeout\<close>;
       val mirabelle_stride = Options.default_int \<^system_option>\<open>mirabelle_stride\<close>;
+      val mirabelle_max_calls = Options.default_int \<^system_option>\<open>mirabelle_max_calls\<close>;
       val mirabelle_actions = Options.default_string \<^system_option>\<open>mirabelle_actions\<close>;
       val mirabelle_theories = Options.default_string \<^system_option>\<open>mirabelle_theories\<close>;
 
@@ -167,35 +161,62 @@
         (case read_actions mirabelle_actions of
           SOME actions => actions
         | NONE => error ("Failed to parse mirabelle_actions: " ^ quote mirabelle_actions));
-      val check =
-        if mirabelle_theories = "" then check_session qualifier
-        else check_theories (space_explode "," mirabelle_theories);
+    in
+      if null actions then
+        ()
+      else
+        let
+          val check_theory = check_theories (space_explode "," mirabelle_theories);
 
-      fun theory_commands (thy, segments) =
-        let
-          val commands = segments
-            |> map_index (fn (n, {command = tr, prev_state = st, state = st', ...}) =>
-              if n mod mirabelle_stride = 0 then
+          fun make_commands (thy_index, (thy, segments)) =
+            let
+              val thy_long_name = Context.theory_long_name thy;
+              val check_thy = check_theory thy_long_name;
+              fun make_command {command = tr, prev_state = st, state = st', ...} =
                 let
                   val name = Toplevel.name_of tr;
                   val pos = Toplevel.pos_of tr;
                 in
                   if can (Proof.assert_backward o Toplevel.proof_of) st andalso
-                    member (op =) whitelist name andalso
-                    check (Context.theory_long_name thy) pos
-                  then SOME {name = name, pos = pos, pre = Toplevel.proof_of st, post = st'}
+                    member (op =) whitelist name andalso check_thy pos
+                  then SOME {theory_index = thy_index, name = name, pos = pos,
+                    pre = Toplevel.proof_of st, post = st'}
                   else NONE
-                end
-              else NONE)
-            |> map_filter I;
-        in if null commands then NONE else SOME (thy, commands) end;
+                end;
+            in
+              if Resources.theory_qualifier thy_long_name = qualifier then
+                map_filter make_command segments
+              else
+                []
+            end;
 
-      fun app_actions (thy, commands) =
-        (actions, ()) |-> fold_index (fn (index, (name, arguments)) => fn () =>
-          apply_action (index + 1) name arguments mirabelle_timeout commands thy);
-    in
-      if null actions then ()
-      else List.app app_actions (map_filter theory_commands loaded_theories)
+          (* initialize actions *)
+          val contexts = actions |> map_index (fn (n, (name, args)) =>
+            let
+              val make_action = the (get_action name);
+              val context = {index = n, name = name, arguments = args, timeout = mirabelle_timeout};
+            in
+              (make_action context, context)
+            end);
+        in
+          (* run actions on all relevant goals *)
+          loaded_theories
+          |> map_index I
+          |> maps make_commands
+          |> map_index I
+          |> maps (fn (n, command) =>
+            let val (m, k) = Integer.div_mod (n + 1) mirabelle_stride in
+              if k = 0 andalso m <= mirabelle_max_calls then
+                map (fn context => (context, command)) contexts
+              else
+                []
+            end)
+          |> Par_List.map (fn ((action, context), command) => apply_action action context command)
+          |> ignore;
+
+          (* finalize actions *)
+          List.app (uncurry finalize_action) contexts
+          end
     end);