explicit type variable arguments for constructors
authorhaftmann
Thu, 17 Jun 2010 15:59:46 +0200
changeset 37448 3bd4b3809bee
parent 37447 ad3e04f289b6
child 37449 034ebe92f090
explicit type variable arguments for constructors
src/Pure/Isar/code.ML
src/Tools/Code/code_eval.ML
src/Tools/Code/code_thingol.ML
--- a/src/Pure/Isar/code.ML	Thu Jun 17 11:33:04 2010 +0200
+++ b/src/Pure/Isar/code.ML	Thu Jun 17 15:59:46 2010 +0200
@@ -66,7 +66,7 @@
   val del_eqns: string -> theory -> theory
   val add_case: thm -> theory -> theory
   val add_undefined: string -> theory -> theory
-  val get_type: theory -> string -> ((string * sort) list * (string * typ list) list)
+  val get_type: theory -> string -> ((string * sort) list * ((string * string list) * typ list) list)
   val get_type_of_constr_or_abstr: theory -> string -> (string * bool) option
   val is_constr: theory -> string -> bool
   val is_abstr: theory -> string -> bool
@@ -429,7 +429,13 @@
  of SOME (vs, Abstractor spec) => (vs, spec)
   | _ => error ("Not an abstract type: " ^ tyco);
  
-fun get_type thy = fst o get_type_spec thy;
+fun get_type thy tyco =
+  let
+    val ((vs, cos), _) = get_type_spec thy tyco;
+    fun args_of c tys = map (fst o dest_TFree)
+      (Sign.const_typargs thy (c, tys ---> Type (tyco, map TFree vs)));
+    fun add_typargs (c, tys) = ((c, args_of c tys), tys);
+  in (vs, map add_typargs cos) end;
 
 fun get_type_of_constr_or_abstr thy c =
   case (snd o strip_type o const_typ thy) c
@@ -1115,7 +1121,7 @@
     val ([x, y], ctxt) = Name.variants ["A", "A'"] Name.context;
     val (zs, _) = Name.variants (replicate (num_args - 1) "") ctxt;
     val (ws, vs) = chop pos zs;
-    val T = Logic.unvarifyT_global (const_typ thy case_const);
+    val T = Logic.unvarifyT_global (Sign.the_const_type thy case_const);
     val Ts = (fst o strip_type) T;
     val T_cong = nth Ts pos;
     fun mk_prem z = Free (z, T_cong);
@@ -1177,7 +1183,7 @@
   Interpretation(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
 
 fun datatype_interpretation f = Datatype_Interpretation.interpretation
-  (fn (tyco, _) => fn thy => f (tyco, get_type thy tyco) thy);
+  (fn (tyco, _) => fn thy => f (tyco, fst (get_type_spec thy tyco)) thy);
 
 fun add_datatype proto_constrs thy =
   let
--- a/src/Tools/Code/code_eval.ML	Thu Jun 17 11:33:04 2010 +0200
+++ b/src/Tools/Code/code_eval.ML	Thu Jun 17 15:59:46 2010 +0200
@@ -122,7 +122,7 @@
 
 fun check_datatype thy tyco consts =
   let
-    val constrs = (map fst o snd o Code.get_type thy) tyco;
+    val constrs = (map (fst o fst) o snd o Code.get_type thy) tyco;
     val missing_constrs = subtract (op =) consts constrs;
     val _ = if null missing_constrs then []
       else error ("Missing constructor(s) " ^ commas (map quote missing_constrs)
--- a/src/Tools/Code/code_thingol.ML	Thu Jun 17 11:33:04 2010 +0200
+++ b/src/Tools/Code/code_thingol.ML	Thu Jun 17 15:59:46 2010 +0200
@@ -67,14 +67,16 @@
   datatype stmt =
       NoStmt
     | Fun of string * ((typscheme * ((iterm list * iterm) * (thm option * bool)) list) * thm option)
-    | Datatype of string * ((vname * sort) list * (string * itype list) list)
+    | Datatype of string * ((vname * sort) list *
+        ((string * vname list (*type argument wrt. canonical order*)) * itype list) list)
     | Datatypecons of string * string
     | Class of class * (vname * ((class * string) list * (string * itype) list))
     | Classrel of class * class
     | Classparam of string * class
-    | Classinst of (class * (string * (vname * sort) list) (*class and arity*) )
+    | Classinst of (class * (string * (vname * sort) list) (*class and arity*))
           * ((class * (string * (string * dict list list))) list (*super instances*)
-        * ((string * const) * (thm * bool)) list (*class parameter instances*))
+        * (((string * const) * (thm * bool)) list (*class parameter instances*)
+          * ((string * const) * (thm * bool)) list (*super class parameter instances*)))
   type program = stmt Graph.T
   val empty_funs: program -> string list
   val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
@@ -403,14 +405,16 @@
 datatype stmt =
     NoStmt
   | Fun of string * ((typscheme * ((iterm list * iterm) * (thm option * bool)) list) * thm option)
-  | Datatype of string * ((vname * sort) list * (string * itype list) list)
+  | Datatype of string * ((vname * sort) list * ((string * vname list) * itype list) list)
   | Datatypecons of string * string
   | Class of class * (vname * ((class * string) list * (string * itype) list))
   | Classrel of class * class
   | Classparam of string * class
   | Classinst of (class * (string * (vname * sort) list))
         * ((class * (string * (string * dict list list))) list
-      * ((string * const) * (thm * bool)) list) (*see also signature*);
+      * (((string * const) * (thm * bool)) list
+        * ((string * const) * (thm * bool)) list))
+      (*see also signature*);
 
 type program = stmt Graph.T;
 
