--- a/src/Pure/thm.ML Mon Apr 10 14:59:40 2023 +0200
+++ b/src/Pure/thm.ML Mon Apr 10 18:08:23 2023 +0200
@@ -385,20 +385,22 @@
type constraint = {theory: theory, typ: typ, sort: sort};
-local
+structure Constraints =
+ Set(
+ type key = constraint;
+ val ord =
+ Context.theory_id_ord o apply2 (Context.theory_id o #theory)
+ ||| Term_Ord.typ_ord o apply2 #typ
+ ||| Term_Ord.sort_ord o apply2 #sort;
+ );
-val constraint_ord : constraint ord =
- Context.theory_id_ord o apply2 (Context.theory_id o #theory)
- ||| Term_Ord.typ_ord o apply2 #typ
- ||| Term_Ord.sort_ord o apply2 #sort;
+local
val smash_atyps =
map_atyps (fn TVar (_, S) => Term.aT S | TFree (_, S) => Term.aT S | T => T);
in
-val union_constraints = Ord_List.union constraint_ord;
-
fun insert_constraints thy (T, S) =
let
val ignored =
@@ -409,7 +411,7 @@
| _ => false);
in
if ignored then I
- else Ord_List.insert constraint_ord {theory = thy, typ = smash_atyps T, sort = S}
+ else Constraints.insert {theory = thy, typ = smash_atyps T, sort = S}
end;
fun insert_constraints_env thy env =
@@ -429,7 +431,7 @@
{cert: Context.certificate, (*background theory certificate*)
tags: Properties.T, (*additional annotations/comments*)
maxidx: int, (*maximum index of any Var or TVar*)
- constraints: constraint Ord_List.T, (*implicit proof obligations for sort constraints*)
+ constraints: Constraints.T, (*implicit proof obligations for sort constraints*)
shyps: Sortset.T, (*sort hypotheses*)
hyps: term Ord_List.T, (*hypotheses*)
tpairs: (term * term) list, (*flex-flex pairs*)
@@ -496,6 +498,7 @@
val maxidx_of = #maxidx o rep_thm;
fun maxidx_thm th i = Int.max (maxidx_of th, i);
+val constraints_of = #constraints o rep_thm;
val shyps_of = #shyps o rep_thm;
val hyps_of = #hyps o rep_thm;
val prop_of = #prop o rep_thm;
@@ -566,17 +569,17 @@
t = t, T = T, maxidx = maxidx, sorts = sorts});
fun trim_context_thm th =
- (case th of
- Thm (_, {constraints = _ :: _, ...}) =>
- raise THM ("trim_context: pending sort constraints", 0, [th])
- | Thm (_, {cert = Context.Certificate_Id _, ...}) => th
- | Thm (der,
- {cert = Context.Certificate thy, tags, maxidx, constraints = [], shyps, hyps,
- tpairs, prop}) =>
- Thm (der,
- {cert = Context.Certificate_Id (Context.theory_id thy),
- tags = tags, maxidx = maxidx, constraints = [], shyps = shyps, hyps = hyps,
- tpairs = tpairs, prop = prop}));
+ if Constraints.is_empty (constraints_of th) then
+ (case th of
+ Thm (_, {cert = Context.Certificate_Id _, ...}) => th
+ | Thm (der,
+ {cert = Context.Certificate thy, tags, maxidx, constraints = _, shyps, hyps,
+ tpairs, prop}) =>
+ Thm (der,
+ {cert = Context.Certificate_Id (Context.theory_id thy),
+ tags = tags, maxidx = maxidx, constraints = Constraints.empty,
+ shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop}))
+ else raise THM ("trim_context: pending sort constraints", 0, [th]);
fun transfer_ctyp thy' cT =
let
@@ -804,7 +807,7 @@
val _ = Context.eq_certificate (cert, orig_cert) orelse err "bad theory";
val _ = prop aconv orig_prop orelse err "bad prop";
- val _ = null constraints orelse err "bad sort constraints";
+ val _ = Constraints.is_empty constraints orelse err "bad sort constraints";
val _ = null tpairs orelse err "bad flex-flex constraints";
val _ = null hyps orelse err "bad hyps";
val _ = Sortset.subset (shyps, orig_shyps) orelse err "bad shyps";
@@ -827,7 +830,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -849,7 +852,7 @@
in
Thm (der,
{cert = cert, tags = [], maxidx = maxidx,
- constraints = [], shyps = shyps, hyps = [], tpairs = [], prop = prop})
+ constraints = Constraints.empty, shyps = shyps, hyps = [], tpairs = [], prop = prop})
end
| NONE => raise THEORY ("No axiom " ^ quote name, [thy]));
@@ -981,30 +984,31 @@
in
-fun solve_constraints (thm as Thm (_, {constraints = [], ...})) = thm
- | solve_constraints (thm as Thm (der, args)) =
- let
- val {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop} = args;
+fun solve_constraints (thm as Thm (der, args)) =
+ if Constraints.is_empty (constraints_of thm) then thm
+ else
+ let
+ val {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop} = args;
- val thy = Context.certificate_theory cert;
- val bad_thys =
- constraints |> map_filter (fn {theory = thy', ...} =>
- if Context.eq_thy (thy, thy') then NONE else SOME thy');
- val () =
- if null bad_thys then ()
- else
- raise THEORY ("solve_constraints: bad theories for theorem\n" ^
- Syntax.string_of_term_global thy (prop_of thm), thy :: bad_thys);
+ val thy = Context.certificate_theory cert;
+ val bad_thys =
+ Constraints.fold (fn {theory = thy', ...} =>
+ if Context.eq_thy (thy, thy') then I else cons thy') constraints [];
+ val () =
+ if null bad_thys then ()
+ else
+ raise THEORY ("solve_constraints: bad theories for theorem\n" ^
+ Syntax.string_of_term_global thy (prop_of thm), thy :: bad_thys);
- val Deriv {promises, body = PBody {oracles, thms, proof}} = der;
- val (oracles', thms') = (oracles, thms)
- |> fold (fold union_digest o constraint_digest) constraints;
- val body' = PBody {oracles = oracles', thms = thms', proof = proof};
- in
- Thm (Deriv {promises = promises, body = body'},
- {constraints = [], cert = cert, tags = tags, maxidx = maxidx,
- shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop})
- end;
+ val Deriv {promises, body = PBody {oracles, thms, proof}} = der;
+ val (oracles', thms') = (oracles, thms)
+ |> Constraints.fold (fold union_digest o constraint_digest) constraints;
+ val body' = PBody {oracles = oracles', thms = thms', proof = proof};
+ in
+ Thm (Deriv {promises = promises, body = body'},
+ {constraints = Constraints.empty, cert = cert, tags = tags, maxidx = maxidx,
+ shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop})
+ end;
end;
@@ -1137,7 +1141,7 @@
{cert = Context.join_certificate (Context.Certificate thy', cert2),
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1174,7 +1178,7 @@
{cert = cert,
tags = [],
maxidx = ~1,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [prop],
tpairs = [],
@@ -1225,7 +1229,7 @@
{cert = join_certificate2 (thAB, thA),
tags = [],
maxidx = Int.max (maxidx1, maxidx2),
- constraints = union_constraints constraintsA constraints,
+ constraints = Constraints.merge (constraintsA, constraints),
shyps = Sortset.merge (shypsA, shyps),
hyps = union_hyps hypsA hyps,
tpairs = union_tpairs tpairsA tpairs,
@@ -1301,7 +1305,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1347,7 +1351,7 @@
{cert = join_certificate2 (th1, th2),
tags = [],
maxidx = Int.max (maxidx1, maxidx2),
- constraints = union_constraints constraints1 constraints2,
+ constraints = Constraints.merge (constraints1, constraints2),
shyps = Sortset.merge (shyps1, shyps2),
hyps = union_hyps hyps1 hyps2,
tpairs = union_tpairs tpairs1 tpairs2,
@@ -1370,7 +1374,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1382,7 +1386,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1393,7 +1397,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1460,7 +1464,7 @@
{cert = join_certificate2 (th1, th2),
tags = [],
maxidx = Int.max (maxidx1, maxidx2),
- constraints = union_constraints constraints1 constraints2,
+ constraints = Constraints.merge (constraints1, constraints2),
shyps = Sortset.merge (shyps1, shyps2),
hyps = union_hyps hyps1 hyps2,
tpairs = union_tpairs tpairs1 tpairs2,
@@ -1488,7 +1492,7 @@
{cert = join_certificate2 (th1, th2),
tags = [],
maxidx = Int.max (maxidx1, maxidx2),
- constraints = union_constraints constraints1 constraints2,
+ constraints = Constraints.merge (constraints1, constraints2),
shyps = Sortset.merge (shyps1, shyps2),
hyps = union_hyps hyps1 hyps2,
tpairs = union_tpairs tpairs1 tpairs2,
@@ -1517,7 +1521,7 @@
{cert = join_certificate2 (th1, th2),
tags = [],
maxidx = Int.max (maxidx1, maxidx2),
- constraints = union_constraints constraints1 constraints2,
+ constraints = Constraints.merge (constraints1, constraints2),
shyps = Sortset.merge (shyps1, shyps2),
hyps = union_hyps hyps1 hyps2,
tpairs = union_tpairs tpairs1 tpairs2,
@@ -1741,7 +1745,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = [],
+ constraints = Constraints.empty,
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1765,7 +1769,7 @@
{cert = cert,
tags = [],
maxidx = maxidx,
- constraints = insert_constraints thy (T, [c]) [],
+ constraints = Constraints.build (insert_constraints thy (T, [c])),
shyps = sorts,
hyps = [],
tpairs = [],
@@ -1798,7 +1802,7 @@
{cert = cert,
tags = [],
maxidx = maxidx_of_term prop',
- constraints = [],
+ constraints = Constraints.empty,
shyps = Sortset.make [[]], (*potentially redundant*)
hyps = [],
tpairs = [],
@@ -2166,7 +2170,7 @@
(ntps, (map normt (Bs @ As), normt C))
end
val constraints' =
- union_constraints constraints1 constraints2
+ Constraints.merge (constraints1, constraints2)
|> insert_constraints_env (Context.certificate_theory cert) env;
fun bicompose_proof prf1 prf2 =
Proofterm.bicompose_proof flatten (map normt Bs) (map normt As) A oldAs n (nlift+1)