performance tuning: replace Ord_List by Table();
authorwenzelm
Tue, 11 Apr 2023 20:32:04 +0200
changeset 77825 61f652dd955a
parent 77824 e3fe192fa4a8
child 77826 e3db27e3b0c6
performance tuning: replace Ord_List by Table();
src/HOL/Tools/Mirabelle/mirabelle.ML
src/HOL/Tools/Sledgehammer/sledgehammer_util.ML
src/Pure/Proof/extraction.ML
src/Pure/proofterm.ML
src/Pure/thm.ML
src/Pure/thm_deps.ML
--- a/src/HOL/Tools/Mirabelle/mirabelle.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -302,7 +302,7 @@
 
 fun fold_body_thms f =
   let
-    fun app n (PBody {thms, ...}) = thms |> fold (fn (i, thm_node) =>
+    fun app n (PBody {thms, ...}) = thms |> PThms.fold (fn (i, thm_node) =>
       fn (x, seen) =>
         if Inttab.defined seen i then (x, seen)
         else
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -104,7 +104,7 @@
             else (num_thms + 1, name' :: names)
           | NONE => accum)
       end
-    and app_body map_name (PBody {thms, ...}) = fold (app_thm map_name) thms
+    and app_body map_name (PBody {thms, ...}) = PThms.fold (app_thm map_name) thms
   in
     snd (app_body map_plain_name body (0, []))
   end
--- a/src/Pure/Proof/extraction.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/Pure/Proof/extraction.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -180,11 +180,7 @@
         (fn Oracle (name, prop, _) => apfst (cons ((name, Position.none), SOME prop))
           | PThm (header, thm_body) => apsnd (cons (Proofterm.make_thm header thm_body))
           | _ => I);
-    val body =
-      PBody
-       {oracles = Oracles.make oracles,
-        thms = Ord_List.make Proofterm.thm_ord thms,
-        proof = prf};
+    val body = PBody {oracles = Oracles.make oracles, thms = PThms.make thms, proof = prf};
   in Proofterm.thm_body body end;
 
 
--- a/src/Pure/proofterm.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/Pure/proofterm.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -11,6 +11,11 @@
   val ord = prod_ord (prod_ord fast_string_ord Position.ord) (option_ord Term_Ord.fast_term_ord)
 );
 
+structure PThms = Table(
+  type key = serial
+  val ord = rev_order o int_ord
+);
+
 signature PROOFTERM =
 sig
   type thm_header =
@@ -32,9 +37,12 @@
    | PThm of thm_header * thm_body
   and proof_body = PBody of
     {oracles: Oracles.T,
-     thms: (serial * thm_node) Ord_List.T,
+     thms: thm_node PThms.table,
      proof: proof}
   type thm = serial * thm_node
+  type thms = thm_node PThms.table
+  val union_thms: thms * thms -> thms
+  val unions_thms: thms list -> thms
   exception MIN_PROOF of unit
   val proof_of: proof_body -> proof
   val join_proof: proof_body future -> proof
@@ -48,15 +56,13 @@
   val thm_node_name: thm_node -> string
   val thm_node_prop: thm_node -> term
   val thm_node_body: thm_node -> proof_body future
-  val thm_node_thms: thm_node -> thm list
-  val join_thms: thm list -> proof_body list
+  val thm_node_thms: thm_node -> thms
+  val join_thms: thms -> proof_body list
   val make_thm: thm_header -> thm_body -> thm
   val fold_proof_atoms: bool -> (proof -> 'a -> 'a) -> proof list -> 'a -> 'a
   val fold_body_thms:
     ({serial: serial, name: string, prop: term, body: proof_body} -> 'a -> 'a) ->
     proof_body list -> 'a -> 'a
-  val thm_ord: thm ord
-  val unions_thms: thm Ord_List.T list -> thm Ord_List.T
   val no_proof_body: proof -> proof_body
   val no_thm_names: proof -> proof
   val no_thm_proofs: proof -> proof
@@ -220,7 +226,7 @@
  | PThm of thm_header * thm_body
 and proof_body = PBody of
   {oracles: Oracles.T,
-   thms: (serial * thm_node) Ord_List.T,
+   thms: thm_node PThms.table,
    proof: proof}
 and thm_body =
   Thm_Body of {open_proof: proof -> proof, body: proof_body future}
@@ -229,8 +235,10 @@
     body: proof_body future, export: unit lazy, consolidate: unit lazy};
 
 type thm = serial * thm_node;
