--- a/src/Pure/Isar/class.ML Fri Mar 07 13:53:06 2008 +0100
+++ b/src/Pure/Isar/class.ML Fri Mar 07 13:53:07 2008 +0100
@@ -70,15 +70,6 @@
(Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
#> ProofContext.theory_of;
-fun get_remove_global_constraint c thy =
- let
- val ty = Sign.the_const_constraint thy c;
- in
- thy
- |> Sign.add_const_constraint (c, NONE)
- |> pair (c, Logic.unvarifyT ty)
- end;
-
(** primitive axclass and instance commands **)
@@ -345,15 +336,20 @@
fun class_interpretation class facts defs thy =
let
val params = these_params thy [class];
+ val consts = map (fst o snd) params;
+ val constraints = map (fn c => map_atyps (K (TFree (Name.aT,
+ [the (AxClass.class_of_param thy c)]))) (Sign.the_const_type thy c)) consts;
+ val no_constraints = map (map_atyps (K (TFree (Name.aT, [])))) constraints;
+ fun add_constraint c T = Sign.add_const_constraint (c, SOME T);
val inst = (#inst o the_class_data thy) class;
val tac = ALLGOALS (ProofContext.fact_tac facts);
val prfx = class_prefix class;
in
thy
- |> fold_map (get_remove_global_constraint o fst o snd) params
- ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
+ |> fold2 add_constraint consts no_constraints
+ |> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
(inst, map (fn def => (("", []), def)) defs)
- |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
+ |> fold2 add_constraint consts constraints
end;
fun prove_subclass (sub, sup) thm thy =
@@ -398,26 +394,7 @@
(* class context syntax *)
-structure ClassSyntax = ProofDataFun(
- type T = {
- local_constraints: (string * typ) list,
- global_constraints: (string * typ) list,
- base_sort: sort,
- operations: (string * (typ * term)) list,
- unchecks: (term * term) list,
- passed: bool
- };
- fun init _ = {
- local_constraints = [],
- global_constraints = [],
- base_sort = [],
- operations = [],
- unchecks = [],
- passed = true
- };;
-);
-
-fun synchronize_syntax sups base_sort ctxt =
+fun synchronize_class_syntax sups base_sort ctxt =
let
val thy = ProofContext.theory_of ctxt;
fun subst_class_typ sort = map_atyps
@@ -430,80 +407,44 @@
fun declare_const (c, _) =
let val b = Sign.base_name c
in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
+ fun improve (c, ty) = (case AList.lookup (op =) local_constraints c
+ of SOME ty' => (case try (Type.raw_match (ty', ty)) Vartab.empty
+ of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
+ of SOME (_, ty' as TVar (tvar as (vi, _))) =>
+ if TypeInfer.is_param vi
+ then SOME (ty', TFree (Name.aT, base_sort))
+ else NONE
+ | _ => NONE)
+ | NONE => NONE)
+ | NONE => NONE)
+ fun subst (c, ty) = Option.map snd (AList.lookup (op =) operations c);
val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
in
ctxt
|> fold declare_const local_constraints
|> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
- |> ClassSyntax.put {
+ |> Overloading.map_improvable_syntax (K {
local_constraints = local_constraints,
global_constraints = global_constraints,
- base_sort = base_sort,
- operations = (map o apsnd) snd operations,
+ improve = improve,
+ subst = subst,
unchecks = unchecks,
passed = false
- }
+ })
end;
fun refresh_syntax class ctxt =
let
val thy = ProofContext.theory_of ctxt;
val base_sort = (#base_sort o the_class_data thy) class;
- in synchronize_syntax [class] base_sort ctxt end;
-
-val mark_passed = ClassSyntax.map
- (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } =>
- { local_constraints = local_constraints, global_constraints = global_constraints,
- base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true });
-
-fun sort_term_check ts ctxt =
- let
- val { local_constraints, global_constraints, base_sort, operations, passed, ... } =
- ClassSyntax.get ctxt;
- fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c
- of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty
- of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
- of SOME (_, TVar (tvar as (vi, _))) =>
- if TypeInfer.is_param vi then cons tvar else I
- | _ => I)
- | NONE => I)
- | NONE => I)
- | check_improve _ = I;
- val improvements = (fold o fold_aterms) check_improve ts [];
- val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar =>
- if member (op =) improvements tvar
- then TFree (Name.aT, base_sort) else ty | ty => ty) ts;
- fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c
- of SOME (ty0, t) =>
- if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0)
- then SOME (ty0, check t) else NONE
- | NONE => NONE)
- | _ => NONE) t0;
- val ts'' = map check ts';
- in if eq_list (op aconv) (ts, ts'') andalso passed then NONE
- else
- ctxt
- |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
- |> mark_passed
- |> pair ts''
- |> SOME
- end;
-
-fun sort_term_uncheck ts ctxt =
- let
- val thy = ProofContext.theory_of ctxt;
- val unchecks = (#unchecks o ClassSyntax.get) ctxt;
- val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
- in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
+ in synchronize_class_syntax [class] base_sort ctxt end;
fun init_ctxt sups base_sort ctxt =
ctxt
|> Variable.declare_term
(Logic.mk_type (TFree (Name.aT, base_sort)))
- |> synchronize_syntax sups base_sort
- |> Context.proof_map (
- Syntax.add_term_check 0 "class" sort_term_check
- #> Syntax.add_term_uncheck 0 "class" sort_term_uncheck)
+ |> synchronize_class_syntax sups base_sort
+ |> Overloading.add_improvable_syntax;
fun init class thy =
thy
@@ -710,46 +651,52 @@
|> find_first (fn (_, (v', _)) => v = v')
|> Option.map (fst o fst);
-fun confirm_declaration c = (map_instantiation o apsnd)
- (filter_out (fn (_, (c', _)) => c' = c));
-
(* syntax *)
-fun inst_term_check ts lthy =
+fun synchronize_inst_syntax ctxt =
let
- val params = instantiation_params lthy;
- val tsig = ProofContext.tsig_of lthy;
- val thy = ProofContext.theory_of lthy;
-
- fun check_improve (Const (c, ty)) = (case AxClass.inst_tyco_of thy (c, ty)
- of SOME tyco => (case AList.lookup (op =) params (c, tyco)
- of SOME (_, ty') => perhaps (try (Type.typ_match tsig (ty, ty')))
- | NONE => I)
- | NONE => I)
- | check_improve _ = I;
- val subst_param = map_aterms (fn t as Const (c, ty) =>
- (case AxClass.inst_tyco_of thy (c, ty)
+ val Instantiation { arities = (_, _, sorts), params = params } = Instantiation.get ctxt;
+ val thy = ProofContext.theory_of ctxt;
+ val operations = these_operations thy sorts;
+ fun subst_class_typ sort = map_atyps
+ (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
+ val local_constraints =
+ (map o apsnd) (subst_class_typ [] o fst o snd) operations;
+ val global_constraints = map_filter (fn (c, (class, (ty, _))) =>
+ if exists (fn ((c', _), _) => c = c') params
+ then SOME (c, subst_class_typ [class] ty)
+ else NONE) operations;
+ fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
of SOME tyco => (case AList.lookup (op =) params (c, tyco)
- of SOME v_ty => Free v_ty
- | NONE => t)
- | NONE => t)
- | t => t);
-
- val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
- val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
- val ts'' = map subst_param ts';
- in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end;
-
-fun inst_term_uncheck ts lthy =
- let
- val params = instantiation_params lthy;
- val ts' = (map o map_aterms) (fn t as Free (v, ty) =>
- (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params
- of SOME c => Const (c, ty)
- | NONE => t)
- | t => t) ts;
- in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+ of SOME (_, ty') => SOME (ty, ty')
+ | NONE => NONE)
+ | NONE => NONE;
+ (*| NONE => (case map_filter
+ (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params
+ of [ty'] => (case Sign.const_typargs thy (c, ty)
+ of [TVar (vi, _)] => if TypeInfer.is_param vi then SOME (ty, ty') else NONE
+ | _ => NONE)
+ | _ => NONE*);
+ fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
+ of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+ of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty)
+ | NONE => NONE)
+ | NONE => NONE;
+ val unchecks =
+ map (fn ((c, _), v_ty as (_, ty)) => (Free v_ty, Const (c, ty))) params;
+ in
+ ctxt
+ |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
+ |> Overloading.map_improvable_syntax (K {
+ local_constraints = local_constraints,
+ global_constraints = global_constraints,
+ improve = improve,
+ subst = subst,
+ unchecks = unchecks,
+ passed = false
+ })
+ end;
(* target *)
@@ -786,12 +733,14 @@
|> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
|> fold (Variable.declare_term o Logic.mk_type o TFree) vs
|> fold (Variable.declare_names o Free o snd) params
- |> fold (fn tyco => ProofContext.add_arity (tyco, map snd vs, sort)) tycos
- |> Context.proof_map (
- Syntax.add_term_check 0 "instance" inst_term_check
- #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck)
+ |> synchronize_inst_syntax
+ |> Overloading.add_improvable_syntax
end;
+fun confirm_declaration c = (map_instantiation o apsnd)
+ (filter_out (fn (_, (c', _)) => c' = c))
+ #> LocalTheory.target synchronize_inst_syntax
+
fun gen_instantiation_instance do_proof after_qed lthy =
let
val (tycos, vs, sort) = (#arities o the_instantiation) lthy;
--- a/src/Pure/Isar/overloading.ML Fri Mar 07 13:53:06 2008 +0100
+++ b/src/Pure/Isar/overloading.ML Fri Mar 07 13:53:07 2008 +0100
@@ -14,6 +14,11 @@
val define: bool -> string -> string * term -> theory -> thm * theory
val operation: Proof.context -> string -> (string * bool) option
val pretty: Proof.context -> Pretty.T
+
+ type improvable_syntax
+ val add_improvable_syntax: Proof.context -> Proof.context
+ val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
+ -> Proof.context -> Proof.context
end;
structure Overloading: OVERLOADING =
@@ -44,45 +49,101 @@
Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
-(* syntax *)
+(* generic check/uncheck combinators for improvable constants *)
+
+type improvable_syntax = {
+ local_constraints: (string * typ) list,
+ global_constraints: (string * typ) list,
+ improve: string * typ -> (typ * typ) option,
+ subst: string * typ -> (typ * term) option,
+ unchecks: (term * term) list,
+ passed: bool
+};
-fun subst_operation overloading = map_aterms (fn t as Const (c, ty) =>
- (case AList.lookup (op =) overloading (c, ty)
- of SOME (v, _) => Free (v, ty)
- | NONE => t)
- | t => t);
+structure ImprovableSyntax = ProofDataFun(
+ type T = improvable_syntax;
+ fun init _ = {
+ local_constraints = [],
+ global_constraints = [],
+ improve = K NONE,
+ subst = K NONE,
+ unchecks = [],
+ passed = true
+ };
+);
-fun term_check ts lthy =
+val map_improvable_syntax = ImprovableSyntax.map;
+
+val mark_passed = map_improvable_syntax
+ (fn { local_constraints, global_constraints, improve, subst, unchecks, passed } =>
+ { local_constraints = local_constraints, global_constraints = global_constraints,
+ improve = improve, subst = subst, unchecks = unchecks, passed = true });
+
+fun improve_term_check ts ctxt =
let
- val overloading = get_overloading lthy;
- val ts' = map (subst_operation overloading) ts;
- in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+ val { local_constraints, global_constraints, improve, subst, passed, ... } =
+ ImprovableSyntax.get ctxt;
+ val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
+ fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
+ of SOME ty_ty' => (perhaps o try o Type.typ_match tsig) ty_ty'
+ | _ => I)
+ | accumulate_improvements _ = I;
+ val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
+ val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts;
+ fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
+ of SOME (ty', t') =>
+ if Type.typ_instance tsig (ty, ty')
+ then SOME (ty', apply_subst t') else NONE
+ | NONE => NONE)
+ | _ => NONE) t;
+ val ts'' = map apply_subst ts';
+ in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else
+ if passed then SOME (ts'', ctxt)
+ else SOME (ts'', ctxt
+ |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
+ |> mark_passed)
+ end;
-fun term_uncheck ts lthy =
+fun improve_term_uncheck ts ctxt =
let
- val overloading = get_overloading lthy;
- fun subst (t as Free (v, ty)) = (case get_first (fn ((c, _), (v', _)) =>
- if v = v' then SOME c else NONE) overloading
- of SOME c => Const (c, ty)
- | NONE => t)
- | subst t = t;
- val ts' = (map o map_aterms) subst ts;
- in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+ val thy = ProofContext.theory_of ctxt;
+ val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
+ val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
+ in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
+
+fun add_improvable_syntax ctxt = ctxt
+ |> Context.proof_map
+ (Syntax.add_term_check 0 "improvement" improve_term_check
+ #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
+ |> fold (ProofContext.add_const_constraint o apsnd SOME)
+ ((#local_constraints o ImprovableSyntax.get) ctxt);
(* target *)
-fun init overloading thy =
+fun init raw_overloading thy =
let
- val _ = if null overloading then error "At least one parameter must be given" else ();
+ val _ = if null raw_overloading then error "At least one parameter must be given" else ();
+ val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
+ fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty)
+ of SOME (v, _) => SOME (ty, Free (v, ty))
+ | NONE => NONE;
+ val unchecks =
+ map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
in
thy
|> ProofContext.init
- |> OverloadingData.put (map (fn (v, c_ty, checked) => (c_ty, (v, checked))) overloading)
- |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) overloading
- |> Context.proof_map (
- Syntax.add_term_check 0 "overloading" term_check
- #> Syntax.add_term_uncheck 0 "overloading" term_uncheck)
+ |> OverloadingData.put overloading
+ |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) raw_overloading
+ |> map_improvable_syntax (K {
+ local_constraints = [],
+ global_constraints = [],
+ improve = K NONE,
+ subst = subst,
+ unchecks = unchecks,
+ passed = false
+ })
+ |> add_improvable_syntax
end;
fun conclude lthy =