src/HOL/Tools/Datatype/datatype.ML
changeset 45822 843dc212f69e
parent 45821 c2f6c50e3d42
child 45839 43a5b86bc102
--- a/src/HOL/Tools/Datatype/datatype.ML	Mon Dec 12 20:55:57 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype.ML	Mon Dec 12 23:05:21 2011 +0100
@@ -61,7 +61,7 @@
 (** proof of characteristic theorems **)
 
 fun representation_proofs (config : Datatype_Aux.config) (dt_info : Datatype_Aux.info Symtab.table)
-    descr sorts types_syntax constr_syntax case_names_induct thy =
+    descr types_syntax constr_syntax case_names_induct thy =
   let
     val descr' = flat descr;
     val new_type_names = map (Binding.name_of o fst) types_syntax;
@@ -74,17 +74,16 @@
     val rep_set_names = map (Sign.full_bname thy1) rep_set_names';
 
     val tyvars = map (fn (_, (_, Ts, _)) => map Datatype_Aux.dest_DtTFree Ts) (hd descr);
-    val leafTs' = Datatype_Aux.get_nonrec_types descr' sorts;
-    val branchTs = Datatype_Aux.get_branching_types descr' sorts;
+    val leafTs' = Datatype_Aux.get_nonrec_types descr';
+    val branchTs = Datatype_Aux.get_branching_types descr';
     val branchT =
       if null branchTs then HOLogic.unitT
       else Balanced_Tree.make (fn (T, U) => Type (@{type_name Sum_Type.sum}, [T, U])) branchTs;
     val arities = remove (op =) 0 (Datatype_Aux.get_arities descr');
     val unneeded_vars =
-      subtract (op =) (fold Term.add_tfree_namesT (leafTs' @ branchTs) []) (hd tyvars);
-    val leafTs =
-      leafTs' @ map (fn n => TFree (n, (the o AList.lookup (op =) sorts) n)) unneeded_vars;
-    val recTs = Datatype_Aux.get_rec_types descr' sorts;
+      subtract (op =) (fold Term.add_tfreesT (leafTs' @ branchTs) []) (hd tyvars);
+    val leafTs = leafTs' @ map TFree unneeded_vars;
+    val recTs = Datatype_Aux.get_rec_types descr';
     val (newTs, oldTs) = chop (length (hd descr)) recTs;
     val sumT =
       if null leafTs then HOLogic.unitT
@@ -156,7 +155,7 @@
           (case Datatype_Aux.strip_dtyp dt of
             (dts, Datatype_Aux.DtRec k) =>
               let
-                val Ts = map (Datatype_Aux.typ_of_dtyp descr' sorts) dts;
+                val Ts = map (Datatype_Aux.typ_of_dtyp descr') dts;
                 val free_t =
                   Datatype_Aux.app_bnds (Datatype_Aux.mk_Free "x" (Ts ---> Univ_elT) j) (length Ts)
               in
@@ -166,7 +165,7 @@
                 mk_lim free_t Ts :: ts)
               end
           | _ =>
-              let val T = Datatype_Aux.typ_of_dtyp descr' sorts dt
+              let val T = Datatype_Aux.typ_of_dtyp descr' dt
               in (j + 1, prems, (Leaf $ mk_inj T (Datatype_Aux.mk_Free "x" T j)) :: ts) end);
 
         val (_, prems, ts) = fold_rev mk_prem cargs (1, [], []);
@@ -194,8 +193,7 @@
       |> Sign.parent_path
       |> fold_map
         (fn (((name, mx), tvs), c) =>
-          Typedef.add_typedef_global false NONE
-            (name, map (rpair dummyS) tvs, mx)
+          Typedef.add_typedef_global false NONE (name, tvs, mx)
             (Collect $ Const (c, UnivT')) NONE
             (rtac exI 1 THEN rtac CollectI 1 THEN
               QUIET_BREADTH_FIRST (has_fewer_prems 1)
@@ -223,7 +221,7 @@
       let
         fun constr_arg dt (j, l_args, r_args) =
           let
-            val T = Datatype_Aux.typ_of_dtyp descr' sorts dt;
+            val T = Datatype_Aux.typ_of_dtyp descr' dt;
             val free_t = Datatype_Aux.mk_Free "x" T j;
           in
             (case (Datatype_Aux.strip_dtyp dt, strip_type T) of
@@ -235,7 +233,7 @@
           end;
 
         val (_, l_args, r_args) = fold_rev constr_arg cargs (1, [], []);
-        val constrT = (map (Datatype_Aux.typ_of_dtyp descr' sorts) cargs) ---> T;
+        val constrT = map (Datatype_Aux.typ_of_dtyp descr') cargs ---> T;
         val abs_name = Sign.intern_const thy ("Abs_" ^ tname);
         val rep_name = Sign.intern_const thy ("Rep_" ^ tname);
         val lhs = list_comb (Const (cname, constrT), l_args);
@@ -305,7 +303,7 @@
 
     fun make_iso_def k ks n (cname, cargs) (fs, eqns, i) =
       let
-        val argTs = map (Datatype_Aux.typ_of_dtyp descr' sorts) cargs;
+        val argTs = map (Datatype_Aux.typ_of_dtyp descr') cargs;
         val T = nth recTs k;
         val rep_name = nth all_rep_names k;
         val rep_const = Const (rep_name, T --> Univ_elT);
@@ -313,7 +311,7 @@
 
         fun process_arg ks' dt (i2, i2', ts, Ts) =
           let
-            val T' = Datatype_Aux.typ_of_dtyp descr' sorts dt;
+            val T' = Datatype_Aux.typ_of_dtyp descr' dt;
             val (Us, U) = strip_type T'
           in
             (case Datatype_Aux.strip_dtyp dt of
@@ -556,7 +554,7 @@
       in prove ts end;
 
     val distinct_thms =
-      map2 (prove_distinct_thms) dist_rewrites (Datatype_Prop.make_distincts descr sorts);
+      map2 (prove_distinct_thms) dist_rewrites (Datatype_Prop.make_distincts descr);
 
     (* prove injectivity of constructors *)
 
@@ -582,7 +580,7 @@
 
     val constr_inject =
       map (fn (ts, thms) => map (prove_constr_inj_thm thms) ts)
-        ((Datatype_Prop.make_injs descr sorts) ~~ constr_rep_thms);
+        (Datatype_Prop.make_injs descr ~~ constr_rep_thms);
 
     val ((constr_inject', distinct_thms'), thy6) =
       thy5
@@ -642,7 +640,7 @@
       else map (Free o apfst fst o dest_Var) Ps;
     val indrule_lemma' = cterm_instantiate (map cert Ps ~~ map cert frees) indrule_lemma;
 
-    val dt_induct_prop = Datatype_Prop.make_ind descr sorts;
+    val dt_induct_prop = Datatype_Prop.make_ind descr;
     val dt_induct =
       Skip_Proof.prove_global thy6 []
       (Logic.strip_imp_prems dt_induct_prop)
@@ -698,7 +696,7 @@
     val _ =
       (case duplicates (op =) (map fst new_dts) of
         [] => ()
-      | dups => error ("Duplicate datatypes: " ^ commas dups));
+      | dups => error ("Duplicate datatypes: " ^ commas_quote dups));
 
     fun prep_dt_spec (tvs, tname, mx, constrs) (dts', constr_syntax, sorts, i) =
       let
@@ -721,21 +719,28 @@
       in
         (case duplicates (op =) (map fst constrs') of
           [] =>
-            (dts' @ [(i, (Sign.full_name tmp_thy tname, map Datatype_Aux.DtTFree tvs, constrs'))],
+            (dts' @ [(i, (Sign.full_name tmp_thy tname, tvs, constrs'))],
               constr_syntax @ [constr_syntax'], sorts', i + 1)
         | dups =>
-            error ("Duplicate constructors " ^ commas dups ^ " in datatype " ^ Binding.print tname))
+            error ("Duplicate constructors " ^ commas_quote dups ^
+              " in datatype " ^ Binding.print tname))
       end;
 
-    val (dts', constr_syntax, sorts', i) = fold prep_dt_spec dts ([], [], [], 0);
-    val sorts =
-      sorts' @ map (rpair (Sign.defaultS tmp_thy)) (subtract (op =) (map fst sorts') tyvars);
+    val (dts0, constr_syntax, sorts', i) = fold prep_dt_spec dts ([], [], [], 0);
+    val tmp_ctxt' = tmp_ctxt |> fold (Variable.declare_typ o TFree) sorts';
+
+    val dts' = dts0 |> map (fn (i, (name, tvs, cs)) =>
+      let
+        val args = tvs |>
+          map (fn a => Datatype_Aux.DtTFree (a, Proof_Context.default_sort tmp_ctxt' (a, ~1)));
+      in (i, (name, args, cs)) end);
+
     val dt_info = Datatype_Data.get_all thy;
-    val (descr, _) = Datatype_Aux.unfold_datatypes tmp_ctxt dts' sorts dt_info dts' i;
+    val (descr, _) = Datatype_Aux.unfold_datatypes tmp_ctxt dts' dt_info dts' i;
     val _ =
       Datatype_Aux.check_nonempty descr
         handle (exn as Datatype_Aux.Datatype_Empty s) =>
-          if #strict config then error ("Nonemptiness check failed for datatype " ^ s)
+          if #strict config then error ("Nonemptiness check failed for datatype " ^ quote s)
           else reraise exn;
 
     val _ =
@@ -743,10 +748,10 @@
         ("Constructing datatype(s) " ^ commas_quote (map (Binding.name_of o #2) dts));
   in
     thy
-    |> representation_proofs config dt_info descr sorts types_syntax constr_syntax
+    |> representation_proofs config dt_info descr types_syntax constr_syntax
       (Datatype_Data.mk_case_names_induct (flat descr))
     |-> (fn (inject, distinct, induct) =>
-      Datatype_Data.derive_datatype_props config dt_names descr sorts induct inject distinct)
+      Datatype_Data.derive_datatype_props config dt_names descr induct inject distinct)
   end;
 
 val add_datatype = gen_add_datatype Datatype_Data.cert_typ;