keep type variable arguments of datatype constructors in bookkeeping
authorhaftmann
Fri, 26 Nov 2010 22:33:21 +0100
changeset 40726 16dcfedc4eb7
parent 40717 88f2955a111e
child 40727 29885c9be6ae
child 40757 b469a373df31
keep type variable arguments of datatype constructors in bookkeeping
src/HOL/Tools/code_evaluation.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_runtime.ML
src/Tools/Code/code_thingol.ML
--- a/src/HOL/Tools/code_evaluation.ML	Fri Nov 26 18:07:00 2010 +0100
+++ b/src/HOL/Tools/code_evaluation.ML	Fri Nov 26 22:33:21 2010 +0100
@@ -54,7 +54,7 @@
 
 (* code equations for datatypes *)
 
-fun mk_term_of_eq thy ty (c, tys) =
+fun mk_term_of_eq thy ty (c, (_, tys)) =
   let
     val t = list_comb (Const (c, tys ---> ty),
       map Free (Name.names Name.context "a" tys));
@@ -74,7 +74,7 @@
     val vs = map (fn (v, sort) =>
       (v, curry (Sorts.inter_sort algebra) @{sort typerep} sort)) raw_vs;
     val ty = Type (tyco, map TFree vs);
-    val cs = (map o apsnd o map o map_atyps)
+    val cs = (map o apsnd o apsnd o map o map_atyps)
       (fn TFree (v, _) => TFree (v, (the o AList.lookup (op =) vs) v)) raw_cs;
     val const = AxClass.param_of_inst thy (@{const_name term_of}, tyco);
     val eqs = map (mk_term_of_eq thy ty) cs;
@@ -121,7 +121,7 @@
     |> Code.add_eqn eq
   end;
 
-fun ensure_abs_term_of_code (tyco, (raw_vs, ((abs, ty), (proj, _)))) thy =
+fun ensure_abs_term_of_code (tyco, (raw_vs, ((abs, (_, ty)), (proj, _)))) thy =
   let
     val has_inst = can (Sorts.mg_domain (Sign.classes_of thy) tyco) @{sort term_of};
   in if has_inst then add_abs_term_of_code tyco raw_vs abs ty proj thy else thy end;
--- a/src/Pure/Isar/code.ML	Fri Nov 26 18:07:00 2010 +0100
+++ b/src/Pure/Isar/code.ML	Fri Nov 26 22:33:21 2010 +0100
@@ -21,7 +21,7 @@
 
   (*constructor sets*)
   val constrset_of_consts: theory -> (string * typ) list
-    -> string * ((string * sort) list * (string * typ list) list)
+    -> string * ((string * sort) list * (string * ((string * sort) list * typ list)) list)
 
   (*code equations and certificates*)
   val mk_eqn: theory -> thm * bool -> thm * bool
@@ -48,11 +48,11 @@
   val add_datatype: (string * typ) list -> theory -> theory
   val add_datatype_cmd: string list -> theory -> theory
   val datatype_interpretation:
