'datatype' specifications allow explicit sort constraints;
authorwenzelm
Tue, 13 Dec 2011 23:23:51 +0100
changeset 45839 43a5b86bc102
parent 45838 653c84d5c6c9
child 45842 3fd2cd187299
'datatype' specifications allow explicit sort constraints; tuned signatures;
NEWS
doc-src/IsarRef/Thy/HOL_Specific.thy
doc-src/IsarRef/Thy/document/HOL_Specific.tex
src/HOL/Nominal/nominal_atoms.ML
src/HOL/Nominal/nominal_datatype.ML
src/HOL/SPARK/Tools/spark_vcs.ML
src/HOL/Tools/Datatype/datatype.ML
src/HOL/Tools/Datatype/datatype_aux.ML
src/HOL/Tools/Datatype/datatype_data.ML
src/HOL/Tools/inductive_realizer.ML
--- a/NEWS	Tue Dec 13 20:29:59 2011 +0100
+++ b/NEWS	Tue Dec 13 23:23:51 2011 +0100
@@ -53,6 +53,8 @@
 
 *** HOL ***
 
+* 'datatype' specifications allow explicit sort constraints.
+
 * Theory HOL/Library/Diagonalize has been removed. INCOMPATIBILITY, use
 theory HOL/Library/Nat_Bijection instead.
 
--- a/doc-src/IsarRef/Thy/HOL_Specific.thy	Tue Dec 13 20:29:59 2011 +0100
+++ b/doc-src/IsarRef/Thy/HOL_Specific.thy	Tue Dec 13 23:23:51 2011 +0100
@@ -693,7 +693,7 @@
     @@{command (HOL) rep_datatype} ('(' (@{syntax name} +) ')')? (@{syntax term} +)
     ;
 
-    spec: @{syntax typespec} @{syntax mixfix}? '=' (cons + '|')
+    spec: @{syntax typespec_sorts} @{syntax mixfix}? '=' (cons + '|')
     ;
     cons: @{syntax name} (@{syntax type} * ) @{syntax mixfix}?
   "}
--- a/doc-src/IsarRef/Thy/document/HOL_Specific.tex	Tue Dec 13 20:29:59 2011 +0100
+++ b/doc-src/IsarRef/Thy/document/HOL_Specific.tex	Tue Dec 13 23:23:51 2011 +0100
@@ -1036,7 +1036,7 @@
 \rail@endplus
 \rail@end
 \rail@begin{2}{\isa{spec}}
-\rail@nont{\hyperlink{syntax.typespec}{\mbox{\isa{typespec}}}}[]
+\rail@nont{\hyperlink{syntax.typespec-sorts}{\mbox{\isa{typespec{\isaliteral{5F}{\isacharunderscore}}sorts}}}}[]
 \rail@bar
 \rail@nextbar{1}
 \rail@nont{\hyperlink{syntax.mixfix}{\mbox{\isa{mixfix}}}}[]