@@ -428,6 +432,9 @@
       (ICase (((map_terms_bottom_up f t, ty), (map o pairself)
         (map_terms_bottom_up f) ps), map_terms_bottom_up f t0));
 
+fun map_classparam_instances_as_term f =
+  (map o apfst o apsnd) (fn const => case f (IConst const) of IConst const' => const')
+
 fun map_terms_stmt f NoStmt = NoStmt
   | map_terms_stmt f (Fun (c, ((tysm, eqs), case_cong))) = Fun (c, ((tysm, (map o apfst)
       (fn (ts, t) => (map f ts, f t)) eqs), case_cong))
@@ -436,9 +443,8 @@
   | map_terms_stmt f (stmt as Class _) = stmt
   | map_terms_stmt f (stmt as Classrel _) = stmt
   | map_terms_stmt f (stmt as Classparam _) = stmt
-  | map_terms_stmt f (Classinst (arity, (super_instances, classparams))) =
-      Classinst (arity, (super_instances, (map o apfst o apsnd) (fn const =>
-        case f (IConst const) of IConst const' => const') classparams));
+  | map_terms_stmt f (Classinst (arity, (super_instances, classparam_instances))) =
+      Classinst (arity, (super_instances, (pairself o map_classparam_instances_as_term) f classparam_instances));
 
 fun is_cons program name = case Graph.get_node program name
  of Datatypecons _ => true
@@ -557,8 +563,9 @@
     val (vs, cos) = Code.get_type thy tyco;
     val stmt_datatype =
       fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs
-      ##>> fold_map (fn (c, tys) =>
+      ##>> fold_map (fn ((c, vs), tys) =>
         ensure_const thy algbr eqngr permissive c
+        ##>> pair (map (unprefix "'") vs)
         ##>> fold_map (translate_typ thy algbr eqngr permissive) tys) cos
       #>> (fn info => Datatype (tyco, info));
   in ensure_stmt lookup_tyco (declare_tyco thy) stmt_datatype tyco end
@@ -607,7 +614,10 @@
 and ensure_inst thy (algbr as (_, algebra)) eqngr permissive (class, tyco) =
   let
     val super_classes = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class;
-    val classparams = these (try (#params o AxClass.get_info thy) class);
+    val these_classparams = these o try (#params o AxClass.get_info thy);
+    val classparams = these_classparams class;
+    val further_classparams = maps these_classparams
+      ((Sorts.complete_sort algebra o Sorts.super_classes algebra) class);
     val vs = Name.names Name.context "'a" (Sorts.mg_domain algebra tyco [class]);
     val sorts' = Sorts.mg_domain (Sign.classes_of thy) tyco [class];
     val vs' = map2 (fn (v, sort1) => fn sort2 => (v,
@@ -637,8 +647,11 @@
       ##>> fold_map (translate_tyvar_sort thy algbr eqngr permissive) vs
       ##>> fold_map translate_super_instance super_classes
       ##>> fold_map translate_classparam_instance classparams
-      #>> (fn ((((class, tyco), arity_args), super_instances), classparam_instances) =>
-             Classinst ((class, (tyco, arity_args)), (super_instances, classparam_instances)));
+      ##>> fold_map translate_classparam_instance further_classparams
+      #>> (fn (((((class, tyco), arity_args), super_instances),
+        classparam_instances), further_classparam_instances) =>
+          Classinst ((class, (tyco, arity_args)), (super_instances,
+            (classparam_instances, further_classparam_instances))));
   in ensure_stmt lookup_instance (declare_instance thy) stmt_inst (class, tyco) end
 and translate_typ thy algbr eqngr permissive (TFree (v, _)) =
       pair (ITyVar (unprefix "'" v))
@@ -682,15 +695,15 @@
         then translation_error thy permissive some_thm
           "Abstraction violation" ("constant " ^ Code.string_of_const thy c)
       else ()
-    val tys = Sign.const_typargs thy (c, ty);
+    val arg_typs = Sign.const_typargs thy (c, ty);
     val sorts = Code_Preproc.sortargs eqngr c;
-    val tys_args = (fst o Term.strip_type) ty;
+    val function_typs = (fst o Term.strip_type) ty;
   in
     ensure_const thy algbr eqngr permissive c
-    ##>> fold_map (translate_typ thy algbr eqngr permissive) tys
-    ##>> fold_map (translate_dicts thy algbr eqngr permissive some_thm) (tys ~~ sorts)
-    ##>> fold_map (translate_typ thy algbr eqngr permissive) tys_args
-    #>> (fn (((c, tys), iss), tys_args) => IConst (c, ((tys, iss), tys_args)))
+    ##>> fold_map (translate_typ thy algbr eqngr permissive) arg_typs
+    ##>> fold_map (translate_dicts thy algbr eqngr permissive some_thm) (arg_typs ~~ sorts)
+    ##>> fold_map (translate_typ thy algbr eqngr permissive) function_typs
+    #>> (fn (((c, arg_typs), dss), function_typs) => IConst (c, ((arg_typs, dss), function_typs)))
   end
 and translate_app_const thy algbr eqngr permissive some_thm ((c_ty, ts), some_abs) =
   translate_const thy algbr eqngr permissive some_thm (c_ty, some_abs)