src/HOL/Tools/datatype_codegen.ML
changeset 25534 d0b74fdd6067
parent 25505 4d531475129a
child 25569 c597835d5de4
--- a/src/HOL/Tools/datatype_codegen.ML	Wed Dec 05 14:15:39 2007 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML	Wed Dec 05 14:15:45 2007 +0100
@@ -2,32 +2,21 @@
     ID:         $Id$
     Author:     Stefan Berghofer & Florian Haftmann, TU Muenchen
 
-Code generator for inductive datatypes.
+Code generator facilities for inductive datatypes.
 *)
 
 signature DATATYPE_CODEGEN =
 sig
   val get_eq: theory -> string -> thm list
-  val get_eq_datatype: theory -> string -> thm list
   val get_case_cert: theory -> string -> thm
-
-  type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
-    -> theory -> theory
-  val add_codetypes_hook: hook -> theory -> theory
-  val get_codetypes_arities: theory -> (string * bool) list -> sort
-    -> (string * (arity * term list)) list
-  val prove_codetypes_arities: tactic -> (string * bool) list -> sort
-    -> (arity list -> (string * term list) list -> theory
-      -> ((bstring * Attrib.src list) * term) list * theory)
-    -> (arity list -> (string * term list) list -> thm list -> theory -> theory)
-    -> theory -> theory
-
   val setup: theory -> theory
 end;
 
 structure DatatypeCodegen : DATATYPE_CODEGEN =
 struct
 
+(** SML code generator **)
+
 open Codegen;
 
 fun mk_tuple [p] = p
@@ -310,66 +299,21 @@
   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
 
 
-(** datatypes for code 2nd generation **)
-
-local
-
-val not_sym = thm "HOL.not_sym";
-val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
-val refl = thm "refl";
-val eqTrueI = thm "eqTrueI";
+(** generic code generator **)
 
