further unification of datatype and rep_datatype
authorhaftmann
Mon, 28 Sep 2009 10:51:12 +0200
changeset 32728 2c55fc50f670
parent 32727 9072201cd69d
child 32729 1441cf4ddc1a
further unification of datatype and rep_datatype
src/HOL/Tools/Datatype/datatype.ML
--- a/src/HOL/Tools/Datatype/datatype.ML	Mon Sep 28 10:20:21 2009 +0200
+++ b/src/HOL/Tools/Datatype/datatype.ML	Mon Sep 28 10:51:12 2009 +0200
@@ -347,41 +347,41 @@
     |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd
   end;
 
-fun derive_datatype_props config dt_names alt_names descr sorts induct inject distinct thy1 =
+fun derive_datatype_props config dt_names alt_names descr sorts
+    induct inject (distinct_rules, distinct_rewrites, distinct_entry) thy1 =
   let
     val thy2 = thy1 |> Theory.checkpoint
+    val flat_descr = flat descr;
     val new_type_names = map Long_Name.base_name (the_default dt_names alt_names);
-    val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
-    val (case_names_induct, case_names_exhausts) =
-      (mk_case_names_induct descr, mk_case_names_exhausts descr dt_names);
+    val _ = message config ("Deriving properties for datatype(s) " ^ commas_quote new_type_names);
     val inducts = Project_Rule.projections (ProofContext.init thy2) induct;
 
     val (casedist_thms, thy3) = thy2 |>
-      DatatypeAbsProofs.prove_casedist_thms config new_type_names [descr] sorts induct
-        case_names_exhausts;
+      DatatypeAbsProofs.prove_casedist_thms config new_type_names descr sorts induct
+        (mk_case_names_exhausts flat_descr dt_names);
     val ((rec_names, rec_rewrites), thy4) = DatatypeAbsProofs.prove_primrec_thms
-      config new_type_names [descr] sorts (get_all thy3) inject distinct
+      config new_type_names descr sorts (get_all thy3) inject distinct_rewrites
       (Simplifier.theory_context thy3 dist_ss) induct thy3;
     val ((case_rewrites, case_names), thy5) = DatatypeAbsProofs.prove_case_thms
-      config new_type_names [descr] sorts rec_names rec_rewrites thy4;
+      config new_type_names descr sorts rec_names rec_rewrites thy4;
     val (split_thms, thy6) = DatatypeAbsProofs.prove_split_thms
-      config new_type_names [descr] sorts inject distinct casedist_thms case_rewrites thy5;
+      config new_type_names descr sorts inject distinct_rewrites casedist_thms case_rewrites thy5;
     val (nchotomys, thy7) = DatatypeAbsProofs.prove_nchotomys config new_type_names
-      [descr] sorts casedist_thms thy6;
+      descr sorts casedist_thms thy6;
     val (case_congs, thy8) = DatatypeAbsProofs.prove_case_congs new_type_names
-      [descr] sorts nchotomys case_rewrites thy7;
+      descr sorts nchotomys case_rewrites thy7;
     val (weak_case_congs, thy9) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
-      [descr] sorts thy8;
+      descr sorts thy8;
 
