explicit type schemes for functions
authorhaftmann
Fri, 23 May 2008 16:05:07 +0200
changeset 26970 bc28e7bcb765
parent 26969 cf3f998d0631
child 26971 160117247294
explicit type schemes for functions
src/Pure/Isar/code.ML
src/Pure/Isar/code_unit.ML
src/Tools/nbe.ML
--- a/src/Pure/Isar/code.ML	Fri May 23 16:05:04 2008 +0200
+++ b/src/Pure/Isar/code.ML	Fri May 23 16:05:07 2008 +0200
@@ -38,7 +38,7 @@
   val get_datatype_of_constr: theory -> string -> string option
   val get_case_data: theory -> string -> (int * string list) option
   val is_undefined: theory -> string -> bool
-  val default_typ: theory -> string -> typ
+  val default_typ: theory -> string -> (string * sort) list * typ
 
   val preprocess_conv: cterm -> thm
   val preprocess_term: theory -> term -> term
@@ -89,7 +89,7 @@
   val empty = [];
   val copy = I;
   val extend = I;
-  fun merge _ = AList.merge (op =) (K true);
+  fun merge _ = AList.merge (op = : string * string -> bool) (K true);
 );
 
 fun add_attribute (attr as (name, _)) =
@@ -510,9 +510,14 @@
 
 (** theorem transformation and certification **)
 
+fun const_of thy = dest_Const o fst o strip_comb o fst o Logic.dest_equals
+  o ObjectLogic.drop_judgment thy o Thm.plain_prop_of;
+
+fun const_of_func thy = AxClass.unoverload_const thy o const_of thy;
+
 fun common_typ_funcs [] = []
   | common_typ_funcs [thm] = [thm]
-  | common_typ_funcs (thms as thm :: _) =
+  | common_typ_funcs (thms as thm :: _) = (*FIXME is too general*)
       let
         val thy = Thm.theory_of_thm thm;
         fun incr_thm thm max =