--- a/src/HOL/Nominal/nominal_atoms.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Nominal/nominal_atoms.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -99,7 +99,7 @@
     
     val (_,thy1) = 
     fold_map (fn ak => fn thy => 
-          let val dt = ([], Binding.name ak, NoSyn, [(Binding.name ak, [@{typ nat}], NoSyn)])
+          let val dt = ((Binding.name ak, [], NoSyn), [(Binding.name ak, [@{typ nat}], NoSyn)])
               val (dt_names, thy1) = Datatype.add_datatype Datatype.default_config [dt] thy;
             
               val injects = maps (#inject o Datatype.the_info thy1) dt_names;
--- a/src/HOL/Nominal/nominal_datatype.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Nominal/nominal_datatype.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -6,9 +6,7 @@
 
 signature NOMINAL_DATATYPE =
 sig
-  val add_nominal_datatype : Datatype.config ->
-    (string list * binding * mixfix * (binding * string list * mixfix) list) list ->
-    theory -> theory
+  val add_nominal_datatype : Datatype.config -> Datatype.spec_cmd list -> theory -> theory
   type descr
   type nominal_datatype_info
   val get_nominal_datatypes : theory -> nominal_datatype_info Symtab.table
@@ -187,30 +185,16 @@
 fun fresh_star_const T U =
   Const ("Nominal.fresh_star", HOLogic.mk_setT T --> U --> HOLogic.boolT);
 
-fun gen_add_nominal_datatype prep_typ config dts thy =
+fun gen_add_nominal_datatype prep_specs config dts thy =
   let
-    val new_type_names = map (Binding.name_of o #2) dts;
-
+    val new_type_names = map (fn ((tname, _, _), _) => Binding.name_of tname) dts;
 
-    (* this theory is used just for parsing *)
-
-    val tmp_thy = thy |>
-      Theory.copy |>
-      Sign.add_types_global (map (fn (tvs, tname, mx, _) => (tname, length tvs, mx)) dts);
+    val (dts', _) = prep_specs dts thy;
 
     val atoms = atoms_of thy;
 
-    fun prep_constr (cname, cargs, mx) (constrs, sorts) =
-      let val (cargs', sorts') = fold_map (prep_typ tmp_thy) cargs sorts
-      in (constrs @ [(cname, cargs', mx)], sorts') end
-
-    fun prep_dt_spec (tvs, tname, mx, constrs) (dts, sorts) =
-      let val (constrs', sorts') = fold prep_constr constrs ([], sorts)
-      in (dts @ [(tvs, tname, mx, constrs')], sorts') end
-
-    val (dts', sorts) = fold prep_dt_spec dts ([], []);
-    val tyvars = map (map (fn s =>
-      (s, the (AList.lookup (op =) sorts s))) o #1) dts';
+    val tyvars = map (fn ((_, tvs, _), _) => tvs) dts';
+    val sorts = flat tyvars;
 
     fun inter_sort thy S S' = Sign.inter_sort thy (S, S');
     fun augment_sort_typ thy S =
@@ -220,12 +204,12 @@
       end;
     fun augment_sort thy S = map_types (augment_sort_typ thy S);
 
-    val types_syntax = map (fn (tvs, tname, mx, constrs) => (tname, mx)) dts';
-    val constr_syntax = map (fn (tvs, tname, mx, constrs) =>
+    val types_syntax = map (fn ((tname, tvs, mx), constrs) => (tname, mx)) dts';
+    val constr_syntax = map (fn (_, constrs) =>
       map (fn (cname, cargs, mx) => (cname, mx)) constrs) dts';
 
-    val ps = map (fn (_, n, _, _) =>
-      (Sign.full_name tmp_thy n, Sign.full_name tmp_thy (Binding.suffix_name "_Rep" n))) dts;
+    val ps = map (fn ((n, _, _), _) =>
+      (Sign.full_name thy n, Sign.full_name thy (Binding.suffix_name "_Rep" n))) dts;
     val rps = map Library.swap ps;
 
     fun replace_types (Type ("Nominal.ABS", [T, U])) =
@@ -234,8 +218,8 @@
           Type (the_default s (AList.lookup op = ps s), map replace_types Ts)
       | replace_types T = T;
 
-    val dts'' = map (fn (tvs, tname, mx, constrs) =>
-      (tvs, Binding.suffix_name "_Rep" tname, NoSyn,
+    val dts'' = map (fn ((tname, tvs, mx), constrs) =>
+      ((Binding.suffix_name "_Rep" tname, tvs, NoSyn),
         map (fn (cname, cargs, mx) => (Binding.suffix_name "_Rep" cname,
           map replace_types cargs, NoSyn)) constrs)) dts';
 
@@ -2081,11 +2065,11 @@
     thy13
   end;
 
-val add_nominal_datatype = gen_add_nominal_datatype Datatype.read_typ;
+val add_nominal_datatype = gen_add_nominal_datatype Datatype.read_specs;
 
 val _ =
   Outer_Syntax.command "nominal_datatype" "define inductive datatypes" Keyword.thy_decl
-    (Parse.and_list1 Datatype.parse_decl
+    (Parse.and_list1 Datatype.spec_cmd
       >> (Toplevel.theory o add_nominal_datatype Datatype.default_config));
 
 end
--- a/src/HOL/SPARK/Tools/spark_vcs.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/SPARK/Tools/spark_vcs.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -306,8 +306,7 @@
               in
                 (thy |>
                  Datatype.add_datatype {strict = true, quiet = true}
-                   [([], tyb, NoSyn,
-                     map (fn s => (Binding.name s, [], NoSyn)) els)] |> snd |>
+                   [((tyb, [], NoSyn), map (fn s => (Binding.name s, [], NoSyn)) els)] |> snd |>
                  add_enum_type s tyname,
                  tyname)
               end
--- a/src/HOL/Tools/Datatype/datatype.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -10,13 +10,17 @@
 signature DATATYPE =
 sig
   include DATATYPE_DATA
-  val add_datatype: config ->
-    (string list * binding * mixfix * (binding * typ list * mixfix) list) list ->
-    theory -> string list * theory
-  val add_datatype_cmd:
-    (string list * binding * mixfix * (binding * string list * mixfix) list) list ->
-    theory -> theory
-  val parse_decl: (string list * binding * mixfix * (binding * string list * mixfix) list) parser
+  type spec =
+    (binding * (string * sort) list * mixfix) *
+    (binding * typ list * mixfix) list
+  type spec_cmd =
+    (binding * (string * string option) list * mixfix) *
+    (binding * string list * mixfix) list
+  val read_specs: spec_cmd list -> theory -> spec list * Proof.context
+  val check_specs: spec list -> theory -> spec list * Proof.context
+  val add_datatype: config -> spec list -> theory -> string list * theory
+  val add_datatype_cmd: spec_cmd list -> theory -> theory
+  val spec_cmd: spec_cmd parser
 end;
 
 structure Datatype : DATATYPE =
@@ -670,27 +674,74 @@
 
 (** datatype definition **)
 
-fun gen_add_datatype prep_typ config dts thy =
+(* specifications *)
+
+type spec = (binding * (string * sort) list * mixfix) * (binding * typ list * mixfix) list;
+
+type spec_cmd =
+  (binding * (string * string option) list * mixfix) * (binding * string list * mixfix) list;
+
+local
+
+fun parse_spec ctxt ((b, args, mx), constrs) =
+  ((b, map (apsnd (Typedecl.read_constraint ctxt)) args, mx),
+    constrs |> map (fn (c, Ts, mx') => (c, map (Syntax.parse_typ ctxt) Ts, mx')));
+
+fun check_specs ctxt (specs: spec list) =
+  let
+    fun prep_spec ((tname, args, mx), constrs) tys =
+      let
+        val (args', tys1) = chop (length args) tys;
+        val (constrs', tys3) = (constrs, tys1) |-> fold_map (fn (cname, cargs, mx') => fn tys2 =>
+          let val (cargs', tys3) = chop (length cargs) tys2;
+          in ((cname, cargs', mx'), tys3) end);
+      in (((tname, map dest_TFree args', mx), constrs'), tys3) end;
+
+    val all_tys =
+      specs |> maps (fn ((_, args, _), cs) => map TFree args @ maps #2 cs)
+      |> Syntax.check_typs ctxt;
+
+  in #1 (fold_map prep_spec specs all_tys) end;
+
+fun prep_specs parse raw_specs thy =
+  let
+    val ctxt = thy
+      |> Theory.copy
+      |> Sign.add_types_global (map (fn ((b, args, mx), _) => (b, length args, mx)) raw_specs)
+      |> Proof_Context.init_global
+      |> fold (fn ((_, args, _), _) => fold (fn (a, _) =>
+          Variable.declare_typ (TFree (a, dummyS))) args) raw_specs;
+    val specs = check_specs ctxt (map (parse ctxt) raw_specs);
+  in (specs, ctxt) end;
+
+in
+
+val read_specs = prep_specs parse_spec;
+val check_specs = prep_specs (K I);
+
+end;
+
+
+(* main commands *)
+
+fun gen_add_datatype prep_specs config raw_specs thy =
   let
     val _ = Theory.requires thy "Datatype" "datatype definitions";
 
-    (* this theory is used just for parsing *)
-    val tmp_thy = thy
-      |> Theory.copy
-      |> Sign.add_types_global (map (fn (tvs, tname, mx, _) => (tname, length tvs, mx)) dts);
-    val tmp_ctxt = Proof_Context.init_global tmp_thy;
+    val (dts, spec_ctxt) = prep_specs raw_specs thy;
+    val ((_, tyvars, _), _) :: _ = dts;
+    val string_of_tyvar = Syntax.string_of_typ spec_ctxt o TFree;
 
-    val (tyvars, _, _, _) ::_ = dts;
-    val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
-      let val full_tname = Sign.full_name tmp_thy tname in
+    val (new_dts, types_syntax) = dts |> map (fn ((tname, tvs, mx), _) =>
+      let val full_tname = Sign.full_name thy tname in
         (case duplicates (op =) tvs of
           [] =>
             if eq_set (op =) (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
-            else error ("Mutually recursive datatypes must have same type parameters")
+            else error "Mutually recursive datatypes must have same type parameters"
         | dups =>
             error ("Duplicate parameter(s) for datatype " ^ Binding.print tname ^
-              " : " ^ commas dups))
-      end) dts);
+              " : " ^ commas (map string_of_tyvar dups)))
+      end) |> split_list;
     val dt_names = map fst new_dts;
 
     val _ =
@@ -698,45 +749,37 @@
         [] => ()
       | dups => error ("Duplicate datatypes: " ^ commas_quote dups));
 
-    fun prep_dt_spec (tvs, tname, mx, constrs) (dts', constr_syntax, sorts, i) =
+    fun prep_dt_spec ((tname, tvs, mx), constrs) (dts', constr_syntax, i) =
       let
-        fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') =
+        fun prep_constr (cname, cargs, mx') (constrs, constr_syntax') =
           let
-            val (cargs', sorts'') = fold_map (prep_typ tmp_thy) cargs sorts';
             val _ =
-              (case subtract (op =) tvs (fold Term.add_tfree_namesT cargs' []) of
+              (case subtract (op =) tvs (fold Term.add_tfreesT cargs []) of
                 [] => ()
-              | vs => error ("Extra type variables on rhs: " ^ commas vs));
-            val c = Sign.full_name_path tmp_thy (Binding.name_of tname) cname;
+              | vs => error ("Extra type variables on rhs: " ^ commas (map string_of_tyvar vs)));
+            val c = Sign.full_name_path thy (Binding.name_of tname) cname;
           in
-            (constrs @ [(c, map (Datatype_Aux.dtyp_of_typ new_dts) cargs')],
-              constr_syntax' @ [(cname, mx')], sorts'')
+            (constrs @ [(c, map (Datatype_Aux.dtyp_of_typ new_dts) cargs)],
+              constr_syntax' @ [(cname, mx')])
           end handle ERROR msg =>
             cat_error msg ("The error above occurred in constructor " ^ Binding.print cname ^
               " of datatype " ^ Binding.print tname);
 
-        val (constrs', constr_syntax', sorts') = fold prep_constr constrs ([], [], sorts);
+        val (constrs', constr_syntax') = fold prep_constr constrs ([], []);
       in
         (case duplicates (op =) (map fst constrs') of
           [] =>
-            (dts' @ [(i, (Sign.full_name tmp_thy tname, tvs, constrs'))],
-              constr_syntax @ [constr_syntax'], sorts', i + 1)
+            (dts' @ [(i, (Sign.full_name thy tname, map Datatype_Aux.DtTFree tvs, constrs'))],
+              constr_syntax @ [constr_syntax'], i + 1)
         | dups =>
             error ("Duplicate constructors " ^ commas_quote dups ^
               " in datatype " ^ Binding.print tname))
       end;
 
-    val (dts0, constr_syntax, sorts', i) = fold prep_dt_spec dts ([], [], [], 0);
-    val tmp_ctxt' = tmp_ctxt |> fold (Variable.declare_typ o TFree) sorts';
-
-    val dts' = dts0 |> map (fn (i, (name, tvs, cs)) =>
-      let
-        val args = tvs |>
-          map (fn a => Datatype_Aux.DtTFree (a, Proof_Context.default_sort tmp_ctxt' (a, ~1)));
-      in (i, (name, args, cs)) end);
+    val (dts', constr_syntax, i) = fold prep_dt_spec dts ([], [], 0);
 
     val dt_info = Datatype_Data.get_all thy;
-    val (descr, _) = Datatype_Aux.unfold_datatypes tmp_ctxt dts' dt_info dts' i;
+    val (descr, _) = Datatype_Aux.unfold_datatypes spec_ctxt dts' dt_info dts' i;
     val _ =
       Datatype_Aux.check_nonempty descr
         handle (exn as Datatype_Aux.Datatype_Empty s) =>
@@ -745,7 +788,7 @@
 
     val _ =
       Datatype_Aux.message config
-        ("Constructing datatype(s) " ^ commas_quote (map (Binding.name_of o #2) dts));
+        ("Constructing datatype(s) " ^ commas_quote (map (Binding.name_of o #1 o #1) dts));
   in
     thy
     |> representation_proofs config dt_info descr types_syntax constr_syntax
@@ -754,20 +797,20 @@
       Datatype_Data.derive_datatype_props config dt_names descr induct inject distinct)
   end;
 
-val add_datatype = gen_add_datatype Datatype_Data.cert_typ;
-val add_datatype_cmd = snd oo gen_add_datatype Datatype_Data.read_typ Datatype_Aux.default_config;
+val add_datatype = gen_add_datatype check_specs;
+val add_datatype_cmd = snd oo gen_add_datatype read_specs Datatype_Aux.default_config;
 
 
-(* concrete syntax *)
+(* outer syntax *)
 
-val parse_decl =
-  Parse.type_args -- Parse.binding -- Parse.opt_mixfix --
+val spec_cmd =
+  Parse.type_args_constrained -- Parse.binding -- Parse.opt_mixfix --
   (Parse.$$$ "=" |-- Parse.enum1 "|" (Parse.binding -- Scan.repeat Parse.typ -- Parse.opt_mixfix))
-  >> (fn (((vs, t), mx), cons) => (vs, t, mx, map Parse.triple1 cons));
+  >> (fn (((vs, t), mx), cons) => ((t, vs, mx), map Parse.triple1 cons));
 
 val _ =
   Outer_Syntax.command "datatype" "define inductive datatypes" Keyword.thy_decl
-    (Parse.and_list1 parse_decl >> (Toplevel.theory o add_datatype_cmd));
+    (Parse.and_list1 spec_cmd >> (Toplevel.theory o add_datatype_cmd));
 
 
 open Datatype_Data;
--- a/src/HOL/Tools/Datatype/datatype_aux.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_aux.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -57,7 +57,7 @@
   exception Datatype
   exception Datatype_Empty of string
   val name_of_typ : typ -> string
-  val dtyp_of_typ : (string * string list) list -> typ -> dtyp
+  val dtyp_of_typ : (string * (string * sort) list) list -> typ -> dtyp
   val mk_Free : string -> typ -> int -> term
   val is_rec_type : dtyp -> bool
   val typ_of_dtyp : descr -> dtyp -> typ
@@ -242,7 +242,7 @@
       (case AList.lookup (op =) new_dts tname of
         NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
       | SOME vs =>
-          if map (try (fst o dest_TFree)) Ts = map SOME vs then
+          if map (try dest_TFree) Ts = map SOME vs then
             DtRec (find_index (curry op = tname o fst) new_dts)
           else error ("Illegal occurrence of recursive type " ^ quote tname));
 
--- a/src/HOL/Tools/Datatype/datatype_data.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_data.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -28,8 +28,6 @@
   val make_case :  Proof.context -> Datatype_Case.config -> string list -> term ->
     (term * term) list -> term
   val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
-  val read_typ: theory -> string -> (string * sort) list -> typ * (string * sort) list
-  val cert_typ: theory -> typ -> (string * sort) list -> typ * (string * sort) list
   val mk_case_names_induct: descr -> attribute
   val setup: theory -> theory
 end;
@@ -171,27 +169,6 @@
 
 (** various auxiliary **)
 
-(* prepare datatype specifications *)
-
-fun read_typ thy str sorts =
-  let
-    val ctxt = Proof_Context.init_global thy
-      |> fold (Variable.declare_typ o TFree) sorts;
-    val T = Syntax.read_typ ctxt str;
-  in (T, Term.add_tfreesT T sorts) end;
-
-fun cert_typ sign raw_T sorts =
-  let
-    val T = Type.no_tvars (Sign.certify_typ sign raw_T)
-      handle TYPE (msg, _, _) => error msg;
-    val sorts' = Term.add_tfreesT T sorts;
-    val _ =
-      (case duplicates (op =) (map fst sorts') of
-        [] => ()
-      | dups => error ("Inconsistent sort constraints for " ^ commas dups));
-  in (T, sorts') end;
-
-
 (* case names *)
 
 local
@@ -427,8 +404,7 @@
         (TFree o (the o AList.lookup (op =) (map fst raw_vs ~~ vs)) o fst o dest_TFree) T);
 
     val cs = map (apsnd (map norm_constr)) raw_cs;
-    val dtyps_of_typ =
-      map (Datatype_Aux.dtyp_of_typ (map (rpair (map fst vs) o fst) cs)) o binder_types;
+    val dtyps_of_typ = map (Datatype_Aux.dtyp_of_typ (map (rpair vs o fst) cs)) o binder_types;
     val dt_names = map fst cs;
 
     fun mk_spec (i, (tyco, constr)) =
--- a/src/HOL/Tools/inductive_realizer.ML	Tue Dec 13 20:29:59 2011 +0100
+++ b/src/HOL/Tools/inductive_realizer.ML	Tue Dec 13 23:23:51 2011 +0100
@@ -69,8 +69,9 @@
         filter_out (equal Extraction.nullT) (map
           (Logic.unvarifyT_global o Extraction.etype_of thy vs []) (prems_of intr)),
             NoSyn);
-  in (map (fn a => "'" ^ a) vs @ map (fst o fst) iTs, tname, NoSyn,
-    map constr_of_intr intrs)
+  in
+    ((tname, map (rpair dummyS) (map (fn a => "'" ^ a) vs @ map (fst o fst) iTs), NoSyn),
+      map constr_of_intr intrs)
   end;
 
 fun mk_rlz T = Const ("realizes", [T, HOLogic.boolT] ---> HOLogic.boolT);
@@ -233,8 +234,9 @@
       end) concls rec_names)
   end;
 
-fun add_dummy name dname (x as (_, (vs, s, mfx, cs))) =
-  if Binding.eq_name (name, s) then (true, (vs, s, mfx, (dname, [HOLogic.unitT], NoSyn) :: cs))
+fun add_dummy name dname (x as (_, ((s, vs, mx), cs))) =
+  if Binding.eq_name (name, s)
+  then (true, ((s, vs, mx), (dname, [HOLogic.unitT], NoSyn) :: cs))
   else x;
 
 fun add_dummies f [] _ thy =