reactivate "sledgehammer_filter": statically correct, but untested (no proof_file);
author wenzelm
Sat, 15 May 2021 17:40:36 +0200
changeset 73696 03e134d5f867
parent 73695 b6d444194280
child 73697 0e7a5c7a14c8
reactivate "sledgehammer_filter": statically correct, but untested (no proof_file);
src/HOL/Mirabelle.thy
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML
src/Pure/Concurrent/unsynchronized.ML
--- a/src/HOL/Mirabelle.thy	Sat May 15 17:38:49 2021 +0200
+++ b/src/HOL/Mirabelle.thy	Sat May 15 17:40:36 2021 +0200
@@ -1,5 +1,6 @@
 (*  Title:      HOL/Mirabelle.thy
     Author:     Jasmin Blanchette and Sascha Boehme, TU Munich
+    Author:     Makarius
 *)
 
 theory Mirabelle
@@ -12,8 +13,8 @@
 ML_file \<open>Tools/Mirabelle/mirabelle_quickcheck.ML\<close>
 (*
 ML_file \<open>Tools/Mirabelle/mirabelle_sledgehammer.ML\<close>
+*)
 ML_file \<open>Tools/Mirabelle/mirabelle_sledgehammer_filter.ML\<close>
-*)
 ML_file \<open>Tools/Mirabelle/mirabelle_try0.ML\<close>
 
 end
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML	Sat May 15 17:38:49 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML	Sat May 15 17:40:36 2021 +0200
@@ -1,10 +1,11 @@
 (*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
     Author:     Jasmin Blanchette, TU Munich
+    Author:     Makarius
 
 Mirabelle action: "sledgehammer_filter".
 *)
 
-structure Mirabelle_Sledgehammer_Filter : MIRABELLE_ACTION =
+structure Mirabelle_Sledgehammer_Filter: sig end =
 struct
 
 fun get args name default_value =
@@ -39,142 +40,144 @@
 structure Prooftab =
   Table(type key = int * int val ord = prod_ord int_ord int_ord)
 
-val proof_table = Unsynchronized.ref (Prooftab.empty: string list list Prooftab.table)
-
-val num_successes = Unsynchronized.ref ([] : (int * int) list)
-val num_failures = Unsynchronized.ref ([] : (int * int) list)
-val num_found_proofs = Unsynchronized.ref ([] : (int * int) list)
-val num_lost_proofs = Unsynchronized.ref ([] : (int * int) list)
-val num_found_facts = Unsynchronized.ref ([] : (int * int) list)
-val num_lost_facts = Unsynchronized.ref ([] : (int * int) list)
-
-fun get id c = the_default 0 (AList.lookup (op =) (!c) id)
-fun add id c n =
-  c := (case AList.lookup (op =) (!c) id of
-         SOME m => AList.update (op =) (id, m + n) (!c)
-       | NONE => (id, n) :: !c)
+fun print_int x = Value.print_int (! x)
 
-fun init proof_file _ thy =
-  let
-    fun do_line line =
-      (case line |> space_explode ":" of
-        [line_num, offset, proof] =>
-        SOME (apply2 (the o Int.fromString) (line_num, offset),
-              proof |> space_explode " " |> filter_out (curry (op =) ""))
-       | _ => NONE)
-    val proofs = File.read (Path.explode proof_file)
-    val proof_tab =
-      proofs |> space_explode "\n"
-             |> map_filter do_line
-             |> AList.coalesce (op =)
-             |> Prooftab.make
-  in proof_table := proof_tab; thy end
-
-fun percentage a b = if b = 0 then "N/A" else string_of_int (a * 100 div b)
+fun percentage a b = if b = 0 then "N/A" else Value.print_int (a * 100 div b)
 fun percentage_alt a b = percentage a (a + b)
 
