src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML
changeset 73847 58f6b41efe88
parent 73696 03e134d5f867
child 74948 15ce207f69c8
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML	Sun Jun 06 21:39:26 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer_filter.ML	Thu Jun 10 11:21:57 2021 +0200
@@ -1,11 +1,12 @@
 (*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
     Author:     Jasmin Blanchette, TU Munich
     Author:     Makarius
+    Author:     Martin Desharnais, UniBw Munich
 
 Mirabelle action: "sledgehammer_filter".
 *)
 
-structure Mirabelle_Sledgehammer_Filter: sig end =
+structure Mirabelle_Sledgehammer_Filter: MIRABELLE_ACTION =
 struct
 
 fun get args name default_value =
@@ -40,7 +41,7 @@
 structure Prooftab =
   Table(type key = int * int val ord = prod_ord int_ord int_ord)
 
-fun print_int x = Value.print_int (! x)
+fun print_int x = Value.print_int (Synchronized.value x)
 
 fun percentage a b = if b = 0 then "N/A" else Value.print_int (a * 100 div b)
 fun percentage_alt a b = percentage a (a + b)
@@ -51,133 +52,135 @@
 
 val proof_fileK = "proof_file"
 
-val _ =
-  Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer_filter\<close>
-    (fn context => fn commands =>
+fun make_action ({arguments, ...} : Mirabelle.action_context) =
+  let
+    val (proof_table, args) =
       let
-        val (proof_table, args) =
-          let
-            val (pf_args, other_args) =
-              #arguments context |> List.partition (curry (op =) proof_fileK o fst)
-            val proof_file =
-              (case pf_args of
-                [] => error "No \"proof_file\" specified"
-              | (_, s) :: _ => s)
-            fun do_line line =
-              (case line |> space_explode ":" of
-                [line_num, offset, proof] =>
-                  SOME (apply2 (the o Int.fromString) (line_num, offset),
-                    proof |> space_explode " " |> filter_out (curry (op =) ""))
-              | _ => NONE)
-            val proof_table =
-              File.read (Path.explode proof_file)
-              |> space_explode "\n"
-              |> map_filter do_line
-              |> AList.coalesce (op =)
-              |> Prooftab.make
-          in (proof_table, other_args) end
+        val (pf_args, other_args) =
+          List.partition (curry (op =) proof_fileK o fst) arguments
+        val proof_file =
+          (case pf_args of
+            [] => error "No \"proof_file\" specified"
+          | (_, s) :: _ => s)
+        fun do_line line =
+          (case space_explode ":" line of
+            [line_num, offset, proof] =>
+              SOME (apply2 (the o Int.fromString) (line_num, offset),
+                proof |> space_explode " " |> filter_out (curry (op =) ""))
+          | _ => NONE)
+        val proof_table =
+          File.read (Path.explode proof_file)
+          |> space_explode "\n"
+          |> map_filter do_line
+          |> AList.coalesce (op =)
+          |> Prooftab.make
+      in (proof_table, other_args) end
 
-        val num_successes = Unsynchronized.ref 0
-        val num_failures = Unsynchronized.ref 0
-        val num_found_proofs = Unsynchronized.ref 0
-        val num_lost_proofs = Unsynchronized.ref 0
-        val num_found_facts = Unsynchronized.ref 0
-        val num_lost_facts = Unsynchronized.ref 0
+    val num_successes = Synchronized.var "num_successes" 0
+    val num_failures = Synchronized.var "num_failures" 0
+    val num_found_proofs = Synchronized.var "num_found_proofs" 0
+    val num_lost_proofs = Synchronized.var "num_lost_proofs" 0
+    val num_found_facts = Synchronized.var "num_found_facts" 0
+    val num_lost_facts = Synchronized.var "num_lost_facts" 0
 
+    fun run_action ({pos, pre, ...} : Mirabelle.command) =
+      let
         val results =
-          commands |> maps (fn {pos, pre, ...} =>
-            (case (Position.line_of pos, Position.offset_of pos) of
-              (SOME line_num, SOME offset) =>
-                (case Prooftab.lookup proof_table (line_num, offset) of
-                  SOME proofs =>
-                    let
-                      val thy = Proof.theory_of pre
-                      val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
-                      val prover = AList.lookup (op =) args "prover" |> the_default default_prover
-                      val params as {max_facts, ...} = Sledgehammer_Commands.default_params thy args
-                      val default_max_facts =
-                        Sledgehammer_Prover_Minimize.default_max_facts_of_prover ctxt prover
-                      val relevance_fudge =
-                        extract_relevance_fudge args Sledgehammer_MePo.default_relevance_fudge
-                      val subgoal = 1
-                      val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal subgoal ctxt
-                      val ho_atp = Sledgehammer_Prover_ATP.is_ho_atp ctxt prover
-                      val keywords = Thy_Header.get_keywords' ctxt
-                      val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
-                      val facts =
-                        Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
-                          Sledgehammer_Fact.no_fact_override keywords css_table chained_ths
-                          hyp_ts concl_t
-                        |> Sledgehammer_Fact.drop_duplicate_facts
-                        |> Sledgehammer_MePo.mepo_suggested_facts ctxt params
-                            (the_default default_max_facts max_facts)
-                            (SOME relevance_fudge) hyp_ts concl_t
-                        |> map (fst o fst)
-                      val (found_facts, lost_facts) =
-                        flat proofs |> sort_distinct string_ord
-                        |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
-                        |> List.partition (curry (op <=) 0 o fst)
-                        |>> sort (prod_ord int_ord string_ord) ||> map snd
-                      val found_proofs = filter (forall (member (op =) facts)) proofs
-                      val n = length found_proofs
-                      val log1 =
-                        if n = 0 then
-                          (Unsynchronized.inc num_failures; "Failure")
-                        else
-                          (Unsynchronized.inc num_successes;
-                           Unsynchronized.add num_found_proofs n;
-                           "Success (" ^ Value.print_int n ^ " of " ^
-                             Value.print_int (length proofs) ^ " proofs)")
-                      val _ = Unsynchronized.add num_lost_proofs (length proofs - n)
-                      val _ = Unsynchronized.add num_found_facts (length found_facts)
-                      val _ = Unsynchronized.add num_lost_facts (length lost_facts)
-                      val log2 =
-                        if null found_facts then []
-                        else
-                          let
-                            val found_weight =
-                              Real.fromInt (fold (fn (n, _) => Integer.add (n * n)) found_facts 0)
-                                / Real.fromInt (length found_facts)
-                              |> Math.sqrt |> Real.ceil
-                          in
-                            ["Found facts (among " ^ Value.print_int (length facts) ^
-                             ", weight " ^ Value.print_int found_weight ^ "): " ^
-                             commas (map with_index found_facts)]
-                          end
-                      val log3 =
-                        if null lost_facts then []
-                        else
-                          ["Lost facts (among " ^ Value.print_int (length facts) ^ "): " ^
-                           commas lost_facts]
-                    in [XML.Text (cat_lines (log1 :: log2 @ log3))] end
-                | NONE => [XML.Text "No known proof"])
-            | _ => []))
+          (case (Position.line_of pos, Position.offset_of pos) of
+            (SOME line_num, SOME offset) =>
+              (case Prooftab.lookup proof_table (line_num, offset) of
+                SOME proofs =>
+                  let
+                    val thy = Proof.theory_of pre
+                    val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
+                    val prover = AList.lookup (op =) args "prover" |> the_default default_prover
+                    val params as {max_facts, ...} = Sledgehammer_Commands.default_params thy args
+                    val default_max_facts =
+                      Sledgehammer_Prover_Minimize.default_max_facts_of_prover ctxt prover
+                    val relevance_fudge =
+                      extract_relevance_fudge args Sledgehammer_MePo.default_relevance_fudge
+                    val subgoal = 1
+                    val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal subgoal ctxt
+                    val ho_atp = Sledgehammer_Prover_ATP.is_ho_atp ctxt prover
+                    val keywords = Thy_Header.get_keywords' ctxt
+                    val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
+                    val facts =
+                      Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
+                        Sledgehammer_Fact.no_fact_override keywords css_table chained_ths
+                        hyp_ts concl_t
+                      |> Sledgehammer_Fact.drop_duplicate_facts
+                      |> Sledgehammer_MePo.mepo_suggested_facts ctxt params
+                          (the_default default_max_facts max_facts)
+                          (SOME relevance_fudge) hyp_ts concl_t
+                      |> map (fst o fst)
+                    val (found_facts, lost_facts) =
+                      flat proofs |> sort_distinct string_ord
+                      |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
+                      |> List.partition (curry (op <=) 0 o fst)
+                      |>> sort (prod_ord int_ord string_ord) ||> map snd
+                    val found_proofs = filter (forall (member (op =) facts)) proofs
+                    val n = length found_proofs
+                    val _ = Int.div
+                    val _ = Synchronized.change num_failures (curry op+ 1)
+                    val log1 =
+                      if n = 0 then
+                        (Synchronized.change num_failures (curry op+ 1); "Failure")
+                      else
+                        (Synchronized.change num_successes (curry op+ 1);
+                         Synchronized.change num_found_proofs (curry op+ n);
+                         "Success (" ^ Value.print_int n ^ " of " ^
+                           Value.print_int (length proofs) ^ " proofs)")
+                    val _ = Synchronized.change num_lost_proofs (curry op+ (length proofs - n))
+                    val _ = Synchronized.change num_found_facts (curry op+ (length found_facts))
+                    val _ = Synchronized.change num_lost_facts (curry op+ (length lost_facts))
+                    val log2 =
+                      if null found_facts then
+                        ""
+                      else
+                        let
+                          val found_weight =
+                            Real.fromInt (fold (fn (n, _) => Integer.add (n * n)) found_facts 0)
+                              / Real.fromInt (length found_facts)
+                            |> Math.sqrt |> Real.ceil
+                        in
+                          "Found facts (among " ^ Value.print_int (length facts) ^
+                          ", weight " ^ Value.print_int found_weight ^ "): " ^
+                          commas (map with_index found_facts)
+                        end
+                    val log3 =
+                      if null lost_facts then
+                        ""
+                      else
+                        "Lost facts (among " ^ Value.print_int (length facts) ^ "): " ^
+                         commas lost_facts
+                  in cat_lines [log1, log2, log3] end
+              | NONE => "No known proof")
+          | _ => "")
+      in
+        results
+      end
 
-        val report =
-          if ! num_successes + ! num_failures > 0 then
-            let
-              val props =
-                [("num_successes", print_int num_successes),
-                 ("num_failures", print_int num_failures),
-                 ("num_found_proofs", print_int num_found_proofs),
-                 ("num_lost_proofs", print_int num_lost_proofs),
-                 ("num_found_facts", print_int num_found_facts),
-                 ("num_lost_facts", print_int num_lost_facts)]
-              val text =
-                "\nNumber of overall successes: " ^ print_int num_successes ^
-                "\nNumber of overall failures: " ^ print_int num_failures ^
-                "\nOverall success rate: " ^
-                    percentage_alt (! num_successes) (! num_failures) ^ "%" ^
-                "\nNumber of found proofs: " ^ print_int num_found_proofs ^
-                "\nNumber of lost proofs: " ^ print_int num_lost_proofs ^
-                "\nProof found rate: " ^
-                    percentage_alt (! num_found_proofs) (! num_lost_proofs) ^ "%" ^
-                "\nNumber of found facts: " ^ print_int num_found_facts ^
-                "\nNumber of lost facts: " ^ print_int num_lost_facts ^
-                "\nFact found rate: " ^
-                    percentage_alt (! num_found_facts) (! num_lost_facts) ^ "%"
-            in [Mirabelle.log_report props [XML.Text text]] end
-          else []
-      in results @ report end))
+    fun finalize () =
+      if Synchronized.value num_successes + Synchronized.value num_failures > 0 then
+        "\nNumber of overall successes: " ^ print_int num_successes ^
+        "\nNumber of overall failures: " ^ print_int num_failures ^
+        "\nOverall success rate: " ^
+          percentage_alt (Synchronized.value num_successes)
+            (Synchronized.value num_failures) ^ "%" ^
+        "\nNumber of found proofs: " ^ print_int num_found_proofs ^
+        "\nNumber of lost proofs: " ^ print_int num_lost_proofs ^
+        "\nProof found rate: " ^
+          percentage_alt (Synchronized.value num_found_proofs)
+            (Synchronized.value num_lost_proofs) ^ "%" ^
+        "\nNumber of found facts: " ^ print_int num_found_facts ^
+        "\nNumber of lost facts: " ^ print_int num_lost_facts ^
+        "\nFact found rate: " ^
+          percentage_alt (Synchronized.value num_found_facts)
+            (Synchronized.value num_lost_facts) ^ "%"
+      else
+        ""
+  in {run_action = run_action, finalize = finalize} end
+
+val () = Mirabelle.register_action "sledgehammer_filter" make_action
 
 end