diff -r 98e6a0a011f3 -r 7dc7dcd63224 src/Pure/Tools/class_package.ML --- a/src/Pure/Tools/class_package.ML Tue Jan 17 10:26:50 2006 +0100 +++ b/src/Pure/Tools/class_package.ML Tue Jan 17 16:36:57 2006 +0100 @@ -18,23 +18,21 @@ -> ((bstring * term) * theory attribute list) list -> theory -> Proof.state val add_classentry: class -> xstring list -> xstring list -> theory -> theory - val the_consts: theory -> class -> string list - val the_tycos: theory -> class -> (string * string) list - val print_classes: theory -> unit val syntactic_sort_of: theory -> sort -> sort - val get_arities: theory -> sort -> string -> sort list - val get_superclasses: theory -> class -> class list - val get_const_sign: theory -> string -> string -> typ - val get_inst_consts_sign: theory -> string * class -> (string * typ) list + val the_superclasses: theory -> class -> class list + val the_consts_sign: theory -> class -> string * (string * typ) list val lookup_const_class: theory -> string -> class option + val the_instances: theory -> class -> (string * string) list + val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list val get_classtab: theory -> (string list * (string * string) list) Symtab.table + val print_classes: theory -> unit type sortcontext = (string * sort) list datatype sortlookup = Instance of (class * string) * sortlookup list list | Lookup of class list * (string * int) val extract_sortctxt: theory -> typ -> sortcontext - val extract_sortlookup: theory -> typ * typ -> sortlookup list list + val extract_sortlookup: theory -> string * typ -> sortlookup list list end; structure ClassPackage: CLASS_PACKAGE = @@ -126,21 +124,19 @@ insts = insts @ [inst] }); -val the_consts = map fst o #consts oo get_class_data; -val the_tycos = #insts oo get_class_data; - (* classes and instances *) +fun subst_clsvar v ty_subst = + map_type_tfree (fn u as (w, _) => + if w = v then ty_subst else TFree u); + local open Element fun gen_add_class add_locale bname raw_import raw_body thy = let - fun subst_clsvar v ty_subst = - map_type_tfree (fn u as (w, _) => - if w = v then ty_subst else TFree u); fun extract_assumes c_adds elems = let fun subst_free ts = @@ -240,7 +236,9 @@ fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs; fun check_defs c_given c_req thy = let - fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2) + fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 + andalso Sign.typ_instance thy (ty1, ty2) + andalso Sign.typ_instance thy (ty2, ty1) val _ = case fold (remove eq_c) c_given c_req of [] => () | cs => error ("no definition(s) given for" @@ -263,9 +261,12 @@ val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x; -(* class queries *) +(* queries *) -fun is_class thy cls = lookup_class_data thy cls |> Option.map (not o null o #consts) |> the_default false; +fun is_class thy cls = + lookup_class_data thy cls + |> Option.map (not o null o #consts) + |> the_default false; fun syntactic_sort_of thy sort = let @@ -280,11 +281,7 @@ |> Sorts.norm_sort classes end; -fun get_arities thy sort tycon = - Sorts.mg_domain (Sign.classes_arities_of thy) tycon (syntactic_sort_of thy sort) - |> map (syntactic_sort_of thy); - -fun get_superclasses thy class = +fun the_superclasses thy class = if is_class thy class then Sorts.superclasses (Sign.classes_of thy) class @@ -292,49 +289,43 @@ else error ("no syntactic class: " ^ class); - -(* instance queries *) - -fun mk_const_sign thy class tvar ty = +fun the_consts_sign thy class = let - val (ty', thaw) = Type.freeze_thaw_type ty; - val tvars_used = Term.add_tfreesT ty' []; - val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1); - in - ty' - |> map_type_tfree (fn (tvar', sort) => - if Sorts.sort_eq (Sign.classes_of thy) ([class], sort) - then TFree (tvar, []) - else if tvar' = tvar - then TVar ((tvar_rename, 0), sort) - else TFree (tvar', sort)) - |> thaw - end; + val data = (the oo Symtab.lookup) ((fst o ClassData.get) thy) class + in (#var data, #consts data) end; + +fun lookup_const_class thy = + Symtab.lookup ((snd o ClassData.get) thy); + +fun the_instances thy class = + (#insts o the o Symtab.lookup ((fst o ClassData.get) thy)) class; -fun get_const_sign thy tvar const = - let - val class = (the o lookup_const_class thy) const; - val ty = Sign.the_const_constraint thy const; - in mk_const_sign thy class tvar ty end; - -fun get_inst_consts_sign thy (tyco, class) = +fun the_inst_sign thy (class, tyco) = let - val consts = the_consts thy class; - val arities = get_arities thy [class] tyco; - val const_signs = map (get_const_sign thy "'a") consts; - val vars_used = fold (fn ty => curry (gen_union (op =)) - (map fst (typ_tfrees ty) |> remove (op =) "'a")) const_signs []; - val vars_new = Term.invent_names vars_used "'a" (length arities); - val typ_arity = Type (tyco, map2 (curry TFree) vars_new arities); - val instmem_signs = - map (typ_subst_TVars [(("'a", 0), typ_arity)]) const_signs; - in consts ~~ instmem_signs end; + val _ = if is_class thy class then () else error ("no syntactic class: " ^ class); + val arity = + Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class] + |> map (syntactic_sort_of thy); + val clsvar = (#var o the o Symtab.lookup ((fst o ClassData.get) thy)) class; + val const_sign = (snd o the_consts_sign thy) class; + fun add_var sort used = + let + val v = hd (Term.invent_names used "'a" 1) + in ((v, sort), v::used) end; + val (vsorts, _) = + [] + |> fold (fn (_, ty) => curry (gen_union (op =)) + ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign + |> fold_map add_var arity; + val ty_inst = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vsorts); + val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign; + in (vsorts, inst_signs) end; fun get_classtab thy = Symtab.fold (fn (class, { consts = consts, insts = insts, ... }) => Symtab.update_new (class, (map fst consts, insts))) - (fst (ClassData.get thy)) Symtab.empty; + ((fst o ClassData.get) thy) Symtab.empty; (* extracting dictionary obligations from types *) @@ -342,15 +333,16 @@ type sortcontext = (string * sort) list; fun extract_sortctxt thy ty = - (typ_tfrees o Type.no_tvars) ty + (typ_tfrees o fst o Type.freeze_thaw_type) ty |> map (apsnd (syntactic_sort_of thy)) |> filter (not o null o snd); datatype sortlookup = Instance of (class * string) * sortlookup list list | Lookup of class list * (string * int) -fun extract_sortlookup thy (raw_typ_def, raw_typ_use) = +fun extract_sortlookup thy (c, raw_typ_use) = let + val raw_typ_def = Sign.the_const_constraint thy c; val typ_def = Type.varifyT raw_typ_def; val typ_use = Type.varifyT raw_typ_use; val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty; @@ -374,8 +366,22 @@ let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class in Lookup (deriv, (vname, classindex)) end; in map mk_look sort_def end; + fun reorder_sortctxt ctxt = + case lookup_const_class thy c + of NONE => ctxt + | SOME class => + let + val data = (the o Symtab.lookup ((fst o ClassData.get) thy)) class; + val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c; + val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty; + val v : string = case Vartab.lookup match_tab (#var data, 0) + of SOME (_, TVar ((v, _), _)) => v; + in + (v, (the o AList.lookup (op =) ctxt) v) :: AList.delete (op =) v ctxt + end; in extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def) + |> reorder_sortctxt |> map (tab_lookup o fst) |> map (apfst (syntactic_sort_of thy)) |> filter (not o null o fst) @@ -388,11 +394,26 @@ fun add_classentry raw_class raw_cs raw_insts thy = let val class = Sign.intern_class thy raw_class; - val cs = raw_cs |> map (Sign.intern_const thy); + val cs_proto = + raw_cs + |> map (Sign.intern_const thy) + |> map (fn c => (c, Sign.the_const_constraint thy c)); + val used = + [] + |> fold (fn (_, ty) => curry (gen_union (op =)) + ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) cs_proto + val v = hd (Term.invent_names used "'a" 1); + val cs = + cs_proto + |> map (fn (c, ty) => (c, map_type_tvar (fn var as ((tvar', _), sort) => + if Sorts.sort_eq (Sign.classes_of thy) ([class], sort) + then TFree (v, []) + else TVar var + ) ty)); val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts; in thy - |> add_class_data (class, ([], "", class, "", map (rpair dummyT) cs)) + |> add_class_data (class, ([], "", class, v, cs)) |> fold (curry add_inst_data class) insts end;