-fun done id ({log, ...} : Mirabelle.done_args) =
-  if get id num_successes + get id num_failures > 0 then
-    (log "";
-     log ("Number of overall successes: " ^ string_of_int (get id num_successes));
-     log ("Number of overall failures: " ^ string_of_int (get id num_failures));
-     log ("Overall success rate: " ^
-          percentage_alt (get id num_successes) (get id num_failures) ^ "%");
-     log ("Number of found proofs: " ^ string_of_int (get id num_found_proofs));
-     log ("Number of lost proofs: " ^ string_of_int (get id num_lost_proofs));
-     log ("Proof found rate: " ^
-          percentage_alt (get id num_found_proofs) (get id num_lost_proofs) ^ "%");
-     log ("Number of found facts: " ^ string_of_int (get id num_found_facts));
-     log ("Number of lost facts: " ^ string_of_int (get id num_lost_facts));
-     log ("Fact found rate: " ^
-          percentage_alt (get id num_found_facts) (get id num_lost_facts) ^ "%"))
-  else
-    ()
-
 val default_prover = ATP_Proof.eN (* arbitrary ATP *)
 
-fun with_index (i, s) = s ^ "@" ^ string_of_int i
-
-fun action args id ({pre, pos, log, ...} : Mirabelle.run_args) =
-  case (Position.line_of pos, Position.offset_of pos) of
-    (SOME line_num, SOME offset) =>
-    (case Prooftab.lookup (!proof_table) (line_num, offset) of
-       SOME proofs =>
-       let
-         val thy = Proof.theory_of pre
-         val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
-         val prover = AList.lookup (op =) args "prover" |> the_default default_prover
-         val params as {max_facts, ...} = Sledgehammer_Commands.default_params thy args
-         val default_max_facts =
-           Sledgehammer_Prover_Minimize.default_max_facts_of_prover ctxt prover
-         val relevance_fudge =
-           extract_relevance_fudge args Sledgehammer_MePo.default_relevance_fudge
-         val subgoal = 1
-         val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal subgoal ctxt
-         val ho_atp = Sledgehammer_Prover_ATP.is_ho_atp ctxt prover
-         val keywords = Thy_Header.get_keywords' ctxt
-         val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
-         val facts =
-           Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
-               Sledgehammer_Fact.no_fact_override keywords css_table chained_ths
-               hyp_ts concl_t
-           |> Sledgehammer_Fact.drop_duplicate_facts
-           |> Sledgehammer_MePo.mepo_suggested_facts ctxt params
-                  (the_default default_max_facts max_facts) (SOME relevance_fudge) hyp_ts concl_t
-            |> map (fst o fst)
-         val (found_facts, lost_facts) =
-           flat proofs |> sort_distinct string_ord
-           |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
-           |> List.partition (curry (op <=) 0 o fst)
-           |>> sort (prod_ord int_ord string_ord) ||> map snd
-         val found_proofs = filter (forall (member (op =) facts)) proofs
-         val n = length found_proofs
-         val _ =
-           if n = 0 then
-             (add id num_failures 1; log "Failure")
-           else
-             (add id num_successes 1;
-              add id num_found_proofs n;
-              log ("Success (" ^ string_of_int n ^ " of " ^
-                   string_of_int (length proofs) ^ " proofs)"))
-         val _ = add id num_lost_proofs (length proofs - n)
-         val _ = add id num_found_facts (length found_facts)
-         val _ = add id num_lost_facts (length lost_facts)
-         val _ =
-           if null found_facts then
-             ()
-           else
-             let
-               val found_weight =
-                 Real.fromInt (fold (fn (n, _) => Integer.add (n * n)) found_facts 0)
-                   / Real.fromInt (length found_facts)
-                 |> Math.sqrt |> Real.ceil
-             in
-               log ("Found facts (among " ^ string_of_int (length facts) ^
-                    ", weight " ^ string_of_int found_weight ^ "): " ^
-                    commas (map with_index found_facts))
-             end
-         val _ = if null lost_facts then
-                   ()
-                 else
-                   log ("Lost facts (among " ^ string_of_int (length facts) ^
-                        "): " ^ commas lost_facts)
-       in () end
-     | NONE => log "No known proof")
-  | _ => ()
+fun with_index (i, s) = s ^ "@" ^ Value.print_int i
 
 val proof_fileK = "proof_file"
 
