generalized datatype code generation code so that it works with old-style and new-style (co)datatypes (as long as they are not local)
authorblanchet
Mon, 02 Dec 2013 20:31:54 +0100
changeset 54615 62fb5af93fe2
parent 54614 689398f0953f
child 54616 a21a2223c02b
generalized datatype code generation code so that it works with old-style and new-style (co)datatypes (as long as they are not local)
src/HOL/Ctr_Sugar.thy
src/HOL/Inductive.thy
src/HOL/Tools/Datatype/datatype_codegen.ML
src/HOL/Tools/ctr_sugar.ML
src/HOL/Tools/ctr_sugar_code.ML
--- a/src/HOL/Ctr_Sugar.thy	Mon Dec 02 20:31:54 2013 +0100
+++ b/src/HOL/Ctr_Sugar.thy	Mon Dec 02 20:31:54 2013 +0100
@@ -40,6 +40,7 @@
 
 ML_file "Tools/ctr_sugar_util.ML"
 ML_file "Tools/ctr_sugar_tactics.ML"
+ML_file "Tools/ctr_sugar_code.ML"
 ML_file "Tools/ctr_sugar.ML"
 
 end
--- a/src/HOL/Inductive.thy	Mon Dec 02 20:31:54 2013 +0100
+++ b/src/HOL/Inductive.thy	Mon Dec 02 20:31:54 2013 +0100
@@ -274,7 +274,7 @@
 ML_file "Tools/Datatype/datatype_prop.ML"
 ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
 ML_file "Tools/Datatype/rep_datatype.ML"
-ML_file "Tools/Datatype/datatype_codegen.ML" setup Datatype_Codegen.setup
+ML_file "Tools/Datatype/datatype_codegen.ML"
 ML_file "Tools/Datatype/primrec.ML"
 
 text{* Lambda-abstractions with pattern matching: *}
--- a/src/HOL/Tools/Datatype/datatype_codegen.ML	Mon Dec 02 20:31:54 2013 +0100
+++ b/src/HOL/Tools/Datatype/datatype_codegen.ML	Mon Dec 02 20:31:54 2013 +0100
@@ -6,152 +6,24 @@
 
 signature DATATYPE_CODEGEN =
 sig
-  val setup: theory -> theory
 end;
 
 structure Datatype_Codegen : DATATYPE_CODEGEN =
 struct
 