-    (string * ((string * sort) list * (string * typ list) list)
+    (string * ((string * sort) list * (string * ((string * sort) list * typ list)) list)
       -> theory -> theory) -> theory -> theory
   val add_abstype: thm -> theory -> theory
   val abstype_interpretation:
-    (string * ((string * sort) list * ((string * typ) * (string * thm)))
+    (string * ((string * sort) list * ((string * ((string * sort) list * typ)) * (string * thm)))
       -> theory -> theory) -> theory -> theory
   val add_eqn: thm -> theory -> theory
   val add_nbe_eqn: thm -> theory -> theory
@@ -66,7 +66,8 @@
   val del_eqns: string -> theory -> theory
   val add_case: thm -> theory -> theory
   val add_undefined: string -> theory -> theory
-  val get_type: theory -> string -> ((string * sort) list * ((string * string list) * typ list) list)
+  val get_type: theory -> string
+    -> ((string * sort) list * (string * ((string * sort) list * typ list)) list) * bool
   val get_type_of_constr_or_abstr: theory -> string -> (string * bool) option
   val is_constr: theory -> string -> bool
   val is_abstr: theory -> string -> bool
@@ -147,11 +148,11 @@
 
 (* datatypes *)
 
-datatype typ_spec = Constructors of (string * typ list) list
-  | Abstractor of (string * typ) * (string * thm);
+datatype typ_spec = Constructors of (string * ((string * sort) list * typ list)) list
+  | Abstractor of (string * ((string * sort) list * typ)) * (string * thm);
 
 fun constructors_of (Constructors cos) = (cos, false)
-  | constructors_of (Abstractor ((co, ty), _)) = ([(co, [ty])], true);
+  | constructors_of (Abstractor ((co, (vs, ty)), _)) = ([(co, (vs, [ty]))], true);
 
 
 (* functions *)
@@ -412,7 +413,8 @@
       let
         val the_v = the o AList.lookup (op =) (vs ~~ vs');
         val ty' = map_atyps (fn TFree (v, _) => TFree (the_v v)) ty;
-      in (c, (fst o strip_type) ty') end;
+        val vs'' = map dest_TFree (Sign.const_typargs thy (c, ty'));
+      in (c, (vs'', (fst o strip_type) ty')) end;
     val c' :: cs' = map (ty_sorts thy) cs;
     val ((tyco, sorts), cs'') = fold add cs' (apsnd single c');
     val vs = Name.names Name.context Name.aT sorts;
@@ -423,7 +425,7 @@
  of (_, entry) :: _ => SOME entry
   | _ => NONE;
 
-fun get_type_spec thy tyco = case get_type_entry thy tyco
+fun get_type thy tyco = case get_type_entry thy tyco
  of SOME (vs, spec) => apfst (pair vs) (constructors_of spec)
   | NONE => arity_number thy tyco
       |> Name.invents Name.context Name.aT
@@ -435,17 +437,9 @@
  of SOME (vs, Abstractor spec) => (vs, spec)
   | _ => error ("Not an abstract type: " ^ tyco);
  
-fun get_type thy tyco =
-  let
-    val ((vs, cos), _) = get_type_spec thy tyco;
-    fun args_of c tys = map (fst o dest_TFree)
-      (Sign.const_typargs thy (c, tys ---> Type (tyco, map TFree vs)));
-    fun add_typargs (c, tys) = ((c, args_of c tys), tys);
-  in (vs, map add_typargs cos) end;
-
 fun get_type_of_constr_or_abstr thy c =
   case (snd o strip_type o const_typ thy) c
-   of Type (tyco, _) => let val ((vs, cos), abstract) = get_type_spec thy tyco
+   of Type (tyco, _) => let val ((vs, cos), abstract) = get_type thy tyco
         in if member (op =) (map fst cos) c then SOME (tyco, abstract) else NONE end
     | _ => NONE;
 
@@ -683,8 +677,9 @@
     val _ = if param = rhs then () else bad "Not an abstype certificate";
     val ((tyco, sorts), (abs, (vs, ty'))) = ty_sorts thy (abs, Logic.unvarifyT_global raw_ty);
     val ty = domain_type ty';
+    val vs' = map dest_TFree (Sign.const_typargs thy (abs, ty'));
     val ty_abs = range_type ty';
-  in (tyco, (vs ~~ sorts, ((abs, ty), (rep, thm)))) end;
+  in (tyco, (vs ~~ sorts, ((abs, (vs', ty)), (rep, thm)))) end;
 
 
 (* code equation certificates *)
@@ -784,7 +779,7 @@
 
 fun cert_of_proj thy c tyco =
   let
-    val (vs, ((abs, ty), (rep, cert))) = get_abstype_spec thy tyco;
+    val (vs, ((abs, (_, ty)), (rep, cert))) = get_abstype_spec thy tyco;
     val _ = if c = rep then () else
       error ("Wrong head of projection,\nexpected constant " ^ string_of_const thy rep);
   in Projection (mk_proj tyco vs ty abs rep, tyco) end;
@@ -979,8 +974,8 @@
         pretty_typ typ
         :: Pretty.str "="
         :: (if abstract then [Pretty.str "(abstract)"] else [])
-        @ separate (Pretty.str "|") (map (fn (c, []) => Pretty.str (string_of_const thy c)
-             | (c, tys) =>
+        @ separate (Pretty.str "|") (map (fn (c, (_, [])) => Pretty.str (string_of_const thy c)
+             | (c, (_, tys)) =>
                  (Pretty.block o Pretty.breaks)
                     (Pretty.str (string_of_const thy c)
                       :: Pretty.str "of"
@@ -1202,7 +1197,7 @@
   Interpretation(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
 
 fun datatype_interpretation f = Datatype_Interpretation.interpretation
-  (fn (tyco, _) => fn thy => f (tyco, fst (get_type_spec thy tyco)) thy);
+  (fn (tyco, _) => fn thy => f (tyco, fst (get_type thy tyco)) thy);
 
 fun add_datatype proto_constrs thy =
   let
@@ -1226,7 +1221,7 @@
 
 fun add_abstype proto_thm thy =
   let
-    val (tyco, (vs, (abs_ty as (abs, ty), (rep, cert)))) =
+    val (tyco, (vs, (abs_ty as (abs, (_, ty)), (rep, cert)))) =
       error_thm (check_abstype_cert thy) proto_thm;
   in
     thy
--- a/src/Tools/Code/code_runtime.ML	Fri Nov 26 18:07:00 2010 +0100
+++ b/src/Tools/Code/code_runtime.ML	Fri Nov 26 22:33:21 2010 +0100
@@ -258,7 +258,7 @@
 
 fun check_datatype thy tyco some_consts =
   let
-    val constrs = (map (fst o fst) o snd o Code.get_type thy) tyco;
+    val constrs = (map fst o snd o fst o Code.get_type thy) tyco;
     val _ = case some_consts
      of SOME consts =>
           let
--- a/src/Tools/Code/code_thingol.ML	Fri Nov 26 18:07:00 2010 +0100
+++ b/src/Tools/Code/code_thingol.ML	Fri Nov 26 22:33:21 2010 +0100
@@ -573,12 +573,12 @@
 
 fun ensure_tyco thy algbr eqngr permissive tyco =
   let
-    val (vs, cos) = Code.get_type thy tyco;
+    val ((vs, cos), _) = Code.get_type thy tyco;
     val stmt_datatype =
       fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs
-      ##>> fold_map (fn ((c, vs), tys) =>
+      ##>> fold_map (fn (c, (vs, tys)) =>
         ensure_const thy algbr eqngr permissive c
-        ##>> pair (map (unprefix "'") vs)
+        ##>> pair (map (unprefix "'" o fst) vs)
         ##>> fold_map (translate_typ thy algbr eqngr permissive) tys) cos
       #>> (fn info => Datatype (tyco, info));
   in ensure_stmt lookup_tyco (declare_tyco thy) stmt_datatype tyco end