src/HOL/Tools/record.ML
changeset 36151 b89a2a05a3ce
parent 36137 0be811a98d3a
child 36153 1ac501e16a6a
--- a/src/HOL/Tools/record.ML	Thu Apr 15 16:55:49 2010 +0200
+++ b/src/HOL/Tools/record.ML	Thu Apr 15 16:58:12 2010 +0200
@@ -54,9 +54,9 @@
   val print_records: theory -> unit
   val read_typ: Proof.context -> string -> (string * sort) list -> typ * (string * sort) list
   val cert_typ: Proof.context -> typ -> (string * sort) list -> typ * (string * sort) list
-  val add_record: bool -> string list * binding -> (typ list * string) option ->
+  val add_record: bool -> (string * sort) list * binding -> (typ list * string) option ->
     (binding * typ * mixfix) list -> theory -> theory
-  val add_record_cmd: bool -> string list * binding -> string option ->
+  val add_record_cmd: bool -> (string * string option) list * binding -> string option ->
     (binding * string * mixfix) list -> theory -> theory
   val setup: theory -> theory
 end;
@@ -64,7 +64,8 @@
 
 signature ISO_TUPLE_SUPPORT =
 sig
-  val add_iso_tuple_type: bstring * string list -> typ * typ -> theory -> (term * term) * theory
+  val add_iso_tuple_type: bstring * (string * sort) list ->
+    typ * typ -> theory -> (term * term) * theory
   val mk_cons_tuple: term * term -> term
   val dest_cons_tuple: term -> term * term
   val iso_tuple_intros_tac: int -> tactic
@@ -900,10 +901,9 @@
     val midx = maxidx_of_typ T;
     val varifyT = varifyT midx;
 
-    fun mk_type_abbr subst name alphas =
-      let val abbrT = Type (name, map (fn a => varifyT (TFree (a, HOLogic.typeS))) alphas) in
-        Syntax.term_of_typ (! Syntax.show_sorts) (Envir.norm_type subst abbrT)
-      end;
+    fun mk_type_abbr subst name args =
+      let val abbrT = Type (name, map (varifyT o TFree) args)
+      in Syntax.term_of_typ (! Syntax.show_sorts) (Envir.norm_type subst abbrT) end;
 
     fun match rT T = Sign.typ_match thy (varifyT rT, T) Vartab.empty;
   in
@@ -912,7 +912,7 @@
         SOME (name, _) =>
           if name = last_ext then
             let val subst = match schemeT T in
-              if HOLogic.is_unitT (Envir.norm_type subst (varifyT (TFree (zeta, HOLogic.typeS))))
+              if HOLogic.is_unitT (Envir.norm_type subst (varifyT (TFree zeta)))
               then mk_type_abbr subst abbr alphas
               else mk_type_abbr subst (suffix schemeN abbr) (alphas @ [zeta])
             end handle Type.TYPE_MATCH => record_type_tr' ctxt tm
@@ -1639,11 +1639,10 @@
     val fields_moreTs = fieldTs @ [moreT];
 
     val alphas_zeta = alphas @ [zeta];
-    val alphas_zetaTs = map (fn a => TFree (a, HOLogic.typeS)) alphas_zeta;
 
     val ext_binding = Binding.name (suffix extN base_name);
     val ext_name = suffix extN name;
-    val extT = Type (suffix ext_typeN name, alphas_zetaTs);
+    val extT = Type (suffix ext_typeN name, map TFree alphas_zeta);
     val ext_type = fields_moreTs ---> extT;
 
 
@@ -1846,10 +1845,8 @@
 
 (* record_definition *)
 
-fun record_definition (args, binding) parent (parents: parent_info list) raw_fields thy =
+fun record_definition (alphas, binding) parent (parents: parent_info list) raw_fields thy =
   let
-    val alphas = map fst args;
-
     val name = Sign.full_name thy binding;
     val full = Sign.full_name_path thy (Binding.name_of binding); (* FIXME Binding.qualified (!?) *)
 
@@ -1869,7 +1866,7 @@
     val fields = map (apfst full) bfields;
     val names = map fst fields;
     val types = map snd fields;
