performance tuning: replace Ord_List by Set();
authorwenzelm
Mon, 10 Apr 2023 18:08:23 +0200
changeset 77803 f34d11942ac1
parent 77802 25c114e2528e
child 77804 849c996f052b
performance tuning: replace Ord_List by Set();
src/Pure/thm.ML
--- a/src/Pure/thm.ML	Mon Apr 10 14:59:40 2023 +0200
+++ b/src/Pure/thm.ML	Mon Apr 10 18:08:23 2023 +0200
@@ -385,20 +385,22 @@
 
 type constraint = {theory: theory, typ: typ, sort: sort};
 
-local
+structure Constraints =
+  Set(
+    type key = constraint;
+    val ord =
+      Context.theory_id_ord o apply2 (Context.theory_id o #theory)
+      ||| Term_Ord.typ_ord o apply2 #typ
+      ||| Term_Ord.sort_ord o apply2 #sort;
+  );
 
-val constraint_ord : constraint ord =
-  Context.theory_id_ord o apply2 (Context.theory_id o #theory)
-  ||| Term_Ord.typ_ord o apply2 #typ
-  ||| Term_Ord.sort_ord o apply2 #sort;
+local
 
 val smash_atyps =
   map_atyps (fn TVar (_, S) => Term.aT S | TFree (_, S) => Term.aT S | T => T);
 
 in
 
-val union_constraints = Ord_List.union constraint_ord;
-
 fun insert_constraints thy (T, S) =
   let
     val ignored =
@@ -409,7 +411,7 @@
         | _ => false);
   in
     if ignored then I
-    else Ord_List.insert constraint_ord {theory = thy, typ = smash_atyps T, sort = S}
+    else Constraints.insert {theory = thy, typ = smash_atyps T, sort = S}
   end;
 
 fun insert_constraints_env thy env =
@@ -429,7 +431,7 @@
  {cert: Context.certificate,    (*background theory certificate*)
   tags: Properties.T,           (*additional annotations/comments*)
   maxidx: int,                  (*maximum index of any Var or TVar*)
-  constraints: constraint Ord_List.T,  (*implicit proof obligations for sort constraints*)
+  constraints: Constraints.T,   (*implicit proof obligations for sort constraints*)
   shyps: Sortset.T,             (*sort hypotheses*)
   hyps: term Ord_List.T,        (*hypotheses*)
   tpairs: (term * term) list,   (*flex-flex pairs*)
@@ -496,6 +498,7 @@
 
 val maxidx_of = #maxidx o rep_thm;
 fun maxidx_thm th i = Int.max (maxidx_of th, i);
+val constraints_of = #constraints o rep_thm;
 val shyps_of = #shyps o rep_thm;
 val hyps_of = #hyps o rep_thm;
 val prop_of = #prop o rep_thm;
@@ -566,17 +569,17 @@
         t = t, T = T, maxidx = maxidx, sorts = sorts});
 
 fun trim_context_thm th =
-  (case th of
-    Thm (_, {constraints = _ :: _, ...}) =>
-      raise THM ("trim_context: pending sort constraints", 0, [th])
-  | Thm (_, {cert = Context.Certificate_Id _, ...}) => th
-  | Thm (der,
-      {cert = Context.Certificate thy, tags, maxidx, constraints = [], shyps, hyps,
-        tpairs, prop}) =>
-      Thm (der,
-       {cert = Context.Certificate_Id (Context.theory_id thy),
-        tags = tags, maxidx = maxidx, constraints = [], shyps = shyps, hyps = hyps,
-        tpairs = tpairs, prop = prop}));
+  if Constraints.is_empty (constraints_of th) then
+    (case th of
+      Thm (_, {cert = Context.Certificate_Id _, ...}) => th
+    | Thm (der,
+        {cert = Context.Certificate thy, tags, maxidx, constraints = _, shyps, hyps,
+          tpairs, prop}) =>
+        Thm (der,
+         {cert = Context.Certificate_Id (Context.theory_id thy),
+          tags = tags, maxidx = maxidx, constraints = Constraints.empty,
+          shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop}))
+  else raise THM ("trim_context: pending sort constraints", 0, [th]);
 
 fun transfer_ctyp thy' cT =
   let
@@ -804,7 +807,7 @@
 
     val _ = Context.eq_certificate (cert, orig_cert) orelse err "bad theory";
     val _ = prop aconv orig_prop orelse err "bad prop";
-    val _ = null constraints orelse err "bad sort constraints";
+    val _ = Constraints.is_empty constraints orelse err "bad sort constraints";
     val _ = null tpairs orelse err "bad flex-flex constraints";
     val _ = null hyps orelse err "bad hyps";
     val _ = Sortset.subset (shyps, orig_shyps) orelse err "bad shyps";
@@ -827,7 +830,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
-      constraints = [],
+      constraints = Constraints.empty,
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -849,7 +852,7 @@
       in
         Thm (der,
           {cert = cert, tags = [], maxidx = maxidx,
-            constraints = [], shyps = shyps, hyps = [], tpairs = [], prop = prop})
+            constraints = Constraints.empty, shyps = shyps, hyps = [], tpairs = [], prop = prop})
       end
   | NONE => raise THEORY ("No axiom " ^ quote name, [thy]));
 
@@ -981,30 +984,31 @@
 
 in
 
