(* Title: Pure/Tools/class_package.ML
ID: $Id$
Author: Florian Haftmann, TU Muenchen
Haskell98-like operational view on type classes.
*)
(* Interface of the class package: classes built from a locale, an axclass
   and global constants; instance statements for type constructor arities;
   and queries over the recorded class data. *)
signature CLASS_PACKAGE =
sig
(* Declare a class named bstring from a locale expression and context
   elements (source syntax); the result pairs the locale context with the
   extended theory. *)
val add_class: bstring -> Locale.expr -> Element.context list -> theory
-> ProofContext.context * theory
(* Same as add_class, for internal (already checked) context elements. *)
val add_class_i: bstring -> Locale.expr -> Element.context_i list -> theory
-> ProofContext.context * theory
(* Start the instance proof for arity (tyco, argument sorts) :: sort,
   given named definitions of the class operations (source syntax). *)
val add_instance_arity: (xstring * string list) * string
-> ((bstring * string) * Attrib.src list) list
-> theory -> Proof.state
(* Same as add_instance_arity, for internal arities and terms. *)
val add_instance_arity_i: (string * sort list) * sort
-> ((bstring * term) * theory attribute list) list
-> theory -> Proof.state
(* Record an existing class, its operations and its instance type
   constructors directly in the class data. *)
val add_classentry: class -> xstring list -> xstring list -> theory -> theory
(* Names of the operations recorded for a class. *)
val the_consts: theory -> class -> string list
(* Recorded instances of a class as (type constructor, theory name) pairs. *)
val the_tycos: theory -> class -> (string * string) list
(* Print the recorded class data. *)
val print_classes: theory -> unit
(* Restrict a sort to classes recorded with at least one operation; other
   classes are replaced by their (recursively restricted) superclasses. *)
val syntactic_sort_of: theory -> sort -> sort
(* Sorts.mg_domain of the type constructor for the restricted sort, each
   argument sort restricted as well. *)
val get_arities: theory -> sort -> string -> sort list
(* Restricted superclasses of a class; error unless the class has operations. *)
val get_superclasses: theory -> class -> class list
(* Type of a class operation with its class type variable named by the
   given string. *)
val get_const_sign: theory -> string -> string -> typ
(* Operations of a class paired with their types at the given type
   constructor. *)
val get_inst_consts_sign: theory -> string * class -> (string * typ) list
(* The class that declares a constant, if any. *)
val lookup_const_class: theory -> string -> class option
(* Class name -> (operation names, instances). *)
val get_classtab: theory -> (string list * (string * string) list) Symtab.table
(* Type variables together with their non-empty restricted sorts. *)
type sortcontext = (string * sort) list
(* Dictionary evidence for class constraints. *)
datatype sortlookup = Instance of (class * string) * sortlookup list list
| Lookup of class list * (string * int)
val extract_sortctxt: theory -> typ -> sortcontext
val extract_sortlookup: theory -> typ * typ -> sortlookup list list
end;
structure ClassPackage: CLASS_PACKAGE =
struct
(* data kind 'Pure/classes' *)
(* Data recorded for each class. *)
type class_data = {
superclasses: class list, (* every writer in this file records [] *)
name_locale: string, (* locale the class was built from; "" when added by add_classentry *)
name_axclass: string, (* internal name of the corresponding axclass *)
var: string, (* name of the class type variable; "" when added by add_classentry *)
consts: (string * typ) list, (* operations with their types *)
insts: (string * string) list (* instances as (type constructor, theory name) *)
};
structure ClassData = TheoryDataFun (
struct
  val name = "Pure/classes";
  (* first table: class name -> class data;
     second table: operation name -> class declaring it *)
  type T = class_data Symtab.table * class Symtab.table;
  val empty = (Symtab.empty, Symtab.empty);
  val copy = I;
  val extend = I;
  fun merge _ ((classes1, consts1), (classes2, consts2)) =
    let
      val classes = Symtab.merge (op =) (classes1, classes2);
      val consts = Symtab.merge (op =) (consts1, consts2);
    in (classes, consts) end;
  (* print one chunk per recorded class *)
  fun print thy (classes, _) =
    let
      fun labelled label prts =
        (Pretty.block o Pretty.fbreaks) (Pretty.str label :: prts);
      fun pretty_const (c, ty) =
        Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty);
      fun pretty_inst (tyco, thyname) =
        Pretty.str (tyco ^ ", in theory " ^ thyname);
      fun pretty_class (name, {superclasses, name_locale, name_axclass, var, consts, insts}) =
        (Pretty.block o Pretty.fbreaks) [
          Pretty.str ("class " ^ name ^ ":"),
          labelled "superclasses: " (map Pretty.str superclasses),
          Pretty.str ("locale: " ^ name_locale),
          Pretty.str ("axclass: " ^ name_axclass),
          Pretty.str ("class variable: " ^ var),
          labelled "constants: " (map pretty_const consts),
          labelled "instances: " (map pretty_inst insts)
        ];
    in
      Symtab.dest classes
      |> map pretty_class
      |> Pretty.chunks
      |> Pretty.writeln
    end;
