src/HOL/Tools/datatype_codegen.ML
changeset 20177 0af885e3dabf
parent 20105 454f4be984b7
child 20182 79c9ff40d760
--- a/src/HOL/Tools/datatype_codegen.ML	Fri Jul 21 14:46:27 2006 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML	Fri Jul 21 14:47:22 2006 +0200
@@ -11,9 +11,17 @@
     -> (((string * sort) list * (string * typ list) list) * tactic) option
   val get_all_datatype_cons: theory -> (string * string) list
   val dest_case_expr: theory -> term
-    -> ((string * typ) list * ((term * typ) * (term * term) list)) option;
+    -> ((string * typ) list * ((term * typ) * (term * term) list)) option
   val add_datatype_case_const: string -> theory -> theory
   val add_datatype_case_defs: string -> theory -> theory
+  val datatypes_dependency: theory -> string list list
+  val get_datatype_mut_specs: theory -> string list
+    -> ((string * sort) list * (string * (string * typ list) list) list)
+  val get_datatype_arities: theory -> string list -> sort
+    -> (string * (((string * sort list) * sort)  * term list)) list option
+  val datatype_prove_arities : tactic -> string list -> sort
+    -> ((string * term list) list
+    -> ((bstring * attribute list) * term) list) -> theory -> theory
   val setup: theory -> theory
 end;
 
@@ -306,13 +314,88 @@
 
 (** code 2nd generation **)
 
+fun datatypes_dependency thy =
+  let
+    val dtnames = DatatypePackage.get_datatypes thy;
+    fun add_node (dtname, _) =
+      let
+        fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
+          | add_tycos _ = I;
+        val deps = (filter (Symtab.defined dtnames) o maps (fn ty =>
+          add_tycos ty [])
+            o maps snd o snd o the o DatatypePackage.get_datatype_spec thy) dtname
+      in
+        Graph.default_node (dtname, ())
+        #> fold (fn dtname' =>
+             Graph.default_node (dtname', ())
+             #> Graph.add_edge (dtname', dtname)
+           ) deps
+      end
+  in
+    Graph.empty
+    |> Symtab.fold add_node dtnames
+    |> Graph.strong_conn
+  end;
+
+fun get_datatype_mut_specs thy (tycos as tyco :: _) =
+  let
+    val tycos' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
+    val _ = if gen_subset (op =) (tycos, tycos') then () else
+      error ("datatype constructors are not mutually recursive: " ^ (commas o map quote) tycos);
+    val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
+  in (vs, tycos ~~ css) end;
+
+fun get_datatype_arities thy tycos sort =
+  let
+    val algebra = Sign.classes_of thy;
+    val (vs_proto, css_proto) = get_datatype_mut_specs thy tycos;
+    val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
+    fun inst_type tyco (c, tys) =
+      let
+        val tys' = (map o map_atyps)
+          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) vs v))) tys
+      in (c, tys') end;
+    val css = map (fn (tyco, cs) => (tyco, (map (inst_type tyco) cs))) css_proto;
+    fun mk_arity tyco =
+      ((tyco, map snd vs), sort);
+    fun typ_of_sort ty =
+      let
+        val arities = map (fn (tyco, _) => ((tyco, map snd vs), sort)) css;
+      in ClassPackage.assume_arities_of_sort thy arities (ty, sort) end;
+    fun mk_cons tyco (c, tys) =
+      let
+        val ts = Name.give_names Name.context "a" tys;
+        val ty = tys ---> Type (tyco, map TFree vs);
+      in list_comb (Const (c, ty), map Free ts) end;
+  in if forall (fn (_, cs) => forall (fn (_, tys) => forall typ_of_sort tys) cs) css
+    then SOME (
+      map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
+    ) else NONE
+  end;
+
+fun datatype_prove_arities tac tycos sort f thy =
+  case get_datatype_arities thy tycos sort
+   of NONE => thy
+    | SOME insts => let
+        fun proven ((tyco, asorts), sort) =
+          Sorts.of_sort (Sign.classes_of thy)
+            (Type (tyco, map TFree (Name.give_names Name.context "'a" asorts)), sort);
+        val (arities, css) = (split_list o map_filter
+          (fn (tyco, (arity, cs)) => if proven arity
+            then SOME (arity, (tyco, cs)) else NONE)) insts;
+      in
+        thy
+        |> ClassPackage.prove_instance_arity tac
+             arities ("", []) (f css)
+      end;
+
 fun dtyp_of_case_const thy c =
   get_first (fn (dtco, { case_name, ... }) => if case_name = c then SOME dtco else NONE)
     ((Symtab.dest o DatatypePackage.get_datatypes) thy);
 
 fun dest_case_app cs ts tys =
   let
-    val abs = CodegenThingol.give_names [] (Library.drop (length ts, tys));
+    val abs = Name.give_names Name.context "a" (Library.drop (length ts, tys));
     val (ts', t) = split_last (ts @ map Free abs);
     val (tys', sty) = split_last tys;
     fun freenames_of t = fold_aterms
@@ -382,7 +465,7 @@
     get_datatype_spec_thms #>
   CodegenPackage.set_get_all_datatype_cons
     get_all_datatype_cons #>
-  DatatypeHooks.add add_datatype_case_const #>
-  DatatypeHooks.add add_datatype_case_defs
+  DatatypeHooks.add (fold add_datatype_case_const) #>
+  DatatypeHooks.add (fold add_datatype_case_defs)
 
 end;