-(** generic code generator **)
-
-(* liberal addition of code data for datatypes *)
-
-fun mk_constr_consts thy vs tyco cos =
-  let
-    val cs = map (fn (c, tys) => (c, tys ---> Type (tyco, map TFree vs))) cos;
-    val cs' = map (fn c_ty as (_, ty) => (Axclass.unoverload_const thy c_ty, ty)) cs;
-  in
-    if is_some (try (Code.constrset_of_consts thy) cs')
-    then SOME cs
-    else NONE
-  end;
-
-
-(* case certificates *)
-
-fun mk_case_cert thy tyco =
+fun add_code_for_datatype fcT_name thy =
   let
-    val raw_thms = #case_rewrites (Datatype_Data.the_info thy tyco);
-    val thms as hd_thm :: _ = raw_thms
-      |> Conjunction.intr_balanced
-      |> Thm.unvarify_global
-      |> Conjunction.elim_balanced (length raw_thms)
-      |> map Simpdata.mk_meta_eq
-      |> map Drule.zero_var_indexes;
-    val params = fold_aterms (fn (Free (v, _)) => insert (op =) v | _ => I) (Thm.prop_of hd_thm) [];
-    val rhs = hd_thm
-      |> Thm.prop_of
-      |> Logic.dest_equals
-      |> fst
-      |> Term.strip_comb
-      |> apsnd (fst o split_last)
-      |> list_comb;
-    val lhs = Free (singleton (Name.variant_list params) "case", Term.fastype_of rhs);
-    val asm = Thm.cterm_of thy (Logic.mk_equals (lhs, rhs));
+    val (As', ctr_specs) = Datatype_Data.the_spec thy fcT_name;
+    val {inject = inject_thms, distinct = distinct_thms, case_rewrites = case_thms, ...} =
+      Datatype_Data.the_info thy fcT_name;
+
+    val As = map TFree As';
+    val fcT = Type (fcT_name, As);
+    val ctrs = map (fn (c, arg_Ts) => (c, arg_Ts ---> fcT)) ctr_specs;
   in
-    thms
-    |> Conjunction.intr_balanced
-    |> rewrite_rule [Thm.symmetric (Thm.assume asm)]
-    |> Thm.implies_intr asm
-    |> Thm.generalize ([], params) 0
-    |> Axclass.unoverload thy
-    |> Thm.varifyT_global
+    Ctr_Sugar_Code.add_ctr_code fcT_name As ctrs inject_thms distinct_thms case_thms thy
   end;
 
-
-(* equality *)
-
-fun mk_eq_eqns thy tyco =
-  let
-    val (vs, cos) = Datatype_Data.the_spec thy tyco;
-    val {descr, index, inject = inject_thms, distinct = distinct_thms, ...} =
-      Datatype_Data.the_info thy tyco;
-    val ty = Type (tyco, map TFree vs);
-    fun mk_eq (t1, t2) = Const (@{const_name HOL.equal}, ty --> ty --> HOLogic.boolT) $ t1 $ t2;
-    fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, @{term True});
-    fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, @{term False});
-    val triv_injects =
-      map_filter
-        (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
-          | _ => NONE) cos;
-    fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
-      trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
-    val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr]) index);
-    fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
-      [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
-    val distincts = maps prep_distinct (nth (Datatype_Prop.make_distincts [descr]) index);
-    val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
-    val simp_ctxt =
-      Simplifier.global_context thy HOL_basic_ss
-        addsimps (map Simpdata.mk_eq (@{thms equal eq_True} @ inject_thms @ distinct_thms));
-    fun prove prop =
-      Goal.prove_sorry_global thy [] [] prop (K (ALLGOALS (simp_tac simp_ctxt)))
-      |> Simpdata.mk_eq;
-  in (map prove (triv_injects @ injects @ distincts), prove refl) end;
-
-fun add_equality vs tycos thy =
-  let
-    fun add_def tyco lthy =
-      let
-        val ty = Type (tyco, map TFree vs);
-        fun mk_side const_name =
-          Const (const_name, ty --> ty --> HOLogic.boolT) $ Free ("x", ty) $ Free ("y", ty);
-        val def =
-          HOLogic.mk_Trueprop (HOLogic.mk_eq
-            (mk_side @{const_name HOL.equal}, mk_side @{const_name HOL.eq}));
-        val def' = Syntax.check_term lthy def;
-        val ((_, (_, thm)), lthy') =
-          Specification.definition (NONE, (Attrib.empty_binding, def')) lthy;
-        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy);
-        val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
-      in (thm', lthy') end;
-    fun tac thms = Class.intro_classes_tac [] THEN ALLGOALS (Proof_Context.fact_tac thms);
-    fun prefix tyco =
-      Binding.qualify true (Long_Name.base_name tyco) o Binding.qualify true "eq" o Binding.name;
-    fun add_eq_thms tyco =
-      `(fn thy => mk_eq_eqns thy tyco)
-      #-> (fn (thms, thm) =>
-        Global_Theory.note_thmss Thm.lemmaK
-          [((prefix tyco "refl", [Code.add_nbe_default_eqn_attribute]), [([thm], [])]),
-            ((prefix tyco "simps", [Code.add_default_eqn_attribute]), [(rev thms, [])])])
-      #> snd;
-  in
-    thy
-    |> Class.instantiation (tycos, vs, [HOLogic.class_equal])
-    |> fold_map add_def tycos
-    |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
-         (fn _ => fn def_thms => tac def_thms) def_thms)
-    |-> (fn def_thms => fold Code.del_eqn def_thms)
-    |> fold add_eq_thms tycos
-  end;
-
-
-(* register a datatype etc. *)
-
-fun add_all_code config tycos thy =
-  let
-    val (vs :: _, coss) = split_list (map (Datatype_Data.the_spec thy) tycos);
-    val any_css = map2 (mk_constr_consts thy vs) tycos coss;
-    val css = if exists is_none any_css then [] else map_filter I any_css;
-    val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) tycos;
-    val certs = map (mk_case_cert thy) tycos;
-    val tycos_eq =
-      filter_out
-        (fn tyco => Sorts.has_instance (Sign.classes_of thy) tyco [HOLogic.class_equal]) tycos;
-  in
-    if null css then thy
-    else
-      thy
-      |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
-      |> fold Code.add_datatype css
-      |> fold_rev Code.add_default_eqn case_rewrites
-      |> fold Code.add_case certs
-      |> not (null tycos_eq) ? add_equality vs tycos_eq
-   end;
-
-
-(** theory setup **)
-
-val setup = Datatype_Data.interpretation add_all_code;
+val _ = Theory.setup (Datatype_Data.interpretation (K (fold add_code_for_datatype)));
 
 end;
--- a/src/HOL/Tools/ctr_sugar.ML	Mon Dec 02 20:31:54 2013 +0100
+++ b/src/HOL/Tools/ctr_sugar.ML	Mon Dec 02 20:31:54 2013 +0100
@@ -66,6 +66,7 @@
 
 open Ctr_Sugar_Util
 open Ctr_Sugar_Tactics
+open Ctr_Sugar_Code
 
 type ctr_sugar =
   {ctrs: term list,
@@ -926,11 +927,13 @@
         (ctr_sugar,
          lthy
          |> not rep_compat ?
-            (Local_Theory.declaration {syntax = false, pervasive = true}
-               (fn phi => Case_Translation.register
-                  (Morphism.term phi casex) (map (Morphism.term phi) ctrs)))
+            Local_Theory.declaration {syntax = false, pervasive = true}
+              (fn phi => Case_Translation.register
+                 (Morphism.term phi casex) (map (Morphism.term phi) ctrs))
          |> Local_Theory.notes (anonymous_notes @ notes) |> snd
-         |> register_ctr_sugar fcT_name ctr_sugar)
+         |> register_ctr_sugar fcT_name ctr_sugar
+         |> Local_Theory.background_theory
+           (add_ctr_code fcT_name As (map dest_Const ctrs) inject_thms distinct_thms case_thms))
       end;
   in
     (goalss, after_qed, lthy')
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/ctr_sugar_code.ML	Mon Dec 02 20:31:54 2013 +0100
@@ -0,0 +1,129 @@
+(*  Title:      HOL/Tools/ctr_sugar_code.ML
+    Author:     Jasmin Blanchette, TU Muenchen
+    Author:     Dmitriy Traytel, TU Muenchen
+    Author:     Stefan Berghofer, TU Muenchen
+    Author:     Florian Haftmann, TU Muenchen
+    Copyright   2001-2013
+
+Code generation for freely generated types.
+*)
+
+signature CTR_SUGAR_CODE =
+sig
+  val add_ctr_code: string -> typ list -> (string * typ) list -> thm list -> thm list -> thm list ->
+    theory -> theory
+end;
+
+structure Ctr_Sugar_Code : CTR_SUGAR_CODE =
+struct
+
+open Ctr_Sugar_Util
+
+val eqN = "eq"
+val reflN = "refl"
+val simpsN = "simps"
+
+fun mk_case_certificate thy raw_thms =
+  let
+    val thms as thm1 :: _ = raw_thms
+      |> Conjunction.intr_balanced
+      |> Thm.unvarify_global
+      |> Conjunction.elim_balanced (length raw_thms)
+      |> map Simpdata.mk_meta_eq
+      |> map Drule.zero_var_indexes;
+    val params = Term.add_free_names (Thm.prop_of thm1) [];
+    val rhs = thm1
+      |> Thm.prop_of |> Logic.dest_equals |> fst |> Term.strip_comb
+      ||> fst o split_last |> list_comb;
+    val lhs = Free (singleton (Name.variant_list params) "case", Term.fastype_of rhs);
+    val assum = Thm.cterm_of thy (Logic.mk_equals (lhs, rhs));
+  in
+    thms
+    |> Conjunction.intr_balanced
+    |> rewrite_rule [Thm.symmetric (Thm.assume assum)]
+    |> Thm.implies_intr assum
+    |> Thm.generalize ([], params) 0
+    |> Axclass.unoverload thy
+    |> Thm.varifyT_global
+  end;
+
+fun mk_free_ctr_equations fcT ctrs inject_thms distinct_thms thy =
+  let
+    fun mk_fcT_eq (t, u) = Const (@{const_name HOL.equal}, fcT --> fcT --> HOLogic.boolT) $ t $ u;
+    fun true_eq tu = HOLogic.mk_eq (mk_fcT_eq tu, @{term True});
+    fun false_eq tu = HOLogic.mk_eq (mk_fcT_eq tu, @{term False});
+
+    val monomorphic_prop_of = prop_of o Thm.unvarify_global o Drule.zero_var_indexes;
+
+    fun massage_inject (tp $ (eqv $ (_ $ t $ u) $ rhs)) = tp $ (eqv $ mk_fcT_eq (t, u) $ rhs);
+    fun massage_distinct (tp $ (_ $ (_ $ t $ u))) = [tp $ false_eq (t, u), tp $ false_eq (u, t)];
+
+    val triv_inject_goals =
+      map_filter (fn c as (_, T) =>
+          if T = fcT then SOME (HOLogic.mk_Trueprop (true_eq (Const c, Const c))) else NONE)
+        ctrs;
+    val inject_goals = map (massage_inject o monomorphic_prop_of) inject_thms;
+    val distinct_goals = maps (massage_distinct o monomorphic_prop_of) distinct_thms;
+    val refl_goal = HOLogic.mk_Trueprop (true_eq (Free ("x", fcT), Free ("x", fcT)));
+
+    val simp_ctxt =
+      Simplifier.global_context thy HOL_basic_ss
+        addsimps (map Simpdata.mk_eq (@{thms equal eq_True} @ inject_thms @ distinct_thms));
+
+    fun prove goal =
+      Goal.prove_sorry_global thy [] [] goal (K (ALLGOALS (simp_tac simp_ctxt)))
+      |> Simpdata.mk_eq;
+  in
+    (map prove (triv_inject_goals @ inject_goals @ distinct_goals), prove refl_goal)
+  end;
+
+fun add_equality fcT fcT_name As ctrs inject_thms distinct_thms =
+  let
+    fun add_def lthy =
+      let
+        fun mk_side const_name =
+          Const (const_name, fcT --> fcT --> HOLogic.boolT) $ Free ("x", fcT) $ Free ("y", fcT);
+        val spec =
+          mk_Trueprop_eq (mk_side @{const_name HOL.equal}, mk_side @{const_name HOL.eq})
+          |> Syntax.check_term lthy;
+        val ((_, (_, raw_def)), lthy') =
+          Specification.definition (NONE, (Attrib.empty_binding, spec)) lthy;
+        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy); (* FIXME? *)
+        val def = singleton (Proof_Context.export lthy' ctxt_thy) raw_def;
+      in
+        (def, lthy')
+      end;
+
+    fun tac thms = Class.intro_classes_tac [] THEN ALLGOALS (Proof_Context.fact_tac thms);
+
+    val qualify =
+      Binding.qualify true (Long_Name.base_name fcT_name) o Binding.qualify true eqN o Binding.name;
+  in
+    Class.instantiation ([fcT_name], map dest_TFree As, [HOLogic.class_equal])
+    #> add_def
+    #-> Class.prove_instantiation_exit_result (map o Morphism.thm) (K tac) o single
+    #-> fold Code.del_eqn
+    #> `(mk_free_ctr_equations fcT ctrs inject_thms distinct_thms)
+    #-> (fn (thms, thm) => Global_Theory.note_thmss Thm.lemmaK
+      [((qualify reflN, [Code.add_nbe_default_eqn_attribute]), [([thm], [])]),
+        ((qualify simpsN, [Code.add_default_eqn_attribute]), [(rev thms, [])])])
+    #> snd
+  end;
+
+fun add_ctr_code fcT_name As ctrs inject_thms distinct_thms case_thms thy =
+  let
+    val fcT = Type (fcT_name, As);
+    val unover_ctrs = map (fn ctr as (_, fcT) => (Axclass.unoverload_const thy ctr, fcT)) ctrs;
+  in
+    if can (Code.constrset_of_consts thy) unover_ctrs then
+      thy
+      |> Code.add_datatype ctrs
+      |> fold_rev Code.add_default_eqn case_thms
+      |> Code.add_case (mk_case_certificate thy case_thms)
+      |> not (Sorts.has_instance (Sign.classes_of thy) fcT_name [HOLogic.class_equal])
+        ? add_equality fcT fcT_name As ctrs inject_thms distinct_thms
+    else
+      thy
+  end;
+
+end;