-    val simps = flat (distinct @ inject @ case_rewrites) @ rec_rewrites;
-    val dt_infos = map (make_dt_info alt_names descr sorts induct inducts rec_names rec_rewrites)
-      ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_rewrites ~~ casedist_thms ~~
-        map FewConstrs distinct ~~ inject ~~ split_thms ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
+    val simps = flat (distinct_rules @ inject @ case_rewrites) @ rec_rewrites;
+    val dt_infos = map (make_dt_info alt_names flat_descr sorts induct inducts rec_names rec_rewrites)
+      ((0 upto length descr - 1) ~~ flat_descr ~~ case_names ~~ case_rewrites ~~ casedist_thms ~~
+        distinct_entry ~~ inject ~~ split_thms ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
     val dt_names = map fst dt_infos;
   in
     thy9
     |> add_case_tr' case_names
-    |> add_rules simps case_rewrites rec_rewrites inject distinct weak_case_congs (Simplifier.attrib (op addcongs))
+    |> add_rules simps case_rewrites rec_rewrites inject distinct_rules weak_case_congs (Simplifier.attrib (op addcongs))
     |> register dt_infos
     |> add_cases_induct dt_infos inducts
     |> Sign.parent_path
@@ -398,17 +398,16 @@
   let
     val raw_distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
     val new_type_names = map Long_Name.base_name (the_default dt_names alt_names);
-    val (case_names_induct, case_names_exhausts) =
-      (mk_case_names_induct descr, mk_case_names_exhausts descr dt_names);
     val (((inject, distinct), [induct]), thy2) =
       thy1
       |> store_thmss "inject" new_type_names raw_inject
       ||>> store_thmss "distinct" new_type_names raw_distinct
       ||> Sign.add_path (space_implode "_" new_type_names)
-      ||>> PureThy.add_thms [((Binding.name "induct", raw_induct), [case_names_induct])];
+      ||>> PureThy.add_thms [((Binding.name "induct", raw_induct), [mk_case_names_induct descr])];
   in
     thy2
-    |> derive_datatype_props config dt_names alt_names descr sorts induct inject distinct
+    |> derive_datatype_props config dt_names alt_names [descr] sorts
+         induct inject (distinct, distinct, map FewConstrs distinct)
  end;
 
 fun gen_rep_datatype prep_term config after_qed alt_names raw_ts thy =
@@ -482,18 +481,17 @@
 
 (** definitional introduction of datatypes **)
 
-fun add_datatype_def config new_type_names descr sorts types_syntax constr_syntax dt_info
-    case_names_induct case_names_exhausts thy =
+fun add_datatype_def config dt_names new_type_names descr sorts types_syntax constr_syntax dt_info thy =
   let
-    val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
-
+    val _ = message config ("Constructing datatype(s) " ^ commas_quote new_type_names);
+    val flat_descr = flat descr;
     val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
       DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts
-        types_syntax constr_syntax case_names_induct;
-    val inducts = Project_Rule.projections (ProofContext.init thy2) induct;
+        types_syntax constr_syntax (mk_case_names_induct flat_descr);
 
+    val inducts = Project_Rule.projections (ProofContext.init thy2) induct;
     val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr
-      sorts induct case_names_exhausts thy2;
+      sorts induct (mk_case_names_exhausts flat_descr dt_names) thy2;
     val ((rec_names, rec_rewrites), thy4) = DatatypeAbsProofs.prove_primrec_thms
       config new_type_names descr sorts dt_info inject dist_rewrites
       (Simplifier.theory_context thy3 dist_ss) induct thy3;
@@ -509,7 +507,7 @@
       descr sorts thy9;
 
     val dt_infos = map
-      (make_dt_info (SOME new_type_names) (flat descr) sorts induct inducts rec_names rec_rewrites)
+      (make_dt_info (SOME new_type_names) flat_descr sorts induct inducts rec_names rec_rewrites)
       ((0 upto length (hd descr) - 1) ~~ hd descr ~~ case_names ~~ case_rewrites ~~
         casedist_thms ~~ simproc_dists ~~ inject ~~ split_thms ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
 
@@ -534,7 +532,6 @@
     val _ = Theory.requires thy "Datatype" "datatype definitions";
 
     (* this theory is used just for parsing *)
-
     val tmp_thy = thy |>
       Theory.copy |>
       Sign.add_types (map (fn (tvs, tname, mx, _) =>
@@ -549,6 +546,7 @@
           | dups => error ("Duplicate parameter(s) for datatype " ^ quote (Binding.str_of tname) ^
               " : " ^ commas dups))
       end) dts);
+    val dt_names = map fst new_dts;
 
     val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of
       [] => () | dups => error ("Duplicate datatypes: " ^ commas dups));
@@ -590,14 +588,8 @@
     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
       if #strict config then error ("Nonemptiness check failed for datatype " ^ s)
       else raise exn;
-
-    val descr' = flat descr;
-    val case_names_induct = mk_case_names_induct descr';
-    val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   in
-    add_datatype_def
-      config new_type_names descr sorts types_syntax constr_syntax dt_info
-      case_names_induct case_names_exhausts thy
+    add_datatype_def config dt_names new_type_names descr sorts types_syntax constr_syntax dt_info thy
   end;
 
 val add_datatype = gen_add_datatype cert_typ;