@@ -521,7 +526,7 @@
             val max' = Thm.maxidx_of thm' + 1;
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
-        val ty1 :: tys = map (snd o CodeUnit.head_func) thms';
+        val ty1 :: tys = map (snd o const_of thy) thms';
         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
           handle Type.TUNIFY =>
             error ("Type unificaton failed, while unifying defining equations\n"
@@ -533,8 +538,6 @@
           cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
       in map (Thm.instantiate (instT, [])) thms' end;
 
-fun const_of_func thy = AxClass.unoverload_const thy o CodeUnit.head_func;
-
 fun certify_const thy const thms =
   let
     fun cert thm = if const = const_of_func thy thm
@@ -569,7 +572,7 @@
       |> map (Thm.transfer thy)
     fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys
       | sorts_of tys = map (snd o dest_TVar) tys;
-    val sorts = map (sorts_of o Sign.const_typargs thy o CodeUnit.head_func) funcs;
+    val sorts = map (sorts_of o Sign.const_typargs thy o const_of thy) funcs;
   in sorts end;
 
 fun weakest_constraints thy algebra (class, tyco) =
@@ -640,14 +643,14 @@
     fun check_typ_classparam tyco (c, thm) =
           let
             val SOME class = AxClass.class_of_param thy c;
-            val (_, ty) = CodeUnit.head_func thm;
+            val (_, ty) = const_of thy thm;
             val ty_decl = classparam_weakest_typ thy class (c, tyco);
             val ty_strongest = classparam_strongest_typ thy class (c, tyco);
             fun constrain thm = 
               let
                 val max = Thm.maxidx_of thm + 1;
                 val ty_decl' = Logic.incr_tvar max ty_decl;
-                val (_, ty') = CodeUnit.head_func thm;
+                val (_, ty') = const_of thy thm;
                 val (env, _) = Sign.typ_unify thy (ty_decl', ty') (Vartab.empty, max);
                 val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
                   cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
@@ -669,7 +672,7 @@
           end;
     fun check_typ_fun (c, thm) =
       let
-        val (_, ty) = CodeUnit.head_func thm;
+        val (_, ty) = const_of thy thm;
         val ty_decl = Sign.the_const_type thy c;
       in if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
         then thm
@@ -926,8 +929,9 @@
   |> map (CodeUnit.rewrite_func ((#inlines o the_thmproc o the_exec) thy))
   |> fold (fn (_, (_, f)) => apply_inline_proc thy f) ((#inline_procs o the_thmproc o the_exec) thy)
 (*FIXME - must check: rewrite rule, defining equation, proper constant |> map (snd o check_func false thy) *)
-  |> common_typ_funcs
-  |> map (AxClass.unoverload thy);
+  |> map (AxClass.unoverload thy)
+  |> common_typ_funcs;
+
 
 fun preprocess_conv ct =
   let
@@ -984,10 +988,10 @@
   end;
 
 fun default_typ thy c = case default_typ_proto thy c
- of SOME ty => ty
+ of SOME ty => CodeUnit.typscheme thy (c, ty)
   | NONE => (case get_funcs thy c
      of thm :: _ => snd (CodeUnit.head_func (AxClass.unoverload thy thm))
-      | [] => Sign.the_const_type thy c);
+      | [] => CodeUnit.typscheme thy (c, Sign.the_const_type thy c));
 
 end; (*local*)
 
--- a/src/Pure/Isar/code_unit.ML	Fri May 23 16:05:04 2008 +0200
+++ b/src/Pure/Isar/code_unit.ML	Fri May 23 16:05:07 2008 +0200
@@ -14,6 +14,7 @@
   val try_thm: (thm -> thm) -> thm -> thm option
 
   (*typ instantiations*)
+  val typscheme: theory -> string * typ -> (string * sort) list * typ
   val inst_thm: sort Vartab.table -> thm -> thm
   val constrain_thm: sort -> thm -> thm
 
@@ -39,7 +40,7 @@
   val assert_rew: thm -> thm
   val mk_rew: thm -> thm
   val mk_func: thm -> thm
-  val head_func: thm -> string * typ
+  val head_func: thm -> string * ((string * sort) list * typ)
   val expand_eta: int -> thm -> thm
   val rewrite_func: thm list -> thm -> thm
   val norm_args: thm list -> thm list 
@@ -72,6 +73,13 @@
 
 (* utilities *)
 
+fun typscheme thy (c, ty) =
+  let
+    fun dest (TVar ((v, 0), sort)) = (v, sort)
+      | dest ty = error ("Illegal type parameter in type scheme: " ^ Syntax.string_of_typ_global thy ty);
+    val vs = map dest (Sign.const_typargs thy (c, ty));
+  in (vs, ty) end;
+
 fun inst_thm tvars' thm =
   let
     val thy = Thm.theory_of_thm thm;
@@ -297,7 +305,7 @@
     fun ty_sorts (c, ty) =
       let
         val ty_decl = (Logic.unvarifyT o Sign.the_const_type thy) c;
-        val (tyco, vs_decl) = last_typ (c, ty) ty_decl;
+        val (tyco, _) = last_typ (c, ty) ty_decl;
         val (_, vs) = last_typ (c, ty) ty;
       in ((tyco, map snd vs), (c, (map fst vs, ty_decl))) end;
     fun add ((tyco', sorts'), c) ((tyco, sorts), cs) =
@@ -399,8 +407,8 @@
   let
     val thy = Thm.theory_of_thm thm;
     val Const (c, ty) = (fst o strip_comb o fst o Logic.dest_equals
-      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of) thm;
-  in (c, ty) end;
+      o (*ObjectLogic.drop_judgment thy o *)Thm.plain_prop_of) thm;
+  in (c, typscheme thy (c, ty)) end;
 
 
 (* case cerificates *)
--- a/src/Tools/nbe.ML	Fri May 23 16:05:04 2008 +0200
+++ b/src/Tools/nbe.ML	Fri May 23 16:05:07 2008 +0200
@@ -326,7 +326,7 @@
           let
             val ts' = take_until is_dict ts;
             val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
-            val T = Code.default_typ thy c;
+            val (_, T) = Code.default_typ thy c;
             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
             val typidx' = typidx + maxidx_of_typ T' + 1;
           in of_apps bounds (Term.Const (c, T'), ts') typidx' end