-fun solve_constraints (thm as Thm (_, {constraints = [], ...})) = thm
-  | solve_constraints (thm as Thm (der, args)) =
-      let
-        val {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop} = args;
+fun solve_constraints (thm as Thm (der, args)) =
+  if Constraints.is_empty (constraints_of thm) then thm
+  else
+    let
+      val {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop} = args;
 
-        val thy = Context.certificate_theory cert;
-        val bad_thys =
-          constraints |> map_filter (fn {theory = thy', ...} =>
-            if Context.eq_thy (thy, thy') then NONE else SOME thy');
-        val () =
-          if null bad_thys then ()
-          else
-            raise THEORY ("solve_constraints: bad theories for theorem\n" ^
-              Syntax.string_of_term_global thy (prop_of thm), thy :: bad_thys);
+      val thy = Context.certificate_theory cert;
+      val bad_thys =
+        Constraints.fold (fn {theory = thy', ...} =>
+          if Context.eq_thy (thy, thy') then I else cons thy') constraints [];
+      val () =
+        if null bad_thys then ()
+        else
+          raise THEORY ("solve_constraints: bad theories for theorem\n" ^
+            Syntax.string_of_term_global thy (prop_of thm), thy :: bad_thys);
 
-        val Deriv {promises, body = PBody {oracles, thms, proof}} = der;
-        val (oracles', thms') = (oracles, thms)
-          |> fold (fold union_digest o constraint_digest) constraints;
-        val body' = PBody {oracles = oracles', thms = thms', proof = proof};
-      in
-        Thm (Deriv {promises = promises, body = body'},
-          {constraints = [], cert = cert, tags = tags, maxidx = maxidx,
-            shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop})
-      end;
+      val Deriv {promises, body = PBody {oracles, thms, proof}} = der;
+      val (oracles', thms') = (oracles, thms)
+        |> Constraints.fold (fold union_digest o constraint_digest) constraints;
+      val body' = PBody {oracles = oracles', thms = thms', proof = proof};
+    in
+      Thm (Deriv {promises = promises, body = body'},
+        {constraints = Constraints.empty, cert = cert, tags = tags, maxidx = maxidx,
+          shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop})
+    end;
 
 end;
 
@@ -1137,7 +1141,7 @@
              {cert = Context.join_certificate (Context.Certificate thy', cert2),
               tags = [],
               maxidx = maxidx,
-              constraints = [],
+              constraints = Constraints.empty,
               shyps = sorts,
               hyps = [],
               tpairs = [],
@@ -1174,7 +1178,7 @@
      {cert = cert,
       tags = [],
       maxidx = ~1,
-      constraints = [],
+      constraints = Constraints.empty,
       shyps = sorts,
       hyps = [prop],
       tpairs = [],
@@ -1225,7 +1229,7 @@
            {cert = join_certificate2 (thAB, thA),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
-            constraints = union_constraints constraintsA constraints,
+            constraints = Constraints.merge (constraintsA, constraints),
             shyps = Sortset.merge (shypsA, shyps),
             hyps = union_hyps hypsA hyps,
             tpairs = union_tpairs tpairsA tpairs,
@@ -1301,7 +1305,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
-    constraints = [],
+    constraints = Constraints.empty,
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1347,7 +1351,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
-            constraints = union_constraints constraints1 constraints2,
+            constraints = Constraints.merge (constraints1, constraints2),
             shyps = Sortset.merge (shyps1, shyps2),
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1370,7 +1374,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
-      constraints = [],
+      constraints = Constraints.empty,
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -1382,7 +1386,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
-    constraints = [],
+    constraints = Constraints.empty,
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1393,7 +1397,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
-    constraints = [],
+    constraints = Constraints.empty,
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1460,7 +1464,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
-            constraints = union_constraints constraints1 constraints2,
+            constraints = Constraints.merge (constraints1, constraints2),
             shyps = Sortset.merge (shyps1, shyps2),
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1488,7 +1492,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
-            constraints = union_constraints constraints1 constraints2,
+            constraints = Constraints.merge (constraints1, constraints2),
             shyps = Sortset.merge (shyps1, shyps2),
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1517,7 +1521,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
-            constraints = union_constraints constraints1 constraints2,
+            constraints = Constraints.merge (constraints1, constraints2),
             shyps = Sortset.merge (shyps1, shyps2),
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1741,7 +1745,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
-      constraints = [],
+      constraints = Constraints.empty,
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -1765,7 +1769,7 @@
        {cert = cert,
         tags = [],
         maxidx = maxidx,
-        constraints = insert_constraints thy (T, [c]) [],
+        constraints = Constraints.build (insert_constraints thy (T, [c])),
         shyps = sorts,
         hyps = [],
         tpairs = [],
@@ -1798,7 +1802,7 @@
        {cert = cert,
         tags = [],
         maxidx = maxidx_of_term prop',
-        constraints = [],
+        constraints = Constraints.empty,
         shyps = Sortset.make [[]],  (*potentially redundant*)
         hyps = [],
         tpairs = [],
@@ -2166,7 +2170,7 @@
                   (ntps, (map normt (Bs @ As), normt C))
              end
            val constraints' =
-             union_constraints constraints1 constraints2
+             Constraints.merge (constraints1, constraints2)
              |> insert_constraints_env (Context.certificate_theory cert) env;
            fun bicompose_proof prf1 prf2 =
              Proofterm.bicompose_proof flatten (map normt Bs) (map normt As) A oldAs n (nlift+1)