end
);
(* Printing and raw lookups in the class data. *)
val print_classes = ClassData.print;
fun lookup_class_data thy = Symtab.lookup (fst (ClassData.get thy));
fun lookup_const_class thy = Symtab.lookup (snd (ClassData.get thy));
(* Data of a class; error if the class is not recorded. *)
fun get_class_data thy class =
  (case lookup_class_data thy class of
    SOME data => data
  | NONE => error ("undeclared class " ^ quote class));
(* Record a class with no instances yet, and map each of its operations back
   to the class. *)
fun add_class_data (class, (superclasses, name_locale, name_axclass, classvar, consts)) =
  let
    val data = {
      superclasses = superclasses,
      name_locale = name_locale,
      name_axclass = name_axclass,
      var = classvar,
      consts = consts,
      insts = []
    };
    fun register_const (c, _) = Symtab.update (c, class);
  in
    ClassData.map (fn (classtab, consttab) =>
      (Symtab.update (class, data) classtab, fold register_const consts consttab))
  end;
(* Append an instance entry to the data of an already recorded class; an
   unrecorded class is left untouched by Symtab.map_entry. *)
fun add_inst_data (class, inst) =
  let
    fun append_inst {superclasses, name_locale, name_axclass, var, consts, insts} =
      {superclasses = superclasses, name_locale = name_locale,
        name_axclass = name_axclass, var = var, consts = consts,
        insts = insts @ [inst]};
  in ClassData.map (apfst (Symtab.map_entry class append_inst)) end;