-fun invoke args =
-  let
-    val (pf_args, other_args) = args |> List.partition (curry (op =) proof_fileK o fst)
-    val proof_file =
-      (case pf_args of
-        [] => error "No \"proof_file\" specified"
-      | (_, s) :: _ => s)
-  in Mirabelle.register (init proof_file, action other_args, done) end
+val _ =
+  Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer_filter\<close>
+    (fn context => fn commands =>
+      let
+        val (proof_table, args) =
+          let
+            val (pf_args, other_args) =
+              #arguments context |> List.partition (curry (op =) proof_fileK o fst)
+            val proof_file =
+              (case pf_args of
+                [] => error "No \"proof_file\" specified"
+              | (_, s) :: _ => s)
+            fun do_line line =
+              (case line |> space_explode ":" of
+                [line_num, offset, proof] =>
+                  SOME (apply2 (the o Int.fromString) (line_num, offset),
+                    proof |> space_explode " " |> filter_out (curry (op =) ""))
+              | _ => NONE)
+            val proof_table =
+              File.read (Path.explode proof_file)
+              |> space_explode "\n"
+              |> map_filter do_line
+              |> AList.coalesce (op =)
+              |> Prooftab.make
+          in (proof_table, other_args) end
+
+        val num_successes = Unsynchronized.ref 0
+        val num_failures = Unsynchronized.ref 0
+        val num_found_proofs = Unsynchronized.ref 0
+        val num_lost_proofs = Unsynchronized.ref 0
+        val num_found_facts = Unsynchronized.ref 0
+        val num_lost_facts = Unsynchronized.ref 0
 
-end;
+        val results =
+          commands |> maps (fn {pos, pre, ...} =>
+            (case (Position.line_of pos, Position.offset_of pos) of
+              (SOME line_num, SOME offset) =>
+                (case Prooftab.lookup proof_table (line_num, offset) of
+                  SOME proofs =>
+                    let
+                      val thy = Proof.theory_of pre
+                      val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
+                      val prover = AList.lookup (op =) args "prover" |> the_default default_prover
+                      val params as {max_facts, ...} = Sledgehammer_Commands.default_params thy args
+                      val default_max_facts =
+                        Sledgehammer_Prover_Minimize.default_max_facts_of_prover ctxt prover
+                      val relevance_fudge =
+                        extract_relevance_fudge args Sledgehammer_MePo.default_relevance_fudge
+                      val subgoal = 1
+                      val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal subgoal ctxt
+                      val ho_atp = Sledgehammer_Prover_ATP.is_ho_atp ctxt prover
+                      val keywords = Thy_Header.get_keywords' ctxt
+                      val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
+                      val facts =
+                        Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
+                          Sledgehammer_Fact.no_fact_override keywords css_table chained_ths
+                          hyp_ts concl_t
+                        |> Sledgehammer_Fact.drop_duplicate_facts
+                        |> Sledgehammer_MePo.mepo_suggested_facts ctxt params
+                            (the_default default_max_facts max_facts)
+                            (SOME relevance_fudge) hyp_ts concl_t
+                        |> map (fst o fst)
+                      val (found_facts, lost_facts) =
+                        flat proofs |> sort_distinct string_ord
+                        |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
+                        |> List.partition (curry (op <=) 0 o fst)
+                        |>> sort (prod_ord int_ord string_ord) ||> map snd
+                      val found_proofs = filter (forall (member (op =) facts)) proofs
+                      val n = length found_proofs
+                      val log1 =
+                        if n = 0 then
+                          (Unsynchronized.inc num_failures; "Failure")
+                        else
+                          (Unsynchronized.inc num_successes;
+                           Unsynchronized.add num_found_proofs n;
+                           "Success (" ^ Value.print_int n ^ " of " ^
+                             Value.print_int (length proofs) ^ " proofs)")
+                      val _ = Unsynchronized.add num_lost_proofs (length proofs - n)
+                      val _ = Unsynchronized.add num_found_facts (length found_facts)
+                      val _ = Unsynchronized.add num_lost_facts (length lost_facts)
+                      val log2 =
+                        if null found_facts then []
+                        else
+                          let
+                            val found_weight =
+                              Real.fromInt (fold (fn (n, _) => Integer.add (n * n)) found_facts 0)
+                                / Real.fromInt (length found_facts)
+                              |> Math.sqrt |> Real.ceil
+                          in
+                            ["Found facts (among " ^ Value.print_int (length facts) ^
+                             ", weight " ^ Value.print_int found_weight ^ "): " ^
+                             commas (map with_index found_facts)]
+                          end
+                      val log3 =
+                        if null lost_facts then []
+                        else
+                          ["Lost facts (among " ^ Value.print_int (length facts) ^ "): " ^
+                           commas lost_facts]
+                    in [XML.Text (cat_lines (log1 :: log2 @ log3))] end
+                | NONE => [XML.Text "No known proof"])
+            | _ => []))
 
-(* Workaround to keep the "mirabelle.pl" script happy *)
-structure Mirabelle_Sledgehammer_filter = Mirabelle_Sledgehammer_Filter;
+        val report =
+          if ! num_successes + ! num_failures > 0 then
+            let
+              val props =
+                [("num_successes", print_int num_successes),
+                 ("num_failures", print_int num_failures),
+                 ("num_found_proofs", print_int num_found_proofs),
+                 ("num_lost_proofs", print_int num_lost_proofs),
+                 ("num_found_facts", print_int num_found_facts),
+                 ("num_lost_facts", print_int num_lost_facts)]
+              val text =
+                "\nNumber of overall successes: " ^ print_int num_successes ^
+                "\nNumber of overall failures: " ^ print_int num_failures ^
+                "\nOverall success rate: " ^
+                    percentage_alt (! num_successes) (! num_failures) ^ "%" ^
+                "\nNumber of found proofs: " ^ print_int num_found_proofs ^
+                "\nNumber of lost proofs: " ^ print_int num_lost_proofs ^
+                "\nProof found rate: " ^
+                    percentage_alt (! num_found_proofs) (! num_lost_proofs) ^ "%" ^
+                "\nNumber of found facts: " ^ print_int num_found_facts ^
+                "\nNumber of lost facts: " ^ print_int num_lost_facts ^
+                "\nFact found rate: " ^
+                    percentage_alt (! num_found_facts) (! num_lost_facts) ^ "%"
+            in [Mirabelle.log_report props [XML.Text text]] end
+          else []
+      in results @ report end))
+
+end
--- a/src/Pure/Concurrent/unsynchronized.ML	Sat May 15 17:38:49 2021 +0200
+++ b/src/Pure/Concurrent/unsynchronized.ML	Sat May 15 17:40:36 2021 +0200
@@ -13,6 +13,7 @@
   val change_result: 'a ref -> ('a -> 'b * 'a) -> 'b
   val inc: int ref -> int
   val dec: int ref -> int
+  val add: int ref -> int -> int
   val setmp: 'a ref -> 'a -> ('b -> 'c) -> 'b -> 'c
 end;
 
@@ -29,6 +30,7 @@
 
 fun inc i = (i := ! i + (1: int); ! i);
 fun dec i = (i := ! i - (1: int); ! i);
+fun add i n = (i := ! i + (n: int); ! i);
 
 fun setmp flag value f x =
   Thread_Attributes.uninterruptible (fn restore_attributes => fn () =>