--- a/src/Pure/zterm.ML Wed Jan 03 12:40:10 2024 +0100
+++ b/src/Pure/zterm.ML Thu Jan 04 15:16:10 2024 +0100
@@ -805,18 +805,27 @@
local
-fun close_prop hyps concl =
- fold_rev (fn A => fn B => ZApp (ZApp (ZConst0 "Pure.imp", A), B)) hyps concl;
+fun close_prop prems prop =
+ fold_rev (fn A => fn B => ZApp (ZApp (ZConst0 "Pure.imp", A), B)) prems prop;
-fun close_proof hyps prf =
+fun close_proof of_class prems prf =
let
- val m = length hyps - 1;
- val bounds = ZTerms.build (hyps |> fold_index (fn (i, h) => ZTerms.update (h, m - i)));
+ fun err msg t = raise ZTERM (msg, [], [t], [prf]);
+
+ val m = length prems - 1;
+ val bounds = ZTerms.build (prems |> fold_index (fn (i, h) => ZTerms.update (h, m - i)));
+ fun get_bound lev t = ZTerms.lookup bounds t |> Option.map (fn i => ZBoundp (lev + i));
fun proof lev (ZHyp t) =
- (case ZTerms.lookup bounds t of
- SOME i => ZBoundp (lev + i)
- | NONE => raise ZTERM ("Unbound proof hypothesis", [], [t], []))
+ (case get_bound lev t of
+ SOME p => p
+ | NONE => err "Loose bound in proof term" t)
+ | proof lev (ZClassp C) =
+ (case get_bound lev (ZClass C) of
+ SOME p => p
+ | NONE =>
+ if of_class C then raise Same.SAME
+ else err "Invalid class constraint " (ZClass C))
| proof lev (ZAbst (x, T, p)) = ZAbst (x, T, proof lev p)
| proof lev (ZAbsp (x, t, p)) = ZAbsp (x, t, proof (lev + 1) p)
| proof lev (ZAppt (p, t)) = ZAppt (proof lev p, t)
@@ -824,21 +833,50 @@
(ZAppp (proof lev p, Same.commit (proof lev) q)
handle Same.SAME => ZAppp (p, proof lev q))
| proof _ _ = raise Same.SAME;
- in ZAbsps hyps (Same.commit (proof 0) prf) end;
+ in ZAbsps prems (Same.commit (proof 0) prf) end;
fun box_proof zproof_name thy hyps concl prf =
let
- val {zterm, ...} = zterm_cache thy;
- val hyps' = map zterm hyps;
- val concl' = zterm concl;
+ val {zterm, typ, ...} = norm_cache thy;
+
+ val algebra = Sign.classes_of thy;
+ fun of_class (T, c) = Sorts.of_sort algebra (typ T, [c]);
+
+ val present_set = Types.build (fold Types.add_atyps hyps #> Types.add_atyps concl);
+ val ucontext = Logic.unconstrain_context [] present_set;
+
+ val outer_constraints = map (apfst ztyp_of) (#outer_constraints ucontext);
+ val constraints = map (zterm o #2) (#constraints ucontext);
+
+ val typ_operation = #typ_operation ucontext {strip_sorts = true};
+ val unconstrain_typ = Same.commit typ_operation;
+ val unconstrain_ztyp =
+ ZTypes.unsynchronized_cache (Same.function_eq (op =) (typ_of #> unconstrain_typ #> ztyp_of));
+ val unconstrain_zterm = zterm o Term.map_types typ_operation;
+ val unconstrain_proof = Same.commit (map_proof_types {hyps = true} unconstrain_ztyp);
- val prop' = beta_norm_term (close_prop hyps' concl');
- val prf' = beta_norm_prooft (close_proof hyps' prf);
+ val constrain_instT =
+ ZTVars.build (present_set |> Types.fold (fn (T, _) =>
+ let
+ val ZTVar v = ztyp_of (unconstrain_typ T);
+ val U = ztyp_of T;
+ in ZTVars.add (v, U) end));
+ val constrain_proof =
+ map_proof_types {hyps = true} (subst_type_same (fn v =>
+ let
+ val T = ZTVar v;
+ val T' = the_default T (ZTVars.lookup constrain_instT v);
+ in if T = T' then raise Same.SAME else T' end));
+
+ val args = map ZClassp outer_constraints @ map (ZHyp o zterm) hyps;
+ val prems = constraints @ map unconstrain_zterm hyps;
+ val prop' = beta_norm_term (close_prop prems (unconstrain_zterm concl));
+ val prf' = beta_norm_prooft (close_proof of_class prems (unconstrain_proof prf));
val i = serial ();
val zbox: zbox = (i, (prop', prf'));
- val zbox_prf = ZAppts (ZConstp (zproof_const (zproof_name i) prop'), hyps');
- in (zbox, zbox_prf) end;
+ val const = constrain_proof (ZConstp (zproof_const (zproof_name i) prop'));
+ in (zbox, ZAppps (const, args)) end;
in