speed up MaSh duplicate check
authorblanchet
Sun, 04 Oct 2015 17:48:34 +0200
changeset 61322 44f4ffe2b210
parent 61321 c982a4cc8dc4
child 61323 99b3a17a7eab
child 61326 3ad2b2055ffc
speed up MaSh duplicate check
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/TPTP/mash_export.ML	Sun Oct 04 17:41:52 2015 +0200
+++ b/src/HOL/TPTP/mash_export.ML	Sun Oct 04 17:48:34 2015 +0200
@@ -314,7 +314,7 @@
         val mess =
           [(mepo_weight, (mepo_suggs, [])),
            (mash_weight, (mash_suggs, []))]
-        val mesh_suggs = mesh_facts (op =) max_suggs mess
+        val mesh_suggs = mesh_facts I (op =) max_suggs mess
         val mesh_line = encode_str name ^ ": " ^ encode_strs mesh_suggs ^ "\n"
       in File.append mesh_path mesh_line end
 
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML	Sun Oct 04 17:41:52 2015 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML	Sun Oct 04 17:48:34 2015 +0200
@@ -28,6 +28,7 @@
   val clasimpset_rule_table_of : Proof.context -> status Termtab.table
   val build_name_tables : (thm -> string) -> ('a * thm) list ->
     string Symtab.table * string Symtab.table
+  val fact_distinct : (term * term -> bool) -> ('a * thm) list -> ('a * thm) list
   val maybe_instantiate_inducts : Proof.context -> term list -> term ->
     (((unit -> string) * 'a) * thm) list -> (((unit -> string) * 'a) * thm) list
   val fact_of_raw_fact : raw_fact -> fact
@@ -375,11 +376,13 @@
   end
 
 fun fact_distinct eq facts =
-  fold (fn fact as (_, th) =>
-      Net.insert_term_safe (eq o apply2 (normalize_eq o Thm.prop_of o snd))
-        (normalize_eq (Thm.prop_of th), fact))
-    facts Net.empty
+  fold (fn (i, fact as (_, th)) =>
+      Net.insert_term_safe (eq o apply2 (normalize_eq o Thm.prop_of o snd o snd))
+        (normalize_eq (Thm.prop_of th), (i, fact)))
+    (tag_list 0 facts) Net.empty
   |> Net.entries
+  |> sort (int_ord o apply2 fst)
+  |> map snd
 
 fun struct_induct_rule_on th =
   (case Logic.strip_horn (Thm.prop_of th) of
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Oct 04 17:41:52 2015 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Oct 04 17:48:34 2015 +0200
@@ -44,7 +44,8 @@
   val the_mash_algorithm : unit -> mash_algorithm
   val str_of_mash_algorithm : mash_algorithm -> string
 
-  val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
+  val mesh_facts : ('a list -> 'a list) -> ('a * 'a -> bool) -> int ->
+    (real * (('a * real) list * 'a list)) list -> 'a list
   val nickname_of_thm : thm -> string
   val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
   val crude_thm_ord : Proof.context -> thm * thm -> order
@@ -169,9 +170,10 @@
   | normalize_scores max_facts xs =
     map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
 
-fun mesh_facts fact_eq max_facts [(_, (sels, unks))] =
-    distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks)
-  | mesh_facts fact_eq max_facts mess =
+fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] =
+    map fst (take max_facts sels) @ take (max_facts - length sels) unks
+    |> maybe_distinct
+  | mesh_facts _ fact_eq max_facts mess =
     let
       val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
 
@@ -503,7 +505,7 @@
        MaSh_NB => nb ()
      | MaSh_kNN => knn ()
      | MaSh_NB_kNN =>
-       mesh_facts (op =) max_suggs
+       mesh_facts I (op =) max_suggs
          [(0.5 (* FUDGE *), (weight_facts_steeply (nb ()), [])),
           (0.5 (* FUDGE *), (weight_facts_steeply (knn ()), []))])
      |> map (curry Vector.sub fact_names))
@@ -1169,7 +1171,7 @@
     val unknown = raw_unknown
       |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate]
   in
-    (mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess, unknown)
+    (mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts mess, unknown)
   end
 
 fun mash_suggested_facts ctxt thy_name ({debug, ...} : params) max_suggs hyp_ts concl_t facts =
@@ -1603,8 +1605,8 @@
          |> weight_facts_steeply, [])
 
       fun mash () =
-        mash_suggested_facts ctxt thy_name params
-          (generous_max_suggestions max_facts) hyp_ts concl_t facts
+        mash_suggested_facts ctxt thy_name params (generous_max_suggestions max_facts) hyp_ts
+          concl_t facts
         |>> weight_facts_steeply
 
       val mess =
@@ -1612,7 +1614,9 @@
         [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
            |> effective_fact_filter <> mashN ? cons (mepo_weight, mepo)
            |> Par_List.map (apsnd (fn f => f ()))
-      val mesh = mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess |> add_and_take
+      val mesh =
+        mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts mess
+        |> add_and_take
     in
       (case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>