-fun mk_distinct cos =
-  let
-    fun sym_product [] = []
-      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
-    fun mk_co_args (co, tys) ctxt =
-      let
-        val names = Name.invents ctxt "a" (length tys);
-        val ctxt' = fold Name.declare names ctxt;
-        val vs = map2 (curry Free) names tys;
-      in (vs, ctxt') end;
-    fun mk_dist ((co1, tys1), (co2, tys2)) =
-      let
-        val ((xs1, xs2), _) = Name.context
-          |> mk_co_args (co1, tys1)
-          ||>> mk_co_args (co2, tys2);
-        val prem = HOLogic.mk_eq
-          (list_comb (co1, xs1), list_comb (co2, xs2));
-        val t = HOLogic.mk_not prem;
-      in HOLogic.mk_Trueprop t end;
-  in map mk_dist (sym_product cos) end;
+(* specification *)
 
-in
-
-fun get_eq_datatype thy dtco =
+fun add_datatype_spec vs dtco cos thy =
   let
-    val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
-    fun mk_triv_inject co =
-      let
-        val ct' = Thm.cterm_of thy
-          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
-        val cty' = Thm.ctyp_of_term ct';
-        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
-          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
-          (Thm.prop_of refl) NONE;
-      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
-    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
-    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
-    val ctxt = ProofContext.init thy;
-    val simpset = Simplifier.context ctxt
-      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
-    val cos = map (fn (co, tys) =>
-        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
-    val tac = ALLGOALS (simp_tac simpset)
-      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
-    val distinct =
-      mk_distinct cos
-      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
-      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
-  in inject1 @ inject2 @ distinct end;
+    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+  in
+    thy
+    |> try (Code.add_datatype cs)
+    |> the_default thy
+  end;
 
-end;
+
+(* case certificates *)
 
 fun get_case_cert thy tyco =
   let
@@ -402,170 +346,116 @@
     |> Thm.varifyT
   end;
 
-
-
-(** codetypes for code 2nd generation **)
-
-(* abstraction over datatypes vs. type copies *)
-
-fun get_typecopy_spec thy tyco =
+fun add_datatype_cases dtco thy =
   let
-    val SOME { vs, constr, typ, ... } = TypecopyPackage.get_info thy tyco
-  in (vs, [(constr, [typ])]) end;
-
-
-fun get_spec thy (dtco, true) =
-      (the o DatatypePackage.get_datatype_spec thy) dtco
-  | get_spec thy (tyco, false) =
-      get_typecopy_spec thy tyco;
-
-local
-  fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
-   of SOME _ => get_eq_datatype thy tyco
-    | NONE => [TypecopyPackage.get_eq thy tyco];
-  fun constrain_op_eq_thms thy thms =
-    let
-      fun add_eq (Const ("op =", ty)) =
-            fold (insert (eq_fst (op =))) (Term.add_tvarsT ty [])
-        | add_eq _ =
-            I
-      val eqs = fold (fold_aterms add_eq o Thm.prop_of) thms [];
-      val instT = map (fn (v_i, sort) =>
-        (Thm.ctyp_of thy (TVar (v_i, sort)),
-           Thm.ctyp_of thy (TVar (v_i, Sorts.inter_sort (Sign.classes_of thy)
-             (sort, [HOLogic.class_eq]))))) eqs;
-    in
-      thms
-      |> map (Thm.instantiate (instT, []))
-    end;
-in
-  fun get_eq thy tyco =
-    get_eq_thms thy tyco
-    |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy)
-    |> constrain_op_eq_thms thy
-end;
-
-type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
-  -> theory -> theory;
-
-fun add_codetypes_hook hook thy =
-  let
-    fun add_spec thy (tyco, is_dt) =
-      (tyco, (is_dt, get_spec thy (tyco, is_dt)));
-    fun datatype_hook dtcos thy =
-      hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
-    fun typecopy_hook tyco thy =
-      hook ([(tyco, (false, get_typecopy_spec thy tyco))]) thy;
+    val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
+    val certs = get_case_cert thy dtco;
   in
     thy
-    |> DatatypePackage.interpretation datatype_hook
-    |> TypecopyPackage.interpretation typecopy_hook
+    |> Code.add_case certs
+    |> fold_rev Code.add_default_func case_rewrites
   end;
 
-fun the_codetypes_mut_specs thy ([(tyco, is_dt)]) =
-      let
-        val (vs, cs) = get_spec thy (tyco, is_dt)
-      in (vs, [(tyco, (is_dt, cs))]) end
-  | the_codetypes_mut_specs thy (tycos' as (tyco, true) :: _) =
-      let
-        val tycos = map fst tycos';
-        val tycos'' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
-        val _ = if gen_subset (op =) (tycos, tycos'') then () else
-          error ("type constructors are not mutually recursive: " ^ (commas o map quote) tycos);
-        val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
-      in (vs, map2 (fn (tyco, is_dt) => fn cs => (tyco, (is_dt, cs))) tycos' css) end;
+
+(* equality *)
+
+local
 
-
-(* instrumentalizing the sort algebra *)
+val not_sym = thm "HOL.not_sym";
+val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
+val refl = thm "refl";
+val eqTrueI = thm "eqTrueI";
 
-fun get_codetypes_arities thy tycos sort =
+fun mk_distinct cos =
   let
-    val pp = Sign.pp thy;
-    val algebra = Sign.classes_of thy;
-    val (vs_proto, css_proto) = the_codetypes_mut_specs thy tycos;
-    val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
-    val css = map (fn (tyco, (_, cs)) => (tyco, cs)) css_proto;
-    val algebra' = algebra
-      |> fold (fn (tyco, _) =>
-           Sorts.add_arities pp (tyco, map (fn class => (class, map snd vs)) sort)) css;
-    fun typ_sort_inst ty = CodeUnit.typ_sort_inst algebra' (Logic.varifyT ty, sort);
-    val venv = Vartab.empty
-      |> fold (fn (v, sort) => Vartab.update_new ((v, 0), sort)) vs
-      |> fold (fn (_, cs) => fold (fn (_, tys) => fold typ_sort_inst tys) cs) css;
-    fun inst (v, _) = (v, (the o Vartab.lookup venv) (v, 0));
-    val vs' = map inst vs;
-    fun mk_arity tyco = (tyco, map snd vs', sort);
-    fun mk_cons tyco (c, tys) =
+    fun sym_product [] = []
+      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
+    fun mk_co_args (co, tys) ctxt =
+      let
+        val names = Name.invents ctxt "a" (length tys);
+        val ctxt' = fold Name.declare names ctxt;
+        val vs = map2 (curry Free) names tys;
+      in (vs, ctxt') end;
+    fun mk_dist ((co1, tys1), (co2, tys2)) =
       let
-        val tys' = (map o Term.map_type_tfree) (TFree o inst) tys;
-        val ts = Name.names Name.context "a" tys';
-        val ty = (tys' ---> Type (tyco, map TFree vs'));
-      in list_comb (Const (c, ty), map Free ts) end;
-  in
-    map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
-  end;
+        val ((xs1, xs2), _) = Name.context
+          |> mk_co_args (co1, tys1)
+          ||>> mk_co_args (co2, tys2);
+        val prem = HOLogic.mk_eq
+          (list_comb (co1, xs1), list_comb (co2, xs2));
+        val t = HOLogic.mk_not prem;
+      in HOLogic.mk_Trueprop t end;
+  in map mk_dist (sym_product cos) end;
+
+in
 
-fun prove_codetypes_arities tac tycos sort f after_qed thy =
-  case try (get_codetypes_arities thy tycos) sort
-   of NONE => thy
-    | SOME insts => let
-        fun proven (tyco, asorts, sort) =
-          Sorts.of_sort (Sign.classes_of thy)
-            (Type (tyco, map TFree (Name.names Name.context "'a" asorts)), sort);
-        val (arities, css) = (split_list o map_filter
-          (fn (tyco, (arity, cs)) => if proven arity
-            then NONE else SOME (arity, (tyco, cs)))) insts;
-      in
-        thy
-        |> not (null arities) ? (
-            f arities css
-            #-> (fn defs =>
-              Instance.prove_instance tac arities defs
-            #-> (fn defs =>
-              after_qed arities css defs)))
-      end;
+fun get_eq thy dtco =
+  let
+    val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
+    fun mk_triv_inject co =
+      let
+        val ct' = Thm.cterm_of thy
+          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
+        val cty' = Thm.ctyp_of_term ct';
+        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
+          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
+          (Thm.prop_of refl) NONE;
+      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
+    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
+    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
+    val ctxt = ProofContext.init thy;
+    val simpset = Simplifier.context ctxt
+      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
+    val cos = map (fn (co, tys) =>
+        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
+    val tac = ALLGOALS (simp_tac simpset)
+      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
+    val distinct =
+      mk_distinct cos
+      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
+      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
+  in inject1 @ inject2 @ distinct end;
 
-
-(* operational equality *)
+end;
 
-fun eq_hook specs =
+fun add_datatypes_equality vs dtcos thy =
   let
-    fun add_eq_thms (dtco, (_, (vs, cs))) thy =
+    fun get_eq' thy dtco = get_eq thy dtco
+      |> map (CodeUnit.constrain_thm [HOLogic.class_eq])
+      |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy);
+    fun add_eq_thms dtco thy =
       let
         val thy_ref = Theory.check_thy thy;
         val const = Class.param_of_inst thy ("op =", dtco);
-        val get_thms = (fn () => get_eq (Theory.deref thy_ref) dtco |> rev);
+        val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev);
       in
         Code.add_funcl (const, Susp.delay get_thms) thy
       end;
+    val sorts_eq =
+      map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
   in
-    prove_codetypes_arities (Class.intro_classes_tac [])
-      (map (fn (tyco, (is_dt, _)) => (tyco, is_dt)) specs)
-      [HOLogic.class_eq] ((K o K o pair) []) ((K o K o K) (fold add_eq_thms specs))
+    thy
+    |> Instance.instantiate (dtcos, sorts_eq, [HOLogic.class_eq]) (pair ())
+         ((K o K) (Class.intro_classes_tac []))
+    |> fold add_eq_thms dtcos
   end;
 
 
-
 (** theory setup **)
 
-fun add_datatype_spec dtco thy =
+fun add_datatype_code dtcos thy =
   let
-    val SOME (vs, cos) = DatatypePackage.get_datatype_spec thy dtco;
-    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
-    val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
-    val certs = get_case_cert thy dtco;
+    val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
   in
     thy
-    |> try (Code.add_datatype cs)
-    |> the_default thy
-    |> Code.add_case certs
-    |> fold_rev Code.add_default_func case_rewrites
+    |> fold2 (add_datatype_spec vs) dtcos coss
+    |> fold add_datatype_cases dtcos
+    |> add_datatypes_equality vs dtcos
   end;
 
 val setup = 
   add_codegen "datatype" datatype_codegen
   #> add_tycodegen "datatype" datatype_tycodegen
-  #> DatatypePackage.interpretation (fold add_datatype_spec)
-  #> add_codetypes_hook eq_hook
+  #> DatatypePackage.interpretation add_datatype_code
 
 end;