-val thm_ord: thm ord = fn ((i, _), (j, _)) => int_ord (j, i);
 
+type thms = thm_node PThms.table;
+val union_thms: thms * thms -> thms = PThms.merge (K true);
+val unions_thms: thms list -> thms = PThms.merges (K true);
 
 exception MIN_PROOF of unit;
 
@@ -256,11 +264,11 @@
 val thm_node_export = #export o rep_thm_node;
 val thm_node_consolidate = #consolidate o rep_thm_node;
 
-fun join_thms (thms: thm list) =
-  Future.joins (map (thm_node_body o #2) thms);
+fun join_thms (thms: thms) =
+  Future.joins (PThms.fold_rev (cons o thm_node_body o #2) thms []);
 
 val consolidate_bodies =
-  maps (fn PBody {thms, ...} => map (thm_node_consolidate o #2) thms)
+  maps (fn PBody {thms, ...} => PThms.fold_rev (cons o thm_node_consolidate o #2) thms [])
   #> Lazy.consolidate #> map Lazy.force #> ignore;
 
 fun make_thm_node theory_name name prop body export =
@@ -300,7 +308,7 @@
 fun fold_body_thms f =
   let
     fun app (PBody {thms, ...}) =
-      tap join_thms thms |> fold (fn (i, thm_node) => fn (x, seen) =>
+      tap join_thms thms |> PThms.fold (fn (i, thm_node) => fn (x, seen) =>
         if Intset.member seen i then (x, seen)
         else
           let
@@ -314,9 +322,7 @@
 
 (* proof body *)
 
-val unions_thms = Ord_List.unions thm_ord;
-
-fun no_proof_body proof = PBody {oracles = Oracles.empty, thms = [], proof = proof};
+fun no_proof_body proof = PBody {oracles = Oracles.empty, thms = PThms.empty, proof = proof};
 val no_thm_body = thm_body (no_proof_body MinProof);
 
 fun no_thm_names (Abst (x, T, prf)) = Abst (x, T, no_thm_names prf)
@@ -372,7 +378,8 @@
         (map Position.properties_of pos, (prop, (types, map_proof_of open_proof (Future.join body)))))]
 and proof_body consts (PBody {oracles, thms, proof = prf}) =
   triple (list (pair (pair string (properties o Position.properties_of))
-      (option (term consts)))) (list (thm consts)) (proof consts) (Oracles.dest oracles, thms, prf)
+      (option (term consts)))) (list (thm consts)) (proof consts)
+      (Oracles.dest oracles, PThms.dest thms, prf)
 and thm consts (a, thm_node) =
   pair int (pair string (pair string (pair (term consts) (proof_body consts))))
     (a, (thm_node_theory_name thm_node, (thm_node_name thm_node, (thm_node_prop thm_node,
@@ -438,7 +445,7 @@
     val (a, b, c) =
       triple (list (pair (pair string (Position.of_properties o properties))
         (option (term consts)))) (list (thm consts)) (proof consts) x;
-  in PBody {oracles = Oracles.make a, thms = b, proof = c} end
+  in PBody {oracles = Oracles.make a, thms = PThms.make b, proof = c} end
 and thm consts x =
   let
     val (a, (b, (c, (d, e)))) =
@@ -1991,7 +1998,9 @@
         (fold (fn (_, PBody {oracles, ...}) => not (Oracles.is_empty oracles) ? cons oracles)
           ps [oracles0]);
     val thms =
-      unions_thms (fold (fn (_, PBody {thms, ...}) => not (null thms) ? cons thms) ps [thms0]);
+      unions_thms
+        (fold (fn (_, PBody {thms, ...}) => not (PThms.is_empty thms) ? cons thms)
+          ps [thms0]);
     val proof = rew_proof thy proof0;
   in PBody {oracles = oracles, thms = thms, proof = proof} end;
 
@@ -2140,9 +2149,9 @@
       else
         boxes
         |> Inttab.update (i, thm_node_export thm_node)
-        |> fold export_thm (thm_node_thms thm_node);
+        |> PThms.fold export_thm (thm_node_thms thm_node);
 
-    fun export_body (PBody {thms, ...}) = fold export_thm thms;
+    fun export_body (PBody {thms, ...}) = PThms.fold export_thm thms;
 
     val exports = Inttab.build (fold export_body bodies) |> Inttab.dest;
   in List.app (Lazy.force o #2) exports end;
--- a/src/Pure/thm.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/Pure/thm.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -107,7 +107,7 @@
   val expose_proofs: theory -> thm list -> unit
   val expose_proof: theory -> thm -> unit
   val future: thm future -> cterm -> thm
-  val thm_deps: thm -> Proofterm.thm Ord_List.T
+  val thm_deps: thm -> Proofterm.thms
   val extra_shyps: thm -> Sortset.T
   val strip_shyps: thm -> thm
   val derivation_closed: thm -> bool
@@ -739,7 +739,7 @@
 fun make_deriv promises oracles thms proof =
   Deriv {promises = promises, body = PBody {oracles = oracles, thms = thms, proof = proof}};
 
-val empty_deriv = make_deriv empty_promises Oracles.empty [] MinProof;
+val empty_deriv = make_deriv empty_promises Oracles.empty PThms.empty MinProof;
 
 
 (* inference rules *)
@@ -753,7 +753,7 @@
   let
     val ps = merge_promises (promises1, promises2);
     val oracles = Oracles.merge (oracles1, oracles2);
-    val thms = Proofterm.unions_thms [thms1, thms2];
+    val thms = Proofterm.union_thms (thms1, thms2);
     val prf =
       (case ! Proofterm.proofs of
         2 => f prf1 prf2
@@ -766,7 +766,7 @@
 
 fun deriv_rule0 make_prf =
   if ! Proofterm.proofs <= 1 then empty_deriv
-  else deriv_rule1 I (make_deriv empty_promises Oracles.empty [] (make_prf ()));
+  else deriv_rule1 I (make_deriv empty_promises Oracles.empty PThms.empty (make_prf ()));
 
 fun deriv_rule_unconditional f (Deriv {promises, body = PBody {oracles, thms, proof}}) =
   make_deriv promises oracles thms (f proof);
@@ -832,7 +832,7 @@
     val i = serial ();
     val future = future_thm |> Future.map (future_result i cert sorts prop);
   in
-    Thm (make_deriv (make_promises [(i, future)]) Oracles.empty [] MinProof,
+    Thm (make_deriv (make_promises [(i, future)]) Oracles.empty PThms.empty MinProof,
      {cert = cert,
       tags = [],
       maxidx = maxidx,
@@ -974,7 +974,7 @@
 local
 
 fun union_digest (oracles1, thms1) (oracles2, thms2) =
-  (Oracles.merge (oracles1, oracles2), Proofterm.unions_thms [thms1, thms2]);
+  (Oracles.merge (oracles1, oracles2), Proofterm.union_thms (thms1, thms2));
 
 fun thm_digest (Thm (Deriv {body = PBody {oracles, thms, ...}, ...}, _)) =
   (oracles, thms);
@@ -982,11 +982,12 @@
 fun constraint_digest ({theory = thy, typ, sort, ...}: constraint) =
   Sorts.of_sort_derivation (Sign.classes_of thy)
    {class_relation = fn _ => fn _ => fn (digest, c1) => fn c2 =>
-      if c1 = c2 then (Oracles.empty, []) else union_digest digest (thm_digest (the_classrel thy (c1, c2))),
+      if c1 = c2 then (Oracles.empty, PThms.empty)
+      else union_digest digest (thm_digest (the_classrel thy (c1, c2))),
     type_constructor = fn (a, _) => fn dom => fn c =>
       let val arity_digest = thm_digest (the_arity thy (a, (map o map) #2 dom, c))
       in (fold o fold) (union_digest o #1) dom arity_digest end,
-    type_variable = fn T => map (pair (Oracles.empty, [])) (Type.sort_of_atyp T)}
+    type_variable = fn T => map (pair (Oracles.empty, PThms.empty)) (Type.sort_of_atyp T)}
    (typ, sort);
 
 in
@@ -1093,12 +1094,13 @@
 
 (*dependencies of PThm node*)
 fun thm_deps (thm as Thm (Deriv {promises, body = PBody {thms, ...}, ...}, _)) =
-  if null_promises promises then
-      (case (derivation_id thm, thms) of
-        (SOME {serial = i, ...}, [(j, thm_node)]) =>
-          if i = j then Proofterm.thm_node_thms thm_node else thms
-      | _ => thms)
-  else raise THM ("thm_deps: bad promises", 0, [thm]);
+  if not (null_promises promises) then raise THM ("thm_deps: bad promises", 0, [thm])
+  else if PThms.size thms = 1 then
+    (case (derivation_id thm, PThms.dest thms) of
+      (SOME {serial = i, ...}, [(j, thm_node)]) =>
+        if i = j then Proofterm.thm_node_thms thm_node else thms
+    | _ => thms)
+  else thms;
 
 fun name_derivation name_pos =
   strip_shyps #> (fn thm as Thm (der, args) =>
@@ -1114,7 +1116,7 @@
       val (pthm, proof) =
         Proofterm.thm_proof thy (classrel_proof thy) (arity_proof thy)
           name_pos shyps (Termset.dest hyps) prop ps body;
-      val der' = make_deriv empty_promises Oracles.empty [pthm] proof;
+      val der' = make_deriv empty_promises Oracles.empty (PThms.make [pthm]) proof;
     in Thm (der', args) end);
 
 fun close_derivation pos =
@@ -1145,7 +1147,7 @@
               | 0 => (((name, Position.none), NONE), MinProof)
               | i => bad_proofs i);
           in
-            Thm (make_deriv empty_promises (Oracles.make [oracle]) [] prf,
+            Thm (make_deriv empty_promises (Oracles.make [oracle]) PThms.empty prf,
              {cert = Context.join_certificate (Context.Certificate thy', cert2),
               tags = [],
               maxidx = maxidx,
@@ -1805,7 +1807,7 @@
       val (pthm, proof) =
         Proofterm.unconstrain_thm_proof thy (classrel_proof thy) (arity_proof thy)
           shyps prop ps body;
-      val der' = make_deriv empty_promises Oracles.empty [pthm] proof;
+      val der' = make_deriv empty_promises Oracles.empty (PThms.make [pthm]) proof;
       val prop' = Proofterm.thm_node_prop (#2 pthm);
     in
       Thm (der',
--- a/src/Pure/thm_deps.ML	Tue Apr 11 15:03:02 2023 +0200
+++ b/src/Pure/thm_deps.ML	Tue Apr 11 20:32:04 2023 +0200
@@ -24,7 +24,7 @@
   let
     fun collect (PBody {oracles, thms, ...}) =
       (if Oracles.is_empty oracles then I else apfst (cons oracles)) #>
-      (tap Proofterm.join_thms thms |> fold (fn (i, thm_node) => fn (res, seen) =>
+      (tap Proofterm.join_thms thms |> PThms.fold (fn (i, thm_node) => fn (res, seen) =>
         if Intset.member seen i then (res, seen)
         else
           let val body = Future.join (Proofterm.thm_node_body thm_node)
@@ -63,11 +63,11 @@
               Inttab.update (i, SOME (thm_id, thm_name)) res
           | NONE =>
               Inttab.update (i, NONE) res
-              |> fold deps (Proofterm.thm_node_thms thm_node))
+              |> PThms.fold deps (Proofterm.thm_node_thms thm_node))
         end;
   in
     fn thms =>
-      (Inttab.build (fold (fold deps o Thm.thm_deps o Thm.transfer thy) thms), [])
+      (Inttab.build (fold (PThms.fold deps o Thm.thm_deps o Thm.transfer thy) thms), [])
       |-> Inttab.fold_rev (fn (_, SOME entry) => cons entry | _ => I)
   end;