--- a/src/Pure/Isar/class.ML Thu Sep 20 16:37:28 2007 +0200
+++ b/src/Pure/Isar/class.ML Thu Sep 20 16:37:29 2007 +0200
@@ -44,6 +44,14 @@
val inst_const: theory -> string * string -> string
val param_const: theory -> string -> (string * string) option
val params_of_sort: theory -> sort -> (string * (string * typ)) list
+
+ (*experimental*)
+ val init_ref: (class -> Proof.context -> (theory -> theory) * Proof.context) ref
+ val init: class -> Proof.context -> (theory -> theory) * Proof.context;
+ val init_default: class -> Proof.context -> (theory -> theory) * Proof.context;
+ val remove_constraints: class -> theory -> (string * typ) list * theory
+ val class_term_check: theory -> class -> term list -> Proof.context -> term list * Proof.context
+ val local_param: theory -> class -> string -> (term * (class * int)) option
end;
structure Class : CLASS =
@@ -91,6 +99,15 @@
| NONE => thm;
in strip end;
+fun get_remove_contraint c thy =
+ let
+ val ty = Sign.the_const_constraint thy c;
+ in
+ thy
+ |> Sign.add_const_constraint_i (c, NONE)
+ |> pair (c, Logic.unvarifyT ty)
+ end;
+
(** axclass command **)
@@ -277,14 +294,6 @@
in fold_map read defs cs end;
val (defs, other_cs) = read_defs raw_defs cs
(fold Sign.primitive_arity arities (Theory.copy theory));
- fun get_remove_contraint c thy =
- let
- val ty = Sign.the_const_constraint thy c;
- in
- thy
- |> Sign.add_const_constraint_i (c, NONE)
- |> pair (c, Logic.unvarifyT ty)
- end;
fun after_qed' cs defs =
fold Sign.add_const_constraint_i (map (apsnd SOME) cs)
#> after_qed defs;
@@ -320,30 +329,45 @@
datatype class_data = ClassData of {
locale: string,
consts: (string * string) list
- (*locale parameter ~> toplevel theory constant*),
- v: string option,
+ (*locale parameter ~> theory constant name*),
+ v: string,
inst: typ Symtab.table * term Symtab.table
(*canonical interpretation*),
- intro: thm
-} * thm list (*derived defs*);
+ intro: thm,
+ defs: thm list,
+ localized: (string * (term * (class * int))) list
+ (*theory constant name ~> (locale parameter, (class, instantiaton index of class typ))*)
+};
-fun rep_classdata (ClassData c) = c;
+fun rep_class_data (ClassData d) = d;
+fun mk_class_data ((locale, consts, v, inst, intro), (defs, localized)) =
+ ClassData { locale = locale, consts = consts, v = v, inst = inst, intro = intro,
+ defs = defs, localized = localized };
+fun map_class_data f (ClassData { locale, consts, v, inst, intro, defs, localized }) =
+ mk_class_data (f ((locale, consts, v, inst, intro), (defs, localized)))
+fun merge_class_data _ (ClassData { locale = locale, consts = consts, v = v, inst = inst,
+ intro = intro, defs = defs1, localized = localized1 },
+ ClassData { locale = _, consts = _, v = _, inst = _, intro = _,
+ defs = defs2, localized = localized2 }) =
+ mk_class_data ((locale, consts, v, inst, intro),
+ (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (localized1, localized2)));
fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
structure ClassData = TheoryDataFun
(
- type T = class_data Graph.T * class Symtab.table (*locale name ~> class name*);
+ type T = class_data Graph.T * class Symtab.table
+ (*locale name ~> class name*);
val empty = (Graph.empty, Symtab.empty);
val copy = I;
val extend = I;
- fun merge _ = merge_pair (Graph.merge (K true)) (Symtab.merge (K true));
+ fun merge _ = merge_pair (Graph.join merge_class_data) (Symtab.merge (K true));
);
(* queries *)
-val lookup_class_data = Option.map rep_classdata oo try o Graph.get_node
+val lookup_class_data = Option.map rep_class_data oo try o Graph.get_node
o fst o ClassData.get;
fun class_of_locale thy = Symtab.lookup ((snd o ClassData.get) thy);
@@ -358,18 +382,23 @@
fun params class =
let
val const_typs = (#params o AxClass.get_definition thy) class;
- val const_names = (#consts o fst o the_class_data thy) class;
+ val const_names = (#consts o the_class_data thy) class;
in
(map o apsnd) (fn c => (c, (the o AList.lookup (op =) const_typs) c)) const_names
end;
in maps params o ancestry thy end;
-fun these_defs thy = maps (these o Option.map snd o lookup_class_data thy) o ancestry thy;
+fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
fun these_intros thy =
- Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o fst o rep_classdata) data))
+ Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
((fst o ClassData.get) thy) [];
+fun these_localized thy class =
+ maps (#localized o the_class_data thy) (ancestry thy [class]);
+
+fun local_param thy = AList.lookup (op =) o these_localized thy;
+
fun print_classes thy =
let
val algebra = Sign.classes_of thy;
@@ -389,7 +418,7 @@
(SOME o Pretty.str) ("class " ^ class ^ ":"),
(SOME o Pretty.block) [Pretty.str "supersort: ",
(Sign.pretty_sort thy o Sign.certify_sort thy o Sign.super_classes thy) class],
- Option.map (Pretty.str o prefix "locale: " o #locale o fst) (lookup_class_data thy class),
+ Option.map (Pretty.str o prefix "locale: " o #locale) (lookup_class_data thy class),
((fn [] => NONE | ps => (SOME o Pretty.block o Pretty.fbreaks) (Pretty.str "parameters:" :: ps)) o map mk_param
o these o Option.map #params o try (AxClass.get_definition thy)) class,
(SOME o Pretty.block o Pretty.breaks) [
@@ -408,15 +437,16 @@
fun add_class_data ((class, superclasses), (locale, consts, v, inst, intro)) =
ClassData.map (fn (gr, tab) => (
gr
- |> Graph.new_node (class, ClassData ({ locale = locale, consts = consts,
- v = v, inst = inst, intro = intro }, []))
+ |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, v, inst, intro),
+ ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
|> fold (curry Graph.add_edge class) superclasses,
tab
|> Symtab.update (locale, class)
));
-fun add_class_const_thm (class, thm) = (ClassData.map o apfst o Graph.map_node class)
- (fn ClassData (data, thms) => ClassData (data, thm :: thms));
+fun add_class_const_def (class, (entry, def)) =
+ (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
+ (fn (defs, localized) => (def :: defs, (apsnd o apsnd) (pair class) entry :: localized));
(** rule calculation, tactics and methods **)
@@ -452,7 +482,7 @@
fun class_interpretation class facts defs thy =
let
- val ({ locale, inst, ... }, _) = the_class_data thy class;
+ val { locale, inst, ... } = the_class_data thy class;
val tac = (ALLGOALS o ProofContext.fact_tac) facts;
val prfx = Logic.const_of_class (NameSpace.base class);
in
@@ -464,7 +494,7 @@
let
fun mk_axioms class =
let
- val ({ locale, inst = (_, insttab), ... }, _) = the_class_data thy class;
+ val { locale, inst = (_, insttab), ... } = the_class_data thy class;
in
Locale.global_asms_of thy locale
|> maps snd
@@ -546,7 +576,7 @@
val sups = filter (is_some o lookup_class_data thy) supclasses
|> Sign.certify_sort thy;
val supsort = Sign.certify_sort thy supclasses;
- val suplocales = map (Locale.Locale o #locale o fst o the_class_data thy) sups;
+ val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
val supexpr = Locale.Merge (suplocales @ includes);
val supparams = (map fst o Locale.parameters_of_expr thy)
(Locale.Merge suplocales);
@@ -563,10 +593,10 @@
let
val params = Locale.parameters_of thy name_locale;
val v = case (maps typ_tfrees o map (snd o fst)) params
- of (v, _) :: _ => SOME v
- | _ => NONE;
+ of (v, _) :: _ => v
+ | [] => AxClass.param_tyvarname;
in
- (v, (map (fst o fst) params, params
+ (v, (map fst params, params
|> (map o apfst o apsnd o Term.map_type_tfree) mk_tyvar
|> (map o apsnd) (fork_mixfix true NONE #> fst)
|> chop (length supconsts)
@@ -578,7 +608,6 @@
fun subst (Free (c, ty)) =
Const ((fst o the o AList.lookup (op =) consts) c, ty)
| subst t = t;
- val super_defs = these_defs thy sups;
fun prep_asm ((name, atts), ts) =
((NameSpace.base name, map (Attrib.attribute thy) atts),
(map o map_aterms) subst ts);
@@ -595,15 +624,15 @@
|> add_locale (SOME "") bname supexpr ((*elems_constrains @*) elems)
|-> (fn name_locale => ProofContext.theory_result (
`(fn thy => extract_params thy name_locale)
- #-> (fn (v, (param_names, params)) =>
+ #-> (fn (v, (globals, params)) =>
AxClass.define_class_params (bname, supsort) params
(extract_assumes name_locale params) other_consts
#-> (fn (name_axclass, (consts, axioms)) =>
`(fn thy => class_intro thy name_locale name_axclass sups)
#-> (fn class_intro =>
add_class_data ((name_axclass, sups),
- (name_locale, map (fst o fst) params ~~ map fst consts, v,
- (mk_instT name_axclass, mk_inst name_axclass param_names
+ (name_locale, map fst params ~~ map fst consts, v,
+ (mk_instT name_axclass, mk_inst name_axclass (map fst globals)
(map snd supconsts @ consts)), class_intro))
#> note_intro name_axclass class_intro
#> class_interpretation name_axclass axioms []
@@ -619,52 +648,62 @@
end; (*local*)
+(* class target context *)
+
+fun remove_constraints class thy =
+ thy |> fold_map (get_remove_contraint o fst) (these_localized thy class);
+
+
(* definition in class target *)
fun export_fixes thy class =
let
- val v = (#v o fst o the_class_data thy) class;
- val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
- val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
- if SOME w = v then TFree (w, constrain_sort sort) else TFree var);
val consts = params_of_sort thy [class];
fun subst_aterm (t as Free (v, ty)) = (case AList.lookup (op =) consts v
of SOME (c, _) => Const (c, ty)
| NONE => t)
| subst_aterm t = t;
- in map_types subst_typ #> Term.map_aterms subst_aterm end;
+ in Term.map_aterms subst_aterm end;
fun add_const_in_class class ((c, rhs), syn) thy =
let
val prfx = (Logic.const_of_class o NameSpace.base) class;
- fun mk_name inject c =
+ fun mk_name c =
let
val n1 = Sign.full_name thy c;
val n2 = NameSpace.qualifier n1;
val n3 = NameSpace.base n1;
- in NameSpace.implode (n2 :: inject @ [n3]) end;
- val abbr' = mk_name [prfx, prfx] c;
+ in NameSpace.implode [n2, prfx, n3] end;
+ val v = (#v o the_class_data thy) class;
+ val constrain_sort = curry (Sorts.inter_sort (Sign.classes_of thy)) [class];
+ val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
+ if w = v then TFree (w, constrain_sort sort) else TFree var);
val rhs' = export_fixes thy class rhs;
val ty' = Term.fastype_of rhs';
- val def = (c, Logic.mk_equals (Const (mk_name [prfx] c, ty'), rhs'));
+ val ty'' = subst_typ ty';
+ val c' = mk_name c;
+ val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
val (syn', _) = fork_mixfix true NONE syn;
- fun interpret def =
+ fun interpret def thy =
let
val def' = symmetric def;
val def_eq = Thm.prop_of def';
+ val typargs = Sign.const_typargs thy (c', fastype_of rhs);
+ val typidx = find_index (fn TFree (w, _) => v = w | _ => false) typargs;
in
- class_interpretation class [def'] [def_eq]
- #> add_class_const_thm (class, def')
+ thy
+ |> class_interpretation class [def'] [def_eq]
+ |> add_class_const_def (class, ((c', (rhs, typidx)), def'))
end;
in
thy
- |> Sign.hide_consts_i true [abbr']
|> Sign.add_path prfx
|> Sign.add_consts_authentic [(c, ty', syn')]
|> Sign.parent_path
|> Sign.sticky_prefix prfx
|> PureThy.add_defs_i false [(def, [])]
|-> (fn [def] => interpret def)
+ |> Sign.add_const_constraint_i (c', SOME ty'')
|> Sign.restore_naming thy
end;
@@ -677,8 +716,8 @@
let
val class = prep_class theory raw_class;
val superclass = prep_class theory raw_superclass;
- val loc_name = (#locale o fst o the_class_data theory) class;
- val loc_expr = (Locale.Locale o #locale o fst o the_class_data theory) superclass;
+ val loc_name = (#locale o the_class_data theory) class;
+ val loc_expr = (Locale.Locale o #locale o the_class_data theory) superclass;
fun prove_classrel (class, superclass) thy =
let
val classes = (Graph.all_succs o #classes o Sorts.rep_algebra
@@ -717,4 +756,52 @@
end; (*local*)
+(*experimental*)
+fun class_term_check thy class =
+ let
+ val algebra = Sign.classes_of thy;
+ val { v, ... } = the_class_data thy class;
+ fun add_constrain_classtyp sort' (ty as TFree (v, _)) =
+ AList.map_default (op =) (v, []) (curry (Sorts.inter_sort algebra) sort')
+ | add_constrain_classtyp sort' (Type (tyco, tys)) = case Sorts.mg_domain algebra tyco sort'
+ of sorts => fold2 add_constrain_classtyp sorts tys;
+ fun class_arg c idx ty =
+ let
+ val typargs = Sign.const_typargs thy (c, ty);
+ fun classtyp (t as TFree (w, _)) = if w = v then NONE else SOME t
+ | classtyp t = SOME t;
+ in classtyp (nth typargs idx) end;
+ fun add_inst (c, ty) (terminsts, typinsts) = case local_param thy class c
+ of NONE => (terminsts, typinsts)
+ | SOME (t, (class', idx)) => (case class_arg c idx ty
+ of NONE => (((c, ty), t) :: terminsts, typinsts)
+ | SOME ty => (terminsts, add_constrain_classtyp [class'] ty typinsts));
+ in pair o (fn ts => let
+ val cs = (fold o fold_aterms) (fn Const c_ty => insert (op =) c_ty | _ => I) ts [];
+ val (terminsts, typinsts) = fold add_inst cs ([], []);
+ in
+ ts
+ |> (map o map_aterms) (fn t as Const c_ty => the_default t (AList.lookup (op =) terminsts c_ty)
+ | t => t)
+ |> (map o map_types o map_atyps) (fn t as TFree (v, sort) =>
+ case AList.lookup (op =) typinsts v
+ of SOME sort' => TFree (v, Sorts.inter_sort algebra (sort, sort'))
+ | NONE => t)
+ end) end;
+
+val init_ref = ref (K (pair I) : class -> Proof.context -> (theory -> theory) * Proof.context);
+fun init class = ! init_ref class;
+
+fun init_default class ctxt =
+ let
+ val thy = ProofContext.theory_of ctxt;
+ val term_check = class_term_check thy class;
+ in
+ ctxt
+ (*|> ProofContext.theory_result (remove_constraints class)*)
+ |> Context.proof_map (Syntax.add_term_check term_check)
+ (*|>> fold (fn (c, ty) => Sign.add_const_constraint_i (c, SOME ty))*)
+ |> pair I
+ end;
+
end;