hooks now take string list as arguments (mutual datatypes); some nice combinators in datatype_codegen
--- a/src/HOL/Tools/datatype_codegen.ML Fri Jul 21 14:46:27 2006 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML Fri Jul 21 14:47:22 2006 +0200
@@ -11,9 +11,17 @@
-> (((string * sort) list * (string * typ list) list) * tactic) option
val get_all_datatype_cons: theory -> (string * string) list
val dest_case_expr: theory -> term
- -> ((string * typ) list * ((term * typ) * (term * term) list)) option;
+ -> ((string * typ) list * ((term * typ) * (term * term) list)) option
val add_datatype_case_const: string -> theory -> theory
val add_datatype_case_defs: string -> theory -> theory
+ val datatypes_dependency: theory -> string list list
+ val get_datatype_mut_specs: theory -> string list
+ -> ((string * sort) list * (string * (string * typ list) list) list)
+ val get_datatype_arities: theory -> string list -> sort
+ -> (string * (((string * sort list) * sort) * term list)) list option
+ val datatype_prove_arities : tactic -> string list -> sort
+ -> ((string * term list) list
+ -> ((bstring * attribute list) * term) list) -> theory -> theory
val setup: theory -> theory
end;
@@ -306,13 +314,88 @@
(** code 2nd generation **)
+fun datatypes_dependency thy =
+ let
+ val dtnames = DatatypePackage.get_datatypes thy;
+ fun add_node (dtname, _) =
+ let
+ fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
+ | add_tycos _ = I;
+ val deps = (filter (Symtab.defined dtnames) o maps (fn ty =>
+ add_tycos ty [])
+ o maps snd o snd o the o DatatypePackage.get_datatype_spec thy) dtname
+ in
+ Graph.default_node (dtname, ())
+ #> fold (fn dtname' =>
+ Graph.default_node (dtname', ())
+ #> Graph.add_edge (dtname', dtname)
+ ) deps
+ end
+ in
+ Graph.empty
+ |> Symtab.fold add_node dtnames
+ |> Graph.strong_conn
+ end;
+
+fun get_datatype_mut_specs thy (tycos as tyco :: _) =
+ let
+ val tycos' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
+ val _ = if gen_subset (op =) (tycos, tycos') then () else
+ error ("datatype constructors are not mutually recursive: " ^ (commas o map quote) tycos);
+ val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
+ in (vs, tycos ~~ css) end;
+
+fun get_datatype_arities thy tycos sort =
+ let
+ val algebra = Sign.classes_of thy;
+ val (vs_proto, css_proto) = get_datatype_mut_specs thy tycos;
+ val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
+ fun inst_type tyco (c, tys) =
+ let
+ val tys' = (map o map_atyps)
+ (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) vs v))) tys
+ in (c, tys') end;
+ val css = map (fn (tyco, cs) => (tyco, (map (inst_type tyco) cs))) css_proto;
+ fun mk_arity tyco =
+ ((tyco, map snd vs), sort);
+ fun typ_of_sort ty =
+ let
+ val arities = map (fn (tyco, _) => ((tyco, map snd vs), sort)) css;
+ in ClassPackage.assume_arities_of_sort thy arities (ty, sort) end;
+ fun mk_cons tyco (c, tys) =
+ let
+ val ts = Name.give_names Name.context "a" tys;
+ val ty = tys ---> Type (tyco, map TFree vs);
+ in list_comb (Const (c, ty), map Free ts) end;
+ in if forall (fn (_, cs) => forall (fn (_, tys) => forall typ_of_sort tys) cs) css
+ then SOME (
+ map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
+ ) else NONE
+ end;
+
+fun datatype_prove_arities tac tycos sort f thy =
+ case get_datatype_arities thy tycos sort
+ of NONE => thy
+ | SOME insts => let
+ fun proven ((tyco, asorts), sort) =
+ Sorts.of_sort (Sign.classes_of thy)
+ (Type (tyco, map TFree (Name.give_names Name.context "'a" asorts)), sort);
+ val (arities, css) = (split_list o map_filter
+ (fn (tyco, (arity, cs)) => if proven arity
+ then SOME (arity, (tyco, cs)) else NONE)) insts;
+ in
+ thy
+ |> ClassPackage.prove_instance_arity tac
+ arities ("", []) (f css)
+ end;
+
fun dtyp_of_case_const thy c =
get_first (fn (dtco, { case_name, ... }) => if case_name = c then SOME dtco else NONE)
((Symtab.dest o DatatypePackage.get_datatypes) thy);
fun dest_case_app cs ts tys =
let
- val abs = CodegenThingol.give_names [] (Library.drop (length ts, tys));
+ val abs = Name.give_names Name.context "a" (Library.drop (length ts, tys));
val (ts', t) = split_last (ts @ map Free abs);
val (tys', sty) = split_last tys;
fun freenames_of t = fold_aterms
@@ -382,7 +465,7 @@
get_datatype_spec_thms #>
CodegenPackage.set_get_all_datatype_cons
get_all_datatype_cons #>
- DatatypeHooks.add add_datatype_case_const #>
- DatatypeHooks.add add_datatype_case_defs
+ DatatypeHooks.add (fold add_datatype_case_const) #>
+ DatatypeHooks.add (fold add_datatype_case_defs)
end;
--- a/src/HOL/Tools/datatype_hooks.ML Fri Jul 21 14:46:27 2006 +0200
+++ b/src/HOL/Tools/datatype_hooks.ML Fri Jul 21 14:47:22 2006 +0200
@@ -7,7 +7,7 @@
signature DATATYPE_HOOKS =
sig
- type hook = string -> theory -> theory;
+ type hook = string list -> theory -> theory;
val add: hook -> theory -> theory;
val invoke: hook;
val setup: theory -> theory;
@@ -19,7 +19,7 @@
(* theory data *)
-type hook = string -> theory -> theory;
+type hook = string list -> theory -> theory;
datatype T = T of (serial * hook) list;
fun map_T f (T hooks) = T (f hooks);
@@ -43,8 +43,8 @@
fun add hook =
DatatypeHooksData.map (map_T (cons (serial (), hook)));
-fun invoke dtco thy =
- fold (fn (_, f) => f dtco) ((fn T hooks => hooks) (DatatypeHooksData.get thy)) thy;
+fun invoke dtcos thy =
+ fold (fn (_, f) => f dtcos) ((fn T hooks => hooks) (DatatypeHooksData.get thy)) thy;
(* theory setup *)
--- a/src/HOL/Tools/datatype_package.ML Fri Jul 21 14:46:27 2006 +0200
+++ b/src/HOL/Tools/datatype_package.ML Fri Jul 21 14:47:22 2006 +0200
@@ -723,7 +723,7 @@
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
|> snd
|> DatatypeRealizer.add_dt_realizers sorts (map snd dt_infos)
- |> fold (DatatypeHooks.invoke o fst) dt_infos;
+ |> DatatypeHooks.invoke (map fst dt_infos);
in
({distinct = distinct,
inject = inject,
@@ -783,7 +783,7 @@
|> Theory.parent_path
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
|> DatatypeRealizer.add_dt_realizers sorts (map snd dt_infos)
- |> fold (DatatypeHooks.invoke o fst) dt_infos;
+ |> DatatypeHooks.invoke (map fst dt_infos);
in
({distinct = distinct,
inject = inject,
@@ -893,7 +893,7 @@
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
|> snd
|> DatatypeRealizer.add_dt_realizers sorts (map snd dt_infos)
- |> fold (DatatypeHooks.invoke o fst) dt_infos;
+ |> DatatypeHooks.invoke (map fst dt_infos);
in
({distinct = distinct,
inject = inject,