performance tuning: replace Ord_List by Table();
authorwenzelm
Mon, 10 Apr 2023 19:37:15 +0200
changeset 77806 b6aa5eac0a1a
parent 77805 66779a752f10
child 77807 15d39d6bb258
performance tuning: replace Ord_List by Table();
src/Pure/thm.ML
--- a/src/Pure/thm.ML	Mon Apr 10 18:16:33 2023 +0200
+++ b/src/Pure/thm.ML	Mon Apr 10 19:37:15 2023 +0200
@@ -437,9 +437,17 @@
   tpairs: (term * term) list,   (*flex-flex pairs*)
   prop: term}                   (*conclusion*)
 and deriv = Deriv of
- {promises: (serial * thm future) Ord_List.T,
+ {promises: thm future Inttab'.table,
   body: Proofterm.proof_body};
 
+type promises = thm future Inttab'.table;
+val null_promises : promises -> bool = Inttab'.is_empty;
+val empty_promises : promises = Inttab'.empty;
+val merge_promises : promises * promises -> promises = Inttab'.merge (K true);
+val make_promises : (serial * thm future) list -> promises = Inttab'.make;
+val dest_promises : promises -> (serial * thm future) list = Inttab'.dest;
+fun forall_promises f : promises -> bool = Inttab'.forall (f o fst);
+
 type conv = cterm -> thm;
 
 (*errors involving theorems*)
@@ -734,21 +742,19 @@
 fun make_deriv promises oracles thms proof =
   Deriv {promises = promises, body = PBody {oracles = oracles, thms = thms, proof = proof}};
 
-val empty_deriv = make_deriv [] [] [] MinProof;
+val empty_deriv = make_deriv empty_promises [] [] MinProof;
 
 
 (* inference rules *)
 
-val promise_ord: (serial * thm future) ord = fn ((i, _), (j, _)) => int_ord (j, i);
-
 fun bad_proofs i =
   error ("Illegal level of detail for proof objects: " ^ string_of_int i);
 
 fun deriv_rule2 f
-    (Deriv {promises = ps1, body = PBody {oracles = oracles1, thms = thms1, proof = prf1}})
-    (Deriv {promises = ps2, body = PBody {oracles = oracles2, thms = thms2, proof = prf2}}) =
+    (Deriv {promises = promises1, body = PBody {oracles = oracles1, thms = thms1, proof = prf1}})
+    (Deriv {promises = promises2, body = PBody {oracles = oracles2, thms = thms2, proof = prf2}}) =
   let
-    val ps = Ord_List.union promise_ord ps1 ps2;
+    val ps = merge_promises (promises1, promises2);
     val oracles = Proofterm.unions_oracles [oracles1, oracles2];
     val thms = Proofterm.unions_thms [thms1, thms2];
     val prf =
@@ -763,7 +769,7 @@
 
 fun deriv_rule0 make_prf =
   if ! Proofterm.proofs <= 1 then empty_deriv
-  else deriv_rule1 I (make_deriv [] [] [] (make_prf ()));
+  else deriv_rule1 I (make_deriv empty_promises [] [] (make_prf ()));
 
 fun deriv_rule_unconditional f (Deriv {promises, body = PBody {oracles, thms, proof}}) =
   make_deriv promises oracles thms (f proof);
@@ -771,15 +777,18 @@
 
 (* fulfilled proofs *)
 
-fun raw_promises_of (Thm (Deriv {promises, ...}, _)) = promises;
+fun merge_promises_of (Thm (Deriv {promises, ...}, _)) ps = merge_promises (ps, promises);
 
 fun join_promises [] = ()
   | join_promises promises = join_promises_of (Future.joins (map snd promises))
-and join_promises_of thms = join_promises (Ord_List.make promise_ord (maps raw_promises_of thms));
+and join_promises_of thms =
+  join_promises (dest_promises (fold merge_promises_of thms empty_promises));
 
 fun fulfill_body (th as Thm (Deriv {promises, body}, _)) =
-  let val fulfilled_promises = map #1 promises ~~ map fulfill_body (Future.joins (map #2 promises))
-  in Proofterm.fulfill_norm_proof (theory_of_thm th) fulfilled_promises body end;
+  let
+    val pending = dest_promises promises;
+    val fulfilled = map #1 pending ~~ map fulfill_body (Future.joins (map #2 pending));
+  in Proofterm.fulfill_norm_proof (theory_of_thm th) fulfilled body end;
 
 fun proof_bodies_of thms = (join_promises_of thms; map fulfill_body thms);
 val proof_body_of = singleton proof_bodies_of;
@@ -811,8 +820,8 @@
     val _ = null tpairs orelse err "bad flex-flex constraints";
     val _ = null hyps orelse err "bad hyps";
     val _ = Sortset.subset (shyps, orig_shyps) orelse err "bad shyps";
-    val _ = forall (fn (j, _) => i <> j) promises orelse err "bad dependencies";
-    val _ = join_promises promises;
+    val _ = forall_promises (fn j => i <> j) promises orelse err "bad dependencies";
+    val _ = join_promises (dest_promises promises);
   in thm end;
 
 fun future future_thm ct =
@@ -826,7 +835,7 @@
     val i = serial ();
     val future = future_thm |> Future.map (future_result i cert sorts prop);
   in
-    Thm (make_deriv [(i, future)] [] [] MinProof,
+    Thm (make_deriv (make_promises [(i, future)]) [] [] MinProof,
      {cert = cert,
       tags = [],
       maxidx = maxidx,
@@ -1085,12 +1094,13 @@
   Proofterm.get_id shyps hyps prop (proof_of thm);
 
 (*dependencies of PThm node*)
-fun thm_deps (thm as Thm (Deriv {promises = [], body = PBody {thms, ...}, ...}, _)) =
+fun thm_deps (thm as Thm (Deriv {promises, body = PBody {thms, ...}, ...}, _)) =
+  if null_promises promises then
       (case (derivation_id thm, thms) of
         (SOME {serial = i, ...}, [(j, thm_node)]) =>
           if i = j then Proofterm.thm_node_thms thm_node else thms
       | _ => thms)
-  | thm_deps thm = raise THM ("thm_deps: bad promises", 0, [thm]);
+  else raise THM ("thm_deps: bad promises", 0, [thm]);
 
 fun name_derivation name_pos =
   strip_shyps #> (fn thm as Thm (der, args) =>
@@ -1102,11 +1112,11 @@
 
       val _ = null tpairs orelse raise THM ("name_derivation: bad flex-flex constraints", 0, [thm]);
 
-      val ps = map (apsnd (Future.map fulfill_body)) promises;
+      val ps = map (apsnd (Future.map fulfill_body)) (dest_promises promises);
       val (pthm, proof) =
         Proofterm.thm_proof thy (classrel_proof thy) (arity_proof thy)
           name_pos shyps hyps prop ps body;
-      val der' = make_deriv [] [] [pthm] proof;
+      val der' = make_deriv empty_promises [] [pthm] proof;
     in Thm (der', args) end);
 
 fun close_derivation pos =
@@ -1137,7 +1147,7 @@
               | 0 => (((name, Position.none), NONE), MinProof)
               | i => bad_proofs i);
           in
-            Thm (make_deriv [] [oracle] [] prf,
+            Thm (make_deriv empty_promises [oracle] [] prf,
              {cert = Context.join_certificate (Context.Certificate thy', cert2),
               tags = [],
               maxidx = maxidx,
@@ -1791,11 +1801,11 @@
       val tfrees = build_rev (Term.add_tfree_names prop);
       val _ = null tfrees orelse err ("illegal free type variables " ^ commas_quote tfrees);
 
-      val ps = map (apsnd (Future.map fulfill_body)) promises;
+      val ps = map (apsnd (Future.map fulfill_body)) (dest_promises promises);
       val (pthm, proof) =
         Proofterm.unconstrain_thm_proof thy (classrel_proof thy) (arity_proof thy)
           shyps prop ps body;
-      val der' = make_deriv [] [] [pthm] proof;
+      val der' = make_deriv empty_promises [] [pthm] proof;
       val prop' = Proofterm.thm_node_prop (#2 pthm);
     in
       Thm (der',