(* Operation names and instances of a recorded class. *)
fun the_consts thy class = map fst (#consts (get_class_data thy class));
fun the_tycos thy class = #insts (get_class_data thy class);
(* classes and instances *)
local
  open Element

  (* Generic class declaration; add_locale is the locale command used to set
     up the class locale from raw_import and raw_body.  Steps:
     - add the locale;
     - take the fixed constants of the body as class operations and determine
       the single class type variable;
     - declare the operations as global constants;
     - declare an axclass whose axioms are the locale assumptions, with the
       operations replaced by the global constants;
     - constrain each constant's class variable to that axclass;
     - record the class data and print the locale elements.
     Returns the locale context paired with the theory. *)
  fun gen_add_class add_locale bname raw_import raw_body thy =
    let
      (* replace every TFree named v by ty_subst *)
      fun subst_clsvar v ty_subst =
        map_type_tfree (fn u as (w, _) =>
          if w = v then ty_subst else TFree u);
      (* all assumption propositions of elems, with each class operation that
         occurs free replaced by its global constant;
         c_adds: pairs (local name, (global name, type)) *)
      fun extract_assumes c_adds elems =
        let
          fun subst_free ts =
            let
              val frees = fold Term.add_frees ts [];
              (* operations not occurring in any assumption are skipped;
                 looking them up unconditionally raised Option *)
              fun mk_subst (c, (c', _)) =
                Option.map (fn ty => (Free (c, ty), Const (c', ty)))
                  (AList.lookup (op =) frees c);
              val subst_map = List.mapPartial mk_subst c_adds;
            in map (subst_atomic subst_map) ts end;
        in
          elems
          |> (map o List.mapPartial)
              (fn (Assumes asms) => (SOME o map (map fst o snd)) asms
                | _ => NONE)
          |> Library.flat o Library.flat o Library.flat
          |> subst_free
        end;
      (* the single type variable of tys; it must carry the default sort *)
      fun extract_tyvar_name thy tys =
        fold (curry add_typ_tfrees) tys []
        |> (fn [(v, sort)] =>
              if Sorts.sort_eq (Sign.classes_of thy) (Sign.defaultS thy, sort)
              then v
              else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
            | [] => error ("no class type variable")
            | vs => error ("more than one type variable: " ^ (commas o map (Sign.string_of_typ thy o TFree)) vs));
      (* the fixed constants of the body with their mixfix syntax, the class
         type variable replaced by a TFree of empty sort; paired with the
         name of that variable *)
      fun extract_tyvar_consts thy elems =
        elems
        |> Library.flat
        |> List.mapPartial
            (fn (Fixes consts) => SOME consts
              | _ => NONE)
        |> Library.flat
        |> map (fn (c, ty, syn) =>
            ((c, the ty), (Syntax.unlocalize_mixfix o Syntax.fix_mixfix c o the) syn))
        |> `(fn consts => extract_tyvar_name thy (map (snd o fst) consts))
        |-> (fn v => map ((apfst o apsnd) (subst_clsvar v (TFree (v, []))))
            #> pair v);
      (* declare operation c globally with its class variable at the default
         sort; yields (local name, (internal name, type)) *)
      fun add_global_const v ((c, ty), syn) thy =
        thy
        |> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
        |> `(fn thy => (c, (Sign.intern_const thy c, ty)));
      (* constrain the class variable of constant c to the given class *)
      fun add_global_constraint v class (_, (c, ty)) thy =
        thy
        |> Sign.add_const_constraint_i (c, subst_clsvar v (TVar ((v, 0), [class])) ty);
      fun print_ctxt ctxt elem =
        map Pretty.writeln (Element.pretty_ctxt ctxt elem);
    in
      thy
      |> add_locale bname raw_import raw_body
      |-> (fn ((import_elems, body_elems), ctxt) =>
        `(fn thy => Locale.intern thy bname)
        #-> (fn name_locale =>
        `(fn thy => extract_tyvar_consts thy body_elems)
        #-> (fn (v, c_defs) =>
        fold_map (add_global_const v) c_defs
        #-> (fn c_adds =>
        (* superclass sort: the default sort of the theory given as argument *)
        AxClass.add_axclass_i (bname, Sign.defaultS thy)
          (map (Thm.no_attributes o pair "") (extract_assumes c_adds (import_elems @ body_elems)))
        #-> (fn _ =>
        `(fn thy => Sign.intern_class thy bname)
        #-> (fn name_axclass =>
        fold (add_global_constraint v name_axclass) c_adds
        (* the class data is keyed by the locale name *)
        #> add_class_data (name_locale, ([], name_locale, name_axclass, v, map snd c_adds))
        #> tap (fn _ => (map o map) (print_ctxt ctxt) import_elems)
        #> tap (fn _ => (map o map) (print_ctxt ctxt) body_elems)
        #> pair ctxt
        ))))))
    end;
in

val add_class = gen_add_class (Locale.add_locale_context true);
val add_class_i = gen_add_class (Locale.add_locale_context_i true);

end; (* local *)
(* Generic instance declaration for a type constructor arity.
   prep_arity reads or certifies the arity; add_defs adds the given
   definitions to the theory; tap_def turns a raw (name, definition) pair
   into (name, term).  While the definitions are added, the type constraint
   of every required operation is replaced by its declared type.  The old
   constraints are restored afterwards, and finally the instance proof for
   the arity is started. *)
fun gen_instance_arity prep_arity add_defs tap_def raw_arity raw_defs thy =
  let
    (* The handler must cover the application to a term.  Around the
       partial application alone it never fired, so a malformed definition
       escaped as an uncaught TERM. *)
    fun dest_def t = Theory.dest_def (Sign.pp thy) t handle TERM (msg, _) => error msg;
    val arity as (tyco, asorts, sort) = prep_arity thy ((fn ((x, y), z) => (x, y, z)) raw_arity);
    (* tyco applied to fresh schematic type variables of the argument sorts *)
    val ty_inst = Type (tyco, map2 (curry TVar o rpair 0) (Term.invent_names [] "'a" (length asorts)) asorts);
    (* operations of class, with its class variable instantiated to ty_inst *)
    fun get_c_req class =
      let
        val data = get_class_data thy class;
        val subst_ty = map_type_tfree (fn (var as (v, _)) =>
          if #var data = v then ty_inst else TFree var);
      in (map (apsnd subst_ty) o #consts) data end;
    val c_req = (Library.flat o map get_c_req) sort;
    (* set the constraint of c to its declared type; yields (c, old constraint) *)
    fun get_remove_constraint c thy =
      let
        val ty1 = Sign.the_const_constraint thy c;
        val ty2 = Sign.the_const_type thy c;
      in
        thy
        |> Sign.add_const_constraint_i (c, ty2)
        |> pair (c, ty1)
      end;
    (* the (constant, type) part of each given definition *)
    fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
    (* reject missing or superfluous definitions *)
    fun check_defs c_given c_req thy =
      let
        fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2)
        fun string_of_c (c, ty) = quote (c ^ "::" ^ Sign.string_of_typ thy ty);
        val _ = case fold (remove eq_c) c_given c_req
         of [] => ()
          | cs => error ("no definition(s) given for " ^ commas (map string_of_c cs));
        val _ = case fold (remove eq_c) c_req c_given
         of [] => ()
          | cs => error ("superfluous definition(s) given for " ^ commas (map string_of_c cs));
      in thy end;
  in
    thy
    |> fold_map get_remove_constraint (map fst c_req)
    (* NOTE(review): check_defs is only partially applied here, so its
       checks never run; only get_c_given (reading and destructing the
       definitions) is evaluated.  Before supplying the theory argument,
       make eq_c insensitive to the naming of type variables and to their
       sorts.  Otherwise valid instances would be rejected. *)
    ||> tap (fn thy => check_defs (get_c_given thy) c_req)
    ||> add_defs (true, raw_defs)
    |-> (fn cs => fold Sign.add_const_constraint_i cs)
    |> AxClass.instance_arity_i arity
  end;
(* Instance commands for source syntax (arity and definitions are read) and
   for internal arguments (arity certified, terms taken as given). *)
fun add_instance_arity x = gen_instance_arity AxClass.read_arity IsarThy.add_defs read_axm x;
fun add_instance_arity_i x = gen_instance_arity AxClass.cert_arity IsarThy.add_defs_i (K I) x;
(* class queries *)
(* A class counts as syntactic when it is recorded with at least one operation. *)
fun is_class thy cls =
  (case lookup_class_data thy cls of
    SOME {consts, ...} => not (null consts)
  | NONE => false);
(* Keep the syntactic classes of a sort, replacing every other class by its
   superclasses (restricted recursively); the result is normalized. *)
fun syntactic_sort_of thy sort =
  let
    val algebra = Sign.classes_of thy;
    fun syntactic_classes cls =
      if not (is_class thy cls)
      then syntactic_sort_of thy (Sorts.superclasses algebra cls)
      else [cls];
  in Sorts.norm_sort algebra (Library.flat (map syntactic_classes sort)) end;
(* Sorts.mg_domain of tycon for the syntactic part of sort, each resulting
   argument sort restricted to its syntactic classes. *)
fun get_arities thy sort tycon =
  let
    val domain = Sorts.mg_domain (Sign.classes_arities_of thy) tycon (syntactic_sort_of thy sort);
  in map (syntactic_sort_of thy) domain end;
(* Syntactic superclasses of a syntactic class; error for any other class. *)
fun get_superclasses thy class =
  if not (is_class thy class)
  then error ("no syntactic class: " ^ class)
  else syntactic_sort_of thy (Sorts.superclasses (Sign.classes_of thy) class);
(* instance queries *)
(* Type ty of an operation of class, with every type variable whose sort
   equals [class] replaced by the type variable tvar of empty sort.  Any
   other type variable already named tvar becomes a TVar (index 0) under a
   name obtained from Term.invent_names, avoiding the TFree names of the
   frozen type.  TVars of ty are frozen first and thawed at the end.
   NOTE(review): if freezing gave some TVar the name tvar, thaw may turn the
   replacement TFree (tvar, []) back into a TVar.  get_inst_consts_sign
   appears to rely on this, since it substitutes the TVar ("'a", 0); confirm. *)
fun mk_const_sign thy class tvar ty =
let
val (ty', thaw) = Type.freeze_thaw_type ty;
val tvars_used = Term.add_tfreesT ty' [];
val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1);
in
ty'
|> map_type_tfree (fn (tvar', sort) =>
(* the class variable: its sort equals [class] *)
if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
then TFree (tvar, [])
(* some other variable that clashes with the chosen name *)
else if tvar' = tvar
then TVar ((tvar_rename, 0), sort)
else TFree (tvar', sort))
|> thaw
end;
(* Type of a class operation from its current constraint, with the class
   variable named tvar; raises Option if const belongs to no recorded class. *)
fun get_const_sign thy tvar const =
  let
    val class = the (lookup_const_class thy const);
  in mk_const_sign thy class tvar (Sign.the_const_constraint thy const) end;
(* Operations of class paired with their types at type constructor tyco.
   For each operation, the signature from get_const_sign (class variable
   "'a") has the TVar ("'a", 0) replaced by tyco applied to fresh TFrees.
   The fresh TFrees carry the argument sorts from get_arities.
   NOTE(review): this assumes the class variable appears as TVar ("'a", 0)
   after mk_const_sign thaws the type; confirm. *)
fun get_inst_consts_sign thy (tyco, class) =
let
val consts = the_consts thy class;
val arities = get_arities thy [class] tyco;
val const_signs = map (get_const_sign thy "'a") consts;
(* TFree names of all signatures except "'a", avoided by the fresh names *)
val vars_used = fold (fn ty => curry (gen_union (op =))
(map fst (typ_tfrees ty) |> remove (op =) "'a")) const_signs [];
val vars_new = Term.invent_names vars_used "'a" (length arities);
val typ_arity = Type (tyco, map2 (curry TFree) vars_new arities);
val instmem_signs =
map (typ_subst_TVars [(("'a", 0), typ_arity)]) const_signs;
in consts ~~ instmem_signs end;
(* Table from class name to (operation names, instances). *)
fun get_classtab thy =
  let
    fun add_entry (class, {consts, insts, ...} : class_data) =
      Symtab.update_new (class, (map fst consts, insts));
  in Symtab.fold add_entry (fst (ClassData.get thy)) Symtab.empty end;
(* extracting dictionary obligations from types *)
type sortcontext = (string * sort) list;

(* TFrees of ty (which must contain no TVars) with their syntactic sorts;
   variables whose syntactic sort is empty are dropped. *)
fun extract_sortctxt thy ty =
  let
    val tfrees = typ_tfrees (Type.no_tvars ty);
    val restricted = map (fn (v, sort) => (v, syntactic_sort_of thy sort)) tfrees;
  in filter (fn (_, sort) => not (null sort)) restricted end;
(* Evidence for a class constraint:
   Instance ((class, tyco), lookups): the instance of class at type
     constructor tyco, with one lookup list per type argument;
   Lookup (classes, (v, i)): starting from the i-th class of the syntactic
     sort of type variable v, following the listed syntactic classes of the
     derivation path. *)
datatype sortlookup = Instance of (class * string) * sortlookup list list
| Lookup of class list * (string * int)
(* Dictionary lookups for using a constant with defining type raw_typ_def at
   type raw_typ_use.  Both types are varified and matched.  Each type
   variable of the sort context of raw_typ_def whose syntactic sort is
   non-empty yields one list, with one entry per class of that sort:
   - instantiated to Type (tycon, tys): Instance ((class, tycon), ...) with
     the lookups for tys against the syntactic argument sorts from
     Sorts.mg_domain;
   - instantiated to a TVar: Lookup of the derivation path from one of the
     TVar's syntactic classes to the required class (rest of the path,
     syntactic classes only, reversed), the TVar name, and the position of
     that starting class in the TVar's syntactic sort.
   NOTE(review): the TFree names of the frozen raw_typ_def are looked up as
   index-0 variables of the varified type; this assumes freezing and
   varifying agree on names.  Verify. *)
fun extract_sortlookup thy (raw_typ_def, raw_typ_use) =
let
val typ_def = Type.varifyT raw_typ_def;
val typ_use = Type.varifyT raw_typ_use;
val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
fun tab_lookup vname = (the o Vartab.lookup match_tab) (vname, 0);
(* path from the first of subclasses that has one to superclass;
   raises Option if none has *)
fun get_superclass_derivation (subclasses, superclass) =
(the oo get_first) (fn subclass =>
Sorts.class_le_path (Sign.classes_of thy) (subclass, superclass)
) subclasses;
(* NOTE(review): no case for an empty path - raises Match if one occurs *)
fun mk_class_deriv thy subclasses superclass =
case get_superclass_derivation (subclasses, superclass)
of (subclass::deriv) =>
((rev o filter (is_class thy)) deriv, find_index_eq subclass subclasses);
(* lookups for a required sort against the type it is instantiated with;
   TFree is not handled, since typ_use is varified *)
fun mk_lookup (sort_def, (Type (tycon, tys))) =
let
val arity_lookup = map2 (curry mk_lookup)
(map (syntactic_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort_def)) tys
in map (fn class => Instance ((class, tycon), arity_lookup)) sort_def end
| mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
let
fun mk_look class =
let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
in Lookup (deriv, (vname, classindex)) end;
in map mk_look sort_def end;
in
extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
|> map (tab_lookup o fst)
|> map (apfst (syntactic_sort_of thy))
|> filter (not o null o fst)
|> map mk_lookup
end;
(* intermediate auxiliary *)
(* Record an existing class directly.  Operations are stored with dummyT,
   locale name and class variable are left empty, and every instance type
   constructor is attributed to the current theory. *)
fun add_classentry raw_class raw_cs raw_insts thy =
  let
    val class = Sign.intern_class thy raw_class;
    val cs = map (Sign.intern_const thy) raw_cs;
    val thy_name = Context.theory_name thy;
    val insts = map (fn raw_inst => (Sign.intern_type thy raw_inst, thy_name)) raw_insts;
    val consts = map (fn c => (c, dummyT)) cs;
  in
    thy
    |> add_class_data (class, ([], "", class, "", consts))
    |> fold (fn inst => add_inst_data (class, inst)) insts
  end;
(* toplevel interface *)
(* Outer syntax for the commands "class" and "class_instance". *)
local
structure P = OuterParse
and K = OuterKeyword
in
val (classK, instanceK) = ("class", "class_instance")
(* a locale expression optionally followed by "+" and context elements, or
   context elements alone (paired with the empty locale expression) *)
val locale_val =
(P.locale_expr --
Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
Scan.repeat1 P.context_element >> pair Locale.empty);
(* class NAME [= SPEC]: calls add_class; swap turns its (context, theory)
   result into (theory, context) for Toplevel.theory_context *)
val classP =
OuterSyntax.command classK "operational type classes" K.thy_decl
(P.name -- Scan.optional (P.$$$ "=" |-- P.!!! locale_val) (Locale.empty, [])
>> (Toplevel.theory_context
o (fn f => swap o f) o (fn (bname, (expr, elems)) => add_class bname expr elems)));
(* class_instance TYCO :: ARITY DEFS: starts the instance proof via
   add_instance_arity.  NOTE(review): the command description is empty. *)
val instanceP =
OuterSyntax.command instanceK "" K.thy_goal
(P.xname -- (P.$$$ "::" |-- P.!!! P.arity)
-- Scan.repeat1 P.spec_name
>> (Toplevel.theory_to_proof
o (fn ((tyco, (asorts, sort)), defs) => add_instance_arity ((tyco, asorts), sort) defs)));
val _ = OuterSyntax.add_parsers [classP, instanceP];
end; (* local *)
(* setup *)
(* Register the initialisation of the class data with the theory setup. *)
val _ = Context.add_setup [ClassData.init];
end; (* struct *)