--- a/src/Pure/Tools/class_package.ML Tue Jan 17 10:26:50 2006 +0100
+++ b/src/Pure/Tools/class_package.ML Tue Jan 17 16:36:57 2006 +0100
@@ -18,23 +18,21 @@
-> ((bstring * term) * theory attribute list) list
-> theory -> Proof.state
val add_classentry: class -> xstring list -> xstring list -> theory -> theory
- val the_consts: theory -> class -> string list
- val the_tycos: theory -> class -> (string * string) list
- val print_classes: theory -> unit
val syntactic_sort_of: theory -> sort -> sort
- val get_arities: theory -> sort -> string -> sort list
- val get_superclasses: theory -> class -> class list
- val get_const_sign: theory -> string -> string -> typ
- val get_inst_consts_sign: theory -> string * class -> (string * typ) list
+ val the_superclasses: theory -> class -> class list
+ val the_consts_sign: theory -> class -> string * (string * typ) list
val lookup_const_class: theory -> string -> class option
+ val the_instances: theory -> class -> (string * string) list
+ val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
val get_classtab: theory -> (string list * (string * string) list) Symtab.table
+ val print_classes: theory -> unit
type sortcontext = (string * sort) list
datatype sortlookup = Instance of (class * string) * sortlookup list list
| Lookup of class list * (string * int)
val extract_sortctxt: theory -> typ -> sortcontext
- val extract_sortlookup: theory -> typ * typ -> sortlookup list list
+ val extract_sortlookup: theory -> string * typ -> sortlookup list list
end;
structure ClassPackage: CLASS_PACKAGE =
@@ -126,21 +124,19 @@
insts = insts @ [inst]
});
-val the_consts = map fst o #consts oo get_class_data;
-val the_tycos = #insts oo get_class_data;
-
(* classes and instances *)
+fun subst_clsvar v ty_subst =
+ map_type_tfree (fn u as (w, _) =>
+ if w = v then ty_subst else TFree u);
+
local
open Element
fun gen_add_class add_locale bname raw_import raw_body thy =
let
- fun subst_clsvar v ty_subst =
- map_type_tfree (fn u as (w, _) =>
- if w = v then ty_subst else TFree u);
fun extract_assumes c_adds elems =
let
fun subst_free ts =
@@ -240,7 +236,9 @@
fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
fun check_defs c_given c_req thy =
let
- fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2)
+ fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2
+ andalso Sign.typ_instance thy (ty1, ty2)
+ andalso Sign.typ_instance thy (ty2, ty1)
val _ = case fold (remove eq_c) c_given c_req
of [] => ()
| cs => error ("no definition(s) given for"
@@ -263,9 +261,12 @@
val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
-(* class queries *)
+(* queries *)
-fun is_class thy cls = lookup_class_data thy cls |> Option.map (not o null o #consts) |> the_default false;
+fun is_class thy cls =
+ lookup_class_data thy cls
+ |> Option.map (not o null o #consts)
+ |> the_default false;
fun syntactic_sort_of thy sort =
let
@@ -280,11 +281,7 @@
|> Sorts.norm_sort classes
end;
-fun get_arities thy sort tycon =
- Sorts.mg_domain (Sign.classes_arities_of thy) tycon (syntactic_sort_of thy sort)
- |> map (syntactic_sort_of thy);
-
-fun get_superclasses thy class =
+fun the_superclasses thy class =
if is_class thy class
then
Sorts.superclasses (Sign.classes_of thy) class
@@ -292,49 +289,43 @@
else
error ("no syntactic class: " ^ class);
-
-(* instance queries *)
-
-fun mk_const_sign thy class tvar ty =
+fun the_consts_sign thy class =
let
- val (ty', thaw) = Type.freeze_thaw_type ty;
- val tvars_used = Term.add_tfreesT ty' [];
- val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1);
- in
- ty'
- |> map_type_tfree (fn (tvar', sort) =>
- if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
- then TFree (tvar, [])
- else if tvar' = tvar
- then TVar ((tvar_rename, 0), sort)
- else TFree (tvar', sort))
- |> thaw
- end;
+ val data = (the oo Symtab.lookup) ((fst o ClassData.get) thy) class
+ in (#var data, #consts data) end;
+
+fun lookup_const_class thy =
+ Symtab.lookup ((snd o ClassData.get) thy);
+
+fun the_instances thy class =
+ (#insts o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
-fun get_const_sign thy tvar const =
- let
- val class = (the o lookup_const_class thy) const;
- val ty = Sign.the_const_constraint thy const;
- in mk_const_sign thy class tvar ty end;
-
-fun get_inst_consts_sign thy (tyco, class) =
+fun the_inst_sign thy (class, tyco) =
let
- val consts = the_consts thy class;
- val arities = get_arities thy [class] tyco;
- val const_signs = map (get_const_sign thy "'a") consts;
- val vars_used = fold (fn ty => curry (gen_union (op =))
- (map fst (typ_tfrees ty) |> remove (op =) "'a")) const_signs [];
- val vars_new = Term.invent_names vars_used "'a" (length arities);
- val typ_arity = Type (tyco, map2 (curry TFree) vars_new arities);
- val instmem_signs =
- map (typ_subst_TVars [(("'a", 0), typ_arity)]) const_signs;
- in consts ~~ instmem_signs end;
+ val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
+ val arity =
+ Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]
+ |> map (syntactic_sort_of thy);
+ val clsvar = (#var o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
+ val const_sign = (snd o the_consts_sign thy) class;
+ fun add_var sort used =
+ let
+ val v = hd (Term.invent_names used "'a" 1)
+ in ((v, sort), v::used) end;
+ val (vsorts, _) =
+ []
+ |> fold (fn (_, ty) => curry (gen_union (op =))
+ ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign
+ |> fold_map add_var arity;
+ val ty_inst = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vsorts);
+ val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
+ in (vsorts, inst_signs) end;
fun get_classtab thy =
Symtab.fold
(fn (class, { consts = consts, insts = insts, ... }) =>
Symtab.update_new (class, (map fst consts, insts)))
- (fst (ClassData.get thy)) Symtab.empty;
+ ((fst o ClassData.get) thy) Symtab.empty;
(* extracting dictionary obligations from types *)
@@ -342,15 +333,16 @@
type sortcontext = (string * sort) list;
fun extract_sortctxt thy ty =
- (typ_tfrees o Type.no_tvars) ty
+ (typ_tfrees o fst o Type.freeze_thaw_type) ty
|> map (apsnd (syntactic_sort_of thy))
|> filter (not o null o snd);
datatype sortlookup = Instance of (class * string) * sortlookup list list
| Lookup of class list * (string * int)
-fun extract_sortlookup thy (raw_typ_def, raw_typ_use) =
+fun extract_sortlookup thy (c, raw_typ_use) =
let
+ val raw_typ_def = Sign.the_const_constraint thy c;
val typ_def = Type.varifyT raw_typ_def;
val typ_use = Type.varifyT raw_typ_use;
val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
@@ -374,8 +366,22 @@
let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
in Lookup (deriv, (vname, classindex)) end;
in map mk_look sort_def end;
+ fun reorder_sortctxt ctxt =
+ case lookup_const_class thy c
+ of NONE => ctxt
+ | SOME class =>
+ let
+ val data = (the o Symtab.lookup ((fst o ClassData.get) thy)) class;
+ val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
+ val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
+ val v : string = case Vartab.lookup match_tab (#var data, 0)
+ of SOME (_, TVar ((v, _), _)) => v;
+ in
+ (v, (the o AList.lookup (op =) ctxt) v) :: AList.delete (op =) v ctxt
+ end;
in
extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
+ |> reorder_sortctxt
|> map (tab_lookup o fst)
|> map (apfst (syntactic_sort_of thy))
|> filter (not o null o fst)
@@ -388,11 +394,26 @@
fun add_classentry raw_class raw_cs raw_insts thy =
let
val class = Sign.intern_class thy raw_class;
- val cs = raw_cs |> map (Sign.intern_const thy);
+ val cs_proto =
+ raw_cs
+ |> map (Sign.intern_const thy)
+ |> map (fn c => (c, Sign.the_const_constraint thy c));
+ val used =
+ []
+ |> fold (fn (_, ty) => curry (gen_union (op =))
+ ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) cs_proto
+ val v = hd (Term.invent_names used "'a" 1);
+ val cs =
+ cs_proto
+ |> map (fn (c, ty) => (c, map_type_tvar (fn var as ((tvar', _), sort) =>
+ if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
+ then TFree (v, [])
+ else TVar var
+ ) ty));
val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
in
thy
- |> add_class_data (class, ([], "", class, "", map (rpair dummyT) cs))
+ |> add_class_data (class, ([], "", class, v, cs))
|> fold (curry add_inst_data class) insts
end;