-    val alphas_fields = fold Term.add_tfree_namesT types [];
+    val alphas_fields = fold Term.add_tfreesT types [];
     val alphas_ext = inter (op =) alphas_fields alphas;
     val len = length fields;
     val variants =
@@ -1885,9 +1882,8 @@
     val all_vars = parent_vars @ vars;
     val all_named_vars = (parent_names ~~ parent_vars) @ named_vars;
 
-
-    val zeta = Name.variant alphas "'z";
-    val moreT = TFree (zeta, HOLogic.typeS);
+    val zeta = (Name.variant (map #1 alphas) "'z", HOLogic.typeS);
+    val moreT = TFree zeta;
     val more = Free (moreN, moreT);
     val full_moreN = full (Binding.name moreN);
     val bfields_more = bfields @ [(Binding.name moreN, moreT)];
@@ -1978,8 +1974,8 @@
 
     (*record (scheme) type abbreviation*)
     val recordT_specs =
-      [(Binding.suffix_name schemeN binding, alphas @ [zeta], rec_schemeT0, NoSyn),
-        (binding, alphas, recT0, NoSyn)];
+      [(Binding.suffix_name schemeN binding, map #1 (alphas @ [zeta]), rec_schemeT0, NoSyn),
+        (binding, map #1 alphas, recT0, NoSyn)];
 
     val ext_defs = ext_def :: map #ext_def parents;
 
@@ -2349,7 +2345,7 @@
            ((Binding.name "iffs", iffs), [iff_add])];
 
     val info =
-      make_record_info args parent fields extension
+      make_record_info alphas parent fields extension
         ext_induct ext_inject ext_surjective ext_split ext_def
         sel_convs' upd_convs' sel_defs' upd_defs' fold_congs' unfold_congs' splits' derived_defs'
         surjective' equality' induct_scheme' induct' cases_scheme' cases' simps' iffs';
@@ -2371,10 +2367,25 @@
 
 (* add_record *)
 
-(*We do all preparations and error checks here, deferring the real
-  work to record_definition.*)
-fun gen_add_record prep_typ prep_raw_parent quiet_mode
-    (params, binding) raw_parent raw_fields thy =
+local
+
+fun read_parent NONE ctxt = (NONE, ctxt)
+  | read_parent (SOME raw_T) ctxt =
+      (case ProofContext.read_typ_abbrev ctxt raw_T of
+        Type (name, Ts) => (SOME (Ts, name), fold Variable.declare_typ Ts ctxt)
+      | T => error ("Bad parent record specification: " ^ Syntax.string_of_typ ctxt T));
+
+fun prep_field prep (x, T, mx) = (x, prep T, mx)
+  handle ERROR msg =>
+    cat_error msg ("The error(s) above occurred in record field " ^ quote (Binding.str_of x));
+
+fun read_field raw_field ctxt =
+  let val field as (_, T, _) = prep_field (Syntax.read_typ ctxt) raw_field
+  in (field, Variable.declare_typ T ctxt) end;
+
+in
+
+fun add_record quiet_mode (params, binding) raw_parent raw_fields thy =
   let
     val _ = Theory.requires thy "Record" "record definitions";
     val _ =
@@ -2382,40 +2393,19 @@
       else writeln ("Defining record " ^ quote (Binding.str_of binding) ^ " ...");
 
     val ctxt = ProofContext.init thy;
-
-
-    (* parents *)
-
-    fun prep_inst T = fst (cert_typ ctxt T []);
-
-    val parent = Option.map (apfst (map prep_inst) o prep_raw_parent ctxt) raw_parent
-      handle ERROR msg => cat_error msg ("The error(s) above in parent record specification");
+    fun cert_typ T = Type.no_tvars (ProofContext.cert_typ ctxt T)
+      handle TYPE (msg, _, _) => error msg;
+
+
+    (* specification *)
+
+    val parent = Option.map (apfst (map cert_typ)) raw_parent
+      handle ERROR msg =>
+        cat_error msg ("The error(s) above occurred in parent record specification");
+    val parent_args = (case parent of SOME (Ts, _) => Ts | NONE => []);
     val parents = add_parents thy parent [];
 
-    val init_env =
-      (case parent of
-        NONE => []
-      | SOME (types, _) => fold Term.add_tfreesT types []);
-
-
-    (* fields *)
-
-    fun prep_field (x, raw_T, mx) env =
-      let
-        val (T, env') =
-          prep_typ ctxt raw_T env handle ERROR msg =>
-            cat_error msg ("The error(s) above occured in record field " ^ quote (Binding.str_of x));
-      in ((x, T, mx), env') end;
-
-    val (bfields, envir) = fold_map prep_field raw_fields init_env;
-    val envir_names = map fst envir;
-
-
-    (* args *)
-
-    val defaultS = Sign.defaultS thy;
-    val args = map (fn x => (x, AList.lookup (op =) envir x |> the_default defaultS)) params;
-
+    val bfields = map (prep_field cert_typ) raw_fields;
 
     (* errors *)
 
@@ -2424,15 +2414,12 @@
       if is_none (get_record thy name) then []
       else ["Duplicate definition of record " ^ quote name];
 
-    val err_dup_parms =
-      (case duplicates (op =) params of
+    val spec_frees = fold Term.add_tfreesT (parent_args @ map #2 bfields) [];
+    val err_extra_frees =
+      (case subtract (op =) params spec_frees of
         [] => []
-      | dups => ["Duplicate parameter(s) " ^ commas dups]);
-
-    val err_extra_frees =
-      (case subtract (op =) params envir_names of
-        [] => []
-      | extras => ["Extra free type variable(s) " ^ commas extras]);
+      | extras => ["Extra free type variable(s) " ^
+          commas (map (Syntax.string_of_typ ctxt o TFree) extras)]);
 
     val err_no_fields = if null bfields then ["No fields present"] else [];
 
@@ -2445,23 +2432,26 @@
       if forall (not_equal moreN o Binding.name_of o #1) bfields then []
       else ["Illegal field name " ^ quote moreN];
 
-    val err_dup_sorts =
-      (case duplicates (op =) envir_names of
-        [] => []
-      | dups => ["Inconsistent sort constraints for " ^ commas dups]);
-
     val errs =
-      err_dup_record @ err_dup_parms @ err_extra_frees @ err_no_fields @
-      err_dup_fields @ err_bad_fields @ err_dup_sorts;
-
+      err_dup_record @ err_extra_frees @ err_no_fields @ err_dup_fields @ err_bad_fields;
     val _ = if null errs then () else error (cat_lines errs);
   in
-    thy |> record_definition (args, binding) parent parents bfields
+    thy |> record_definition (params, binding) parent parents bfields
   end
   handle ERROR msg => cat_error msg ("Failed to define record " ^ quote (Binding.str_of binding));
 
-val add_record = gen_add_record cert_typ (K I);
-val add_record_cmd = gen_add_record read_typ read_raw_parent;
+fun add_record_cmd quiet_mode (raw_params, binding) raw_parent raw_fields thy0 =
+  let
+    val thy = Theory.checkpoint thy0;
+    val lthy = Theory_Target.init NONE thy;
+    val params = map (apsnd (Typedecl.read_constraint lthy)) raw_params;
+    val (_, lthy1) = Typedecl.predeclare_constraints (binding, params, NoSyn) lthy;
+    val (parent, lthy2) = read_parent raw_parent lthy1;
+    val (fields, lthy3) = fold_map read_field raw_fields lthy2;
+    val params' = map (fn (a, _) => (a, ProofContext.default_sort lthy3 (a, ~1))) params;
+  in thy |> add_record quiet_mode (params', binding) parent fields end;
+
+end;
 
 
 (* setup theory *)
@@ -2479,7 +2469,7 @@
 
 val _ =
   OuterSyntax.command "record" "define extensible record" K.thy_decl
-    (P.type_args -- P.binding --
+    (P.type_args_constrained -- P.binding --
       (P.$$$ "=" |-- Scan.option (P.typ --| P.$$$ "+") -- Scan.repeat1 P.const_binding)
     >> (fn (x, (y, z)) => Toplevel.theory (add_record_cmd false x y z)));